1 条题解

  • 2
    @ 2022-04-09 19:22:06
    #include <bits/stdc++.h>
    using namespace std;
    #define M 100005
    #define ri register int
    
    int n,m,R;
    long long P;
    long long  val[M];
    vector <int> p[M];
    
    void read()     // ok
    {
        scanf("%d%d%d%lld",&n,&m,&R,&P);
        for(ri i=1;i<=n;i++)
        {
            scanf("%lld",&val[i]);
        }
        for(ri i=1;i<n;i++)
        {
            int a,b;
            scanf("%d%d",&a,&b);
            p[a].push_back(b);
            p[b].push_back(a);
        }
    }
    
    int dep[M],sz[M],fa[M],son[M];
    int  vis[M];
    
    void dfs1(int r,int f)
    {
         fa[r]=f;
         dep[r]=dep[f]+1;
         vis[r]=1;
         int maxz=0;
         sz[r]=1;
         for(ri i=0;i<p[r].size();i++)
         {
            int b=p[r][i];
            if(vis[b]) continue ;
            dfs1(b,r);
            sz[r]+=sz[b];
            if(maxz<sz[b])
            {
                son[r]=b;
                maxz=sz[b];
            }
         }
    
    }
    
    int id[M],hao[M];
    int cur,top[M],addr[M];
    void dfs2(int r,int f)
    {
        vis[r]=1;
        id[r]=++cur;
        addr[cur]=r;
        top[r]=f;
    
        if(son[r]) dfs2(son[r],f);
    
        for(ri i=0;i<p[r].size();i++)
        {
            int b=p[r][i];
            if(vis[b]) continue;
            dfs2(b,b);
        }
        hao[r]=cur;
    
    }
    
    struct xdhsu{
        int l,r;
        long long val;
    }tr[M*4];
    void build(int l,int r,int i)
    {   
         tr[i].l=l;tr[i].r=r;
        if(l==r)
        {
            tr[i].val=val[addr[l]]; // 
            return ;
        }
        int mid=(l+r)>>1;
        build(l,mid,i<<1);
        build(mid+1,r,i<<1|1);
        tr[i].val+=tr[i<<1].val;
        tr[i].val+=tr[i<<1|1].val;
        tr[i].val%=P;
    }
    
    long long  lz[M*4];
    
    void down(int i)
    {
        if(!lz[i]) return ;
        lz[i<<1]+=lz[i];lz[i<<1]%=P;
        lz[i<<1|1]+=lz[i];lz[i<<1|1]%=P;
        tr[i<<1].val+=(tr[i<<1].r-tr[i<<1].l+1)*lz[i]%P;tr[i<<1].val%=P;
    
        tr[i<<1|1].val+=(tr[i<<1|1].r-tr[i<<1|1].l+1)*lz[i]%P;tr[i<<1|1].val%=P;
        lz[i]=0;
    } 
    void xiu(int l,int r,int i,int k)
    {
        if(l>tr[i].r||r<tr[i].l) return ;
        if(l<=tr[i].l&&r>=tr[i].r)
        {
            tr[i].val+=(tr[i].r-tr[i].l+1)*k%P;
            tr[i].val%=P;
            lz[i]+=k%P; //// ky ba
            lz[i]%=P;
            return ;    
        }
        down(i);
        xiu(l,r,i<<1,k);
        xiu(l,r,i<<1|1,k);
        tr[i].val=(tr[i<<1].val+tr[i<<1|1].val)%P;
    }
    
    long long  qu(int l,int r,int i)
    {
        if(l>tr[i].r||r<tr[i].l) return 0;
        if(l<=tr[i].l&&r>=tr[i].r)
        {
    
            return tr[i].val%P; 
        }
        down(i);
        return (qu(l,r,i<<1)+qu(l,r,i<<1|1))%P;
    }
    
    void xlca(int a,int b,int k)
    {
        while(top[a]!=top[b])
        {
            if(dep[top[a]]>dep[top[b]])
            swap(a,b);
            xiu(id[top[b]],id[b],1,k);
            b=fa[top[b]];
        }
        if(dep[a]>dep[b]) swap(a,b);
        xiu(id[a],id[b],1,k);
    }
    long long clca(int a,int b)
    {
           long long ans=0;
            while(top[a]!=top[b])
            {
             if(dep[top[a]]>dep[top[b]])
             swap(a,b);
             ans=(ans+qu(id[top[b]],id[b],1))%P;
              b=fa[top[b]];
            }
        if(dep[a]>dep[b]) swap(a,b);
           ans=(ans+qu(id[a],id[b],1))%P;
           return ans;
    }
    
    void solve()
    {
        dfs1(R,0);
        memset(vis,0,sizeof(vis));
        dfs2(R,R);
    
        build(1,cur,1);
    
        for(ri i=1;i<=m;i++)
        {
            int a;
            scanf("%d",&a);
            if(a==1)
            {
                int x,y,z;
                scanf("%d%d%d",&x,&y,&z);
                xlca(x,y,z);
            }
            if(a==2)
            {
                int x,y;
                scanf("%d%d",&x,&y);
    
                long long anss;
                anss=clca(x,y)%P;
                printf("%lld\n",anss);
            }
            if(a==3)
            {
                int x,z;
                scanf("%d%d",&x,&z);
                xiu(id[x],hao[x],1,z);
            }
            if(a==4)
            {
                int x;
                scanf("%d",&x);
                long long anss=qu(id[x],hao[x],1)%P;
                printf("%lld\n",anss);
    
            }
        }
    }
    
    int main(){
    
        read();
        solve();
    
        return 0;
    
    }   
    
  • 1

【模板】轻重链剖分 / 树链剖分

信息

ID
1123
难度
10
分类
树链剖分 点击显示
标签
递交数
7
已通过
3
通过率
43%
上传者