bzoj 4999 This Problem Is Too Simple!

题目大意:

给一颗树,每个节点有个初始值

现在支持以下两种操作:

1. C i x 表示将i节点的值改为x

 

2. Q i j x 表示询问i节点到j节点的路径上有多少个值为x的节点

思路:

首先可以想到树链剖分

虽然颜色的数量看起来很吓人

但是实际上只可能有n+q种颜色

所以我们的线段树只需要像主席树那样去写就可以了

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cmath>
  4 #include<cstdlib>
  5 #include<cstring>
  6 #include<algorithm>
  7 #include<vector>
  8 #include<queue>
  9 #include<map>
 10 #define inf 2139062143
 11 #define ll long long
 12 #define MAXN 100100
 13 using namespace std;
 14 inline int read()
 15 {
 16     int x=0,f=1;char ch=getchar();
 17     while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
 18     while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();}
 19     return x*f;
 20 }
 21 int n,T,val[MAXN],to[MAXN<<1],fst[MAXN],nxt[MAXN<<1],Cnt,cnt[MAXN],dep[MAXN],fa[MAXN],bl[MAXN],hsh[MAXN],rt[MAXN<<2],clr;
 22 map <int,int> m;
 23 struct data {int ls,rs,sum;}tr[MAXN<<7];
 24 void add(int u,int v) {nxt[++Cnt]=fst[u],fst[u]=Cnt,to[Cnt]=v;}
 25 void dfs(int x)
 26 {
 27     cnt[x]=1;
 28     for(int i=fst[x];i;i=nxt[i])
 29     {
 30         if(to[i]==fa[x]) continue;
 31         fa[to[i]]=x,dep[to[i]]=dep[x]+1;
 32         dfs(to[i]);
 33         cnt[x]+=cnt[to[i]];
 34     }
 35 }
 36 void Dfs(int x,int anc)
 37 {
 38     int hvs=0;hsh[x]=++Cnt,bl[x]=anc;
 39     for(int i=fst[x];i;i=nxt[i])
 40         if(to[i]!=fa[x]&&cnt[hvs]<cnt[to[i]]) hvs=to[i];
 41     if(!hvs) return ;
 42     Dfs(hvs,anc);
 43     for(int i=fst[x];i;i=nxt[i])
 44         if(to[i]!=fa[x]&&to[i]!=hvs) Dfs(to[i],to[i]);
 45 }
 46 void inst(int &k,int l,int r,int x,int w)
 47 {
 48     if(!k) k=++Cnt;
 49     tr[k].sum+=w;
 50     if(l==r) return;
 51     int mid=(l+r)>>1;
 52     if(x<=mid) inst(tr[k].ls,l,mid,x,w);
 53     else inst(tr[k].rs,mid+1,r,x,w);
 54 }
 55 int query(int k,int l,int r,int a,int b)
 56 {
 57     if(!k) return 0;
 58     if(a==l&&r==b) return tr[k].sum;
 59     int mid=(l+r)>>1;
 60     if(b<=mid) return query(tr[k].ls,l,mid,a,b);
 61     else if(a>mid) return query(tr[k].rs,mid+1,r,a,b);
 62     else return query(tr[k].ls,l,mid,a,mid)+query(tr[k].rs,mid+1,r,mid+1,b);
 63 }
 64 int main()
 65 {
 66     n=read(),T=read();int a,b,c,res;char ch[5];
 67     for(int i=1;i<=n;i++) val[i]=read();
 68     for(int i=1;i<n;i++) {a=read(),b=read();add(a,b);add(b,a);}
 69     dep[1]=1;
 70     dfs(1);Cnt=0;
 71     Dfs(1,1);Cnt=0;
 72     for(int i=1;i<=n;i++) 
 73     {
 74         if(!m[val[i]]) m[val[i]]=++clr;
 75         inst(rt[m[val[i]]],1,n,hsh[i],1);
 76     }
 77     while(T--)
 78     {
 79         scanf("%s",ch);a=read(),b=read();
 80         if(ch[0]=='C')
 81         {
 82             inst(rt[m[val[a]]],1,n,hsh[a],-1);
 83             if(!m[b]) m[b]=++clr;
 84             inst(rt[m[b]],1,n,hsh[a],1);
 85             val[a]=b;
 86         }
 87         else
 88         {
 89             c=read(),res=0;
 90             if(!m[c]) {puts("0");continue;}
 91             while(bl[a]!=bl[b])
 92             {
 93                 //cout<<a<<" "<<b<<" "<<res<<endl;
 94                 if(dep[bl[a]]<dep[bl[b]]) swap(a,b);
 95                 res+=query(rt[m[c]],1,n,hsh[bl[a]],hsh[a]);
 96                 a=fa[bl[a]];
 97             }
 98         //cout<<a<<" "<<b<<" "<<res<<endl;
 99             res+=query(rt[m[c]],1,n,min(hsh[a],hsh[b]),max(hsh[a],hsh[b]));
100             printf("%d\n",res);
101         }
102     }
103 }
View Code

 

posted @ 2018-02-25 10:45  jack_yyc  阅读(140)  评论(0编辑  收藏  举报