BZOJ - 2243 染色 (树链剖分+线段树+区间合并)

题目链接

线段树维护区间连续段个数即可。设lc为区间左端点颜色,rc为区间右端点颜色,则合并两区间的时候,如果左区间右端点和右区间左端点颜色相同,则连续段个数-1。

在树链上的区间合并可以定义一个结构体作为线段,分成左右两条链暴力合并。也可以考虑到树上的路径中每两个树链“断开”的地方必然有一个结点是另一个结点的祖先,因此如果top[u]的颜色与fa[top[u]]的颜色相同时答案-1即可。

树剖和线段树结合真容易把人搞晕啊,什么时候要用l,r,什么时候要用u,什么时候要用dfn[u],一定要分清楚~~

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int N=1e5+10,inf=0x3f3f3f3f;
 5 int hd[N],ne,n,k,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,a[N],cnt[N<<3],mk[N<<3],lc[N<<3],rc[N<<3];
 6 struct E {int v,nxt;} e[N<<1];
 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
 8 void dfs1(int u,int f,int d) {
 9     fa[u]=f,fa[u]=f,siz[u]=1,dep[u]=d;
10     for(int i=hd[u]; ~i; i=e[i].nxt) {
11         int v=e[i].v;
12         if(v==fa[u])continue;
13         dfs1(v,u,d+1),siz[u]+=siz[v];
14         if(siz[v]>siz[son[u]])son[u]=v;
15     }
16 }
17 void dfs2(int u,int tp) {
18     top[u]=tp,dfn[u]=++tot,rnk[dfn[u]]=u;
19     if(!son[u])return;
20     dfs2(son[u],top[u]);
21     for(int i=hd[u]; ~i; i=e[i].nxt) {
22         int v=e[i].v;
23         if(v==fa[u]||v==son[u])continue;
24         dfs2(v,v);
25     }
26 }
27 #define ls (u<<1)
28 #define rs (u<<1|1)
29 #define mid ((l+r)>>1)
30 void pu(int u) {lc[u]=lc[ls],rc[u]=rc[rs],cnt[u]=cnt[ls]+cnt[rs]; if(rc[ls]==lc[rs])cnt[u]--;}
31 void pd(int u) {if(mk[u])lc[u]=rc[u]=mk[u],cnt[u]=1,mk[ls]=mk[rs]=mk[u],mk[u]=0;}
32 void build(int u=1,int l=1,int r=tot) {
33     if(l==r) {lc[u]=rc[u]=a[rnk[l]],cnt[u]=1; return;}
34     build(ls,l,mid),build(rs,mid+1,r),pu(u);
35 }
36 void upd(int L,int R,int x,int u=1,int l=1,int r=tot) {
37     pd(u);
38     if(l>=L&&r<=R) {mk[u]=x,pd(u); return;}
39     if(l>R||r<L)return;
40     upd(L,R,x,ls,l,mid),upd(L,R,x,rs,mid+1,r),pu(u);
41 }
42 int getcol(int p,int u=1,int l=1,int r=tot) {
43     pd(u);
44     if(l==r)return lc[u];
45     return p<=mid?getcol(p,ls,l,mid):getcol(p,rs,mid+1,r);
46 }
47 int qry(int L,int R,int u=1,int l=1,int r=tot) {
48     pd(u);
49     if(l>=L&&r<=R)return cnt[u];
50     if(l>R||r<L)return 0;
51     int t1=qry(L,R,ls,l,mid),t2=qry(L,R,rs,mid+1,r);
52     int ret=t1+t2;
53     if(t1&&t2&&rc[ls]==lc[rs])ret--;
54     return ret;
55 }
56 void change(int u,int v,int x) {
57     for(; top[u]!=top[v]; u=fa[top[u]]) {
58         if(dep[top[u]]<dep[top[v]])swap(u,v);
59         upd(dfn[top[u]],dfn[u],x);
60     }
61     if(dep[u]<dep[v])swap(u,v);
62     upd(dfn[v],dfn[u],x);
63 }
64 int ask(int u,int v) {
65     int ret=0;
66     for(; top[u]!=top[v]; u=fa[top[u]]) {
67         if(dep[top[u]]<dep[top[v]])swap(u,v);
68         ret+=qry(dfn[top[u]],dfn[u]);
69         if(getcol(dfn[top[u]])==getcol(dfn[fa[top[u]]]))ret--;
70     }
71     if(dep[u]<dep[v])swap(u,v);
72     ret+=qry(dfn[v],dfn[u]);
73     return ret;
74 }
75 int main() {
76     memset(hd,-1,sizeof hd),ne=0;
77     scanf("%d%d",&n,&k);
78     for(int i=1; i<=n; ++i)scanf("%d",&a[i]),a[i]++;
79     for(int i=1; i<n; ++i) {
80         int u,v;
81         scanf("%d%d",&u,&v);
82         addedge(u,v);
83         addedge(v,u);
84     }
85     tot=0,dfs1(1,0,0),dfs2(1,1),build();
86     while(k--) {
87         char ch;
88         int a,b,c;
89         scanf(" %c%d%d",&ch,&a,&b);
90         if(ch=='Q')printf("%d\n",ask(a,b));
91         else scanf("%d",&c),c++,change(a,b,c);
92     }
93     return 0;
94 }

 

posted @ 2019-03-25 21:44  jrltx  阅读(204)  评论(0编辑  收藏  举报