树分裂DSU on tree WA 18分 求调

题目链接
我按照题解的思路写了 DSU on tree。由于是第一次写 DSU on tree，没看出哪里有问题，但程序在样例上输出的是 102。有大佬能帮忙看一下吗？

code:

#include<bits/stdc++.h>
using namespace std;
#define int long long
// Capacity of every per-node array; nodes are numbered from 1.
const int MAXN = 2000006;
vector<int> vec[MAXN];  // adjacency lists of the undirected tree
int n;                  // number of nodes
int tot;                // answer: built by dfs3, then reduced in main
bool flg;               // set by dfs4 when some node has siz == n / 2
int cnt[MAXN];          // cnt[s]: how many nodes have siz == s (filled in main)
int siz[MAXN];          // subtree sizes (dfs1 roots at 1, dfs2 re-roots at the centroid)
int mx[MAXN];           // mx[u]: largest component left after deleting u (dfs1)
int h[MAXN];            // h[u]: heavy child of u, 0 if u has no children (dfs5)
int freq[MAXN];         // DSU-on-tree buckets: siz frequencies of the current subtree
int ans[MAXN];          // per-node counts produced by dfs6
// Walks the subtree of `now` (entered from parent `fa`) and fills, for every node u:
//   siz[u] - number of nodes in u's subtree,
//   mx[u]  - size of the largest component left after deleting u, i.e. the
//            biggest child subtree or the n - siz[u] nodes outside u's subtree.
void dfs1(int now, int fa)
{
	int biggest = 0;
	siz[now] = 1;
	for(auto it = vec[now].begin(); it != vec[now].end(); ++it)
	{
		const int child = *it;
		if(child != fa)
		{
			dfs1(child, now);
			siz[now] += siz[child];
			biggest = max(biggest, siz[child]);
		}
	}
	mx[now] = max(biggest, n - siz[now]);
}
// Returns the smallest-numbered node whose removal leaves no component larger
// than n / 2 (a centroid), computing mx[] via dfs1 rooted at node 1.
// Falls back to node 1 when no node qualifies.
int find()
{
	dfs1(1, 0);
	int i = 1;
	while(i <= n && mx[i] > n / 2)
	{
		++i;
	}
	return i <= n ? i : 1;
}
// Recomputes siz[] for the subtree of `now` (entered from parent `fa`).
// main calls it as dfs2(root, 0), so siz[u] becomes u's subtree size under the centroid root.
void dfs2(int now, int fa)
{
	int total = 1;
	for(size_t k = 0; k < vec[now].size(); ++k)
	{
		const int child = vec[now][k];
		if(child == fa)
		{
			continue;
		}
		dfs2(child, now);
		total += siz[child];
	}
	siz[now] = total;
}
// For every edge (now, child) below `now`, adds siz[child] * (n - siz[child]) to tot.
// This is the product of the sizes of the two parts that the edge separates.
void dfs3(int now, int fa)
{
	for(int child : vec[now])
	{
		if(child == fa) continue;
		const int below = siz[child];
		tot += below * (n - below);
		dfs3(child, now);
	}
}
// Sets flg if any node in the subtree of `now` has siz == n / 2.
// For a non-root node, that means the edge to its parent splits the tree into
// parts of n / 2 and n - n / 2 nodes. flg is only ever set, never cleared.
void dfs4(int now, int fa)
{
	flg = flg || siz[now] == n / 2;
	for(int child : vec[now])
	{
		if(child != fa)
			dfs4(child, now);
	}
}
// Stores in h[now] the heavy child of `now`: the child with the largest siz[].
// Ties go to the child that appears first in vec[now]; nodes without children get 0.
void dfs5(int now, int fa)
{
	int heavy = 0;
	int heavySize = 0;
	for(int child : vec[now])
	{
		if(child == fa) continue;
		dfs5(child, now);
		if(siz[child] > heavySize)
		{
			heavySize = siz[child];
			heavy = child;
		}
	}
	h[now] = heavy;
}
// Records node u in the DSU-on-tree buckets: one more node with subtree size siz[u].
void addnode(int u)
{
	freq[siz[u]] += 1;
}
// Removes node u from the DSU-on-tree buckets (inverse of addnode).
void delnode(int u)
{
	freq[siz[u]] -= 1;
}
// Adds every node of the subtree rooted at u (entered from parent fa) to freq[],
// visiting u first and then each child subtree.
void add(int u, int fa)
{
	addnode(u);
	for(size_t k = 0; k < vec[u].size(); ++k)
	{
		if(vec[u][k] == fa) continue;
		add(vec[u][k], u);
	}
}
// Removes every node of the subtree rooted at u (entered from parent fa) from freq[],
// visiting u first and then each child subtree.
void del(int u, int fa)
{
	delnode(u);
	for(size_t k = 0; k < vec[u].size(); ++k)
	{
		if(vec[u][k] == fa) continue;
		del(vec[u][k], u);
	}
}
// DSU on tree (small-to-large merging).
//
// Children other than the heavy child h[now] are processed first, and their buckets
// are discarded (keep = 0). The heavy child is processed last and its buckets are
// kept. Then `now` and its light subtrees are added on top, so that freq[s] equals
// the number of nodes in now's subtree with siz == s.
//
// For every non-root node: ans[now] = cnt[tmp] - freq[tmp], with tmp = n / 2 - siz[now].
// This is the number of nodes outside now's subtree whose siz equals tmp.
//
// keep: non-zero leaves now's subtree in freq[] for the caller; 0 clears it again.
void dfs6(int now, int fa, int keep)
{
	for(int v : vec[now])
	{
		if(v != fa && v != h[now])
		{
			dfs6(v, now, 0);
		}
	}
	if(h[now])
	{
		dfs6(h[now], now, 1);
	}
	addnode(now);
	for(int v : vec[now])
	{
		if(v != fa && v != h[now])
		{
			add(v, now);
		}
	}
	// The root is entered with fa == 0 (nodes are 1-indexed). The previous guard
	// `now != fa` was always true, so for the root it computed tmp = n/2 - n < 0 and
	// read cnt[tmp] / freq[tmp] out of bounds (undefined behavior). Skip the root,
	// and defensively skip any subtree larger than n/2; that cannot happen when the
	// tree is rooted at a centroid.
	if(fa != 0)
	{
		int tmp = n / 2 - siz[now];
		if(tmp >= 0)
		{
			// NOTE(review): "outside now's subtree" also includes ancestors of `now`
			// whose siz == tmp. Their subtrees contain now's subtree, so they are not
			// disjoint from it. If only disjoint subtrees should be paired, those
			// ancestors must be excluded. Confirm against the problem statement.
			ans[now] = cnt[tmp] - freq[tmp];
		}
	}
	if(!keep)
	{
		del(now, fa);
	}
}
// Reads a tree with n nodes (n - 1 undirected edges u v) and prints tot.
//
// Steps:
//   1. Root the tree at a centroid.
//   2. Compute tot = sum over edges of siz * (n - siz).
//   3. Subtract (n/2)^2 if some node has siz == n/2.
//   4. For every non-root i, subtract ans[i] * siz[i] (ans comes from dfs6).
signed main()
{
	ios::sync_with_stdio(false);
	cin.tie(0);
	cin>>n;
	for(int i = 1; i < n; i++)
	{
		int u, v;
		cin>>u>>v;
		vec[u].push_back(v);
		vec[v].push_back(u);
	}
	// Rooting at a centroid keeps every non-root subtree at most n/2 nodes.
	int root = find();
	dfs2(root, 0);
	// cnt[s] = number of nodes whose subtree (under `root`) has exactly s nodes.
	for(int i = 1; i <= n; i++)
	{
		cnt[siz[i]]++;
	}
	dfs3(root, 0);
	dfs4(root, 0);
	if(flg)
	{
		tot -= (n / 2) * (n / 2);
	}
	dfs5(root, 0);
	dfs6(root, 0, 1);
	// ans[root] is not meaningful, so the root is skipped here.
	for(int i = 1; i <= n; i++)
	{
		if(i != root)
		{
			tot -= ans[i] * siz[i];
		}
	}
	cout<<tot;
	return 0;
}