[BZOJ2243][SDOI2011]染色 解题报告|树链剖分
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
与上一题差别不大,主要就是solve过程要根据左端点和右端点的颜色处理合并时候的情况
线段树的每个节点要记录颜色段数|最左边的颜色|最右边的颜色
同时往下传的时候标记要做好(之前那道题是单点修改所以不用考虑下传
因为这个昨晚调了一个晚上无果...果然是早上头脑清醒的多 一下子就改好了...
program bzoj2243; const maxn=1000010;maxm=2000010; var ter,next:array[-1..maxm]of longint; link,pos,deep,size,belong,v:array[-1..maxn]of longint; fa:array[-1..maxn,-1..20]of longint; cnt,n,m,i,j,x,y,z,t:longint; ch:char; tr:array[-1..5*maxn]of record l,r,lc,rc,cover:longint;wait:boolean;end; procedure add(x,y:longint); begin inc(j);ter[j]:=y;next[j]:=link[x];link[x]:=j; inc(j);ter[j]:=x;next[j]:=link[y];link[y]:=j; end; procedure dfs1(p:longint); var i,j:longint; begin size[p]:=1; for i:=1 to 17 do begin if deep[p]<=1 << i then break; fa[p][i]:=fa[fa[p][i-1]][i-1]; end; j:=link[p]; while j<>0 do begin if deep[ter[j]]=0 then begin deep[ter[j]]:=deep[p]+1; fa[ter[j]][0]:=p; dfs1(ter[j]); inc(size[p],size[ter[j]]); end; j:=next[j]; end; end; procedure dfs2(p,chain:longint); var j,k:longint; begin inc(cnt);pos[p]:=cnt;belong[p]:=chain; k:=0; j:=link[p]; while j<>0 do begin if deep[ter[j]]>deep[p] then if size[ter[j]]>size[k] then k:=ter[j]; j:=next[j]; end; if k=0 then exit; dfs2(k,chain); j:=link[p]; while j<>0 do begin if (deep[ter[j]]>deep[p])and(k<>ter[j]) then dfs2(ter[j],ter[j]); j:=next[j]; end; end; procedure build(p,l,r:longint); var mid:longint; begin tr[p].l:=l;tr[p].r:=r;tr[p].cover:=1;tr[p].lc:=-1;tr[p].rc:=-1; if l=r then exit; mid:=(l+r) >> 1; build(p << 1,l,mid); build(p << 1+1,mid+1,r); end; procedure insert(p,l,r,x:longint); var mid:longint; begin if (tr[p].l=l)and(tr[p].r=r) then begin tr[p].cover:=1;tr[p].lc:=x;tr[p].rc:=x;tr[p].wait:=true; exit; end; mid:=(tr[p].l+tr[p].r) >> 1; if tr[p].wait then begin tr[p << 1].cover:=1;tr[p << 1].lc:=tr[p].lc;tr[p << 1].rc:=tr[p].rc; tr[p << 1+1].cover:=1;tr[p << 1+1].lc:=tr[p].lc;tr[p << 1+1].rc:=tr[p].rc; tr[p].wait:=false; tr[p << 1].wait:=true;tr[p << 1+1].wait:=true; end; if r<=mid then insert(p << 1,l,r,x) else if l>mid then insert(p << 1+1,l,r,x) else begin insert(p << 1,l,mid,x); insert(p << 1+1,mid+1,r,x); end; tr[p].cover:=tr[p << 1].cover+tr[p << 1+1].cover; if tr[p << 1].rc=tr[p << 1+1].lc then dec(tr[p].cover); tr[p].lc:=tr[p << 1].lc;tr[p].rc:=tr[p << 1+1].rc; end; function lca(x,y:longint):longint; var tem,i:longint; begin if deep[x]<deep[y] then begin tem:=x;x:=y;y:=tem; end; if deep[x]>deep[y] then begin i:=trunc(ln(deep[x]-deep[y])/ln(2)); while deep[x]<>deep[y] do begin while (deep[x]-deep[y]>=1 << i) do x:=fa[x][i]; dec(i); end; end; if x=y then exit(x); i:=trunc(ln(n)/ln(2)); while fa[x][0]<>fa[y][0] do begin while (fa[x][i]<>fa[y][i]) do begin x:=fa[x,i];y:=fa[y][i]; end; dec(i); end; exit(fa[x,0]); end; procedure query(p,l,r:longint;var lc,rc,ave:longint); var mid,ll,rr,x,tem:longint; begin if (tr[p].l=l)and(tr[p].r=r) then begin lc:=tr[p].lc;rc:=tr[p].rc; ave:=tr[p].cover; exit; end; mid:=(tr[p].l+tr[p].r) >> 1; if tr[p].wait then begin tr[p << 1].cover:=1;tr[p << 1].lc:=tr[p].lc;tr[p << 1].rc:=tr[p].rc; tr[p << 1+1].cover:=1;tr[p << 1+1].lc:=tr[p].lc;tr[p << 1+1].rc:=tr[p].rc; tr[p].wait:=false; tr[p << 1].wait:=true;tr[p << 1+1].wait:=true; end; if r<=mid then query(p << 1,l,r,lc,rc,ave) else if l>mid then query(p << 1+1,l,r,lc,rc,ave) else begin query(p << 1,l,mid,lc,ll,ave); query(p << 1+1,mid+1,r,rr,rc,tem); inc(ave,tem); if ll=rr then dec(ave); end; end; function solve(x,y:longint):longint; var r,sum,lc,rc,tem:longint; begin r:=-1;sum:=0; while belong[x]<>belong[y] do begin query(1,pos[belong[x]],pos[x],lc,rc,tem); inc(sum,tem); if rc=r then dec(sum); r:=lc; x:=fa[belong[x]][0]; end; query(1,pos[y],pos[x],lc,rc,tem); inc(sum,tem); if rc=r then dec(sum); exit(sum); end; procedure mend(x,y,z:longint); begin while belong[x]<>belong[y] do begin insert(1,pos[belong[x]],pos[x],z); x:=fa[belong[x]][0]; end; insert(1,pos[y],pos[x],z); end; begin readln(n,m); for i:=1 to n do read(v[i]);readln; j:=0; for i:=1 to n-1 do begin readln(x,y); add(x,y); end; deep[1]:=1; dfs1(1); cnt:=0;dfs2(1,1); build(1,1,n); for i:=1 to n do insert(1,pos[i],pos[i],v[i]); for i:=1 to m do begin read(ch); if ch='C' then begin readln(x,y,z); t:=lca(x,y); mend(x,t,z);mend(y,t,z); end else begin readln(x,y); z:=lca(x,y); writeln(solve(x,z)+solve(y,z)-1); end; end; end.
[UPD.5.11]今天复习树链剖分,挑了这道题写了写。一个地方WA检查了三遍对拍了一个多小时才找出来...
就是在判断链之间答案合并的时候是从右往左的而不是从左往右...习惯性的写法干脆检查的时候都忽略了这一点...
另外这次用C++大概150行左右..感觉还不错
1 #include<cstdio> 2 #include<cstdlib> 3 #include<cmath> 4 #include<cstring> 5 #define maxn 1000010 6 #define maxm 2000010 7 struct node{ 8 int l,r,lc,rc,count,wait; 9 }tr[maxn*4]; 10 int next[maxm],ter[maxm],link[maxn],pos[maxn],belong[maxn],dep[maxn]; 11 int fa[maxn][25],sz[maxn],n,m,e,v[maxn],cnt; 12 char ch[100]; 13 void swap(int &x,int &y){ 14 int tmp=x;x=y;y=tmp; 15 } 16 void add(int x,int y){ 17 ter[++e]=y;next[e]=link[x];link[x]=e; 18 ter[++e]=x;next[e]=link[y];link[y]=e; 19 } 20 void dfs1(int p){ 21 sz[p]=1; 22 for (int i=1;i<=20&&dep[p]>(1<<i);i++) fa[p][i]=fa[fa[p][i-1]][i-1]; 23 for (int i=link[p];i;i=next[i]) if (dep[ter[i]]==0){ 24 dep[ter[i]]=dep[p]+1; 25 fa[ter[i]][0]=p; 26 dfs1(ter[i]); 27 sz[p]+=sz[ter[i]]; 28 } 29 } 30 void dfs2(int p,int chain){ 31 pos[p]=++cnt;belong[p]=chain; 32 int k=0; 33 for (int i=link[p];i;i=next[i]) if (dep[ter[i]]==dep[p]+1&&sz[ter[i]]>sz[k]) k=ter[i]; 34 if (k==0) return; 35 dfs2(k,chain); 36 for (int i=link[p];i;i=next[i]) if (dep[ter[i]]==dep[p]+1&&ter[i]!=k) dfs2(ter[i],ter[i]); 37 } 38 void update(int p){ 39 tr[p].lc=tr[p<<1].lc; 40 tr[p].rc=tr[(p<<1)+1].rc; 41 tr[p].count=tr[p<<1].count+tr[(p<<1)+1].count; 42 if (tr[p<<1].rc==tr[(p<<1)+1].lc) tr[p].count--; 43 } 44 void push(int p){ 45 if (tr[p].wait==-1) return; 46 int u=p<<1,w=(p<<1)+1,col=tr[p].wait; 47 tr[u].lc=tr[u].rc=tr[u].wait=col;tr[u].count=1; 48 tr[w].lc=tr[w].rc=tr[w].wait=col;tr[w].count=1; 49 tr[p].wait=-1; 50 } 51 void build(int p,int l,int r){ 52 tr[p].l=l;tr[p].r=r;tr[p].wait=-1; 53 if (l==r) return; 54 int mid=(l+r) >> 1; 55 build(p<<1,l,mid);build((p<<1)+1,mid+1,r); 56 } 57 int lca(int x,int y){ 58 if (dep[x]<dep[y]) swap(x,y); 59 int i; 60 if (dep[x]>dep[y]){ 61 i=(int)(log(dep[x]-dep[y])/log(2)); 62 while (dep[x]>dep[y]){ 63 while (dep[x]-dep[y]>=1<<i) x=fa[x][i]; 64 i--; 65 } 66 } 67 if (x==y) return x; 68 i=(int)(log(n)/log(2)); 69 while (fa[x][0]!=fa[y][0]){ 70 while (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; 71 i--; 72 } 73 return fa[x][0]; 74 } 75 void insert(int p,int l,int r,int col){ 76 if (tr[p].l==l&&tr[p].r==r){ 77 tr[p].lc=col;tr[p].rc=col;tr[p].wait=col;tr[p].count=1; 78 return; 79 } 80 push(p); 81 int mid=(tr[p].l+tr[p].r)>>1; 82 if (r<=mid) insert(p<<1,l,r,col);else 83 if (l>mid) insert((p<<1)+1,l,r,col);else{ 84 insert(p<<1,l,mid,col);insert((p<<1)+1,mid+1,r,col); 85 } 86 update(p); 87 } 88 void query(int p,int l,int r,int &lc,int &rc,int &ans){ 89 if (tr[p].l==l&&tr[p].r==r){ 90 lc=tr[p].lc;rc=tr[p].rc;ans=tr[p].count; 91 return; 92 } 93 push(p); 94 int mid=(tr[p].l+tr[p].r)>>1; 95 if (r<=mid) query(p<<1,l,r,lc,rc,ans);else 96 if (l>mid) query((p<<1)+1,l,r,lc,rc,ans);else 97 { 98 int ll,rr,ans1; 99 query(p<<1,l,mid,lc,rr,ans); 100 query((p<<1)+1,mid+1,r,ll,rc,ans1); 101 ans+=ans1;if (ll==rr) ans--; 102 } 103 update(p); 104 } 105 void change(int x,int y,int col){ 106 while (belong[x]!=belong[y]){ 107 insert(1,pos[belong[x]],pos[x],col); 108 x=fa[belong[x]][0]; 109 } 110 insert(1,pos[y],pos[x],col); 111 } 112 int solve(int x,int y){ 113 int sum=0,r=-1,ll,rr,tmp; 114 while (belong[x]!=belong[y]){ 115 query(1,pos[belong[x]],pos[x],ll,rr,tmp); 116 x=fa[belong[x]][0]; 117 sum+=tmp;if (rr==r) sum--; 118 r=ll; 119 } 120 query(1,pos[y],pos[x],ll,rr,tmp); 121 sum+=tmp;if (rr==r) sum--; 122 return sum; 123 } 124 int main(){ 125 scanf("%d%d",&n,&m); 126 for (int i=1;i<=n;i++) scanf("%d",&v[i]); 127 e=0; 128 int x,y,col,t; 129 for (int i=1;i<=n-1;i++) scanf("%d%d",&x,&y),add(x,y); 130 dep[1]=1;dfs1(1); 131 cnt=0;dfs2(1,1); 132 build(1,1,n); 133 for (int i=1;i<=n;i++) insert(1,pos[i],pos[i],v[i]); 134 for (int i=1;i<=m;i++){ 135 scanf("%s",ch); 136 if (ch[0]=='C'){ 137 scanf("%d%d%d",&x,&y,&col); 138 t=lca(x,y); 139 change(x,t,col);change(y,t,col); 140 }else{ 141 scanf("%d%d",&x,&y); 142 t=lca(x,y); 143 printf("%d\n",solve(x,t)+solve(y,t)-1); 144 } 145 } 146 return 0; 147 }