3 条题解

  • 3
    @ 2019-06-20 19:09:08

    首先将问题进行转化,每一次询问的花费为路径上所有点的点权和,在转成有根树后,我们可以把点权转移到这个点和它儿子所连的边上,问题就变成求边权和
    记f[i]为从i点到根节点的花费,u为a,b的lca,则从a到b的花费为f[a]+f[b]-2 * f[u]+v(u) (可以自己手写一下看看为什么要把lca的权值加上)
    lca可以用倍增求出,接下来我们只需要动态维护点到根节点的花费就可以了,这里我们看出,当一个点的值变化时,它的所有子节点的f[i]会收到影响,而除此以外的其他点不受影响,每次修改一个点,相当于区间修改它所有的子节点的值
    怎么做呢?
    DFS序!
    DFS序本质上是树的先序遍历,同时把遍历的顺序存在一个数组中,它有一个非常优美的性质:一个点的所有子节点一定是其后面连续的一段!
    如此一来,我们就可以用线性的数据结构来维护了,利用树状数组的差分,将区间修改单点查询转变为单点修改区间查询就比较容易了

    #include<iostream>
    #include<vector>
    using namespace std;
    vector<int> q[100010];
    int v[100010],f[100010][19],c[100010],now=0,l[100010],r[100010],tree[100010],n;
    void dfs(int u,int dd)
    {
        int i,d=q[u].size(),ha; 
        now++;
        l[u]=now;
        for(i=0;i<d;i++)
        {
            ha=q[u][i];
            if(ha!=dd)
            {
                f[ha][0]=u;
                c[ha]=c[u]+1;
                dfs(ha,u);
            }
        }
        r[u]=now;
        return ;
    }
    int lca(int x,int y)
    {
        
        int i;
        if(c[x]<c[y])
        {
            x=x^y;
            y=x^y;
            x=x^y;
        }
        while(c[x]>c[y])
        {
            for(i=17;i>=0;i--)
             if(c[f[x][i]]>=c[y])
              {
                 x=f[x][i];
                 break;
              }
        }
        if(x==y)
         return x;
        while(1)
        {
            for(i=17;i>=0;i--)
             if(f[x][i]!=f[y][i])
              break;
            if(i<0)
              break;
            x=f[x][i];
            y=f[y][i];
        }
        return f[x][0];
    }
    int lowbit(int x)
    {
        return x&(-x);
    }
    int sum(int x)
    {
        int i,ans=0;
        for(i=x;i>0;i-=lowbit(i))
         ans+=tree[i];
        return ans;
    }
    void add(int x,int y)
    {
        int i;
        for(i=x;i<=n;i+=lowbit(i))
         tree[i]+=y;
    }
    void xiu(int x,int y)
    {
        add(l[x],y-v[x]);
        add(r[x]+1,v[x]-y);
        v[x]=y;
    }
    int ans(int x,int y)
    {
        return sum(l[x])+sum(l[y])-2*sum(l[lca(x,y)])+v[lca(x,y)];
    }
    int main()
    {
        char t;
        int i,j,m,a,b;
        cin>>n;
        for(i=1;i<=n;i++)
         cin>>v[i];
        for(i=1;i<n;i++)
        {
            cin>>a>>b;
            q[a].push_back(b);
            q[b].push_back(a);
        }
        dfs(1,1);
        f[1][0]=1;
        for(i=0;i<17;i++)
         for(j=1;j<=n;j++)
          f[j][i+1]=f[f[j][i]][i];
        for(i=1;i<=n;i++)
        {
            add(l[i],v[i]);
            add(r[i]+1,-v[i]);
        }
        cin>>m;
        for(i=1;i<=m;i++)
        {
            cin>>t>>a>>b;
            if(t=='Q')
             cout<<ans(a,b)<<endl;
            else
             xiu(a,b);
        }
        return 0;
    }
    
  • 1
    @ 2020-10-24 23:18:55
    #include <cmath>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <algorithm>
    #include <vector>
    #include <deque>
    #include <limits>
    using namespace std;
    
    namespace dts
    {
        int ft=1,cnt;
        class tree_node
        {
            public:
                int fa,dep,size,hs,top,id;
                vector<int> s;
        };
        int rk[(1<<17)+1];
        tree_node tr[(1<<17)+1];
        void tr_dfs1(int now,int fa)
        {
            tr[now].fa=fa;
            tr[now].dep=tr[tr[now].fa].dep+1;
            tr[now].size=1;
            tr[now].hs=-1;
            for (int i=0;i<tr[now].s.size();i++)
                if (tr[now].s[i]!=fa)
                {
                    int next=tr[now].s[i];
                    tr_dfs1(next,now);
                    tr[now].size+=tr[next].size;
                    if (tr[now].hs==-1)
                        tr[now].hs=next;
                    else if (tr[tr[now].hs].size<tr[next].size)
                        tr[now].hs=next;
                }
        }
        void tr_dfs2(int now,int top)
        {
            tr[now].top=top;
            tr[now].id=++cnt;
            rk[cnt]=now;
            if (tr[now].hs!=-1)
            {
                tr_dfs2(tr[now].hs,top);
                for (int i=0;i<tr[now].s.size();i++)
                    if (tr[now].s[i]!=tr[now].fa&&tr[now].s[i]!=tr[now].hs)
                        tr_dfs2(tr[now].s[i],tr[now].s[i]);
            }
        }
        void tr_build()
        {
            cnt=0;
            tr_dfs1(ft,ft);
            tr_dfs2(ft,ft);
        }
        int lca(int x,int y)
        {
            while (tr[x].top!=tr[y].top)
            {
                if (tr[tr[x].top].dep<tr[tr[y].top].dep)
                    swap(x,y);
                x=tr[tr[x].top].fa;
            }
            if (tr[x].dep<tr[y].dep)
                return x;
            else
                return y;
        }
        
        class st_node
        {
            public:
                int l,r,mid,empt=1;
                int lans=0,rans=0,ans=0,sum=0;
                int iflz=0,numlz;
                int len()
                {
                    return tr[rk[r]].dep-tr[rk[l]].dep+1;
                }
        };
        int data[(1<<17)+1];
        st_node st[(1<<19)+2];
        #define lc(now) ((now)<<1)
        #define rc(now) ((now)<<1|1)
        st_node merge(st_node li,st_node ri)//li:左子區間,ri:右子區間
        {
            if (li.empt)
                return ri;
            else if (ri.empt)
                return li;
            st_node a;
            a.empt=a.iflz=0;
            a.l=li.l,a.r=ri.r,a.mid=li.r;
            a.lans=max(li.lans,li.sum+ri.lans);
            a.rans=max(ri.rans,ri.sum+li.rans);
            a.ans=max(max(li.ans,ri.ans),li.rans+ri.lans);
            a.sum=li.sum+ri.sum;
            return a;
        }
        void st_pushup(int now)
        {
            st[now]=merge(st[lc(now)],st[rc(now)]);//別在意時間複雜度常數
        }
        void st_update(int now,int l,int r,int val);
        void st_pushdown(int now)
        {
            if (st[now].iflz)
            {
                st_update(lc(now),st[now].l,st[now].mid,st[now].numlz);
                st_update(rc(now),st[now].mid+1,st[now].r,st[now].numlz);
                st[now].iflz=0;
            }
        }
        void st_update(int now,int l,int r,int val)
        {
            if (st[now].l==l&&r==st[now].r)
            {
                st[now].lans=st[now].rans=st[now].ans=max(st[now].len()*val,0);
                st[now].sum=st[now].len()*val;
                st[now].iflz=1,st[now].numlz=val;
            }
            else
            {
                st_pushdown(now);
                if (r<=st[now].mid)
                    st_update(lc(now),l,r,val);
                else if (st[now].mid+1<=l)
                    st_update(rc(now),l,r,val);
                else
                    st_update(lc(now),l,st[now].mid,val),st_update(rc(now),st[now].mid+1,r,val);
                st_pushup(now);
            }
        }
        st_node st_ask(int now,int l,int r)
        {
            if (st[now].l==l&&r==st[now].r)
                return st[now];
            else
            {
                st_pushdown(now);
                if (r<=st[now].mid)
                    return st_ask(lc(now),l,r);
                else if (st[now].mid+1<=l)
                    return st_ask(rc(now),l,r);
                else
                    return merge(st_ask(lc(now),l,st[now].mid),st_ask(rc(now),st[now].mid+1,r));
            }
        }
        void st_build(int now,int l,int r)
        {
            st[now].empt=st[now].iflz=0;
            st[now].l=l,st[now].r=r;
            if (l<r)
            {
                st[now].mid=(l+r)>>1;
                st_build(lc(now),l,st[now].mid);
                st_build(rc(now),st[now].mid+1,r);
                st_pushup(now);
            }
            else
            {
                st[now].sum=data[rk[l]];
                st[now].lans=st[now].rans=st[now].ans=max(data[rk[l]],0);
            }
        }
        
        void update(int x,int y,int val)
        {
            int i,j,lcan=lca(x,y);
            for (i=x;tr[i].top!=tr[lcan].top;i=tr[tr[i].top].fa)
                st_update(1,tr[tr[i].top].id,tr[i].id,val);
            for (j=y;tr[j].top!=tr[lcan].top;j=tr[tr[j].top].fa)
                st_update(1,tr[tr[j].top].id,tr[j].id,val);
            if (tr[i].dep>tr[j].dep)
                swap(i,j);
            st_update(1,tr[i].id,tr[j].id,val);
        }
        st_node cty(st_node stn)
        {
            swap(stn.l,stn.r);
            swap(stn.lans,stn.rans);
            return stn;
        }
        int ask(int x,int y)
        {
            int i,j,lcan=lca(x,y);
            st_node ians,jans,ans;
            ians.iflz=jans.iflz=0;
            ians.empt=jans.empt=1;
            for (i=x;tr[i].top!=tr[lcan].top;i=tr[tr[i].top].fa)
                ians=merge(st_ask(1,tr[tr[i].top].id,tr[i].id),ians);
            for (j=y;tr[j].top!=tr[lcan].top;j=tr[tr[j].top].fa)
                jans=merge(st_ask(1,tr[tr[j].top].id,tr[j].id),jans);
            if (tr[i].dep>tr[j].dep)
                swap(i,j),swap(ians,jans);
            jans=merge(st_ask(1,tr[i].id,tr[j].id),jans);
            ans=merge(cty(jans),ians);
            return ans.sum;
        }
        
        int n,m;
        
        void main()
        {
            scanf("%d",&n);
            for (int i=1;i<=n;i++)
                scanf("%d",&data[i]);
            for (int i=1;i<=n;i++)
                tr[i].s.clear();
            for (int i=1;i<n;i++)
            {
                int x,y;
                scanf("%d%d",&x,&y);
                tr[x].s.push_back(y);
                tr[y].s.push_back(x);
            }
            if (n>0)
                tr_build();
            st_build(1,1,cnt);
            scanf("%d\n",&m);
            for (int i=1;i<=m;i++)
            {
                int x,y;
                char K;
                scanf("%c%d%d\n",&K,&x,&y);
                if (K=='Q')
                    printf("%d\n",ask(x,y));
                else if (K=='C')
                    update(x,x,y);
            }
        }
    }
    
    int main()
    {
        dts::main();
    }
    
  • 1
    @ 2016-03-16 13:55:57

    DFS序+树状数组+倍增LCA即可。这里有一个DFS序的教程。
    注意那个教程叙述的是边权值的情况而不是点权值的情况,所以细节上可能需要改动一下。

    C++ Code

    #include <cctype>
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    const int MAXSIZE=30000020;
    int bufpos;
    char buf[MAXSIZE];
    void init(){
        #ifdef LOCAL
            freopen("1986.txt","r",stdin);
        #endif
        buf[fread(buf,1,MAXSIZE,stdin)]='\0';
        bufpos=0;
    }
    int readint(){
        bool isneg;
        int val=0;
        for(;!isdigit(buf[bufpos]) && buf[bufpos]!='-';bufpos++);
        bufpos+=(isneg=(buf[bufpos]=='-'))?1:0;
        for(;isdigit(buf[bufpos]);bufpos++)
            val=val*10+buf[bufpos]-'0';
        return isneg?-val:val;
    }
    char readchar(){
        for(;isspace(buf[bufpos]);bufpos++);
        return buf[bufpos++];
    }
    const int maxn=100001;
    const int maxm=200001;
    const int maxe=200001;
    struct bit{
        int n;
        int t[maxe];
        inline int lowbit(int x){
            return x&(-x);
        }
        void update(int p,int v){
            for(int i=p;i<=n;i+=lowbit(i))
                t[i]+=v;
        }
        int query(int p){
            int ans=0;
            for(int i=p;i>0;i-=lowbit(i))
                ans+=t[i];
            return ans;
        }
    };
    struct edge{
        int from,to,next;
    };
    struct graph{
        int n,m,maxi;
        bit b;
        int eu[maxe];
        int val[maxe];
        int dep[maxn];
        int st[maxn];
        int en[maxn];
        int first[maxn];
        bool vis[maxn];
        edge e[maxm];
        int par[maxn][21];
        int cnt;
        void init(int n){
            this->n=n;
            memset(first,-1,sizeof(first));
            maxi=32-__builtin_clz(n);
            m=0;
        }
        void addedge(int from,int to){
            e[++m]=(edge){from,to,first[from]};
            first[from]=m;
        }
        void dfs(int u,int d,int fa){
            eu[++cnt]=u;
            st[u]=cnt;
            dep[u]=d;
            par[u][0]=fa;
            for(int i=1;i<=maxi;i++)
                par[u][i]=par[par[u][i-1]][i-1];
            for(int i=first[u];i!=-1;i=e[i].next){
                int v=e[i].to;
                if (!st[v])
                    dfs(v,d+1,u);
            }
            eu[++cnt]=u;
            en[u]=cnt;
        }
        void prepare(){
            cnt=0;
            memset(st,0,sizeof(st));
            memset(en,0,sizeof(en));
            dfs(1,1,0);
            //for(int i=1;i<=cnt;i++)
                //printf("%d ",eu[i]);
            //putchar('\n');
            b.n=cnt+1;
        }
        int lca(int u,int v){
            if (dep[u]<dep[v])
                swap(u,v);
            int t=dep[u]-dep[v];
            maxi=32-__builtin_clz(dep[u]);
            //printf("before:u=%d v=%d dep[u]=%d dep[v]=%d\n",u,v,dep[u],dep[v]);
            for(int i=0;i<=maxi;i++)
                if (t&(1<<i))
                    u=par[u][i];
            //printf("after:u=%d v=%d dep[u]=%d dep[v]=%d\n",u,v,dep[u],dep[v]);
            //assert(dep[u]==dep[v]);
            for(int i=maxi;i>=0;i--){
                if (par[u][i] && par[u][i]!=par[v][i]){
                    u=par[u][i];
                    v=par[v][i];
                }
            }
            return u==v?u:par[u][0];
        }
        int query(int u,int v){
            //puts("WTF");
            if (st[v]<st[u])
                swap(u,v);
            int lc=lca(u,v);
            //puts("WTF");
            //printf("u=%d v=%d st[u]=%d st[v]=%d lca=%d val[lca]=%d\n",u,v,st[u],st[v],lc,val[lc]);
            //printf("b.query(st[u])=%d b.query(st[v])=%d\n",b.query(st[u]),b.query(st[v]));
            return b.query(st[v])-2*b.query(st[lc]-1)+b.query(st[u])-val[lc];
        }
        void update(int p,int v){
            b.update(st[p],v-val[p]);
            b.update(en[p],val[p]-v);
            val[p]=v;
            //b.update(st[p],v);
            //b.update(en[p],v);
        }
    } g;
    int t[maxn];
    int main(){
        init();
        int n=readint();
        g.init(n);
        for(int i=1;i<=n;i++)
            t[i]=readint();
        for(int i=1;i<=n-1;i++){
            int u=readint(),v=readint();
            g.addedge(u,v);
            g.addedge(v,u);
        }
        g.prepare();
        //puts("WTF");
        for(int i=1;i<=n;i++)
            g.update(i,t[i]);
        int q=readint();
        for(int i=1;i<=q;i++){
            char buf=readchar();
            int x=readint(),y=readint();
            if (buf=='Q')
               printf("%d\n",g.query(x,y));
            else
                g.update(x,y);
        }
    }
    
  • 1

信息

ID
1986
难度
6
分类
小h的妹子树 点击显示
标签
(无)
递交数
172
已通过
50
通过率
29%
被复制
2
上传者