// Note (translated from the Chinese remark "讲得挺好的"): "explained quite well".
#include <bits/stdc++.h>
using namespace std;
// Capacity for per-node arrays and for the Euler-tour array eler, which gets
// two entries per node in dfs (so N must be at least 2 * n + 1).
const int N = 80005;
// Capacity for per-query arrays: q and ans are both indexed by query id 1..m.
const int MAXQ = 100005;

int tot = 0;        // number of distinct weights after compression
int lsh[N];         // sorted, deduplicated copy of the input weights
int w[N];           // node weights, compressed to 1..tot
int len;            // Mo block size
int n, m;           // node count, query count
int eler[N];        // Euler tour: eler[pos] = node recorded at tour position pos
int cnt = 0;        // current length of the Euler tour
int er[N];          // tour position where each node is exited
int el[N];          // tour position where each node is entered
int dfn[N];         // entry position (same value as el)
// f[x][i] = 2^i-th ancestor of x (0 above the root). 17 levels (0..16) allow
// jumps up to 2^17 - 1, which exceeds any depth that fits in N / 2 nodes;
// 15 levels only reached 32767 and broke lca on long chains.
int f[N][17];
int vis[N];         // net count of x's tour positions toggled into the window
int sum = 0;        // number of colors with a nonzero count in col
int col[N];         // per-color count of nodes whose vis is odd
int ans[MAXQ];      // answer per query id (sized like q, not like N)
int dep[N];         // depth; dep[0] = 0 so the root gets depth 1
vector<int> g[N];   // adjacency lists of the tree
// A path query mapped onto a contiguous range [l, r] of the Euler tour.
struct Que
{
    int l, r, id, lc;   // tour range, original query index, LCA of the endpoints

    // Mo's ordering: group by block of l; inside a block, sort r ascending on
    // even blocks and descending on odd blocks.
    bool operator < (const Que &a) const
    {
        const int myBlock = l / len;
        const int otherBlock = a.l / len;
        if (myBlock != otherBlock)
            return myBlock < otherBlock;
        return (myBlock & 1) ? (a.r < r) : (r < a.r);
    }
};
Que q[100005];
// Depth-first traversal from x (whose parent is fa) that builds the Euler tour:
// each node is written to eler once on entry (el/dfn) and once on exit (er).
// Also fills dep[] and the binary-lifting table f[][].
// NOTE(review): recursion depth equals tree depth; a long chain may need a
// large stack on platforms with a small default stack.
void dfs(int x, int fa)
{
    // Highest lifting level that fits in f's second dimension, so this stays
    // in sync with the declaration of f instead of a hard-coded constant.
    const int kTop = static_cast<int>(sizeof(f[0]) / sizeof(f[0][0])) - 1;

    eler[++cnt] = x;
    el[x] = cnt;
    dfn[x] = cnt;

    f[x][0] = fa;
    dep[x] = dep[fa] + 1;
    for (int i = 1; i <= kTop; ++i)
        f[x][i] = f[f[x][i - 1]][i - 1];   // f[0][*] is 0, so jumps past the root stay at 0

    for (int y : g[x])
    {
        if (y == fa) continue;
        dfs(y, x);
    }

    eler[++cnt] = x;
    er[x] = cnt;
}
// Lowest common ancestor of x and y via binary lifting over f[][] and dep[]
// (both filled by dfs). Node 0 acts as a sentinel above the root with depth 0.
int lca(int x, int y)
{
    // Use every level f actually has; a fixed 14 capped jumps at 32767.
    const int kTop = static_cast<int>(sizeof(f[0]) / sizeof(f[0][0])) - 1;

    if (dep[x] < dep[y]) swap(x, y);

    // Lift the deeper node x up to y's depth.
    for (int i = kTop; i >= 0; --i)
        if (dep[f[x][i]] >= dep[y]) x = f[x][i];

    if (x == y) return x;

    // Lift both while their ancestors differ; they end just below the LCA.
    for (int i = kTop; i >= 0; --i)
    {
        if (f[x][i] != f[y][i])
        {
            x = f[x][i];
            y = f[y][i];
        }
    }
    return f[x][0];
}
// Called when Euler-tour position pos enters the window. Bumps vis for the
// node stored there; the node counts toward col/sum while vis is odd, so each
// parity flip adds or removes one occurrence of its color.
void add(int pos)
{
    const int node = eler[pos];
    const int c = w[node];
    ++vis[node];
    if (vis[node] % 2 != 0)
    {
        if (++col[c] == 1) ++sum;   // color newly present
    }
    else if (--col[c] == 0)
    {
        --sum;                      // last node of this color left
    }
}
// Called when Euler-tour position pos leaves the window. Decrements vis for
// the node stored there; as in add, a parity flip toggles whether the node's
// color is counted in col/sum.
void del(int pos)
{
    const int node = eler[pos];
    const int c = w[node];
    --vis[node];
    if (vis[node] % 2 != 0)
    {
        if (++col[c] == 1) ++sum;   // color newly present
    }
    else if (--col[c] == 0)
    {
        --sum;                      // last node of this color left
    }
}
int main()
{
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; ++ i) scanf("%d", &w[i]), lsh[i] = w[i];
sort(lsh + 1, lsh + n + 1);
tot = unique(lsh + 1, lsh + n + 1) - lsh - 1;
for (int i = 1; i <= n; ++ i) w[i] = lower_bound(lsh + 1, lsh + tot + 1, w[i]) - lsh;
int u, v, l, r;
for (int i = 1; i <= n - 1; ++ i)
{
scanf("%d %d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);
for (int i = 1; i <= m; ++ i)
{
scanf("%d %d", &l, &r);
int lc = lca(l, r);
if(dfn[l] > dfn[r]) swap(l, r);
q[i].lc = lc, q[i].id = i;
if(lc == l) q[i].l = el[l], q[i].r = el[r];
else q[i].l = er[l], q[i].r = el[r];
}
len = (int)sqrt(cnt);
sort(q + 1, q + m + 1);
l = 1, r = 0;
for (int i = 1; i <= m; ++ i)
{
while(l < q[i].l) del(l), l ++;
while(l > q[i].l) l --, add(l);
while(r < q[i].r) r ++, add(r);
while(r > q[i].r) del(r), r --;
ans[q[i].id] = sum;
if(col[w[q[i].lc]] == 0 && eler[q[i].l] != q[i].lc) ans[q[i].id] ++;
}
for (int i = 1; i <= m; ++ i) printf("%d\n", ans[i]);
return 0;
}