BZOJ 4012. [HNOI2015]开店
BZOJ 4012. [HNOI2015]开店
题目描述
询问一个点到树上其它点的距离,这些点满足颜色属于区间\([l,r]\).
多组询问
解题思路
首先考虑离线的做法,把所有颜色离散化,然后扫描颜色序列,每扫到一个点就把它到根的贡献算上,就是每个点都加上父边长度的贡献,这样求答案的时候直接跳到根节点,得到的总的权值和,就是dep(lca(u,v))的和。
这样再利用经典的树上算距离的公式算一下就好。
但是这道题目要求强制在线,我们考虑变成主席树。
乍一看做不了,因为是每次\(O(n)\) 个单点修改。所以我们要使用标记永久化,这样只用修改\(O(log^2)\)个节点了。但是空间还是开不下。
注意到同一个版本会有许多重复开的新节点,这样我们对每个节点打标记,如果上次修改这个节点的版本和此次相同,就不开新的节点,这样可以卡很多空间。
#include<bits/stdc++.h>
using namespace std;
#define LL long long
const int N = 150011;
int n, Q, A;
int tot, rt[N], a[N], b[N], num;
int head[N], nex[N<<1], to[N<<1], wei[N<<1], size;
LL sumE[N], sumV[N], cntV[N], val[N], dis[N];
int id[N], dfn[N], son[N], sz[N], top[N], fa[N], cnt;
struct node{
int ls, rs, cnt, tim;
LL val;
}t[N*150];
void add(int x, int y, int z){
to[++size] = y;
nex[size] = head[x];
head[x] = size;
wei[size] = z;
}
void dfs(int u){
sz[u] = 1;
for(int i = head[u];i;i = nex[i]){
int v = to[i];
if(v == fa[u])continue;
fa[v] = u; val[v] = wei[i];
dis[v] = dis[u] + wei[i];
dfs(v);
sz[u] += sz[v];
if(sz[v] > sz[son[u]])son[u] = v;
}
}
void dfs(int u, int tp){
top[u] = tp;
dfn[u] = ++num;
id[num] = u;
if(son[u])dfs(son[u], tp);
for(int i = head[u];i;i = nex[i]){
int v = to[i];
if(v == fa[u] || v == son[u])continue;
dfs(v, v);
}
}
void modi(int &p, int l, int r, int x, int y, int tim){
int now = t[p].tim == tim ? p : ++tot;
t[now] = t[p];
t[now].tim = tim;
p = now;
t[p].val += sumE[min(y, r)] - sumE[max(l, x)-1];
if(l >= x && r <= y){
t[now].cnt++;
return ;
}
int mid = l + r >> 1;
if(mid >= y)modi(t[p].ls, l, mid, x, y, tim);
else if(mid < x)modi(t[p].rs, mid + 1, r, x, y, tim);
else modi(t[p].ls, l, mid, x, y, tim), modi(t[p].rs, mid + 1, r, x, y, tim);
}
LL query(int R, int L, int l, int r, int x, int y){
if(l > y || r < x)return 0;
if(l >= x && r <= y){
return t[R].val - t[L].val;
}
int mid = l + r >> 1;
return query(t[R].ls, t[L].ls, l, mid, x, y) + query(t[R].rs, t[L].rs, mid + 1, r, x, y) + 1LL * (t[R].cnt - t[L].cnt) * (sumE[min(y, r)] - sumE[max(x, l)-1]);
}
bool cmp(int x, int y){
return a[x] < a[y];
}
void jump(int x, int id){
while(x){
modi(rt[id], 1, n, dfn[top[x]], dfn[x], id);
x = fa[top[x]];
}
}
int main(){
freopen("4012.in", "r", stdin);
//freopen("4012.out", "w", stdout);
cin>>n>>Q>>A;
for(int i = 1;i <= n; i++){
scanf("%d", &a[i]);
b[i] = a[i];
}
sort(b + 1, b + 1 + n);
cnt = unique(b + 1, b + 1 + n) - (b + 1);
static int p[N];
for(int i = 1;i <= n; i++){
a[i] = lower_bound(b + 1, b + 1 + cnt, a[i]) - b;
p[i] = i;
}
int u, v, w;
for(int i = 1;i < n; i++){
scanf("%d%d%d", &u, &v, &w);
add(u, v, w); add(v, u, w);
}
dfs(1);
dfs(1, 1);
for(int i = 1;i <= n; i++){
sumE[i] = sumE[i-1] + val[id[i]];
}
sort(p + 1, p + 1 + n, cmp);
int pos = 1;
for(int i = 1;i <= n; i++){
int x = p[i];
rt[a[x]] = rt[a[p[i-1]]];
sumV[a[x]] += dis[x];
cntV[a[x]]++;
jump(x, a[x]);
}
for(int i = 1;i <= cnt; i++)sumV[i] += sumV[i-1], cntV[i] += cntV[i-1];
int a, b;
cin>>u>>a>>b;
int ans = 0;
for(int i = 1;i <= Q; i++){
int L = min((a + ans) % A, (b + ans) % A);
int R = max((a + ans) % A, (b + ans) % A);
L = lower_bound(::b + 1, ::b + 1 + cnt, L) - ::b;
R = upper_bound(::b + 1, ::b + 1 + cnt, R) - ::b - 1;
if(L > R){
ans = 0;
}
else{
LL res = (cntV[R] - cntV[L-1]) * dis[u] + sumV[R] - sumV[L-1];
int x = u;
while(x){
res -= 2ll * query(rt[R], rt[L-1], 1, n, dfn[top[x]], dfn[x]);
x = fa[top[x]];
}
ans = res;
}
printf("%lld\n", ans);
if(i < Q)scanf("%d%d%d", &u, &a, &b);
}
return 0;
}