[BZOJ2588][Spoj 10628]Count on a tree
[BZOJ2588][Spoj 10628]Count on a tree
试题描述
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
输入
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
输出
M行,表示每个询问的答案。最后一个询问不输出换行符
输入示例
8 5 105 2 9 3 8 5 7 7 1 2 1 3 1 4 3 5 3 6 3 7 4 8 2 5 1 0 5 2 10 5 3 11 5 4 110 8 2
输出示例
2 8 9 105 7
数据规模及约定
N,M<=100000
题解
我们可以把主席树按照树形结构来建,即每一个节点上的版本从它父亲节点的版本修改而来,那么一个节点上的主席树记录的就是该节点到根节点的权值信息了,于是利用 d(a, b) = dep(a) + dep(b) - dep(lca(a, b)) - dep(fa[lca(a, b)]) 这个公式(其中 d(a, b) 表示路径 a 到 b 的权值和,dep(u) = d(root, u),root 为根节点,lca(a, b) 为 a 与 b 的最近公共祖先,fa[u] 为 u 的父亲)二分。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <stack>
#include <vector>
#include <queue>
#include <cstring>
#include <string>
#include <map>
#include <set>
using namespace std;
const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
if(Head == Tail) {
int l = fread(buffer, 1, BufferSize, stdin);
Tail = (Head = buffer) + l;
}
return *Head++;
}
int read() {
int x = 0, f = 1; char c = Getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
return x * f;
}
#define maxn 100010
#define maxm 200010
#define maxlog 17
#define maxnode 2000010
int n, rt[maxn], val[maxn], num[maxn];
int ToT, sumv[maxnode], lc[maxnode], rc[maxnode];
void update(int& y, int x, int l, int r, int p) {
sumv[y = ++ToT] = sumv[x] + 1;
if(l == r) return ;
int mid = l + r >> 1; lc[y] = lc[x]; rc[y] = rc[x];
if(p <= mid) update(lc[y], lc[x], l, mid, p);
else update(rc[y], rc[x], mid + 1, r, p);
return ;
}
int m, head[maxn], next[maxm], to[maxm], fa[maxlog][maxn], dep[maxn];
void AddEdge(int a, int b) {
to[++m] = b; next[m] = head[a]; head[a] = m;
swap(a, b);
to[++m] = b; next[m] = head[a]; head[a] = m;
return ;
}
void build(int u) {
update(rt[u], rt[fa[0][u]], 1, n, val[u]);
for(int i = 1; i < maxlog; i++) fa[i][u] = fa[i-1][fa[i-1][u]];
for(int e = head[u]; e; e = next[e]) if(to[e] != fa[0][u]) {
fa[0][to[e]] = u;
dep[to[e]] = dep[u] + 1;
build(to[e]);
}
return ;
}
int lca(int a, int b) {
if(dep[a] < dep[b]) swap(a, b);
for(int i = maxlog - 1; i >= 0; i--) if(dep[a] - dep[b] >= (1 << i)) a = fa[i][a];
for(int i = maxlog - 1; i >= 0; i--) if(fa[i][a] != fa[i][b]) a = fa[i][a], b = fa[i][b];
return a == b ? a : fa[0][b];
}
int solve(int a, int b, int k) {
int lrt[2] = {rt[a], rt[b]}, c = lca(a, b), rrt[2] = {rt[c], rt[fa[0][c]]};
int l = 1, r = n;
while(l < r) {
int mid = l + r >> 1, sum = 0;
for(int i = 0; i < 2; i++) if(lrt[i] && lc[lrt[i]]) sum += sumv[lc[lrt[i]]];
for(int i = 0; i < 2; i++) if(rrt[i] && lc[rrt[i]]) sum -= sumv[lc[rrt[i]]];
if(sum < k) {
k -= sum; l = mid + 1;
for(int i = 0; i < 2; i++) if(lrt[i]) lrt[i] = rc[lrt[i]];
for(int i = 0; i < 2; i++) if(rrt[i]) rrt[i] = rc[rrt[i]];
}
else {
r = mid;
for(int i = 0; i < 2; i++) if(lrt[i]) lrt[i] = lc[lrt[i]];
for(int i = 0; i < 2; i++) if(rrt[i]) rrt[i] = lc[rrt[i]];
}
}
return num[l];
}
int main() {
n = read(); int q = read();
for(int i = 1; i <= n; i++) val[i] = num[i] = read();
sort(num + 1, num + n + 1);
for(int i = 1; i <= n; i++) val[i] = lower_bound(num + 1, num + n + 1, val[i]) - num;
for(int i = 1; i < n; i++) {
int a = read(), b = read();
AddEdge(a, b);
}
build(1);
int lst = 0;
while(q--) {
int a = read() ^ lst, b = read(), k = read();
lst = solve(a, b, k);
if(q) printf("%d\n", lst);
else printf("%d", lst);
}
return 0;
}

浙公网安备 33010602011771号