2020 hdu多校赛 第三场 1003 Tokitsukaze and Colorful Tree

题意:

给你一棵树(n<=1e5),每个节点有颜色col[i]<=n,和权值val[i]<2^20,

每次修改一个节点的权值或颜色,求每次修改之后每个节点与不是他的祖先或在他子树内的且颜色相同的点的权值异或之和。

首先,我们考虑简化问题,如果没有颜色、祖先和子树限制,单纯求各个点对的异或值之和应该怎么求?

很简单,我们设sum[i]为二进制第 i 位为1的数有多少个,然后枚举每个点的权值在着一位为1还是0即可。

那么加上颜色限制呢?

设sum[x][i]为颜色x的点里面二进制第 i 位为 1 的数有多少个,其他不变。

让我们再加上祖先限制。

此时,一个点会对他子树内所有点产生影响,我们就可以考虑在 dfs 序上用线段树差分维护sum[x][i],对 x 点查询的时候直接看1~x的区间和就可以知道x的祖先的影响了。

最后,让我们加上子树限制。

与上文同理,我们用dfs序维护,直接区间求和即可。

(出题人竟然卡我内存,QAQ,赌了一把改成short才过,不然不是MLE就是RE

  1 #include<iostream>
  2 #include<cstdlib>
  3 #include<cstring>
  4 #include<algorithm>
  5 #include<cmath>
  6 #include<cstdio>
  7 #define N 100005
  8 using namespace std;
  9 int T,n,zz,a[N];
 10 struct ro{
 11     int to,next;
 12 }road[N*2];
 13 void build(int x,int y)
 14 {
 15     zz++;
 16     road[zz].to=y;
 17     road[zz].next=a[x];
 18     a[x]=zz;
 19 }
 20 int fa[N],l[N],r[N],dfn[N],zz1,dl[N];
 21 void dfs(int x)
 22 {
 23     zz1++;
 24     dfn[x]=zz1;
 25     dl[zz1]=x;
 26     l[x]=zz1;
 27     for(int i=a[x];i;i=road[i].next)
 28     {
 29         int y=road[i].to;
 30         if(y==fa[x])continue;
 31         fa[y]=x;
 32         dfs(y);
 33     }
 34     r[x]=zz1;
 35 }
 36 int val[N],col[N],zz2;
 37 struct no{
 38     int l,r,size[2];
 39     short da[2][20];
 40 }node[int(N*30)];
 41 int A[N],sm[2];
 42 void add(int l,int r,int x,int to,int op,int dat)
 43 {
 44     if(l==r)
 45     {
 46         node[x].size[op]+=dat;
 47         for(int i=0;i<20;i++)
 48         {
 49             node[x].da[op][i]+=A[i];
 50         }
 51         return;
 52     }
 53     int mid=(l+r)>>1;
 54     if(to>mid)
 55     {
 56         if(!node[x].r)
 57         {
 58             zz2++;
 59             node[zz2].l=node[zz2].r=0;
 60             node[zz2].size[0]=node[zz2].size[1]=0;
 61             memset(node[zz2].da,0,sizeof(node[zz2].da));
 62             node[x].r=zz2;
 63         }
 64         add(mid+1,r,node[x].r,to,op,dat);
 65     }
 66     else
 67     {
 68         if(!node[x].l)
 69         {
 70             zz2++;
 71             node[zz2].l=node[zz2].r=0;
 72             node[zz2].size[0]=node[zz2].size[1]=0;
 73             memset(node[zz2].da,0,sizeof(node[zz2].da));
 74             node[x].l=zz2;
 75         }
 76         add(l,mid,node[x].l,to,op,dat);
 77     }
 78     node[x].size[op]=node[node[x].l].size[op]+node[node[x].r].size[op];
 79     for(int i=0;i<20;i++)
 80     {
 81         node[x].da[op][i]=node[node[x].l].da[op][i]+node[node[x].r].da[op][i];
 82     }
 83 }
 84 void get(int l,int r,int left,int right,int x,int op,int op2)
 85 {
 86     if(!x)return ;
 87     if(left>right)return;
 88     if(l==left&&r==right)
 89     {
 90         sm[op]+=node[x].size[op]*op2;
 91         for(int i=0;i<20;i++)
 92         {
 93             A[i]+=node[x].da[op][i]*op2;
 94         }
 95         return;
 96     }
 97     int mid=(l+r)>>1;
 98     if(left>mid)
 99     {
100         get(mid+1,r,left,right,node[x].r,op,op2);
101     }
102     else if(right<=mid)
103     {
104         get(l,mid,left,right,node[x].l,op,op2);
105     }
106     else
107     {
108         get(l,mid,left,mid,node[x].l,op,op2);
109         get(mid+1,r,mid+1,right,node[x].r,op,op2);
110     }
111 }
112 int cnt[N],root[N],sum[N][20];
113 void del(int x)
114 {
115     for(int i=0;i<20;i++)
116     {
117         if(val[x]&(1<<i)) A[i]=-1,sum[col[x]][i]--;
118         else A[i]=0;
119     }
120     add(1,n+1,root[col[x]],dfn[x],0,-1);
121     add(1,n+1,root[col[x]],l[x],1,-1);
122     for(int i=0;i<20;i++)
123     {
124         if(val[x]&(1<<i)) A[i]=1;
125         else A[i]=0;
126     }
127     add(1,n+1,root[col[x]],r[x]+1,1,1);
128     cnt[col[x]]--;
129 }
130 long long work(int x)
131 {
132     for(int i=0;i<20;i++) A[i]=0;
133     sm[0]=sm[1]=0;
134     get(1,n+1,l[x],r[x],root[col[x]],0,1);
135     get(1,n+1,1,dfn[x],root[col[x]],1,1);
136     long long ans=0;
137     for(int i=0;i<20;i++)
138     {
139         if(val[x]&(1<<i))
140         {
141             ans+=1ll*((cnt[col[x]]-sum[col[x]][i])-(sm[0]+sm[1]-A[i]))*(1<<i);
142         }
143         else
144         {
145             ans+=1ll*(sum[col[x]][i]-A[i])*(1<<i);
146         }
147 
148     }
149     return ans;
150 }
151 void ins(int x)
152 {
153     if(!root[col[x]])
154     {
155         zz2++;
156         root[col[x]]=zz2;
157         node[zz2].l=node[zz2].r=0;
158         node[zz2].size[0]=node[zz2].size[1]=0;
159         memset(node[zz2].da,0,sizeof(node[zz2].da));
160     }
161     for(int i=0;i<20;i++)
162     {
163         if(val[x]&(1<<i)) A[i]=1,sum[col[x]][i]++;
164         else A[i]=0;
165     }
166     add(1,n+1,root[col[x]],dfn[x],0,1);
167     add(1,n+1,root[col[x]],l[x],1,1);
168     for(int i=0;i<20;i++)
169     {
170         if(val[x]&(1<<i)) A[i]=-1;
171         else A[i]=0;
172     }
173     add(1,n+1,root[col[x]],r[x]+1,1,-1);
174     cnt[col[x]]++;
175 }
176 int main()
177 {
178 //    freopen("1003.in","r",stdin);
179 //    freopen("1.out","w",stdout);    
180     scanf("%d",&T);
181     while(T--)
182     {
183         scanf("%d",&n);
184         for(int i=1;i<=n;i++) scanf("%d",&col[i]);
185         for(int i=1;i<=n;i++) scanf("%d",&val[i]);
186         zz=0;
187         memset(a,0,sizeof(a));
188         for(int i=1;i<n;i++)
189         {
190             int x,y;
191             scanf("%d%d",&x,&y);
192             build(x,y);
193             build(y,x);
194         }
195         memset(fa,0,sizeof(fa));
196         memset(dl,0,sizeof(dl));
197         memset(l,0,sizeof(l));
198         memset(r,0,sizeof(r));
199         memset(cnt,0,sizeof(cnt)); 
200         memset(sum,0,sizeof(sum));
201         memset(root,0,sizeof(root));
202         zz1=0;
203         dfs(1);
204         
205         zz2=0;
206         long long ans=0;
207         for(int i=1;i<=n;i++)
208         {
209             if(!root[col[i]])
210             {
211                 zz2++;
212                 root[col[i]]=zz2;
213                 node[zz2].l=node[zz2].r=0;
214                 node[zz2].size[0]=node[zz2].size[1]=0;
215                 memset(node[zz2].da,0,sizeof(node[zz2].da));
216             }
217             ans+=work(i);
218             ins(i);
219         }
220         printf("%lld\n",ans);
221         int q;
222         scanf("%d",&q);
223         for(int i=1;i<=q;i++)
224         {
225             int op,x,y;
226             scanf("%d%d%d",&op,&x,&y);
227             del(x);
228             ans-=work(x);
229         //    cout<<i<<' '<<ans<<endl;
230             if(op==1) val[x]=y;
231             else col[x]=y;
232             ans+=work(x);
233             ins(x);
234             printf("%lld\n",ans);
235         }
236     }
237     return 0;
238 }
View Code

 

posted @ 2020-07-31 09:29  Hzoi_joker  阅读(220)  评论(0编辑  收藏  举报