(应该算普及吧)
换根dp(二次扫描)
换根dp,树形dp的一种,通常不指定根结点,要求输出以不同结点为根的最优解。
eg.
3 个点
1 - 3
1 - 2
以1为根
以2为根
以3为根
例题(id:20221)
给出一棵 n 个点的树,以及每个点的权值 c_i
设 d(u,v) 表示树上 u,v 两点之间的距离, f(x) =\sum_{i = 1}^{n}c_i\cdot d(x, i)
n \le 10^5 c_i \le 10^9
使用预处理+暴力
预处理阶段:
对每个节点做一次 BFS,每个 BFS 会遍历所有 n 个结点和 n−1 条边,时间复杂度为 O(n)
总共有 n 个结点,因此预处理的时间复杂度是 O(n^2)
暴力计算阶段:
遍历 n 个节点,每个节点需要累加 n 项,时间复杂度为 O(n^2)
无法通过此题
考虑使用 换根dp
先处理 1 号结点为根时 f(1)
void dfs(int u, int fa, bool j)//当前结点,父亲结点,是否处理(后序遍历,防止子结点未处理情况)
{
if (!j)
{
for (auto k : g[u])
{
if (k == fa)
{
continue;
}
dfs(k, u, 0);
}
dfs(u, fa, 1);//按照后序遍历顺序
}
else//加上子节点贡献
{
len[u] = a[u];
for (auto k : g[u])
{
if (k == fa)
continue;
len[u] += len[k];//u的子树所有节点的权值和
f[u] += (f[k] + len[k]);//结点u的f值
}
}
}
dfs(1, -1, 0)//1号结点无父结点 未处理
那么结点 1 为根的情况处理好了,这时候要处理其它结点(换根)
void dfs2(int u, int fa)
{
minf = min(minf, f[u]); // 更新最小值
for (int v : g[u])
{
if (v == fa)
continue;
// 核心换根公式:从父节点u推导子节点v的f值
f[v] = f[u] + (tot - len[v]) - len[v];//tot为所有结点权值和
dfs2(v, u); // 递归处理子节点
}
}
这里解释一下换根公式
f[v] = f[u] + (tot - len[v]) - len[v];
- 由于 v 为 u 的子结点,v 比 u 子树内结点距离减 1,子树外结点距离加 1
- 子树内结点距离减 1,即 \sum_{i = g[v].rbegin()}^{g[v].rend()} c[*i], 即 len[v]
f[u] - len[v]//len[v]为v子树权值和
- 子树外结点距离加 1, 即 g[v] 以外的结点的权值和, 即 tot - len[v]
参考代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
int n, tot;
vector<int> g[100010];
int a[100010];
int len[100010];
int f[100010];
void dfs1(int u, int fa, bool j)
{
if (!j)
{
for (auto k : g[u])
{
if (k == fa)
{
continue;
}
dfs1(k, u, 0);
}
dfs1(u, fa, 1);
}
else
{
len[u] = a[u];
for (auto k : g[u])
{
if (k == fa)
continue;
len[u] += len[k];
f[u] += (f[k] + len[k]);
}
}
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(nullptr);
cin >> n;
for (int i = 1; i < n; i ++)
{
int x, y;
cin >> x >> y;
g[x].emplace_back(y);
g[y].emplace_back(x);
}
for (int i = 1; i <= n; i ++)
cin >> a[i], tot += a[i];
dfs1(1, -1, 0);
stack<pair<int,int> > s;//用栈模拟,逻辑一致
s.push({1, -1});
while (!s.empty())
{
pair<int,int> t = s.top();
int u = t.first, fa = t.second;
s.pop();
for (int k : g[u])
{
if (k != fa)
{
f[k] = f[u] - len[k] + tot - len[k];
s.push({k, u});
}
}
}
int ans = LONG_LONG_MAX - 10;//找答案
for (int i = 1; i <= n; i ++)
ans = min(ans, f[i]);
cout << ans;
return 0;
}
The end

