[BZOJ2243][SDOI2011]染色 解题报告|树链剖分

Description

 

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c;

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”3段组成:“11”、“222”和“1”

请你写一个程序依次完成这m个操作。

  与上一题差别不大,主要就是solve过程要根据左端点和右端点的颜色处理合并时候的情况

  线段树的每个节点要记录颜色段数|最左边的颜色|最右边的颜色

  同时往下传的时候标记要做好(之前那道题是单点修改所以不用考虑下传

  因为这个昨晚调了一个晚上无果...果然是早上头脑清醒的多 一下子就改好了...

program bzoj2243;
const maxn=1000010;maxm=2000010;
var ter,next:array[-1..maxm]of longint;
    link,pos,deep,size,belong,v:array[-1..maxn]of longint;
    fa:array[-1..maxn,-1..20]of longint;
    cnt,n,m,i,j,x,y,z,t:longint;
    ch:char;
    tr:array[-1..5*maxn]of record l,r,lc,rc,cover:longint;wait:boolean;end;

procedure add(x,y:longint);
begin
    inc(j);ter[j]:=y;next[j]:=link[x];link[x]:=j;
    inc(j);ter[j]:=x;next[j]:=link[y];link[y]:=j;
end;

procedure dfs1(p:longint);
var i,j:longint;
begin
    size[p]:=1;
    for i:=1 to 17 do
    begin
        if deep[p]<=1 << i then break;
        fa[p][i]:=fa[fa[p][i-1]][i-1];
    end;
    j:=link[p];
    while j<>0 do
    begin
        if deep[ter[j]]=0 then
        begin
            deep[ter[j]]:=deep[p]+1;
            fa[ter[j]][0]:=p;
            dfs1(ter[j]);
            inc(size[p],size[ter[j]]);
        end;
        j:=next[j];
    end;
end;

procedure dfs2(p,chain:longint);
var j,k:longint;
begin
    inc(cnt);pos[p]:=cnt;belong[p]:=chain;
    k:=0;
    j:=link[p];
    while j<>0 do
    begin
        if deep[ter[j]]>deep[p] then
            if size[ter[j]]>size[k] then k:=ter[j];
        j:=next[j];
    end;
    if k=0 then exit;
    dfs2(k,chain);
    j:=link[p];
    while j<>0 do
    begin
        if (deep[ter[j]]>deep[p])and(k<>ter[j]) then dfs2(ter[j],ter[j]);
        j:=next[j];
    end;
end;

procedure build(p,l,r:longint);
var mid:longint;
begin
    tr[p].l:=l;tr[p].r:=r;tr[p].cover:=1;tr[p].lc:=-1;tr[p].rc:=-1;
    if l=r then exit;
    mid:=(l+r) >> 1;
    build(p << 1,l,mid);
    build(p << 1+1,mid+1,r);
end;

procedure insert(p,l,r,x:longint);
var mid:longint;
begin
    if (tr[p].l=l)and(tr[p].r=r) then
    begin
        tr[p].cover:=1;tr[p].lc:=x;tr[p].rc:=x;tr[p].wait:=true;
        exit;
    end;
    mid:=(tr[p].l+tr[p].r) >> 1;
    if tr[p].wait then
    begin
        tr[p << 1].cover:=1;tr[p << 1].lc:=tr[p].lc;tr[p << 1].rc:=tr[p].rc;
        tr[p << 1+1].cover:=1;tr[p << 1+1].lc:=tr[p].lc;tr[p << 1+1].rc:=tr[p].rc;
        tr[p].wait:=false;
                tr[p << 1].wait:=true;tr[p << 1+1].wait:=true;
    end;
    if r<=mid then insert(p << 1,l,r,x) else
        if l>mid then insert(p << 1+1,l,r,x) else
        begin
            insert(p << 1,l,mid,x);
            insert(p << 1+1,mid+1,r,x);
        end;
        tr[p].cover:=tr[p << 1].cover+tr[p << 1+1].cover;
    if tr[p << 1].rc=tr[p << 1+1].lc then dec(tr[p].cover);
    tr[p].lc:=tr[p << 1].lc;tr[p].rc:=tr[p << 1+1].rc;
end;

function lca(x,y:longint):longint;
var tem,i:longint;
begin
    if deep[x]<deep[y] then
    begin
        tem:=x;x:=y;y:=tem;
    end;
    if deep[x]>deep[y] then
    begin
        i:=trunc(ln(deep[x]-deep[y])/ln(2));
        while deep[x]<>deep[y] do
        begin
            while (deep[x]-deep[y]>=1 << i) do x:=fa[x][i];
            dec(i);
        end;
    end;
    if x=y then exit(x);
    i:=trunc(ln(n)/ln(2));
    while fa[x][0]<>fa[y][0] do
    begin
        while (fa[x][i]<>fa[y][i]) do
        begin
            x:=fa[x,i];y:=fa[y][i];
        end;
        dec(i);
    end;
    exit(fa[x,0]);
end;

procedure query(p,l,r:longint;var lc,rc,ave:longint);
var mid,ll,rr,x,tem:longint;
begin
    if (tr[p].l=l)and(tr[p].r=r) then
    begin
        lc:=tr[p].lc;rc:=tr[p].rc;
        ave:=tr[p].cover;
                exit;
        end;
    mid:=(tr[p].l+tr[p].r) >> 1;
        if tr[p].wait then
    begin
        tr[p << 1].cover:=1;tr[p << 1].lc:=tr[p].lc;tr[p << 1].rc:=tr[p].rc;
        tr[p << 1+1].cover:=1;tr[p << 1+1].lc:=tr[p].lc;tr[p << 1+1].rc:=tr[p].rc;
        tr[p].wait:=false;
                tr[p << 1].wait:=true;tr[p << 1+1].wait:=true;
    end;
    if r<=mid then query(p << 1,l,r,lc,rc,ave) else
        if l>mid then query(p << 1+1,l,r,lc,rc,ave) else
        begin
            query(p << 1,l,mid,lc,ll,ave);
            query(p << 1+1,mid+1,r,rr,rc,tem);
                        inc(ave,tem);
            if ll=rr then dec(ave);
        end;
end;

function solve(x,y:longint):longint;
var r,sum,lc,rc,tem:longint;
begin
    r:=-1;sum:=0;
    while belong[x]<>belong[y] do
    begin
        query(1,pos[belong[x]],pos[x],lc,rc,tem);
                inc(sum,tem);
        if rc=r then dec(sum);
        r:=lc;
        x:=fa[belong[x]][0];
    end;
    query(1,pos[y],pos[x],lc,rc,tem);
        inc(sum,tem);
    if rc=r then dec(sum);
    exit(sum);
end;

procedure mend(x,y,z:longint);
begin
    while belong[x]<>belong[y] do
    begin
        insert(1,pos[belong[x]],pos[x],z);
        x:=fa[belong[x]][0];
    end;
    insert(1,pos[y],pos[x],z);
end;

begin
    readln(n,m);
    for i:=1 to n do read(v[i]);readln;
    j:=0;
    for i:=1 to n-1 do
    begin
        readln(x,y);
        add(x,y);
    end;
        deep[1]:=1;
    dfs1(1);
    cnt:=0;dfs2(1,1);
        build(1,1,n);
    for i:=1 to n do insert(1,pos[i],pos[i],v[i]);
        for i:=1 to m do
    begin
        read(ch);
        if ch='C' then
        begin
            readln(x,y,z);
            t:=lca(x,y);
            mend(x,t,z);mend(y,t,z);
        end else
        begin
            readln(x,y);
                z:=lca(x,y);
                writeln(solve(x,z)+solve(y,z)-1);
        end;
    end;
end.

 [UPD.5.11]今天复习树链剖分,挑了这道题写了写。一个地方WA检查了三遍对拍了一个多小时才找出来...

就是在判断链之间答案合并的时候是从右往左的而不是从左往右...习惯性的写法干脆检查的时候都忽略了这一点...

另外这次用C++大概150行左右..感觉还不错

 

  1 #include<cstdio>
  2 #include<cstdlib>
  3 #include<cmath>
  4 #include<cstring>
  5 #define maxn 1000010
  6 #define maxm 2000010
  7 struct node{
  8     int l,r,lc,rc,count,wait;
  9 }tr[maxn*4];
 10 int next[maxm],ter[maxm],link[maxn],pos[maxn],belong[maxn],dep[maxn];
 11 int fa[maxn][25],sz[maxn],n,m,e,v[maxn],cnt;
 12 char ch[100];
 13 void swap(int &x,int &y){
 14     int tmp=x;x=y;y=tmp;
 15 }
 16 void add(int x,int y){
 17     ter[++e]=y;next[e]=link[x];link[x]=e;
 18     ter[++e]=x;next[e]=link[y];link[y]=e;
 19 }
 20 void dfs1(int p){
 21     sz[p]=1;
 22     for (int i=1;i<=20&&dep[p]>(1<<i);i++) fa[p][i]=fa[fa[p][i-1]][i-1];    
 23     for (int i=link[p];i;i=next[i]) if (dep[ter[i]]==0){
 24         dep[ter[i]]=dep[p]+1;
 25         fa[ter[i]][0]=p;
 26         dfs1(ter[i]);
 27         sz[p]+=sz[ter[i]];
 28     }
 29 }
 30 void dfs2(int p,int chain){
 31     pos[p]=++cnt;belong[p]=chain;
 32     int k=0;
 33     for (int i=link[p];i;i=next[i]) if (dep[ter[i]]==dep[p]+1&&sz[ter[i]]>sz[k]) k=ter[i];
 34     if (k==0) return;
 35     dfs2(k,chain);
 36     for (int i=link[p];i;i=next[i]) if (dep[ter[i]]==dep[p]+1&&ter[i]!=k) dfs2(ter[i],ter[i]);
 37 }
 38 void update(int p){
 39     tr[p].lc=tr[p<<1].lc;
 40     tr[p].rc=tr[(p<<1)+1].rc;
 41     tr[p].count=tr[p<<1].count+tr[(p<<1)+1].count;
 42     if (tr[p<<1].rc==tr[(p<<1)+1].lc) tr[p].count--;
 43 }
 44 void push(int p){
 45     if (tr[p].wait==-1) return;
 46     int u=p<<1,w=(p<<1)+1,col=tr[p].wait;
 47     tr[u].lc=tr[u].rc=tr[u].wait=col;tr[u].count=1;
 48     tr[w].lc=tr[w].rc=tr[w].wait=col;tr[w].count=1;
 49     tr[p].wait=-1;
 50 }
 51 void build(int p,int l,int r){
 52     tr[p].l=l;tr[p].r=r;tr[p].wait=-1;
 53     if (l==r) return;
 54     int mid=(l+r) >> 1;
 55     build(p<<1,l,mid);build((p<<1)+1,mid+1,r);
 56 }
 57 int lca(int x,int y){
 58     if (dep[x]<dep[y]) swap(x,y);
 59     int i;
 60     if (dep[x]>dep[y]){
 61         i=(int)(log(dep[x]-dep[y])/log(2));
 62         while (dep[x]>dep[y]){            
 63             while (dep[x]-dep[y]>=1<<i) x=fa[x][i];
 64             i--;    
 65         }
 66     }
 67     if (x==y) return x;
 68     i=(int)(log(n)/log(2));
 69     while (fa[x][0]!=fa[y][0]){
 70         while (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
 71         i--;
 72     }
 73     return fa[x][0];
 74 }
 75 void insert(int p,int l,int r,int col){    
 76     if (tr[p].l==l&&tr[p].r==r){
 77         tr[p].lc=col;tr[p].rc=col;tr[p].wait=col;tr[p].count=1;
 78         return;
 79     }
 80     push(p);
 81     int mid=(tr[p].l+tr[p].r)>>1;
 82     if (r<=mid) insert(p<<1,l,r,col);else
 83         if (l>mid) insert((p<<1)+1,l,r,col);else{
 84             insert(p<<1,l,mid,col);insert((p<<1)+1,mid+1,r,col);
 85         }
 86     update(p);
 87 }
 88 void query(int p,int l,int r,int &lc,int &rc,int &ans){    
 89     if (tr[p].l==l&&tr[p].r==r){
 90         lc=tr[p].lc;rc=tr[p].rc;ans=tr[p].count;
 91         return;
 92     }
 93     push(p);
 94     int mid=(tr[p].l+tr[p].r)>>1;
 95     if (r<=mid) query(p<<1,l,r,lc,rc,ans);else
 96         if (l>mid) query((p<<1)+1,l,r,lc,rc,ans);else
 97         {
 98             int ll,rr,ans1;
 99             query(p<<1,l,mid,lc,rr,ans);
100             query((p<<1)+1,mid+1,r,ll,rc,ans1);
101             ans+=ans1;if (ll==rr) ans--;
102         }
103     update(p);
104 }
105 void change(int x,int y,int col){
106     while (belong[x]!=belong[y]){
107         insert(1,pos[belong[x]],pos[x],col);
108         x=fa[belong[x]][0];
109     }
110     insert(1,pos[y],pos[x],col);
111 }
112 int solve(int x,int y){
113     int sum=0,r=-1,ll,rr,tmp;
114     while (belong[x]!=belong[y]){
115         query(1,pos[belong[x]],pos[x],ll,rr,tmp);                
116         x=fa[belong[x]][0];
117         sum+=tmp;if (rr==r) sum--;
118         r=ll;
119     }
120     query(1,pos[y],pos[x],ll,rr,tmp);
121     sum+=tmp;if (rr==r) sum--;
122     return sum;
123 }
124 int main(){
125     scanf("%d%d",&n,&m);
126     for (int i=1;i<=n;i++) scanf("%d",&v[i]);
127     e=0;
128     int x,y,col,t;
129     for (int i=1;i<=n-1;i++) scanf("%d%d",&x,&y),add(x,y);
130     dep[1]=1;dfs1(1);
131     cnt=0;dfs2(1,1);
132     build(1,1,n);
133     for (int i=1;i<=n;i++) insert(1,pos[i],pos[i],v[i]);
134     for (int i=1;i<=m;i++){        
135         scanf("%s",ch);
136         if (ch[0]=='C'){
137             scanf("%d%d%d",&x,&y,&col);
138             t=lca(x,y);
139             change(x,t,col);change(y,t,col);
140         }else{
141             scanf("%d%d",&x,&y);            
142             t=lca(x,y);
143             printf("%d\n",solve(x,t)+solve(y,t)-1);
144         }
145     }
146     return 0;
147 }

 

 

 

posted @ 2015-04-11 08:17  mjy0724  阅读(211)  评论(0编辑  收藏  举报