点分治学习笔记

参考蓝书发篇学习笔记。。。

一.算法梗概:

点分治是一种用于在一棵树上,无对路劲进行修改的操作,对某些具有限定条件的路径进行静态统计的算法。
点分治一般用来处理无根树,我们可以随意认定根节点。

二.实现过程:

我们拿一道例题来说一下:

P4178 Tree

我们认定根节点为 \(root\),那么对于 \(root\) 而言,树上的路径有两种:
1.经过 \(root\) 的路径;
2.不经过 \(root\) 但包含在 \(root\) 的某棵子树内。

对于路径种类1,我们可以从 \(root\) 点出发,对整棵树进行 \(\text{dfs}\),求出点 \(i\)\(root\) 的距离 \(dis_i\),同时可以求出 \(b_i\),表示 \(i\) 属于 \(root\) 的哪一棵子树。特别地,\(b_{root}=root\)

代码如下:

点击查看代码
inline void getdis(int x,int fa,int d,int from)
{
    a[++now]=x,b[x]=from,dis[x]=d;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to,z=e[i].len;
        if(y==fa || vis[y]) continue;
        getdis(y,x,d+z,from);
    }
    return;
}

而我们要统计的,就是满足如下所有条件的点对 \((x,y)\) 的数量:(1)\(b_x\ne b_y\);(2)\(dis_x+dis_y\leqslant k\)

对于路径种类2,我们可以分治一下,将 \(root\) 的每棵子树递归处理。

那么我们最常见的 \(calc\) 函数的写法,就是指针扫描数组的方法:
将树上每个节点放到一个数组 \(a\) 里去,然后按照节点的 \(dis\) 值排序。显然,\(l\) 在向右扫描的过程中,恰好使得 \(d_{a_l}+d_{a_r}\leqslant k\)\(r\) 是从右向左单调递减。那么我们用 \(cnt_s\) 来统计 \(l+1\sim r\) 之间满足 \(b_{a_i}\)\(i\) 的个数,那么,当某条路径的某一端为 \(a_l\) 时,另一端的合法的个数就为 \(r-l-cnt_{b_{a_l}}\)

代码如下:

点击查看代码
inline int calc(int x)
{
    tot=0,a[++tot]=x,b[x]=x,dis[x]=0,now[b[x]]=1;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to,z=e[i].len;
        if(vis[y]) continue;
        getdis(y,x,z,y);
    }
    sort(a+1,a+tot+1,cmp);
    int l=1,r=tot,res=0;
    while(l<=r)//一定要注意,一定不要写成l<r,因为这样会导致l=r时直接退出,但是有一个没有减掉
    {
        while(l<r && dis[a[l]]+dis[a[r]]<=k)
        {
            res+=r-l+1-now[b[a[l]]];
            now[b[a[l]]]--;l++;
        }
        now[b[a[r]]]--;r--;
    }
    return res;
}

若递归的深度为 \(dep\),那么算法的时间复杂度就为 \(\mathcal{O}(dep· n\log n)\)

但是我们想一种情况,若树的形态为一条链,那么最坏情况下,每次根都选到链的端点,那么递归深度就需要 \(n\) 层,算法时间复杂度就退化成 \(\mathcal{O}(n^2\log n)\)。所以,我们要对根的选择进行一个优化,每次都找到树的重心作为根节点。

代码如下:

点击查看代码
inline void getrt(int x,int fa,int tot)
{
    siz[x]=1,hson[x]=0;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==fa || vis[y]) continue;
        getrt(y,x,tot);
        siz[x]+=siz[y];
        hson[x]=max(hson[x],siz[y]);
    }
    hson[x]=max(hson[x],tot-siz[x]);
    if(!root || hson[x]<hson[root]) root=x;
    return;
}

解释:因为此时 \(root\) 的每棵子树的大小都不会超过整棵树的一半,那么就限制了递归层数最多为 \(\mathcal{O}(\log n)\),那么现在算法的时间复杂度就变成了 \(\mathcal{O}(n\log^2n)\)

完整代码:

点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN=4e4+5;

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    return x*f;
}

int n,k;

struct edge
{
    int to,nxt,len;
}e[MAXN<<1];

int head[MAXN],cnt;

inline void add(int x,int y,int z)
{
    e[++cnt].to=y;
    e[cnt].len=z;
    e[cnt].nxt=head[x];
    head[x]=cnt;
    return;
}

int root,tot;
int siz[MAXN],hson[MAXN];
int dis[MAXN],a[MAXN],b[MAXN];
bool vis[MAXN];
int now[MAXN];

inline void getrt(int x,int fa,int tot)
{
    siz[x]=1,hson[x]=0;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==fa || vis[y]) continue;
        getrt(y,x,tot);
        siz[x]+=siz[y];
        hson[x]=max(hson[x],siz[y]);
    }
    hson[x]=max(hson[x],tot-siz[x]);
    if(!root || hson[x]<hson[root]) root=x;
    return;
}

inline void getdis(int x,int fa,int d,int from)
{
    a[++tot]=x,b[x]=from,dis[x]=d,now[b[x]]++;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to,z=e[i].len;
        if(y==fa || vis[y]) continue;
        getdis(y,x,d+z,from);
    }
    return;
}

inline bool cmp(int a,int b) {return dis[a]<dis[b];}

int ans;

inline int calc(int x)
{
    tot=0,a[++tot]=x,b[x]=x,dis[x]=0,now[b[x]]=1;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to,z=e[i].len;
        if(vis[y]) continue;
        getdis(y,x,z,y);
    }
    sort(a+1,a+tot+1,cmp);
    int l=1,r=tot,res=0;
    while(l<=r)
    {
        while(l<r && dis[a[l]]+dis[a[r]]<=k)
        {
            res+=r-l+1-now[b[a[l]]];
            now[b[a[l]]]--;l++;
        }
        now[b[a[r]]]--;r--;
    }
    return res;
}

inline void solve(int x)
{
    vis[x]=true;ans+=calc(x);
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(vis[y]) continue;
        root=0;
        getrt(y,0,siz[y]);
        solve(root);
    }
    return;
}

signed main()
{
    n=read();
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read(),z=read();
        add(x,y,z),add(y,x,z);
    }
    k=read();
    hson[0]=n,getrt(1,0,n);solve(root);
    printf("%lld\n",ans);
    return 0;
}

典型例题

例一 P3806 【模板】点分治1

纯纯的板子,只是从统计数量变成了是否存在的问题。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int MAXN=1e4+5;

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    return x*f;
}

struct edge
{
    int to,nxt,len;
}e[MAXN<<1];

int head[MAXN],cnt;

inline void add(int x,int y,int z)
{
    e[++cnt].to=y;
    e[cnt].len=z;
    e[cnt].nxt=head[x];
    head[x]=cnt;
    return;
}

int n,m,root;
int siz[MAXN],hson[MAXN];
bool vis[MAXN],flag[MAXN];
int a[MAXN],b[MAXN],dis[MAXN],tot;

inline void getrt(int x,int fa,int tot)
{
    siz[x]=1,hson[x]=0;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==fa || vis[y]) continue;
        getrt(y,x,tot);
        siz[x]+=siz[y];
        hson[x]=max(siz[y],hson[x]);
    }
    hson[x]=max(hson[x],tot-siz[x]);
    if(!root || hson[x]<hson[root]) root=x;
    return;
}

inline void getdis(int x,int fa,int d,int from)
{
    a[++tot]=x,b[x]=from,dis[x]=d;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to,z=e[i].len;
        if(y==fa || vis[y]) continue;
        getdis(y,x,d+z,from);
    }
    return;
}

inline bool cmp(int a,int b)
{
    return dis[a]<dis[b];
}

int ask[MAXN];

inline void calc(int x)
{
    tot=0,a[++tot]=x,b[x]=x,dis[x]=0;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to,z=e[i].len;
        if(vis[y]) continue;
        getdis(y,x,z,y);
    }
    sort(a+1,a+tot+1,cmp);
    for(int i=1;i<=m;i++)
    {
        int l=1,r=tot;
        if(flag[i]) continue;
        while(l<r)
        {
            if(dis[a[l]]+dis[a[r]]>ask[i]) r--;
            else if(dis[a[l]]+dis[a[r]]<ask[i]) l++;
            else if(b[a[l]]==b[a[r]])
            {
                if(dis[a[r]]==dis[a[r-1]]) r--;
                else l++;
            }
            else {flag[i]=true;break;}
        }
    }
}

inline void solve(int x)
{
    vis[x]=true;calc(x);
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(vis[y]) continue;
        root=0;
        getrt(y,0,siz[y]);
        solve(root);
    }
    return;
}

int main()
{
    n=read(),m=read();
    for(int i=1;i<=n-1;i++)
    {
        int x=read(),y=read(),z=read();
        add(x,y,z),add(y,x,z);
    }
    for(int i=1;i<=m;i++)
    {
        ask[i]=read();
        if(!ask[i]) flag[i]=true;
    }
    hson[0]=n;getrt(1,0,n);solve(root);
    for(int i=1;i<=m;i++)
    {
        if(flag[i]) printf("AYE\n");
        else printf("NAY\n");
    }
    return 0;
}
posted @ 2023-06-30 15:56  Code_AC  阅读(3)  评论(0编辑  收藏  举报