cf1172E Nauuo and ODT(LCT)

首先可以转化问题,变为对每种颜色分别考虑不含该颜色的简单路径条数。然后把不是当前颜色的点视为白色,是当前颜色的点视为黑色,显然路径数量是每个白色连通块大小的平方和,然后题目变为:黑白两色的树,单点翻转颜色,维护白色连通块大小平方和,然后根据Auuan大佬的题解,我用了LCT。就是对每个点维护子树、儿子大小平方和,在 link/cut 的时候更新答案。初始化所有点是白色,离线处理每个颜色即可。

这题放在2h比赛上,除了lxl其他人都写不出来(况且lxl还是本题出题人呢)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=4e5+7;
int n,m,c[N],f[N],fa[N],ch[N][2],sum[N],sz[N];
ll ans,d[N],sz2[N];
bool vis[N];
vector<int>vec[N][2],G[N];
bool nroot(int x){return x==ch[fa[x]][0]||x==ch[fa[x]][1];}
void pushup(int x){sum[x]=sum[ch[x][0]]+sum[ch[x][1]]+sz[x]+1;}
void rotate(int x)
{
    int y=fa[x],z=fa[y],w=x==ch[y][1];
    if(nroot(y))ch[z][y==ch[z][1]]=x;
    fa[x]=z,ch[y][w]=ch[x][w^1],fa[ch[x][w^1]]=y,ch[x][w^1]=y,fa[y]=x;
    pushup(y),pushup(x);
}
void splay(int x)
{
    while(nroot(x))
    {
        int y=fa[x],z=fa[y];
        if(nroot(y))
        {
            if((x==ch[y][1])^(y==ch[z][1]))rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
}
void access(int x)
{
    int y=0;
    while(x)
    {
        splay(x);
        sz[x]+=sum[ch[x][1]]-sum[y];
        sz2[x]+=1ll*sum[ch[x][1]]*sum[ch[x][1]]-1ll*sum[y]*sum[y];
        ch[x][1]=y;
        pushup(x);
        x=fa[y=x];
    }
}
int findrt(int x)
{
    access(x),splay(x);
    while(ch[x][0])x=ch[x][0];
    splay(x);
    return x;
}
void link(int x)
{
    int y=f[x],z;
    splay(x);
    ans-=sz2[x]+1ll*sum[ch[x][1]]*sum[ch[x][1]];
    z=findrt(y);
    access(x),splay(z);
    ans-=1ll*sum[ch[z][1]]*sum[ch[z][1]];
    fa[x]=y;
    splay(y);
    sz[y]+=sum[x],sz2[y]+=1ll*sum[x]*sum[x];
    pushup(y),access(x),splay(z);
    ans+=1ll*sum[ch[z][1]]*sum[ch[z][1]];
}
void cut(int x)
{
    int y=f[x],z;
    access(x);
    ans+=sz2[x];
    z=findrt(y);
    access(x),splay(z);
    ans-=1ll*sum[ch[z][1]]*sum[ch[z][1]];
    splay(x);
    ch[x][0]=fa[ch[x][0]]=0;
    pushup(x),splay(z);
    ans+=1ll*sum[ch[z][1]]*sum[ch[z][1]];
}
void dfs(int u)
{for(int i=0;i<G[u].size();i++)if(G[u][i]!=f[u])f[G[u][i]]=u,dfs(G[u][i]);}
int main()
{
    scanf("%d%d",&n,&m);
    ll lst;
    for(int i=1;i<=n;i++)
    scanf("%d",&c[i]),vec[c[i]][0].push_back(i),vec[c[i]][1].push_back(0);
    for(int i=1;i<=n+1;i++)sum[i]=1;
    for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),G[x].push_back(y),G[y].push_back(x);
    for(int i=1,u,v;i<=m;i++)
    {
        scanf("%d%d",&u,&v);
        vec[c[u]][0].push_back(u),vec[c[u]][1].push_back(i);
        c[u]=v;
        vec[v][0].push_back(u),vec[v][1].push_back(i);
    }
    f[1]=n+1;
    dfs(1);
    for(int i=1;i<=n;i++)link(i);
    for(int i=1;i<=n;i++)
    {
        if(!vec[i][0].size()){d[0]+=1ll*n*n;continue;}
        if(vec[i][1][0])d[0]+=1ll*n*n,lst=1ll*n*n;else lst=0;
        for(int j=0;j<vec[i][0].size();j++)
        {
            int u=vec[i][0][j];
            if(vis[u]^=1)cut(u);else link(u);
            if(j==vec[i][0].size()-1||vec[i][1][j+1]!=vec[i][1][j])
            d[vec[i][1][j]]+=ans-lst,lst=ans;
        }
        for(int j=vec[i][0].size()-1;~j;j--)
        {
            int u=vec[i][0][j];
            if(vis[u]^=1)cut(u);else link(u);
        }
    }
    ans=1ll*n*n*n;
    for(int i=0;i<=m;i++)ans-=d[i],printf("%lld\n",ans);
}
View Code

 

posted @ 2019-06-13 18:51  hfctf0210  阅读(441)  评论(0编辑  收藏  举报