LC 834. 树中距离之和
题目描述
这是 LeetCode 上的 834. 树中距离之和 ,难度为 困难。
给定一个无向、连通的树。
树中有 n
个标记为 0...n-1
的节点以及 n-1
条边 。
给定整数 n
和数组 edges
, $edges[i] = [a{i}, b{i}]$表示树中的节点 $a{i}$ 和 $b{i}$ 之间有一条边。
返回长度为 n
的数组 answer
,其中 answer[i]
是树中第 i
个节点与所有其他节点之间的距离之和。
示例 1:1
2
3
4
5
6
7输入: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释: 树如图所示。
我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。
示例 2:1
2
3输入: n = 1, edges = []
输出: [0]
示例 3:
1 |
|
提示:
- $1 <= n <= 3 \times 10^4$
- $edges.length = n - 1$
- $edges[i].length = 2$
- $0 <= a{i}, b{i} < n$
- $a{i} != b{i}$
- 给定的输入保证为有效的树
树形 DP
对于树形 DP,可以随便以某个节点为根,把整棵树“拎起来”进行分析,通常还会以“方向”作为切入点进行思考。
不妨以编号为 0
的节点作为根节点进行分析:假设当前处理到的节点为 u
,当前节点 u
的父节点为 fa
,同时 u
有若干子节点 j
。
对于任意节点 u
而言,其树中距离之和可根据「方向/位置」分为两大类(对应示例图的左右两部分):
- 所有从节点
u
“往下”延伸所达的节点距离之和,即所有经过u -> j
边所能访问到的节点距离之和 - 所有从节点
u
“往上”延伸所达的节点距离之和,即经过u -> fa
边所能访问到的节点距离之和
假设我们能够用 $f[u]$ 和 $g[u]$ 预处理出每个节点“往下”和“往上”的距离之和,那么就有 $ans[u] = f[u] + g[u]$。
不失一般性分别考虑 $f[u]$ 和 $g[u]$ 该如何计算。
为了方便,起始先用「链式前向星」对 edges
进行转存,同时在递归计算 $f[u]$ 时,将父节点 fa
也进行传递,从而避免遍历节点 u
的出边时,重新走回 fa
。
$f[u]$ 的推导
对于叶子节点(没有“往下”出边的节点),我们有 $f[u] = 0$ 的天然条件,计算好的叶子节点值可用于更新其父节点,因此求解 $f[u]$ 是一个「从下往上」的递推过程。
假设当前处理到的节点是 u
,往下节点有 $j{1}$、$j{2}$ 和 $j_{3}$ ,且所有 $f[j]$ 均已计算完成。
由于 $f[u]$ 是由所有存在“往下”出边的节点 j
贡献而来。而单个子节点 j
来说,其对 $f[u]$ 的贡献应当是:在所有原有节点到节点 j
的距离路径中,额外增加一条当前出边(u -> j
),再加上 1
(代表节点 u
到节点 j
的距离)。
原路径距离之和恰好是 $f[j]$,额外需要增加的出边数量为原来参与计算 $f[j]$ 的点的数量(即挂载在节点 j
下的数量),因此我们还需要一个 c
数组,来记录某个节点下的子节点数量。
最终的 $f[u]$ 为所有符合条件的节点的 j
的 $f[j] + c[j] + 1$ 的总和。
$g[u]$ 的推导
对于树形 DP 题目,“往下”的计算往往是容易的,而“往上”的计算则是稍稍麻烦。
假设当前我们处理到节点为 u
,将要遍历的节点为 j
,考虑如何使用已经计算好的 $f[X]$ 来求解 $g[j]$。
这里为什么是求解 $g[j]$,而不是 $g[u]$ 呢?
因为我们求解的方向是“往上”的部分,必然是用父节点的计算结果,来推导子节点的结果,即求解 $g[u]$ 是一个「从上往下」的过程。
对于树形 DP ,通常需要对“往上”进一步拆分:「往上再往上」和「往上再往下」:
往上再往上:是指经过了
j -> u
后,还必然经过u -> fa
这条边时,所能到达的节点距离之和:这部分对 $g[j]$ 的贡献为:在所有原有节点到节点
u
的距离路径中,额外增加一条当前出边(u -> j
),增加当前出边的数量与节点数量相同,点数量为 $n - 1 - c[u]$,含义为 总节点数量 减去u
节点以及子节点数量。即此部分对 $g[j]$ 的贡献为 $g[u] + n - 1 - c[u]$。
往上再往下:是指经过了
j -> u
后,还经过「除u -> j
以外」的其他“往下”边时,所能到达的节点距离之和:这部分的计算需要先在 $f[u]$ 中剔除 $f[j]$ 的贡献,然后再加上额外边(
u -> j
)的累加数量,同样也是节点数量。从 $f[u]$ 中剔除 $f[j]$ 后为 $f[u] - f[j] - c[j]$,而点的数量为 $c[u] - 1 - c[j]$,含义为在以节点
u
为根的子树中剔除调用以节点j
为根节点的部分。即此部分对 $g[j]$ 的贡献为 $f[u] - f[j] - c[j] + c[u] - 1 - c[j]$。
Java 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43class Solution {
int N = 30010, M = 60010, idx = 0, n;
int[] he = new int[N], e = new int[M], ne = new int[M];
int[] f = new int[N], c = new int[N], g = new int[N];
void add(int a, int b) {
e[idx] = b;
ne[idx] = he[a];
he[a] = idx++;
}
public int[] sumOfDistancesInTree(int _n, int[][] es) {
n = _n;
Arrays.fill(he, -1);
for (int[] e : es) {
int a = e[0], b = e[1];
add(a, b); add(b, a);
}
dfs1(0, -1);
dfs2(0, -1);
int[] ans = new int[n];
for (int i = 0; i < n; i++) ans[i] = f[i] + g[i];
return ans;
}
int[] dfs1(int u, int fa) {
int tot = 0, cnt = 0;
for (int i = he[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
int[] next = dfs1(j, u);
tot += next[0] + next[1] + 1; cnt += next[1] + 1;
}
f[u] = tot; c[u] = cnt;
return new int[]{tot, cnt};
}
void dfs2(int u, int fa) {
for (int i = he[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
g[j] += g[u] + n - 1 - c[u]; // 往上再往上
g[j] += f[u] - f[j] - c[j] + c[u] - 1 - c[j]; // 往上再往下
dfs2(j, u);
}
}
}
C++ 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46class Solution {
public:
int N = 30010, M = 60010, idx = 0, n;
int he[30010], e[60010], ne[60010];
int f[30010], c[30010], g[30010];
void add(int a, int b) {
e[idx] = b;
ne[idx] = he[a];
he[a] = idx++;
}
vector<int> sumOfDistancesInTree(int _n, vector<vector<int>>& es) {
n = _n;
memset(he, -1, sizeof(he));
for (auto& e : es) {
int a = e[0], b = e[1];
add(a, b);
add(b, a);
}
dfs1(0, -1);
dfs2(0, -1);
vector<int> ans(n);
for (int i = 0; i < n; i++) ans[i] = f[i] + g[i];
return ans;
}
vector<int> dfs1(int u, int fa) {
int tot = 0, cnt = 0;
for (int i = he[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
vector<int> next = dfs1(j, u);
tot += next[0] + next[1] + 1;
cnt += next[1] + 1;
}
f[u] = tot; c[u] = cnt;
return {tot, cnt};
}
void dfs2(int u, int fa) {
for (int i = he[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
g[j] += g[u] + n - 1 - c[u];
g[j] += f[u] - f[j] - c[j] + c[u] - 1 - c[j];
dfs2(j, u);
}
}
};
- 时间复杂度:$O(n)$
- 空间复杂度:$O(n)$
最后
这是我们「刷穿 LeetCode」系列文章的第 No.834
篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。
在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。
为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:https://github.com/SharingSource/LogicStack-LeetCode 。
在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!