树形dp -> 换根dp

(应该算普及吧)


换根dp(二次扫描)

换根dp,树形dp的一种,通常不指定根结点,要求输出以不同结点为根的最优解。

eg.
3 个点

1 - 3

1 - 2

以1为根
graph


以2为根
graph


以3为根
graph


例题(id:20221)
给出一棵 n 个点的树,以及每个点的权值 c_i
d(u,v) 表示树上 u,v 两点之间的距离, f(x) =\sum_{i = 1}^{n}c_i\cdot d(x, i)
n \le 10^5 c_i \le 10^9

使用预处理+暴力
预处理阶段:
对每个节点做一次 BFS,每个 BFS 会遍历所有 n 个结点和 n−1 条边,时间复杂度为 O(n)
总共有 n 个结点,因此预处理的时间复杂度是 O(n^2)
暴力计算阶段:
遍历 n 个节点,每个节点需要累加 n 项,时间复杂度为 O(n^2)
无法通过此题

考虑使用 换根dp
先处理 1 号结点为根时 f(1)

void dfs(int u, int fa, bool j)//当前结点,父亲结点,是否处理(后序遍历,防止子结点未处理情况)
{
	if (!j)
	{
		for (auto k : g[u])
		{
			if (k == fa)
			{
				continue;
			}
			dfs(k, u, 0);
		}
		dfs(u, fa, 1);//按照后序遍历顺序
	}
	else//加上子节点贡献
	{
		len[u] = a[u];
		for (auto k : g[u])
		{
			if (k == fa)
				continue;
			len[u] += len[k];//u的子树所有节点的权值和
			f[u] += (f[k] + len[k]);//结点u的f值
		}
	}
}
dfs(1, -1, 0)//1号结点无父结点 未处理

那么结点 1 为根的情况处理好了,这时候要处理其它结点(换根)

void dfs2(int u, int fa) 
{
    minf = min(minf, f[u]); // 更新最小值
    for (int v : g[u]) 
    {
        if (v == fa) 
            continue;
        // 核心换根公式:从父节点u推导子节点v的f值
        f[v] = f[u] + (tot -  len[v]) - len[v];//tot为所有结点权值和
        dfs2(v, u); // 递归处理子节点
    }
}

这里解释一下换根公式

f[v] = f[u] + (tot -  len[v]) - len[v];
  • 由于 vu 的子结点,vu 子树内结点距离减 1,子树外结点距离加 1
  • 子树内结点距离减 1,即 \sum_{i = g[v].rbegin()}^{g[v].rend()} c[*i], 即 len[v]
    f[u] - len[v]//len[v]为v子树权值和
  • 子树外结点距离加 1, 即 g[v] 以外的结点的权值和, 即 tot - len[v]

参考代码

#include <bits/stdc++.h>
#define int long long
using namespace std;
int n, tot;
vector<int> g[100010];
int a[100010];
int len[100010];
int f[100010];
void dfs1(int u, int fa, bool j)
{
	if (!j)
	{
		for (auto k : g[u])
		{
			if (k == fa)
			{
				continue;
			}
			dfs1(k, u, 0);
		}
		dfs1(u, fa, 1);
	}
	else
	{
		len[u] = a[u];
		for (auto k : g[u])
		{
			if (k == fa)
				continue;
			len[u] += len[k];
			f[u] += (f[k] + len[k]);
		}
	}
}
signed main()
{
	ios::sync_with_stdio(0);
	cin.tie(nullptr);
	cin >> n;
	for (int i = 1; i < n; i ++)
	{
		int x, y;
		cin >> x >> y;
		g[x].emplace_back(y);
		g[y].emplace_back(x);
	}
		
	
	for (int i = 1; i <= n; i ++)
		cin >> a[i], tot += a[i];
	dfs1(1, -1, 0);
	stack<pair<int,int> > s;//用栈模拟,逻辑一致
	s.push({1, -1});
	while (!s.empty())
	{
		pair<int,int> t = s.top();
		int u = t.first, fa = t.second;
		s.pop();
		for (int k : g[u])
		{
			if (k != fa)
			{
				f[k] = f[u] - len[k] + tot - len[k];
				s.push({k, u});
			}
		}
	}
	int ans = LONG_LONG_MAX - 10;//找答案
	for (int i = 1; i <= n; i ++)
		ans = min(ans, f[i]);
	cout << ans;
	return 0;
} 


The end