#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,m,r,p; // vertex count, number of operations, root vertex, modulus
int h[100005],e[200005],ne[200005],idx; // adjacency lists: h[u] = first edge of u, e[i] = edge target, ne[i] = next edge of the same source, idx = edges used (each undirected edge is stored twice)
int v[100005]; // v[u]: initial weight of vertex u
int fa[100005],dep[100005],son[100005],siz[100005]; // filled by dfs: parent, depth (root has depth 1), heavy child (0 if none), subtree size
int dfn[100005],cnt,gr[100005]; // filled by dfs2: DFS position of each vertex, position counter, top vertex of the vertex's heavy chain
int id[100005]; // id[pos]: vertex placed at DFS position pos (inverse of dfn)
void dfs(int x,int f){
//cout<<1<<" "<<x<<" "<<f<<endl;
fa[x]=f;
siz[x]=1;
dep[x]=dep[f]+1;
int mx=0,hson=0;
//cout<<dep[x]<<" "<<h[x]<<endl;
for (int i=h[x];i;i=ne[i]){
//cout<<i<<endl;
int j=e[i];
//cout<<j<<endl;
if (j==f)continue;
dfs(j,x);
if (siz[j]>mx)hson=j,mx=siz[j];
siz[x]+=siz[j];
}
son[x]=hson;
}
// Second heavy-light pass: assigns DFS positions (dfn / id) and chain tops (gr).
// The heavy child is visited first, so each heavy chain occupies consecutive
// positions, and every subtree occupies [dfn[x], dfn[x]+siz[x]-1].
//   x     : current vertex
//   f     : its parent (0 for the root)
//   grand : top vertex of the heavy chain x belongs to
void dfs2(int x,int f,int grand){
    gr[x]=grand;
    ++cnt;
    dfn[x]=cnt;
    id[cnt]=x;
    if (son[x]!=0) dfs2(son[x],x,grand);        // heavy child continues this chain
    for (int edge=h[x];edge;edge=ne[edge]){
        const int to=e[edge];
        if (to!=f && to!=son[x]) dfs2(to,x,to); // light child starts its own chain
    }
}
// One segment-tree node; the children of node rt are rt*2 and rt*2+1.
struct node{
    int l;    // left end of the covered DFS-position range
    int r;    // right end of the range (inclusive)
    int sum;  // sum of the covered values; pushup/pushdown/modify reduce it mod p
    int f;    // pending "add f to every position" tag not yet pushed to the children
};
node t[400005]; // 4x the 1e5 vertex capacity, the usual recursive segment-tree bound
// Recompute node rt's sum (mod p) from its two children.
inline void pushup(int rt){
    const int left=rt*2;
    const int right=left+1;
    t[rt].sum=(t[left].sum+t[right].sum)%p;
}
// Push node rt's pending add-tag down to both children (applying it to their
// sums and accumulating it into their tags, all mod p), then clear it.
inline void pushdown(int rt){
    const int tag=t[rt].f;
    if (!tag) return;
    for (int c=rt*2;c<=rt*2+1;++c){
        t[c].f=(t[c].f+tag)%p;
        t[c].sum=(t[c].sum+(t[c].r-t[c].l+1)*tag)%p; // every covered position gains tag
    }
    t[rt].f=0;
}
void build(int rt,int l,int r){
t[rt].l=l;
t[rt].r=r;
if (l==r){
t[rt].sum=v[dfn[l]];
//cout<<l<<" "<<t[rt].sum<<endl;
return;
}
int mid=(l+r)/2;
build(rt*2,l,mid);
build(rt*2+1,mid+1,r);
pushup(rt);
}
// Range add on the segment tree: add x to every DFS position in [l, r].
//   rt   : current segment-tree node (callers pass 1)
//   l, r : inclusive target range of DFS positions
//   x    : value to add; normalized into [0, p) first so that
//          (segment length) * x stays small and sums never go negative
void modify(int rt,int l,int r,int x){
    if (t[rt].l>r||t[rt].r<l) return;        // disjoint from the target range
    x%=p;
    if (x<0) x+=p;
    if (l<=t[rt].l&&t[rt].r<=r){             // fully covered: apply lazily
        t[rt].sum=(t[rt].sum+(t[rt].r-t[rt].l+1)*x)%p;
        t[rt].f=(t[rt].f+x)%p;
        return;
    }
    pushdown(rt);
    modify(rt*2,l,r,x);
    modify(rt*2+1,l,r,x);
    pushup(rt);
}
// Sum of the values at DFS positions [l, r], modulo p.
//   rt : current segment-tree node (callers pass 1)
int query(int rt,int l,int r){
    if (t[rt].l>r||t[rt].r<l) return 0;      // disjoint: contributes nothing
    if (l<=t[rt].l&&t[rt].r<=r){
        // Reduce here too: a leaf written by build may hold an unreduced value.
        return t[rt].sum%p;
    }
    pushdown(rt);
    return (query(rt*2,l,r)+query(rt*2+1,l,r))%p;
}
// Add z to every vertex on the tree path between x and y (both included).
void add(int x,int y,int z){
    // Climb one heavy chain at a time. Always lift the endpoint whose chain
    // TOP is deeper: comparing the vertices' own depths can lift an endpoint
    // whose chain passes above the LCA and update vertices off the path.
    while (gr[x]!=gr[y]){
        if (dep[gr[x]]<dep[gr[y]]) swap(x,y);
        modify(1,dfn[gr[x]],dfn[x],z);
        x=fa[gr[x]];
    }
    // Both endpoints are now on one heavy chain, whose positions are
    // contiguous: the shallower vertex has the smaller dfn.
    if (dep[x]>dep[y]) swap(x,y);
    modify(1,dfn[x],dfn[y],z);
}
// Sum of the vertex values on the tree path between x and y (both included),
// modulo p.
int sum(int x,int y){
    int ans=0;
    // Climb one heavy chain at a time, always lifting the endpoint whose chain
    // top is deeper so neither endpoint ever jumps above the LCA.
    while (gr[x]!=gr[y]){
        if (dep[gr[x]]<dep[gr[y]]) swap(x,y);
        ans=(ans+query(1,dfn[gr[x]],dfn[x]))%p;
        x=fa[gr[x]];
    }
    // Same heavy chain: one contiguous range query instead of a per-vertex walk.
    if (dep[x]>dep[y]) swap(x,y);
    ans=(ans+query(1,dfn[x],dfn[y]))%p;
    return ans;
}
signed main(){
cin>>n>>m>>r>>p;
for (int i=1;i<=n;i++)cin>>v[i];
for (int i=1;i<n;i++){
int u,v;
cin>>u>>v;
e[++idx]=v;
ne[idx]=h[u];
h[u]=idx;
e[++idx]=u;
ne[idx]=h[v];
h[v]=idx;
}
dfs(r,0);
dfs2(r,0,r);
build(1,1,n);
for (int i=1;i<=m;i++){
int op,x,y,z;
cin>>op;
if (op==1){
cin>>x>>y>>z;
add(x,y,z);
}
if (op==2){
cin>>x>>y;
cout<<sum(x,y)<<endl;
}
if (op==3){
cin>>x>>z;
modify(1,dfn[x],dfn[x]+siz[x]-1,z);
}
if (op==4){
cin>>x;
cout<<query(1,dfn[x],dfn[x]+siz[x]-1)<<endl;
}
//cout<<"yishang "<<op<<" "<<x<<" "<<y<<" "<<z<<endl;
}
return 0;
}
1 个赞
现在发现了 2 个问题：
1. `build` 函数内应为 `t[rt].sum=v[id[l]]`；
2. `query` 函数中
```cpp
if (t[rt].l >= l && t[rt].r <= r) {
    return t[rt].sum;
}
```
返回时需要取模。
1 个赞
感觉 `sum` 函数有点奇怪：`x` 和 `y` 调到同一条重链上之后，就可以直接在线段树上做一次区间查询了，不需要再一个点一个点地往上跳吧？
1 个赞
发帖当天已 AC，谢谢。