[树上游戏]解题报告

题面

点分治做法

点分治和点分树相关中已有提及,不再分析。

树上差分做法

正难则反,考虑一个点 \(u\) 和一个颜色 \(i\),记 \(f_{u,i}\) 为以 \(u\) 为路径一个端点,不经过颜色为 \(i\) 的路径的数量。这个东西转换一下就变成,删去颜色为 \(i\) 的点后,有多少个点和点 \(u\) 在同一个联通块中。

但这个做法复杂度是 \(O(nm)\) 的(\(m\) 表示颜色的数量),考虑优化。找到一个颜色为 \(i\) 的点 \(u\),那么删去颜色 \(i\) 后,子树 \(v\) 所在连通块的大小为 \(siz_v-\sum\limits_{x\in Sv\text{且}col_x=i}siz_x\)。因为对联通块中的点均有贡献,考虑树上差分,在 \(v\) 处差分数组加上连通块的大小,在 \(x\) 处差分数组减去联通块的大小。每个点会在它到根的路径上找到第一个和它颜色相同的点(或根节点),然后减去一定的贡献,由此可知,每个点最多被减一次贡献。考虑用 vector 存下每个颜色中还未找到颜色相同的点有哪些,dfs 完点 \(v\) 后计算 \(col_u\) 的贡献,根节点所在联通块特殊处理,复杂度为 \(O(n)\)

最后 \(ans_u=n\times m-d_u\)\(d_u\) 表示树上前缀和后的差分数组的值。

点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define ll long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
    template<typename T>
    inline void read(T &x){
        x=0;
        int f=1;
        char ch=getchar();
        while(ch>'9'||ch<'0'){
            if(ch=='-'){
                f=-1;
    }
            ch=getchar();
        }
        while(ch>='0'&&ch<='9'){
            x=x*10+(ch-'0');
            ch=getchar();
        }
        x=(f==1?x:-x);
    }
    template<typename T>
    inline void write(T x){
        if(x<0){
            putchar('-');
            x=-x;
        }
        if(x>=10){
            write(x/10);
        }
        putchar(x%10+'0');
    }
    template<typename T>
    inline void write_endl(T x){
        write(x);
        putchar('\n');
    }
    template<typename T>
    inline void write_space(T x){
        write(x);
        putchar(' ');
    }
}
using namespace IO;
const int N=1e5+10;
int col[N],n,m,c[N],id[N];
int siz[N],dfn[N],idx,cnt[N];
ll d[N];
vector<int>e[N],son[N];
void dfs(int u,int fa){
    siz[u]=1;
    dfn[u]=++idx;
    for(auto v:e[u]){
        if(v==fa){
            continue;
        }
        int lstcnt=cnt[col[u]];
        dfs(v,u);
        siz[u]+=siz[v];
        int nowcnt=siz[v]-(cnt[col[u]]-lstcnt);
        d[v]+=nowcnt;
        cnt[col[u]]+=nowcnt;
        while(son[col[u]].size()&&dfn[son[col[u]].back()]>dfn[u]){
            d[son[col[u]].back()]-=nowcnt;
            son[col[u]].pop_back();
        }
    }
    cnt[col[u]]++;
    son[col[u]].pb(u);
}
void get_ans(int u,int fa){
    d[u]+=d[fa];
    for(auto v:e[u]){
        if(v==fa){
            continue;
        }
        get_ans(v,u);
    }
}
signed main(){
    #ifndef ONLINE_JUDGE
        freopen("1.in","r",stdin);
        freopen("1.out","w",stdout);
    #endif
    read(n);
    for(int i=1;i<=n;i++){
        read(col[i]);
        c[col[i]]=1;
    }
    for(int i=1;i<N;i++){
        if(c[i]){
            id[i]=++m;
        }
    }
    for(int i=1;i<=n;i++){
        col[i]=id[col[i]];
    }
    for(int i=1,u,v;i<n;i++){
        read(u),read(v);
        e[u].pb(v);
        e[v].pb(u);
    }
    dfs(1,0);
    for(int i=1;i<=m;i++){
        d[1]+=n-cnt[i];
        for(auto x:son[i]){
            d[x]-=n-cnt[i];
        }
    }
    get_ans(1,0);
    for(int i=1;i<=n;i++){
        write_endl(1ll*n*m-d[i]);
    }
    return 0;
}

换根dp做法

还是从只有深度最小的点能产生 \(siz\) 的贡献开始,通过一次 dfs 可以得到 \(1\) 号点的答案。考虑计算答案的变化量,设当前根节点为 \(u\),要转移到 \(v\),点 \(u\) 由根变到 \(v\) 的儿子节点,贡献减少 \(siz_v\)\(v\) 从根的一个儿子变为根节点,\(col_u\) 的贡献由原来的 \(cnt_{col_u}\) 变成了 \(n\),唯一没有计算的就是原来子树 \(v\)\(col_u\) 带来的贡献,记为 \(tot_v\),这个可以在第一次 dfs 中同步得到,\(cnt_{col_u}\) 变为 \(n-siz_v+tot_v\)\(cnt_{col_v}\) 变为 \(n\)

点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define int long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
    template<typename T>
    inline void read(T &x){
        x=0;
        int f=1;
        char ch=getchar();
        while(ch>'9'||ch<'0'){
            if(ch=='-'){
                f=-1;
    }
            ch=getchar();
        }
        while(ch>='0'&&ch<='9'){
            x=x*10+(ch-'0');
            ch=getchar();
        }
        x=(f==1?x:-x);
    }
    template<typename T>
    inline void write(T x){
        if(x<0){
            putchar('-');
            x=-x;
        }
        if(x>=10){
            write(x/10);
        }
        putchar(x%10+'0');
    }
    template<typename T>
    inline void write_endl(T x){
        write(x);
        putchar('\n');
    }
    template<typename T>
    inline void write_space(T x){
        write(x);
        putchar(' ');
    }
}
using namespace IO;
const int N=1e5+10;
int n,col[N],cnt[N],tot[N],ans[N];
int siz[N];
vector<int>e[N];
void dfs(int u,int fa){
    int tmp=cnt[col[u]];
    cnt[col[fa]]=0;
    siz[u]=1;
    for(auto v:e[u]){
        if(v==fa){
            continue;
        }
        dfs(v,u);
        siz[u]+=siz[v];
    }
    cnt[col[u]]=siz[u];
    tot[u]=cnt[col[fa]];
    cnt[col[u]]+=tmp;
}
void get_ans(int u,int fa){
    int tmp1=cnt[col[u]],tmp2=cnt[col[fa]];
    if(fa){
        ans[u]=ans[fa]-siz[u]+tot[u]+n-cnt[col[u]];
        cnt[col[u]]=n;
        cnt[col[fa]]=n-siz[u]+tot[u];
    }
    for(auto v:e[u]){
        if(v==fa){
            continue;
        }
        get_ans(v,u);
    }
    cnt[col[u]]=tmp1,cnt[col[fa]]=tmp2;
}
signed main(){
    #ifndef ONLINE_JUDGE
        freopen("1.in","r",stdin);
        freopen("1.out","w",stdout);
    #endif
    read(n);
    for(int i=1;i<=n;i++){
        read(col[i]);
    }
    for(int i=1,u,v;i<n;i++){
        read(u),read(v);
        e[u].pb(v);
        e[v].pb(u);
    }
    dfs(1,0);
    for(int i=1;i<=1e5;i++){
        ans[1]+=cnt[i];
    }
    get_ans(1,0);
    for(int i=1;i<=n;i++){
        write_endl(ans[i]);
    }
    return 0;
}
posted @ 2023-03-16 15:26  luo_shen  阅读(42)  评论(0)    收藏  举报