Luogu P3258 [JLOI2014]松鼠的新家


思路
这道题我一开始做的时候并不会树上差分,然后就卡了好久……
首先是树上差分,这个东西和普通序列上的差分大同小异。设\(sum[i]\)为差分数组,那么\(sum[i]\)表示的是\(i\)这个点到根节点上所有值的和。若要对\(x \sim y\)这条链上所有的点都加上\(v\),
那么就要对\(sum[x]+=v , sum[y]+=v , sum[lca(x,y)]-=v , sum[fa(lca(x,y))]-=v\) 。(这一部分建议自行画图理解一下,和普通差分的原理相似)
关于LCA,应该就没啥好说的,一般就是倍增求LCA和树剖求LCA(偶尔也会有用RMQ求LCA)。这里我使用的是树剖求LCA(比较快嘛)。
最后统计答案,就是进行一遍DFS,每个结点的权值即为该节点的子树的权值和(差分和前缀和互为逆运算)。但是这个题还有一个坑点,就是每次对两个点进行操作时,\(2 \sim (n-1)\)这个区间
中的所有点就被重复加了,所以在最后要对这些点的权值减去\(1\) 。并且根据题目中所说,在第\(n\)个节点是不需要糖的,所以\(n\)这个结点的权值同样也要减\(1\)。
Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define MAXN 300005
int n, a[MAXN], cnt;
int top[MAXN], son[MAXN], fa[MAXN];
int siz[MAXN], dep[MAXN], sum[MAXN];
class node{
public:
int to;
node *nxt = NULL;
} edge[MAXN << 1], *head[MAXN];
inline int read(void){
int f = 1, x = 0;char ch;
do{ch = getchar();if(ch=='-')f = -1;} while (ch < '0' || ch > '9');
do{ x = x * 10 + ch - '0';ch = getchar();} while (ch >= '0' && ch <= '9');
return f * x;
}
inline void add_edge(int x,int y){
++cnt;
edge[cnt].nxt = head[x];
head[x] = &edge[cnt];
head[x]->to = y;
return;
}
void DFS1(int u){
siz[u] = 1;
dep[u] = dep[fa[u]] + 1;
for (node *i = head[u]; i != NULL;i = i->nxt){
int v = i->to;
if(v==fa[u]) continue;
fa[v] = u;
DFS1(v);
siz[u] += siz[v];
if(!son[u]||siz[son[u]]<siz[v]) son[u] = v;
}
}//树剖的第一个DFS
void DFS2(int u,int idx){
top[u] = idx;
if(son[u]) DFS2(son[u], idx);
for (node *i = head[u]; i != NULL;i = i->nxt){
int v = i->to;
if(v==fa[u]||v==son[u]) continue;
DFS2(v, v);
}
}//...
inline int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]>=dep[top[y]]) x = fa[top[x]];
else y = fa[top[y]];
}
return dep[x] < dep[y] ? x : y;
}//...
void DFS3(int u){
for (node *i = head[u]; i != NULL;i = i->nxt){
int v = i->to;
if(v==fa[u]) continue;
DFS3(v);
sum[u] += sum[v];
}
}//统计答案,把差分数组还原
int main(){
n = read();
for (int i = 1; i <= n;++i)
a[i] = read();
for (int i = 1; i < n;++i){
int x = read(), y = read();
add_edge(x, y), add_edge(y, x);//不要忘了是加双向边
}
DFS1(1), DFS2(1, 1);
for (int i = 1; i < n;++i){
++sum[a[i]], ++sum[a[i + 1]];
int la = LCA(a[i], a[i + 1]);
--sum[la], --sum[fa[la]];//差分部分
// std::cout << a[i] << ' ' << a[i + 1]<<' ' << la << '\n';
}
// for (int i = 1; i <= n;++i)
// std::cout << sum[i] << ' ';
// puts("");
DFS3(1);//统计答案
for (int i = 2; i <= n;++i)
--sum[a[i]];//减去重复和不需要的点
for (int i = 1; i <= n;++i)
printf("%d\n", sum[i]);
return 0;
}

浙公网安备 33010602011771号