题目链接
根据题解的思路写了 DSU on tree（树上启发式合并）。这是我第一次写 DSU on tree，自己没看出哪里有问题，但运行样例时程序输出 102，与样例答案不一致。有大佬能帮忙看一下吗？
code:
#include<bits/stdc++.h>
using namespace std;
#define int long long
// Capacity of every per-vertex / per-size table. Vertices are 1-based, so
// index n must fit. (`int` is `long long` in this file because of the #define.)
constexpr int kMaxVertices = 2000006;

vector<int> vec[kMaxVertices];  // adjacency lists of the tree
int n;                          // number of vertices
int tot;                        // running answer, printed at the end of main
bool flg;                       // set by dfs4 when some subtree has exactly n / 2 vertices
int cnt[kMaxVertices];          // cnt[s]: number of vertices whose subtree size is s
int siz[kMaxVertices];          // siz[u]: subtree size of u for the current rooting
int mx[kMaxVertices];           // mx[u]: largest component left after deleting u (dfs1)
int h[kMaxVertices];            // h[u]: heavy (largest-subtree) child of u, 0 if none
int freq[kMaxVertices];         // freq[s]: subtree sizes currently held by the DSU-on-tree bag
int ans[kMaxVertices];          // ans[u]: vertices outside u's subtree with siz == n / 2 - siz[u]
// First pass, tree rooted at vertex 1 (parent fa, 0 for the root).
// For every vertex u below `now` it stores:
//   siz[u] - number of vertices in u's subtree;
//   mx[u]  - size of the largest component left after deleting u, i.e. the
//            maximum of its children's subtree sizes and the "upward" part n - siz[u].
// find() uses mx[] to pick a centroid.
void dfs1(int now, int fa)
{
    int subtree = 1;
    int largest = 0;
    for (int child : vec[now])
    {
        if (child != fa)
        {
            dfs1(child, now);
            subtree += siz[child];
            largest = max(largest, siz[child]);
        }
    }
    siz[now] = subtree;
    mx[now] = max(largest, n - subtree);
}
// Returns a centroid of the tree: the smallest-index vertex whose largest
// remaining component after deletion has at most n / 2 vertices.
// Runs dfs1 from vertex 1 to fill mx[]; falls back to vertex 1 if no vertex qualifies.
int find()
{
    dfs1(1, 0);
    for (int v = 1; v <= n; ++v)
    {
        if (mx[v] <= n / 2)
            return v;
    }
    return 1;
}
// Recomputes siz[] for the tree rooted at `now` (parent fa, 0 for the root).
// main calls it on the centroid, overwriting the sizes dfs1 computed for root 1.
void dfs2(int now, int fa)
{
    int subtree = 1;
    for (int child : vec[now])
    {
        if (child == fa)
            continue;
        dfs2(child, now);
        subtree += siz[child];
    }
    siz[now] = subtree;
}
// For every tree edge (parent, child) below `now`, adds siz[child] * (n - siz[child])
// to tot: the number of vertex pairs that edge separates.
void dfs3(int now, int fa)
{
    for (int child : vec[now])
    {
        if (child == fa)
            continue;
        const int below = siz[child];
        tot += below * (n - below);
        dfs3(child, now);
    }
}
// Sets flg if any vertex in now's subtree (including now) has siz exactly n / 2.
// flg is never cleared here.
void dfs4(int now, int fa)
{
    flg = flg || siz[now] == n / 2;
    for (int child : vec[now])
        if (child != fa)
            dfs4(child, now);
}
// Fills h[u] for every u in now's subtree with u's heavy child: the child with
// the largest siz[]. On ties the first one in adjacency order wins; h[u] = 0
// when u has no children.
void dfs5(int now, int fa)
{
    int heavy = 0;
    int best = 0;
    for (int child : vec[now])
    {
        if (child == fa)
            continue;
        dfs5(child, now);
        if (siz[child] > best)
        {
            best = siz[child];
            heavy = child;
        }
    }
    h[now] = heavy;
}
// Puts vertex u's subtree size into the DSU-on-tree bag freq[].
void addnode(int u)
{
    ++freq[siz[u]];
}
// Takes vertex u's subtree size back out of the DSU-on-tree bag freq[].
void delnode(int u)
{
    freq[siz[u]] -= 1;
}
// Inserts every vertex of the subtree rooted at u (whose parent is fa) into freq[].
void add(int u, int fa)
{
    addnode(u);
    for (int child : vec[u])
    {
        if (child == fa)
            continue;
        add(child, u);
    }
}
// Removes every vertex of the subtree rooted at u (whose parent is fa) from freq[].
// Mirror image of add().
void del(int u, int fa)
{
    delnode(u);
    for (int child : vec[u])
    {
        if (child == fa)
            continue;
        del(child, u);
    }
}
// DSU on tree (small-to-large): for every vertex u in now's subtree sets
//   ans[u] = number of vertices OUTSIDE u's subtree whose siz[] equals n / 2 - siz[u],
// whenever that index is non-negative.
//   now  : current vertex
//   fa   : its parent (0 for the root)
//   keep : nonzero -> leave now's subtree sizes in freq[] for the caller (heavy child);
//          zero    -> remove them again before returning (light child).
// Expects siz[] (dfs2), h[] (dfs5) and cnt[] to be filled for the same root.
void dfs6(int now, int fa, int keep)
{
    // Light children first; each one clears its contribution from freq[] on return.
    for (int v : vec[now])
    {
        if (v != fa && v != h[now])
        {
            dfs6(v, now, 0);
        }
    }
    // Heavy child last and kept, so its sizes stay in freq[] instead of being re-added.
    if (h[now])
    {
        dfs6(h[now], now, 1);
    }
    // Add now and its light subtrees: freq[] now holds exactly now's subtree.
    addnode(now);
    for (int v : vec[now])
    {
        if (v != fa && v != h[now])
        {
            add(v, now);
        }
    }
    // Guard on the index, not on the parent (`now != fa` can never be false).
    // For a vertex with siz[now] > n / 2, e.g. the root, tmp is negative and
    // cnt[tmp] / freq[tmp] would read outside the arrays. When rooted at the
    // centroid from find(), every non-root vertex has siz <= n / 2. So ans[]
    // is still filled for every vertex main reads.
    int tmp = n / 2 - siz[now];
    if (tmp >= 0)
    {
        // cnt[] counts all vertices and freq[] only those inside now's subtree,
        // so the difference also counts now's ANCESTORS with siz == tmp.
        // NOTE(review): if only vertices whose subtree is disjoint from now's
        // should be counted, ancestors must be excluded separately; confirm
        // against the editorial.
        ans[now] = cnt[tmp] - freq[tmp];
    }
    if (!keep)
    {
        del(now, fa);
    }
}
// Reads n and the n - 1 edges "u v" of a tree, roots it at a centroid, and prints
//   tot = sum over edges (parent, v) of siz[v] * (n - siz[v])
//         - (n / 2) * (n / 2)            if some subtree has exactly n / 2 vertices
//         - sum over non-root i of ans[i] * siz[i]
// NOTE(review): every DFS helper is recursive, so a path-shaped tree recurses
// about n levels deep; for n near the 2e6 table size this may exceed the
// default stack. Confirm the limits / stack size of the judge.
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    for (int i = 1; i < n; i++)
    {
        int u, v;
        cin >> u >> v;
        vec[u].push_back(v);
        vec[v].push_back(u);
    }
    // Re-root at a centroid so that every non-root subtree has at most n / 2 vertices.
    int root = find();
    dfs2(root, 0);
    // cnt[s] = number of vertices whose subtree size is s.
    for (int i = 1; i <= n; i++)
    {
        cnt[siz[i]]++;
    }
    dfs3(root, 0);
    dfs4(root, 0);
    if (flg)
    {
        tot -= (n / 2) * (n / 2);
    }
    dfs5(root, 0);
    dfs6(root, 0, 1);
    for (int i = 1; i <= n; i++)
    {
        if (i != root)
        {
            tot -= ans[i] * siz[i];
        }
    }
    cout << tot;
    return 0;
}