Tarjan缩点拓扑最长路P3387,如何为疑问?
摘要:P3387 【模版】缩点 - 洛谷 给定一个有向图,每个点都有一个权值,允许多次经过一条边或者一个点但权值只计算一次。求一条权值之和最大的路径。 很明显的拓扑排序最长路问题,考虑缩点后进行 DP 。一般的方法为 缩点后得到 DAG ,在新的
P3387 【模版】缩点 - 洛谷
给定一个有向图,每个点都有一个权值,允许多次经过一条边或者一个点但权值只计算一次。求一条权值之和最大的路径。
很明显的拓扑排序最长路问题,考虑缩点后进行 DP 。一般的方法为
缩点后得到 DAG ,在新的 DAG 里进行拓扑排序。
在拓扑排序的过程中更新状态转移方程 dp[v] = Max(dp[v], dp[u] + w[v]) 。
在进行收网时更新权值,此时新权值与 \(scc\) 的编号对应。
if dfn[x] == low[x]: # 其余模版省略
t = 0
while True:
t += w[cur - 1]
scc_size.append(t) # 得到一个SCC后其权值就是t索引与scc对应
scc += 1
具体代码为
import sys
from math import inf
from collections import deque
sys.setrecursionlimit(100010)
if 1:
inp = lambda: sys.stdin.readline().strip()
II = lambda: int(inp())
MII = lambda: map(int, inp().split())
LII = lambda: list(MII())
Max = lambda x, y: x if x > y else y
Min = lambda x, y: x if x < y else y
def main():
n, m = MII()
w = LII()
g = [[] for _ in range(n + 1)]
for _ in range(m):
u, v = MII()
g[u].append(v)
dfn = [-1] * (n + 1)
low = [-1] * (n + 1)
st = []
inst = [False] * (n + 1)
scc = 0
timer = 1
scc_id = [-1] * (n + 1)
scc_size = []
def tarjan(x):
nonlocal timer, scc
dfn[x] = low[x] = timer
timer += 1
st.append(x)
inst[x] = True
for y in g[x]:
if dfn[y] == -1:
tarjan(y)
low[x] = Min(low[x], low[y])
elif inst[y]:
low[x] = Min(low[x], dfn[y])
if dfn[x] == low[x]:
t = 0
while True:
cur = st.pop()
inst[cur] = False
scc_id[cur] = scc
t += w[cur - 1]
if cur == x:
break
scc_size.append(t)
scc += 1
for i in range(1, n + 1):
if dfn[i] == -1:
tarjan(i)
dag = [[] for _ in range(scc)]
din = [0] * scc
for u in range(1, n + 1):
for v in g[u]:
su = scc_id[u]
sv = scc_id[v]
if su != sv:
dag[su].append(sv)
din[sv] += 1
dq = deque([i for i in range(scc) if din[i] == 0])
dist = [-inf] * scc
for i in dq:
dist[i] = scc_size[i]
while dq:
u = dq.popleft()
for v in dag[u]:
dist[v] = Max(dist[v], dist[u] + scc_size[v])
din[v] -= 1
if din[v] == 0:
dq.append(v)
print(max(dist))
if __name__ == "__main__":
main()
如果只是单纯的在拓扑序上进行操作的话,其实可以不必要新建 DAG 。因为 Tarjan 算法得到的 scc_id 的顺序就是一个天然的逆拓扑序。
上面代码的 while 循环可以替换为以下代码
dist = scc_size[:]
# 可以去掉 din 数组和队列
for u in range(scc - 1, -1, -1):
for v in dag[u]:
dist[v] = Max(dist[v], dist[u] + scc_size[v])
