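// 莫队 (Mo's algorithm, tree version)
//
// 「SPOJ10707」COT2 - Count on a tree II
//
// 讲得挺好的 ("well explained" — presumably referred to a tutorial link that did not survive extraction)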

#include <bits/stdc++.h>
using namespace std;
// Array capacities. N bounds the per-vertex arrays and the Euler sequence. The Euler
// sequence stores every vertex twice (once on entry, once on exit in dfs), so N must be
// at least 2 * n + 1.
const int N = 80005;
// Per-query capacity. ans[] is indexed by query id (1..m) exactly like q[], so it is sized
// to match q[100005] rather than N. Sized with N, it overflowed once m exceeded 80004.
const int MAXQ = 100005;

int n, m;                     // vertex count, query count
int w[N], lsh[N], tot = 0;    // vertex weights (compressed to 1..tot in main), sorted copy used for compression, number of distinct weights
int eler[N], cnt = 0;         // Euler sequence: eler[p] is the vertex at position p; cnt is its current length
int el[N], er[N], dfn[N];     // entry / exit position of each vertex in eler[]; dfn holds the same value as el
int f[N][15], dep[N];         // f[v][i] = 2^i-th ancestor of v (0 above the root); dep = depth (root is 1)
int len;                      // Mo block length over the Euler sequence
int vis[N], col[N], sum = 0;  // occurrences of each vertex in the Mo window; per-colour count of odd-occurrence vertices; colours with a nonzero count
int ans[MAXQ];                // answer for each query id
std::vector<int> g[N];        // adjacency lists
// An offline query rewritten as the Euler-sequence range [l, r].
// id is its input index and lc is the LCA of its two endpoints.
// The instances live in q[1..m].
struct Que
{
	int l, r, id, lc;
	// Mo's ordering: queries are ordered first by the block of their left end (block width
	// len). Inside a block, the right end ascends for even-numbered blocks and descends for
	// odd-numbered ones, so the right pointer sweeps back and forth between blocks.
	bool operator < (const Que &other) const
	{
		const int blk = l / len;
		const int otherBlk = other.l / len;
		if (blk == otherBlk)
			return (blk & 1) ? other.r < r : r < other.r;
		return blk < otherBlk;
	}
} q[100005];
// Depth-first traversal of the tree rooted at x (whose parent is fa).
// Records the Euler sequence: each vertex is written to eler[] once when it is entered and
// once when it is left. el[]/dfn[] receive the entry position and er[] the exit position.
// Also fills dep[] and the binary-lifting table f[v][0..14].
// An explicit stack replaces recursion: the recursive form needs one call frame per level
// of the tree, which on path-shaped input is one frame per vertex. Neighbours are still
// visited in g[] order, so the resulting sequence matches the recursive traversal.
void dfs(int x, int fa)
{
	// A vertex whose subtree is still being explored, plus the next neighbour index to try.
	struct Frame
	{
		int node, parent;
		std::size_t next;
	};
	std::vector<Frame> frames;

	// Entry work for vertex u reached from p; the frame is pushed so its children get visited.
	auto enter = [&frames](int u, int p)
	{
		eler[++cnt] = u;
		el[u] = cnt;
		dfn[u] = cnt;
		f[u][0] = p;
		dep[u] = dep[p] + 1;
		for (int i = 1; i <= 14; ++i) f[u][i] = f[f[u][i - 1]][i - 1];
		frames.push_back({u, p, 0});
	};

	enter(x, fa);
	while (!frames.empty())
	{
		Frame &top = frames.back();
		const int u = top.node;
		if (top.next < g[u].size())
		{
			const int y = g[u][top.next++];
			// enter() may reallocate `frames`; `top` is not touched after this call.
			if (y != top.parent) enter(y, u);
		}
		else
		{
			// All neighbours done: record the exit slot.
			eler[++cnt] = u;
			er[u] = cnt;
			frames.pop_back();
		}
	}
}
// Lowest common ancestor of x and y by binary lifting.
// f[v][i] is the 2^i-th ancestor of v (0 above the root) for i = 0..14, as filled by dfs().
// Using each level at most once can lift only 2^15 - 1 = 32767 steps. That is too few when
// the tree is deeper than that, for example a long path. Repeating a level while it still
// fits (the `while` loops below) lifts any distance with the same 15-level table; below the
// top level each `while` still runs at most once.
int lca(int x, int y)
{
	if (dep[x] < dep[y]) std::swap(x, y);
	// Bring x up to y's depth.
	for (int i = 14; i >= 0; --i)
		while (dep[x] - dep[y] >= (1 << i)) x = f[x][i];
	if (x == y) return x;
	// Lift both while their 2^i-th ancestors differ; they finish as children of the LCA.
	for (int i = 14; i >= 0; --i)
		while (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
	return f[x][0];
}
// Euler-sequence slot `pos` enters the Mo window.
// vis[] counts how often each vertex currently occurs in the window. A vertex's colour is
// counted in col[] only while that count is odd, and `sum` tracks how many colours have a
// nonzero count.
void add(int pos)
{
	const int node = eler[pos];
	int &bucket = col[w[node]];
	++vis[node];
	if ((vis[node] & 1) == 0)
	{
		// Even number of occurrences now: withdraw this vertex's colour.
		--bucket;
		if (bucket == 0) --sum;
		return;
	}
	++bucket;
	if (bucket == 1) ++sum;
}
// Euler-sequence slot `pos` leaves the Mo window; this is the exact inverse of add().
// After the occurrence count of the vertex drops, the vertex's colour is counted again if
// the count is now odd and withdrawn if it is now even.
void del(int pos)
{
	const int node = eler[pos];
	int &bucket = col[w[node]];
	--vis[node];
	if (vis[node] & 1)
	{
		if (++bucket == 1) ++sum;
	}
	else if (--bucket == 0)
	{
		--sum;
	}
}
int main()
{
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; ++ i) scanf("%d", &w[i]), lsh[i] = w[i];
	sort(lsh + 1, lsh + n + 1);
	tot = unique(lsh + 1, lsh + n + 1) - lsh - 1;
	for (int i = 1; i <= n; ++ i) w[i] = lower_bound(lsh + 1, lsh + tot + 1, w[i]) - lsh;
	int u, v, l, r;
	for (int i = 1; i <= n - 1; ++ i)
	{
		scanf("%d %d", &u, &v);
		g[u].push_back(v);
		g[v].push_back(u);
	}
	dfs(1, 0);
	for (int i = 1; i <= m; ++ i) 
	{
		scanf("%d %d", &l, &r);
		int lc = lca(l, r);
		if(dfn[l] > dfn[r]) swap(l, r);
		q[i].lc = lc, q[i].id = i;
		if(lc == l) q[i].l = el[l], q[i].r = el[r];
		else q[i].l = er[l], q[i].r = el[r];
	}
	len = (int)sqrt(cnt);
	sort(q + 1, q + m + 1);
	l = 1, r = 0;
	for (int i = 1; i <= m; ++ i)
	{
		while(l < q[i].l) del(l), l ++;
		while(l > q[i].l) l --, add(l);
		while(r < q[i].r) r ++, add(r);
		while(r > q[i].r) del(r), r --;
		ans[q[i].id] = sum;
		if(col[w[q[i].lc]] == 0 && eler[q[i].l] != q[i].lc) ans[q[i].id] ++;
	}
	for (int i = 1; i <= m; ++ i) printf("%d\n", ans[i]);
	return  0;
}

// posted @ 2025-02-13 10:36  Helioca  阅读(11)  评论(0)    收藏  举报
// Document