[树上游戏]解题报告
点分治做法
在点分治和点分树相关中已有提及,不再分析。
树上差分做法
正难则反,考虑一个点 \(u\) 和一个颜色 \(i\),记 \(f_{u,i}\) 为以 \(u\) 为路径一个端点,不经过颜色为 \(i\) 的路径的数量。这个东西转换一下就变成,删去颜色为 \(i\) 的点后,有多少个点和点 \(u\) 在同一个联通块中。
但这个做法复杂度是 \(O(nm)\) 的(\(m\) 表示颜色的数量),考虑优化。找到一个颜色为 \(i\) 的点 \(u\),那么删去颜色 \(i\) 后,子树 \(v\) 所在连通块的大小为 \(siz_v-\sum\limits_{x\in Sv\text{且}col_x=i}siz_x\)。因为对联通块中的点均有贡献,考虑树上差分,在 \(v\) 处差分数组加上连通块的大小,在 \(x\) 处差分数组减去联通块的大小。每个点会在它到根的路径上找到第一个和它颜色相同的点(或根节点),然后减去一定的贡献,由此可知,每个点最多被减一次贡献。考虑用 vector 存下每个颜色中还未找到颜色相同的点有哪些,dfs 完点 \(v\) 后计算 \(col_u\) 的贡献,根节点所在联通块特殊处理,复杂度为 \(O(n)\)。
最后 \(ans_u=n\times m-d_u\),\(d_u\) 表示树上前缀和后的差分数组的值。
点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define ll long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
template<typename T>
inline void read(T &x){
x=0;
int f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-'){
f=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+(ch-'0');
ch=getchar();
}
x=(f==1?x:-x);
}
template<typename T>
inline void write(T x){
if(x<0){
putchar('-');
x=-x;
}
if(x>=10){
write(x/10);
}
putchar(x%10+'0');
}
template<typename T>
inline void write_endl(T x){
write(x);
putchar('\n');
}
template<typename T>
inline void write_space(T x){
write(x);
putchar(' ');
}
}
using namespace IO;
const int N=1e5+10;
int col[N],n,m,c[N],id[N];
int siz[N],dfn[N],idx,cnt[N];
ll d[N];
vector<int>e[N],son[N];
void dfs(int u,int fa){
siz[u]=1;
dfn[u]=++idx;
for(auto v:e[u]){
if(v==fa){
continue;
}
int lstcnt=cnt[col[u]];
dfs(v,u);
siz[u]+=siz[v];
int nowcnt=siz[v]-(cnt[col[u]]-lstcnt);
d[v]+=nowcnt;
cnt[col[u]]+=nowcnt;
while(son[col[u]].size()&&dfn[son[col[u]].back()]>dfn[u]){
d[son[col[u]].back()]-=nowcnt;
son[col[u]].pop_back();
}
}
cnt[col[u]]++;
son[col[u]].pb(u);
}
void get_ans(int u,int fa){
d[u]+=d[fa];
for(auto v:e[u]){
if(v==fa){
continue;
}
get_ans(v,u);
}
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
read(n);
for(int i=1;i<=n;i++){
read(col[i]);
c[col[i]]=1;
}
for(int i=1;i<N;i++){
if(c[i]){
id[i]=++m;
}
}
for(int i=1;i<=n;i++){
col[i]=id[col[i]];
}
for(int i=1,u,v;i<n;i++){
read(u),read(v);
e[u].pb(v);
e[v].pb(u);
}
dfs(1,0);
for(int i=1;i<=m;i++){
d[1]+=n-cnt[i];
for(auto x:son[i]){
d[x]-=n-cnt[i];
}
}
get_ans(1,0);
for(int i=1;i<=n;i++){
write_endl(1ll*n*m-d[i]);
}
return 0;
}
换根dp做法
还是从只有深度最小的点能产生 \(siz\) 的贡献开始,通过一次 dfs 可以得到 \(1\) 号点的答案。考虑计算答案的变化量,设当前根节点为 \(u\),要转移到 \(v\),点 \(u\) 由根变到 \(v\) 的儿子节点,贡献减少 \(siz_v\),\(v\) 从根的一个儿子变为根节点,\(col_u\) 的贡献由原来的 \(cnt_{col_u}\) 变成了 \(n\),唯一没有计算的就是原来子树 \(v\) 中 \(col_u\) 带来的贡献,记为 \(tot_v\),这个可以在第一次 dfs 中同步得到,\(cnt_{col_u}\) 变为 \(n-siz_v+tot_v\),\(cnt_{col_v}\) 变为 \(n\)。
点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define int long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
template<typename T>
inline void read(T &x){
x=0;
int f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-'){
f=-1;
}
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+(ch-'0');
ch=getchar();
}
x=(f==1?x:-x);
}
template<typename T>
inline void write(T x){
if(x<0){
putchar('-');
x=-x;
}
if(x>=10){
write(x/10);
}
putchar(x%10+'0');
}
template<typename T>
inline void write_endl(T x){
write(x);
putchar('\n');
}
template<typename T>
inline void write_space(T x){
write(x);
putchar(' ');
}
}
using namespace IO;
const int N=1e5+10;
int n,col[N],cnt[N],tot[N],ans[N];
int siz[N];
vector<int>e[N];
void dfs(int u,int fa){
int tmp=cnt[col[u]];
cnt[col[fa]]=0;
siz[u]=1;
for(auto v:e[u]){
if(v==fa){
continue;
}
dfs(v,u);
siz[u]+=siz[v];
}
cnt[col[u]]=siz[u];
tot[u]=cnt[col[fa]];
cnt[col[u]]+=tmp;
}
void get_ans(int u,int fa){
int tmp1=cnt[col[u]],tmp2=cnt[col[fa]];
if(fa){
ans[u]=ans[fa]-siz[u]+tot[u]+n-cnt[col[u]];
cnt[col[u]]=n;
cnt[col[fa]]=n-siz[u]+tot[u];
}
for(auto v:e[u]){
if(v==fa){
continue;
}
get_ans(v,u);
}
cnt[col[u]]=tmp1,cnt[col[fa]]=tmp2;
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
read(n);
for(int i=1;i<=n;i++){
read(col[i]);
}
for(int i=1,u,v;i<n;i++){
read(u),read(v);
e[u].pb(v);
e[v].pb(u);
}
dfs(1,0);
for(int i=1;i<=1e5;i++){
ans[1]+=cnt[i];
}
get_ans(1,0);
for(int i=1;i<=n;i++){
write_endl(ans[i]);
}
return 0;
}

浙公网安备 33010602011771号