点分治和点分树

其实还是没怎么懂点分治和点分树,随便写点自己的理解。

点分治1

求两个点的距离公式很明了 \(dis(u,v)=d_u+d_v-2\times d_{lca(u,v)}\),其中 \(d_x\) 表示根到 \(x\) 的距离。距离公式中最不好处理的其实是后面的 \(-2\times d_{lca(u,v)}\)

转换下思路,如果 \(lca(u,v)\) 是树根是不是就会好求许多了。这时我们就需要一个算法实现一下几个操作:

  1. 找到一个新根
  2. 统计经过根的答案
  3. 删掉根,递归处理没经过根的路径
  4. 重复以上操作

显然每次操作的复杂度为根所在联通块的大小,那只要使根每次取重心,复杂度就会下降到 \(O(n\log n)\)

因此点分治过程就是每次取出重心,处理出包含重心的连通块或路径的答案的过程。

回到这题中,考虑记录下重心所在联通块中所有点到重心的链,判断这些链是否能拼成一条长度为要求长度的路径。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e7+10;
int root,sum,n,m,q[205],tot;
int head[maxn],mx[maxn],siz[maxn],cur[maxn],d[maxn],tmp[maxn];
bool vis[maxn],ans[maxn],ju[maxn];
struct node{
    int u,v,w,next;
}e[maxn<<1];
int read(){
    int s=0,f=1;
    char ch=getchar();
    while(ch>'9'||ch<'0'){
        if(ch=='-'){
            f=-1;
        }
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        s=s*10+(ch-'0');
        ch=getchar();
    }
    return s*f;
}
void add(int u,int v,int w){
    e[++tot].u=u;
    e[tot].v=v;
    e[tot].w=w;
    e[tot].next=head[u];
    head[u]=tot;
}
void dfs(int u,int fa){
    siz[u]=1;
    mx[u]=0;
    for(int i=head[u];i;i=e[i].next){
        int v=e[i].v;
        if(v==fa||vis[v])continue;
        dfs(v,u);
        siz[u]+=siz[v];
        mx[u]=max(mx[u],siz[v]);
    }
    mx[u]=max(mx[u],sum-siz[u]);
    if(mx[u]<mx[root]){
        root=u;
    }
}
void calc(int u,int fa){
    cur[++cur[0]]=d[u];
    for(int i=head[u];i;i=e[i].next){
        int v=e[i].v;
        if(v==fa||vis[v]){
            continue;
        }
        d[v]=d[u]+e[i].w;
        calc(v,u);
    }
}
void work(int u){
    int cnt=0;
    for(int i=head[u];i;i=e[i].next){
        int v=e[i].v;
        if(vis[v]){
            continue;
        }
        cur[0]=0;
        d[v]=e[i].w;
        calc(v,u);
        for(int j=cur[0];j;j--){
            for(int k=1;k<=m;k++){
                if(q[k]>=cur[j]){
                    ans[k]|=ju[q[k]-cur[j]];
                }
            }
        }
        for(int j=cur[0];j;j--){
            tmp[++cnt]=cur[j];
            ju[cur[j]]=1;
        }
    }
    for(int i=1;i<=cnt;i++){
        ju[tmp[i]]=0;
    }
}
void divid(int u){
    vis[u]=ju[0]=1;
    work(u);
    for(int i=head[u];i;i=e[i].next){
        int v=e[i].v;
        if(vis[v]){
            continue;
        }
        sum=siz[v];
        mx[root=0]=1e8;
        dfs(v,v);
        divid(root);
    }
}
int main(){
    n=read(),m=read();
    int u,v,w;
    for(int i=1;i<n;i++){
        u=read(),v=read(),w=read();
        add(u,v,w);
        add(v,u,w);
    }
    for(int i=1;i<=m;i++){
        q[i]=read();
    }
    mx[root]=sum=n;
    dfs(1,1);
    divid(root);
    for(int i=1;i<=m;i++){
        if(ans[i]){
            puts("AYE");
        }
        else{
            puts("NAY");
        }
    }
    return 0;
}

[国家集训队]聪聪可可

也是一道比较板的题,对于每个分治中心处理出联通块中每个点到分治中心的距离,得到 \(s_0,s_1,s_2\)\(s_i\) 表示距离分治中心距离模 \(3\)\(i\) 的点的数量。贡献为 \(s_2\times s_1\times 2+s_0\times s_0\),再容斥掉没有经过分治重心的贡献即可。

点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define ll long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
    template<typename T>
    inline void read(T &x){
        x=0;
        int f=1;
        char ch=getchar();
        while(ch>'9'||ch<'0'){
            if(ch=='-'){
                f=-1;
    }
            ch=getchar();
        }
        while(ch>='0'&&ch<='9'){
            x=x*10+(ch-'0');
            ch=getchar();
        }
        x=(f==1?x:-x);
    }
    template<typename T>
    inline void write(T x){
        if(x<0){
            putchar('-');
            x=-x;
        }
        if(x>=10){
            write(x/10);
        }
        putchar(x%10+'0');
    }
    template<typename T>
    inline void write_endl(T x){
        write(x);
        putchar('\n');
    }
    template<typename T>
    inline void write_space(T x){
        write(x);
        putchar(' ');
    }
}
using namespace IO;
const int N=2e4+10,inf=1e9;
int n,rt,siz[N],mx[N],del[N],tot,ans;
int d[N],s[10];
vector<pii>e[N];
void get_rt(int u,int fa){
    siz[u]=1;
    mx[u]=0;
    for(auto x:e[u]){
        int v=x.first;
        if(del[v]||v==fa){
            continue;
        }
        get_rt(v,u);
        siz[u]+=siz[v];
        mx[u]=max(mx[u],siz[v]);
    }
    mx[u]=max(mx[u],tot-siz[u]);
    if(mx[u]<mx[rt]){
        rt=u;
    }
}
void dfs(int u,int fa){
    s[d[u]%3]++;
    for(auto x:e[u]){
        int v=x.first,w=x.second;
        if(del[v]||v==fa){
            continue;
        }
        d[v]=d[u]+w;
        d[v]%=3;
        dfs(v,u);
    }
}
int calc(int st,int w){
    memset(s,0,sizeof(s));
    d[st]=w;
    dfs(st,0);
    return s[1]*s[2]*2+s[0]*s[0];
}
void solve(int u){
    get_rt(u,0);
    ans=ans+calc(u,0);
    del[u]=1;
    for(auto x:e[u]){
        int v=x.first,w=x.second;
        if(del[v]){
            continue;
        }
        ans=ans-calc(v,w);
        rt=0;
        tot=siz[v];
        get_rt(v,0);
        solve(rt);
    }
}
signed main(){
    #ifndef ONLINE_JUDGE
        freopen("1.in","r",stdin);
        freopen("1.out","w",stdout);
    #endif
    read(n);
    for(int i=1;i<n;i++){
        int u,v,w;
        read(u),read(v),read(w);
        e[u].pb(mp(v,w));
        e[v].pb(mp(u,w));
    }
    tot=n;
    rt=0;
    mx[0]=inf;
    get_rt(1,0);
    solve(rt);
    write(ans/__gcd(ans,n*n)),putchar('/'),write_endl(n*n/__gcd(ans,n*n));
    return 0;
}

树上游戏

点分治好题,难点在于处理贡献。

先看题目要求求什么,树上路径数颜色,很容易想到树分块和树上莫队,但要求对树上所有路径都要求,果断放弃。

考虑点分治,对经过分治中心的路径一起算贡献。因为颜色数和树的大小是同阶的,所以不能一个颜色一个颜色来算贡献,否则复杂度会退化,只能所有颜色一起算贡献。

先处理路径对分治中心 \(u\) 的贡献。容易发现一个性质,在 \(u\) 的一个子树中所有的颜色相同的点中,只有深度最小的点 \(x\) 能给 \(u\) 造成贡献,贡献为 \(siz_x\)。记 \(cnt_c\) 为颜色 \(c\)\(u\) 造成的贡献,则该连通块中经过分治中心的路径对答案造成的贡献为 \(sum=\sum\limits_{c=1}^{max_c}cnt_c\),需要注意的是 \(u\) 属于的颜色 \(col_u\) 对答案贡献为 \(siz_u\),因为任何一个点到分治中心必然经过分治中心。

接下来计算对非分治中心的贡献,因为一定经过分治中心,所以在计算前先去掉所在子树的贡献。对于子树 \(v\) 内一个点 \(x\),新增的贡献为 \(siz_u-siz_v-cnt_{col_x}\),其中 \(cnt_{col_x}\) 为去掉 \(v\) 子树内的贡献后,颜色 \(col_x\)\(u\) 产生的贡献。需要注意的是,统计完子树 \(v\) 的贡献后要将子树 \(v\) 的贡献加回,方便后续计算。

点击查看代码
#include<bits/stdc++.h>
#define ull unsigned long long
#define int long long
#define pdi pair<double,int>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define eps 1e-9
using namespace std;
namespace IO{
    template<typename T>
    inline void read(T &x){
        x=0;
        int f=1;
        char ch=getchar();
        while(ch>'9'||ch<'0'){
            if(ch=='-'){
                f=-1;
    }
            ch=getchar();
        }
        while(ch>='0'&&ch<='9'){
            x=x*10+(ch-'0');
            ch=getchar();
        }
        x=(f==1?x:-x);
    }
    template<typename T>
    inline void write(T x){
        if(x<0){
            putchar('-');
            x=-x;
        }
        if(x>=10){
            write(x/10);
        }
        putchar(x%10+'0');
    }
    template<typename T>
    inline void write_endl(T x){
        write(x);
        putchar('\n');
    }
    template<typename T>
    inline void write_space(T x){
        write(x);
        putchar(' ');
    }
}
using namespace IO;
const int N=1e5+10,inf=1e9;
int col[N],n,del[N],ans[N];
int mx[N],rt,siz[N],tot;
int sum,cnt[N],Cnt[N];
vector<int>e[N];
void Get_rt(int u,int fa){
    siz[u]=1;
    mx[u]=0;
    for(auto v:e[u]){
        if(del[v]||v==fa){
            continue;
        }
        Get_rt(v,u);
        siz[u]+=siz[v];
        mx[u]=max(mx[u],siz[v]);
    }
    mx[u]=max(mx[u],tot-siz[u]);
    if(mx[u]<mx[rt]){
        rt=u;
    }
}
void Add(int u,int fa){
    cnt[col[u]]++;
    if(cnt[col[u]]==1&&col[u]!=col[rt]){
        Cnt[col[u]]+=siz[u];
        sum+=siz[u];
    }
    for(auto v:e[u]){
        if(del[v]||v==fa){
            continue;
        }
        Add(v,u);
    }
    cnt[col[u]]--;
}
void Del(int u,int fa){
    cnt[col[u]]++;
    if(cnt[col[u]]==1&&col[u]!=col[rt]){
        Cnt[col[u]]-=siz[u];
        sum-=siz[u];
    }
    for(auto v:e[u]){
        if(del[v]||v==fa){
            continue;
        }
        Del(v,u);
    }
    cnt[col[u]]--;
}
void Update(int u,int fa,int belong){
    cnt[col[u]]++;
    if(cnt[col[u]]==1&&col[u]!=col[rt]){
        sum+=tot-siz[belong]-Cnt[col[u]];
    }
    ans[u]+=sum;
    for(auto v:e[u]){
        if(del[v]||v==fa){
            continue;
        }
        Update(v,u,belong);
    }
    if(cnt[col[u]]==1&&col[u]!=col[rt]){
        sum-=tot-siz[belong]-Cnt[col[u]];
    }
    cnt[col[u]]--;
}
void calc(int u){
    sum=siz[u];
    for(auto v:e[u]){
        if(del[v]){
            continue;
        }
        Add(v,u);
    }
    ans[u]+=sum;
    for(auto v:e[u]){
        if(del[v]){
            continue;
        }
        sum-=siz[v];
        Del(v,u);
        Update(v,u,v);
        Add(v,u);
        sum+=siz[v];
    }
    for(auto v:e[u]){
        if(del[v]){
            continue;
        }
        Del(v,u);
    }
}
void Divid(int u){
    del[u]=1;
    Get_rt(u,0);
    calc(u);
    for(auto v:e[u]){
        if(del[v]){
            continue;
        }
        tot=siz[v];
        rt=0;
        Get_rt(v,0);
        Divid(rt);
    }
}
signed main(){
    #ifndef ONLINE_JUDGE
        freopen("1.in","r",stdin);
        freopen("1.out","w",stdout);
    #endif
    read(n);
    for(int i=1;i<=n;i++){
        read(col[i]);
    }
    for(int i=1,u,v;i<n;i++){
        read(u),read(v);
        e[u].pb(v);
        e[v].pb(u);
    }
    mx[0]=inf;
    rt=0;
    tot=n;
    Get_rt(1,0);
    Divid(rt);
    for(int i=1;i<=n;i++){
        write_endl(ans[i]);
    }
    return 0;
}
posted @ 2023-03-15 16:43  luo_shen  阅读(22)  评论(0)    收藏  举报