3 条题解
-
3猫粮寸断 LV 10 @ 2019-06-20 19:09:08
首先将问题进行转化,每一次询问的花费为路径上所有点的点权和,在转成有根树后,我们可以把点权转移到这个点和它儿子所连的边上,问题就变成求边权和
记f[i]为从i点到根节点的花费,u为a,b的lca,则从a到b的花费为f[a]+f[b]-2 * f[u]+v(u) (可以自己手写一下看看为什么要把lca的权值加上)
lca可以用倍增求出,接下来我们只需要动态维护点到根节点的花费就可以了,这里我们看出,当一个点的值变化时,它的所有子节点的f[i]会收到影响,而除此以外的其他点不受影响,每次修改一个点,相当于区间修改它所有的子节点的值
怎么做呢?
DFS序!
DFS序本质上是树的先序遍历,同时把遍历的顺序存在一个数组中,它有一个非常优美的性质:一个点的所有子节点一定是其后面连续的一段!
如此一来,我们就可以用线性的数据结构来维护了,利用树状数组的差分,将区间修改单点查询转变为单点修改区间查询就比较容易了#include<iostream> #include<vector> using namespace std; vector<int> q[100010]; int v[100010],f[100010][19],c[100010],now=0,l[100010],r[100010],tree[100010],n; void dfs(int u,int dd) { int i,d=q[u].size(),ha; now++; l[u]=now; for(i=0;i<d;i++) { ha=q[u][i]; if(ha!=dd) { f[ha][0]=u; c[ha]=c[u]+1; dfs(ha,u); } } r[u]=now; return ; } int lca(int x,int y) { int i; if(c[x]<c[y]) { x=x^y; y=x^y; x=x^y; } while(c[x]>c[y]) { for(i=17;i>=0;i--) if(c[f[x][i]]>=c[y]) { x=f[x][i]; break; } } if(x==y) return x; while(1) { for(i=17;i>=0;i--) if(f[x][i]!=f[y][i]) break; if(i<0) break; x=f[x][i]; y=f[y][i]; } return f[x][0]; } int lowbit(int x) { return x&(-x); } int sum(int x) { int i,ans=0; for(i=x;i>0;i-=lowbit(i)) ans+=tree[i]; return ans; } void add(int x,int y) { int i; for(i=x;i<=n;i+=lowbit(i)) tree[i]+=y; } void xiu(int x,int y) { add(l[x],y-v[x]); add(r[x]+1,v[x]-y); v[x]=y; } int ans(int x,int y) { return sum(l[x])+sum(l[y])-2*sum(l[lca(x,y)])+v[lca(x,y)]; } int main() { char t; int i,j,m,a,b; cin>>n; for(i=1;i<=n;i++) cin>>v[i]; for(i=1;i<n;i++) { cin>>a>>b; q[a].push_back(b); q[b].push_back(a); } dfs(1,1); f[1][0]=1; for(i=0;i<17;i++) for(j=1;j<=n;j++) f[j][i+1]=f[f[j][i]][i]; for(i=1;i<=n;i++) { add(l[i],v[i]); add(r[i]+1,-v[i]); } cin>>m; for(i=1;i<=m;i++) { cin>>t>>a>>b; if(t=='Q') cout<<ans(a,b)<<endl; else xiu(a,b); } return 0; }
-
12020-10-24 23:18:55@
#include <cmath> #include <cstdio> #include <cstdlib> #include <cstring> #include <algorithm> #include <vector> #include <deque> #include <limits> using namespace std; namespace dts { int ft=1,cnt; class tree_node { public: int fa,dep,size,hs,top,id; vector<int> s; }; int rk[(1<<17)+1]; tree_node tr[(1<<17)+1]; void tr_dfs1(int now,int fa) { tr[now].fa=fa; tr[now].dep=tr[tr[now].fa].dep+1; tr[now].size=1; tr[now].hs=-1; for (int i=0;i<tr[now].s.size();i++) if (tr[now].s[i]!=fa) { int next=tr[now].s[i]; tr_dfs1(next,now); tr[now].size+=tr[next].size; if (tr[now].hs==-1) tr[now].hs=next; else if (tr[tr[now].hs].size<tr[next].size) tr[now].hs=next; } } void tr_dfs2(int now,int top) { tr[now].top=top; tr[now].id=++cnt; rk[cnt]=now; if (tr[now].hs!=-1) { tr_dfs2(tr[now].hs,top); for (int i=0;i<tr[now].s.size();i++) if (tr[now].s[i]!=tr[now].fa&&tr[now].s[i]!=tr[now].hs) tr_dfs2(tr[now].s[i],tr[now].s[i]); } } void tr_build() { cnt=0; tr_dfs1(ft,ft); tr_dfs2(ft,ft); } int lca(int x,int y) { while (tr[x].top!=tr[y].top) { if (tr[tr[x].top].dep<tr[tr[y].top].dep) swap(x,y); x=tr[tr[x].top].fa; } if (tr[x].dep<tr[y].dep) return x; else return y; } class st_node { public: int l,r,mid,empt=1; int lans=0,rans=0,ans=0,sum=0; int iflz=0,numlz; int len() { return tr[rk[r]].dep-tr[rk[l]].dep+1; } }; int data[(1<<17)+1]; st_node st[(1<<19)+2]; #define lc(now) ((now)<<1) #define rc(now) ((now)<<1|1) st_node merge(st_node li,st_node ri)//li:左子區間,ri:右子區間 { if (li.empt) return ri; else if (ri.empt) return li; st_node a; a.empt=a.iflz=0; a.l=li.l,a.r=ri.r,a.mid=li.r; a.lans=max(li.lans,li.sum+ri.lans); a.rans=max(ri.rans,ri.sum+li.rans); a.ans=max(max(li.ans,ri.ans),li.rans+ri.lans); a.sum=li.sum+ri.sum; return a; } void st_pushup(int now) { st[now]=merge(st[lc(now)],st[rc(now)]);//別在意時間複雜度常數 } void st_update(int now,int l,int r,int val); void st_pushdown(int now) { if (st[now].iflz) { st_update(lc(now),st[now].l,st[now].mid,st[now].numlz); st_update(rc(now),st[now].mid+1,st[now].r,st[now].numlz); st[now].iflz=0; } } void st_update(int now,int l,int r,int val) { if (st[now].l==l&&r==st[now].r) { st[now].lans=st[now].rans=st[now].ans=max(st[now].len()*val,0); st[now].sum=st[now].len()*val; st[now].iflz=1,st[now].numlz=val; } else { st_pushdown(now); if (r<=st[now].mid) st_update(lc(now),l,r,val); else if (st[now].mid+1<=l) st_update(rc(now),l,r,val); else st_update(lc(now),l,st[now].mid,val),st_update(rc(now),st[now].mid+1,r,val); st_pushup(now); } } st_node st_ask(int now,int l,int r) { if (st[now].l==l&&r==st[now].r) return st[now]; else { st_pushdown(now); if (r<=st[now].mid) return st_ask(lc(now),l,r); else if (st[now].mid+1<=l) return st_ask(rc(now),l,r); else return merge(st_ask(lc(now),l,st[now].mid),st_ask(rc(now),st[now].mid+1,r)); } } void st_build(int now,int l,int r) { st[now].empt=st[now].iflz=0; st[now].l=l,st[now].r=r; if (l<r) { st[now].mid=(l+r)>>1; st_build(lc(now),l,st[now].mid); st_build(rc(now),st[now].mid+1,r); st_pushup(now); } else { st[now].sum=data[rk[l]]; st[now].lans=st[now].rans=st[now].ans=max(data[rk[l]],0); } } void update(int x,int y,int val) { int i,j,lcan=lca(x,y); for (i=x;tr[i].top!=tr[lcan].top;i=tr[tr[i].top].fa) st_update(1,tr[tr[i].top].id,tr[i].id,val); for (j=y;tr[j].top!=tr[lcan].top;j=tr[tr[j].top].fa) st_update(1,tr[tr[j].top].id,tr[j].id,val); if (tr[i].dep>tr[j].dep) swap(i,j); st_update(1,tr[i].id,tr[j].id,val); } st_node cty(st_node stn) { swap(stn.l,stn.r); swap(stn.lans,stn.rans); return stn; } int ask(int x,int y) { int i,j,lcan=lca(x,y); st_node ians,jans,ans; ians.iflz=jans.iflz=0; ians.empt=jans.empt=1; for (i=x;tr[i].top!=tr[lcan].top;i=tr[tr[i].top].fa) ians=merge(st_ask(1,tr[tr[i].top].id,tr[i].id),ians); for (j=y;tr[j].top!=tr[lcan].top;j=tr[tr[j].top].fa) jans=merge(st_ask(1,tr[tr[j].top].id,tr[j].id),jans); if (tr[i].dep>tr[j].dep) swap(i,j),swap(ians,jans); jans=merge(st_ask(1,tr[i].id,tr[j].id),jans); ans=merge(cty(jans),ians); return ans.sum; } int n,m; void main() { scanf("%d",&n); for (int i=1;i<=n;i++) scanf("%d",&data[i]); for (int i=1;i<=n;i++) tr[i].s.clear(); for (int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); tr[x].s.push_back(y); tr[y].s.push_back(x); } if (n>0) tr_build(); st_build(1,1,cnt); scanf("%d\n",&m); for (int i=1;i<=m;i++) { int x,y; char K; scanf("%c%d%d\n",&K,&x,&y); if (K=='Q') printf("%d\n",ask(x,y)); else if (K=='C') update(x,x,y); } } } int main() { dts::main(); }
-
12016-03-16 13:55:57@
DFS序+树状数组+倍增LCA即可。这里有一个DFS序的教程。
注意那个教程叙述的是边权值的情况而不是点权值的情况,所以细节上可能需要改动一下。C++ Code
#include <cctype> #include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int MAXSIZE=30000020; int bufpos; char buf[MAXSIZE]; void init(){ #ifdef LOCAL freopen("1986.txt","r",stdin); #endif buf[fread(buf,1,MAXSIZE,stdin)]='\0'; bufpos=0; } int readint(){ bool isneg; int val=0; for(;!isdigit(buf[bufpos]) && buf[bufpos]!='-';bufpos++); bufpos+=(isneg=(buf[bufpos]=='-'))?1:0; for(;isdigit(buf[bufpos]);bufpos++) val=val*10+buf[bufpos]-'0'; return isneg?-val:val; } char readchar(){ for(;isspace(buf[bufpos]);bufpos++); return buf[bufpos++]; } const int maxn=100001; const int maxm=200001; const int maxe=200001; struct bit{ int n; int t[maxe]; inline int lowbit(int x){ return x&(-x); } void update(int p,int v){ for(int i=p;i<=n;i+=lowbit(i)) t[i]+=v; } int query(int p){ int ans=0; for(int i=p;i>0;i-=lowbit(i)) ans+=t[i]; return ans; } }; struct edge{ int from,to,next; }; struct graph{ int n,m,maxi; bit b; int eu[maxe]; int val[maxe]; int dep[maxn]; int st[maxn]; int en[maxn]; int first[maxn]; bool vis[maxn]; edge e[maxm]; int par[maxn][21]; int cnt; void init(int n){ this->n=n; memset(first,-1,sizeof(first)); maxi=32-__builtin_clz(n); m=0; } void addedge(int from,int to){ e[++m]=(edge){from,to,first[from]}; first[from]=m; } void dfs(int u,int d,int fa){ eu[++cnt]=u; st[u]=cnt; dep[u]=d; par[u][0]=fa; for(int i=1;i<=maxi;i++) par[u][i]=par[par[u][i-1]][i-1]; for(int i=first[u];i!=-1;i=e[i].next){ int v=e[i].to; if (!st[v]) dfs(v,d+1,u); } eu[++cnt]=u; en[u]=cnt; } void prepare(){ cnt=0; memset(st,0,sizeof(st)); memset(en,0,sizeof(en)); dfs(1,1,0); //for(int i=1;i<=cnt;i++) //printf("%d ",eu[i]); //putchar('\n'); b.n=cnt+1; } int lca(int u,int v){ if (dep[u]<dep[v]) swap(u,v); int t=dep[u]-dep[v]; maxi=32-__builtin_clz(dep[u]); //printf("before:u=%d v=%d dep[u]=%d dep[v]=%d\n",u,v,dep[u],dep[v]); for(int i=0;i<=maxi;i++) if (t&(1<<i)) u=par[u][i]; //printf("after:u=%d v=%d dep[u]=%d dep[v]=%d\n",u,v,dep[u],dep[v]); //assert(dep[u]==dep[v]); for(int i=maxi;i>=0;i--){ if (par[u][i] && par[u][i]!=par[v][i]){ u=par[u][i]; v=par[v][i]; } } return u==v?u:par[u][0]; } int query(int u,int v){ //puts("WTF"); if (st[v]<st[u]) swap(u,v); int lc=lca(u,v); //puts("WTF"); //printf("u=%d v=%d st[u]=%d st[v]=%d lca=%d val[lca]=%d\n",u,v,st[u],st[v],lc,val[lc]); //printf("b.query(st[u])=%d b.query(st[v])=%d\n",b.query(st[u]),b.query(st[v])); return b.query(st[v])-2*b.query(st[lc]-1)+b.query(st[u])-val[lc]; } void update(int p,int v){ b.update(st[p],v-val[p]); b.update(en[p],val[p]-v); val[p]=v; //b.update(st[p],v); //b.update(en[p],v); } } g; int t[maxn]; int main(){ init(); int n=readint(); g.init(n); for(int i=1;i<=n;i++) t[i]=readint(); for(int i=1;i<=n-1;i++){ int u=readint(),v=readint(); g.addedge(u,v); g.addedge(v,u); } g.prepare(); //puts("WTF"); for(int i=1;i<=n;i++) g.update(i,t[i]); int q=readint(); for(int i=1;i<=q;i++){ char buf=readchar(); int x=readint(),y=readint(); if (buf=='Q') printf("%d\n",g.query(x,y)); else g.update(x,y); } }
- 1