# Count on a Tree II

### Count on a Tree II

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
using namespace std;
#define MAXN 40006
#define B 206
// NOTE: XOR-swap macro; it replaces every later `swap(` in this file (so std::swap is
// unreachable by that name) and must never be given the same variable twice.
#define swap( a , b ) ( (a) ^= (b) , (b) ^= (a) , (a) ^= (b) )
// n nodes, m queries. main() sets blo = sqrt(n) (spacing of key-point depths),
// sz = number of distinct colours after compression, bl = sqrt(sz) (value-block size
// used by add()/que()).
int n;
int m;
int blo;
int sz;
int bl;
int c[MAXN];   // c[u]: colour of node u (compressed to 1..sz in main)
int C[MAXN];   // sorted, deduplicated copy of the colours used for compression
// Adjacency as linked edge lists: head[u] is the newest edge leaving u, nex[e] the
// next edge leaving the same node, to[e] its endpoint; ecn counts edges stored so far.
int head[MAXN];
int to[MAXN << 1];
int nex[MAXN << 1];
int ecn;
// Prepend the directed edge u -> v to u's edge list.
void ade( int u , int v ) {
    ++ ecn;
    to[ecn] = v;
    nex[ecn] = head[u];
    head[u] = ecn;
}
int ps[B];       // ps[i]: node id of the i-th key point (filled by dfs1)
int im[MAXN];    // im[u]: key-point index of node u, or 0 when u is not a key point
int pre[B][B];   // pre[i][j]: distinct colours on the path from key point i to key point j
int cid;         // index of the key point whose row of pre[][] is being filled
int ee;          // number of key points found by dfs1
int w[MAXN];     // w[x]: occurrences of colour x on the path from the DFS start to u
int an;          // number of colours x with w[x] > 0
// Walk the whole tree from the start node, maintaining the distinct-colour count of
// the start -> u path, and record it in pre[cid][*] at every key point reached.
// Call as dfs(p, p) so that p's parent edge is also followed.
void dfs( int u , int fa ) {
    // Entering u: its colour joins the current path.
    if( w[c[u]] ++ == 0 ) ++ an;
    if( im[u] ) pre[cid][im[u]] = an;
    for( int e = head[u] ; e != 0 ; e = nex[e] ) {
        const int child = to[e];
        if( child != fa ) dfs( child , u );
    }
    // Leaving u: undo its contribution.
    if( -- w[c[u]] == 0 ) -- an;
}

int kk[MAXN << 2][B] , idx = 0 , rt[MAXN];
void build( ) {
idx = rt[0] = 1;
for( int i = 1 ; ( i - 1 ) * bl < sz ; ++ i ) kk[1][i] = i + 1 , ++ idx;
}
// Make version `rt` a copy of version `old` with the count of value x increased by one.
// Only the top row and the single block containing x are copied (path copying).
void add( int rt , int old , int x ) {
    const int outer = ( x - 1 ) / bl + 1;   // block that holds value x
    const int inner = ( x - 1 ) % bl + 1;   // slot of x inside that block
    memcpy( kk[rt] , kk[old] , sizeof kk[0] );
    const int fresh = ++ idx;
    memcpy( kk[fresh] , kk[ kk[rt][outer] ] , sizeof kk[0] );
    kk[rt][outer] = fresh;
    ++ kk[fresh][inner];
}
// Count of value x stored in version rt.
int que( int rt , int x ) {
    const int *leaf = kk[ kk[rt][( x - 1 ) / bl + 1] ];
    return leaf[( x - 1 ) % bl + 1];
}

int dfn[MAXN << 1] , dep[MAXN << 1] , pos[MAXN << 1] , en = 0 , d[MAXN] , par[MAXN];
int dfs1( int u , int fa ) {
dfn[++en] = u , pos[u] = en , dep[en] = d[u] , par[u] = fa;
rt[u] = ++ idx;
add( rt[u] , rt[fa] , c[u] );
int re = 0;
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == fa ) continue;
d[v] = d[u] + 1;
re = max( re , dfs1( v , u ) );
dfn[++ en] = u , dep[en] = d[u];
}
if( re - d[u] >= blo && d[u] % blo == 0 ) ps[++ ee] = u , im[u] = ee;
return re ? re : d[u];
}
int lg[MAXN << 1];       // lg[i] = floor(log2(i)) + 1 for 1 <= i <= en (lg[0] = 0)
int st[MAXN << 1][17];   // st[j][k]: Euler position of a minimum-depth entry in [j, j + 2^k - 1]
// Build the log table and the depth sparse table over Euler positions 1..en.
void ST() {
    for( int i = 1 ; i <= en ; ++ i ) {
        lg[i] = lg[i - 1];
        if( ( 1 << lg[i - 1] ) == i ) ++ lg[i];
    }
    for( int i = 1 ; i <= en ; ++ i ) st[i][0] = i;
    for( int k = 1 ; ( 1 << k ) <= en ; ++ k ) {
        const int half = 1 << ( k - 1 );
        for( int j = 1 ; j + ( 1 << k ) - 1 <= en ; ++ j ) {
            const int a = st[j][k - 1];
            const int b = st[j + half][k - 1];
            st[j][k] = dep[a] < dep[b] ? a : b;   // ties go to the right half
        }
    }
}
int lca( int l , int r ) {
l = pos[l] , r = pos[r];
if( l > r ) swap( l , r );
int k = lg[r - l + 1] - 1;
return dep[st[l][k]] <= dep[st[r-(1<<k)+1][k]] ? dfn[st[l][k]] : dfn[st[r-(1<<k)+1][k]];
}
int col[MAXN];   // colours collected by jump() for the current query
int cn;          // number of entries used in col[]
int l;           // lca of the current query; must be set before jump() is called
// Walk from u towards l, appending each visited node's colour to col[].
// Returns the index of the first key point met (its own colour is NOT appended);
// otherwise appends l's colour and returns -1.
int jump( int u ) {
    for( ; ; u = par[u] ) {
        if( im[u] ) return im[u];
        col[++ cn] = c[u];
        if( u == l ) return -1;
    }
}
// Walk from u towards the root, appending each visited node's colour to col[], until a
// key point is met; returns that key point's index (its colour is not appended).
// The original recursed forever once it stepped past the root (par[root] == 0 and
// par[0] == 0, im[0] == 0), overrunning col[]. Now returns -1 in that case.
// Not called anywhere in the visible source.
int rejjump( int u ) {
    while( u != 0 ) {
        if( im[u] ) return im[u];
        col[++ cn] = c[u];
        u = par[u];
    }
    return -1;
}
int oc[MAXN];
int main() {
//    freopen("10.in","r",stdin);
//    freopen("ot","w",stdout);
cin >> n >> m;
blo = sqrt( n );
for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&c[i]) , C[i] = c[i];
sort( C + 1 , C + 1 + n ); sz = unique( C + 1 , C + 1 + n ) - C - 1;
bl = sqrt( sz );
for( int i = 1 ; i <= n ; ++ i ) c[i] = lower_bound( C + 1 , C + 1 + sz , c[i] ) - C;
for( int i = 1 , u , v ; i < n ; ++ i ) {
scanf("%d%d",&u,&v);
}
build( );
dfs1( 1 , 0 );
//    cout << que( rt[3] , 3 ) << endl;
ST( );
for( int i = 1 ; i <= blo ; ++ i ) {
int p = ps[i]; cid = i;
dfs( p , p );
}
int u , v , psu , psv , flg = 0 , re , last = 0 , t , pr;
while( m-- ) {
scanf("%d%d",&u,&v);
u ^= last;
//        cout << u << ' ' << v << endl;
l = lca( u , v );
cn = re = 0;
psu = jump( u ) , psv = jump( v );
if( psu == psv ) {
oc[c[l]] = 1 , ++ re;
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) oc[col[i]] = 1 , ++ re;
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
oc[c[l]] = 0;
} else if( ~psu && ~psv ) {
re = pre[psu][psv];
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
oc[col[i]] = 1;
if( que( rt[ps[psu]] , col[i] ) + que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) * 2 == 0 )
++ re;
}
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
} else {
if( ~psu ) swap( u , v ) , swap( psu , psv );
pr = l; t = cn;
while( !im[pr] ) {
pr = par[pr];
if( !oc[c[pr]] ) {
oc[c[pr]] = 1;
if (que(rt[ps[psv]], c[pr]) - que(rt[par[l]], c[pr]) == 0)
--re;
col[++ cn] = c[pr];
}
}
re += pre[im[pr]][psv];
for( int i = t + 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
cn = t;
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
oc[col[i]] = 1;
if( que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) == 0 )
++ re;
}
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
}
printf("%d\n",last = re);
}
}


Posted 2020-02-28 17:28 by yijan · Views (141) · Comments (0) · Edit · Bookmark · Report