BZOJ2588 Count on a tree <DFS序+LCA+值域主席树>

Count on a tree

给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。

Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。最后一个询问不输出换行符

Sample Input
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
Sample Output
2
8
9
105
7

Hint
N,M<=100000

标签:DFS序+LCA+值域主席树

这时一个区间第k小的问题,所以可以很自然的想到值域主席树。但是此题将区间移到了树上,在树上套线段树,可以想到DFS序和树链剖分。此题应该是DFS序。
在解区间第k小的时候,对于每次询问区间[a,b],我们需要找到a-1位置的线段树和b位置的线段树,然后递归query的时候用个数相减。对于这道题,我们把每个结点到根的那条链作为一个序列,用区间第k小的方法存储,然后找到u和v的LCA(假定它为t),递归query的时候计算左区间数的个数,即u结点对应线段树左区间数的个数+v结点....数的个数-t结点...数的个数-t的父结点...数的个数。即tmp = tr[tr[u].ls].val+tr[tr[v].ls].val-tr[tr[t].ls].val-tr[tr[fa[t]].ls].val。
写的时候注意强制在线的操作方式和读入数后先离散化。

最后附上AC代码:

#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
#define MAX_N 100000
using namespace std;
int n, m, c[MAX_N+5];
int cnt = 0, root[MAX_N+5];
int tot = 0, map[MAX_N+5];
int anc[MAX_N+5][25], dep[MAX_N+5];
bool vis[MAX_N+500];
vector <int> G[MAX_N+500];
struct Pre {int id, val;} pre[MAX_N+5];
struct TNode {int ls, rs, val;}	tr[MAX_N*32+500];
bool cmp(const Pre &a, const Pre &b) {return a.val < b.val;}
void DFS(int u) {
	vis[u] = true;
	for (int i = 1; (1<<i) <= dep[u]; i++)	anc[u][i] = anc[anc[u][i-1]][i-1];
	for (int i = 0; i < G[u].size(); i++) {
		int v = G[u][i];
		if (!vis[v])	anc[v][0] = u, dep[v] = dep[u]+1, DFS(v);
	}
}
int LCA(int a, int b) {
	int i, j;
	if (dep[a] < dep[b])	swap(a, b);
	for (i = 0; (1<<i) <= dep[a]; i++) ;	i--;
	for (j = i; j >= 0; j--)
		if (dep[a]-(1<<j) >= dep[b])
			a = anc[a][j];
	if (a == b)	return a;
	for (j = i; j >= 0; j--)
		if (anc[a][j] != anc[b][j])
			a = anc[a][j], b = anc[b][j];
	return anc[a][0];
}
void init(int &v, int s, int t) {
	v = ++cnt;
	if (s == t)	return;
	int mid = s+t>>1;
	init(tr[v].ls, s, mid);
	init(tr[v].rs, mid+1, t);
}
void insert(int v, int o, int s, int t, int val) {
	tr[v] = tr[o];
	if (s == t)	{tr[v].val++;	return;}
	int mid = s+t>>1;
	if (val <= mid)	insert(tr[v].ls = ++cnt, tr[o].ls, s, mid, val);
	else	insert(tr[v].rs = ++cnt, tr[o].rs, mid+1, t, val);
	tr[v].val = tr[tr[v].ls].val+tr[tr[v].rs].val;
}
void build(int u) {
	root[u] = ++cnt;
	insert(root[u], root[anc[u][0]], 1, tot, c[u]);
	for (int i = 0; i < G[u].size(); i++) {
		int v = G[u][i];
		if (v != anc[u][0])	build(v);
	}
}
int query(int v1, int v2, int v3, int v4, int s, int t, int k) {
	if (s == t)	return s;
	int mid = s+t>>1, tmp = tr[tr[v1].ls].val+tr[tr[v2].ls].val-tr[tr[v3].ls].val-tr[tr[v4].ls].val;
	if (k <= tmp)	return query(tr[v1].ls, tr[v2].ls, tr[v3].ls, tr[v4].ls, s, mid, k);
	return query(tr[v1].rs, tr[v2].rs, tr[v3].rs, tr[v4].rs, mid+1, t, k-tmp);
}
int main() {
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++)	pre[i].id = i, scanf("%d", &pre[i].val);
	sort(pre+1, pre+n+1, cmp);
	for (int i = 1; i <= n; i++) {
		if (i == 1 || pre[i].val != pre[i-1].val)	map[++tot] = pre[i].val;
		c[pre[i].id] = tot;
	}
	for (int i = 1; i < n; i++) {
		int u, v;
		scanf("%d%d", &u, &v);
		G[u].push_back(v), G[v].push_back(u);
	}
	DFS(1);
	init(root[0], 1, tot), build(1);
	int ans = 0;
	while (m--) {
		int u, v, k;
		scanf("%d%d%d", &u, &v, &k);	u ^= ans;
		int lca = LCA(u, v);
		ans = map[query(root[u], root[v], root[lca], root[anc[lca][0]], 1, tot, k)];
		printf("%d", ans);
		if (m >= 1)	printf("\n");
	}
	return 0;
}
posted @ 2017-09-20 15:26  Azrael_Death  阅读(155)  评论(0编辑  收藏  举报