洛谷P3384 【模板】重链剖分/树链剖分 树剖解法0分 WA#1~7 #11 TLE#8~10 求调

,

原题链接

#include<bits/stdc++.h>
#define int long long
using namespace std;
// Problem input: node count, operation count, root node, modulus.
int n, m, r, p;

// Adjacency list in head / next / target form. Every undirected edge is stored
// twice (once per direction), so edge slots are twice the node slots.
int h[100005];
int e[200005];
int ne[200005];
int idx;

// Initial weight of each node, indexed by node id.
int v[100005];

// Heavy-light decomposition data, indexed by node id:
//   fa  - parent node           dep - depth
//   son - heavy child (0 if none) siz - subtree size
int fa[100005], dep[100005], son[100005], siz[100005];

// dfn[x] - position of node x in the heavy-child-first DFS order
// cnt    - last DFS position handed out
// gr[x]  - top node of the heavy chain that contains x
// id[k]  - node sitting at DFS position k (inverse of dfn)
int dfn[100005], cnt, gr[100005];
int id[100005];
void dfs(int x,int f){
//cout<<1<<" "<<x<<" "<<f<<endl;
    fa[x]=f;
    siz[x]=1;
    dep[x]=dep[f]+1;
    int mx=0,hson=0;
    //cout<<dep[x]<<" "<<h[x]<<endl;
    for (int i=h[x];i;i=ne[i]){
    //cout<<i<<endl;
        int j=e[i];
        //cout<<j<<endl;
        if (j==f)continue;
        dfs(j,x);
        if (siz[j]>mx)hson=j,mx=siz[j];
        siz[x]+=siz[j];
    }
    son[x]=hson;
}
// Second pass of heavy-light decomposition.
// Hands out DFS positions, visiting the heavy child first so that each heavy
// chain occupies a contiguous range of positions. Because this is still a DFS,
// the subtree of x occupies positions [dfn[x], dfn[x]+siz[x]-1].
//   x     - current node
//   f     - parent of x (0 for the root)
//   grand - top node of the heavy chain x belongs to
void dfs2(int x,int f,int grand){
    gr[x]=grand;
    dfn[x]=++cnt;
    id[cnt]=x;
    if (son[x]!=0) dfs2(son[x],x,grand);            // heavy child extends the current chain
    for (int edge=h[x];edge;edge=ne[edge]){
        const int child=e[edge];
        if (child!=f && child!=son[x]) dfs2(child,x,child);  // light child starts its own chain
    }
}
// Segment-tree node over DFS positions.
struct node{
    int l, r;   // interval of DFS positions covered by this node
    int sum;    // sum of the values in [l, r]; update paths reduce it modulo p
    int f;      // pending lazy addition not yet pushed to the children
};
// Tree storage: root at index 1, children of k at 2k and 2k+1 (4x the node limit).
node t[400005];
// Recomputes node rt's sum from its two children, reduced modulo p.
inline void pushup(int rt){
    const int lc=rt*2;
    const int rc=lc+1;
    t[rt].sum=(t[lc].sum+t[rc].sum)%p;
}
// Pushes node rt's pending lazy addition down to both children (updating their
// tags and sums modulo p), then clears rt's tag. Does nothing if no tag is pending.
inline void pushdown(int rt){
    const int tag=t[rt].f;
    if (tag==0) return;
    for (int c=rt*2;c<=rt*2+1;++c){
        const int len=t[c].r-t[c].l+1;
        t[c].f=(t[c].f+tag)%p;
        t[c].sum=(t[c].sum+len*tag)%p;
    }
    t[rt].f=0;
}
void build(int rt,int l,int r){
    t[rt].l=l;
    t[rt].r=r;
    if (l==r){
        t[rt].sum=v[dfn[l]];
        //cout<<l<<" "<<t[rt].sum<<endl;
        return;
    }
    int mid=(l+r)/2;
    build(rt*2,l,mid);
    build(rt*2+1,mid+1,r);
    pushup(rt);
}
void modify(int rt,int l,int r,int x){
//cout<<"m "<<t[rt].l<<" "<<t[rt].r<<" "<<l<<" "<<r<<" "<<x<<endl;
    if (t[rt].l>r||t[rt].r<l)return;
    if (t[rt].l>=l&&t[rt].r<=r){
        t[rt].sum+=(t[rt].r-t[rt].l+1)*x;
        t[rt].f+=x;
        t[rt].sum%=p;
        t[rt].f%=p;
        return;
    }
    pushdown(rt);
    modify(rt*2,l,r,x);
    modify(rt*2+1,l,r,x);
    pushup(rt);
}
// Returns the sum of the values at DFS positions in [l, r] that lie inside
// node rt's interval, modulo p.
// A fully covered node's sum is reduced before returning, so the result is
// within [0, p) even if a stored leaf value was never reduced.
int query(int rt,int l,int r){
    if (t[rt].l>r||t[rt].r<l) return 0;
    if (l<=t[rt].l&&t[rt].r<=r) return t[rt].sum%p;
    pushdown(rt);
    return (query(rt*2,l,r)+query(rt*2+1,l,r))%p;
}
void add(int x,int y,int z){
    while (gr[x]!=gr[y]){
        if (dep[x]<dep[y])swap(x,y);
        if (gr[x]==r)swap(x,y);
        modify(1,dfn[gr[x]],dfn[x],z);
        x=fa[gr[x]];
    }
    while (dep[x]!=dep[y]){
        if (dep[x]<dep[y])swap(x,y);
        modify(1,dfn[x],dfn[x],z);
        x=fa[x];
    }
    modify(1,dfn[y],dfn[y],z);
}
// Returns the sum of the values of all nodes on the tree path between x and y
// (inclusive), modulo p.
// Uses the same climbing scheme as a path update: always lift the endpoint
// whose chain top is deeper, then finish with one range query on the shared
// heavy chain instead of stepping up node by node.
int sum(int x,int y){
    int ans=0;
    while (gr[x]!=gr[y]){
        if (dep[gr[x]]<dep[gr[y]]) swap(x,y);
        ans=(ans+query(1,dfn[gr[x]],dfn[x]))%p;
        x=fa[gr[x]];
    }
    if (dep[x]>dep[y]) swap(x,y);   // x is now the shallower endpoint (the LCA)
    ans=(ans+query(1,dfn[x],dfn[y]))%p;
    return ans;
}
signed main(){
    cin>>n>>m>>r>>p;
    for (int i=1;i<=n;i++)cin>>v[i];
    for (int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        e[++idx]=v;
        ne[idx]=h[u];
        h[u]=idx;
        e[++idx]=u;
        ne[idx]=h[v];
        h[v]=idx;
    }
    dfs(r,0);
    dfs2(r,0,r);
    build(1,1,n);
    for (int i=1;i<=m;i++){
        int op,x,y,z;
        cin>>op;
        if (op==1){
            cin>>x>>y>>z;
            add(x,y,z);
        }
        if (op==2){
            cin>>x>>y;
            cout<<sum(x,y)<<endl;
        }
        if (op==3){
            cin>>x>>z;
            modify(1,dfn[x],dfn[x]+siz[x]-1,z);
        }
        if (op==4){
            cin>>x;
            cout<<query(1,dfn[x],dfn[x]+siz[x]-1)<<endl;
        }
        //cout<<"yishang "<<op<<" "<<x<<" "<<y<<" "<<z<<endl;
    }
    return 0;
}
1 个赞

@周沐瑞1

现在发现了 2 个问题。

  • build 函数内应为 t[rt].sum=v[id[l]]
  • query 函数中
if (t[rt].l >= l && t[rt].r <= r) {
	return t[rt].sum;
}

返回时需要取模。

1 个赞

感觉 sum 函数有点奇怪：x、y 跳到同一条重链上之后，就可以直接在线段树上做一次区间查询了，不需要再一个点一个点地往上跳吧。

1 个赞

发帖当天已 AC，谢谢。