树分治入门(才开坑)

点分治

点分治,可能是一种针对可带权树上简单路径统计问题的算法。(也可能是因为博主理解的浅薄)
是分治思想在线段树上的体现。每次找树上的重心来计算跨过重心的答案就像序列分治每次统计跨过分治中心的答案。

前置知识

dfs,找重心以及树上的一些处理方法。

重心

重心就和我们的分治重心一样直接影响时间复杂度,所以要选择一个最优的点。
重心显然是最优的,但是正常找重心需要两遍bfs或做dp,显然不太优秀,所以有一种写法是从本层的根开始深度优先遍历计算出子树大小,将传入值与以自身为根子树大小之差作为向本层根方向的子树大小(将父亲那一坨当成根),再取最大子树最小的点为重心
这可能不是正确的根,但是有神人已经替我们证明这样的时间复杂度,虽然我并没有看懂,所以我只能把链接放在这里

思路

每次找到树的重心,然后计算以它为lca的答案。
然后遍历它的每一个子树,然后对于每个子树递归处理即可。

代码

不同的题统计答案以及做的事情不一样(doit函数不一样)。
但求重心和 \(solve\) 应该基本上一样。

void getrt(int u,int fa)
{
    siz[u]=1;dp[u]=0;
    for(int i=hd[u];i;i=nxt[i])
    {
        int v=go[i];
        if(v==fa||vis[v])continue;//已经用过的不用计算
        getrt(v,u);
        siz[u]+=siz[v];
        dp[u]=max(dp[u],siz[v]);//记录最大子树
    }
    dp[u]=max(dp[u],sum-siz[u]);
    if(dp[u]<dp[rt])rt=u; 
}
void solve(int u)
{
    vis[u]=judge[0]=1;//要统计上从重心出发的路径
    doit(u);//统计u为根的答案
    for(int i=hd[u];i;i=nxt[i])
    {
        int v=go[i];
        if(vis[v])continue;
        sum=siz[v],dp[0]=inf,rt=0;
        getrt(v,0);
        solve(rt);//递归处理
    }
}

例题

模板题

对于每次处理答案时候,开一个桶数组,每次找到一个距离\(x\) ,查找\(x-k\)是否存在即可。

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int N=1e4+10,M=1010,inf=1e9,maxn=1e7+10;
int n,m,tot,hd[N],nxt[N<<1],go[N<<1],jz[N<<1],siz[N],dp[N],rt,dis[N],cd[N],q[N],sum,vis[N],query[M];
bool judge[maxn],bk[M];
void add(int u,int v,int w)
{
    go[++tot]=v;nxt[tot]=hd[u];hd[u]=tot;jz[tot]=w;
    go[++tot]=u;nxt[tot]=hd[v];hd[v]=tot;jz[tot]=w;
}
void getrt(int u,int fa)
{
    siz[u]=1;dp[u]=0;
    for(int i=hd[u];i;i=nxt[i])
    {
        int v=go[i];
        if(v==fa||vis[v])continue;
        getrt(v,u);
        siz[u]+=siz[v];
        dp[u]=max(dp[u],siz[v]);
    }
    dp[u]=max(dp[u],sum-siz[u]);
    if(dp[u]<dp[rt])rt=u; 
}
void getdis(int u,int fa)
{
    cd[++cd[0]]=dis[u];
    for(int i=hd[u];i;i=nxt[i])
    {
        int v=go[i];
        if(v==fa||vis[v])continue;
        dis[v]=dis[u]+jz[i];
        getdis(v,u);
    }
    return;
}
void doit(int u)
{
    int p=0;
    for(int i=hd[u];i;i=nxt[i])
    {
        int v=go[i];
        if(vis[v])continue;
        cd[0]=0,dis[v]=jz[i];
        getdis(v,u);
        for(int j=1;j<=cd[0];j++)
            for(int k=1;k<=m;k++)
                if(query[k]>=cd[j])
                    bk[k]|=judge[query[k]-cd[j]];
        for(int j=1;j<=cd[0];j++)
        if(cd[j]<maxn)q[++p]=cd[j],judge[cd[j]]=1;
    }
    for(int i=1;i<=p;i++)judge[q[i]]=0;
}
void solve(int u)
{
    vis[u]=judge[0]=1;
    doit(u);
    for(int i=hd[u];i;i=nxt[i])
    {
        int v=go[i];
        if(vis[v])continue;
        sum=siz[v],dp[0]=inf,rt=0;
        getrt(v,0);
        solve(rt);
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int u,v,w;
        scanf("%d%d%d",&u,&v,&w);
        add(u,v,w);
    }
    for(int i=1;i<=m;i++)scanf("%d",&query[i]);
    getrt(1,0);
    solve(1);
    for(int i=1;i<=m;i++)
        if(bk[i])printf("AYE\n");
        else printf("NAY\n");
    return 0;
}

P2634 [国家集训队] 聪聪可可

可以发现与上一个题非常的像,所以直接把模 \(3\) 然后像上一题一样。
记得求gcd,不要像我一样不调用而wa一发

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+10,inf=1e9+10;
#define ll long long 
int siz[N],dp[N],vis[N],sum,cd[N],dis[N],ju[10],rt;
ll n,ans;
struct node 
{
	int v,z;
};
vector<node> ve[N];
void getrt(int u,int f)
{
	siz[u]=1,dp[u]=0;
	int len=ve[u].size();
	for(int i=0;i<len;i++)
	{
		int v=ve[u][i].v;
		if(v==f||vis[v])continue;
		getrt(v,u);
		siz[u]+=siz[v];
		dp[u]=max(dp[u],siz[v]);
	}
	dp[u]=max(dp[u],sum-siz[u]); 
	if(dp[u]<dp[rt]) rt=u;
}
void getdis(int u,int f)
{
	cd[++cd[0]]=dis[u];
	int len=ve[u].size();
	for(int i=0;i<len;i++)
	{
		int v=ve[u][i].v;
		if(v==f||vis[v])continue;
		dis[v]=(dis[u]+ve[u][i].z)%3; 
		getdis(v,u);
	}
}
void doit(int u)
{
	int len=ve[u].size();
	for(int i=0;i<len;i++)
	{
		int v=ve[u][i].v;
		if(vis[v])continue;
		cd[0]=0,dis[v]=ve[u][i].z;
		getdis(v,u);
		for(int j=1;j<=cd[0];j++)ans+=ju[(3-cd[j])%3];
		for(int j=1;j<=cd[0];j++)ju[cd[j]]++;
	} 
	for(int i=0;i<3;i++)ju[i]=0;
}
void solve(int u)
{
	vis[u]=1,ju[0]=1;
	doit(u);
	int len=ve[u].size(); 
	for(int i=0;i<len;i++)
	{
		int v=ve[u][i].v;
		if(vis[v])continue;
		sum=siz[v],dp[0]=inf,rt=0;
		getrt(v,0);
		solve(rt);
	}
}
ll gcd(ll x,ll y)
{
	if(y==0)return x;
	else return gcd(y,x%y);
}
int main()
{
	ios::sync_with_stdio(0);
	cin.tie(0);cout.tie(0);
	cin>>n;
	for(int i=1;i<n;i++)
	{
		int u,v;ll w;cin>>u>>v>>w;
		ve[u].push_back({v,w%3});
		ve[v].push_back({u,w%3}); 
	} 
	dp[0]=inf;sum=n;
	getrt(1,0);
	solve(rt);
	int gc=gcd(ans*2+n,n*n); 
	cout<<(ans*2+n)/gc<<"/"<<(n*n)/gc;
	return 0;
} 
posted @ 2025-03-13 16:43  exCat  阅读(12)  评论(0)    收藏  举报