Loading

长链剖分优化树形DP

长链剖分优化树形DP

当 dp 状态与深度有关时,考虑长链剖分。

基础题:CF1009F Dominant Indices

首先需要列出普通 dp 方程。

\(f_{x,d}\) 表示节点 \(x\) 子树中与 \(x\) 距离为 \(d\) 的点的个数。有转移:

\[f_{x,d}=\sum_{y\in son_x}f_{y,d-1} \]

直接做时 \(\mathcal O(n^2)\) 的。用长链剖分,可以每次直接继承长儿子的信息,做到总复杂度 \(\mathcal O(n)\),具体如下:

首先开一个整体 dp 数组申请一片空间,令 \(f_x\) 表示 \(x\) 在 dp 数组中的起点(指针)。设 \(y\)\(x\)

的长儿子,则令 \(f_v\leftarrow f_u+1\) 这样 \(f_{v,*}\) 的信息就可以直接继承到 \(f_{u}\)。对于 \(x\) 的非长儿子 \(z\),一定是一条长链的链顶,我们为这条长链申请一段连续空间,并使 \(f_z\) 指向起始位置。总结起来就两句话:“同一条长链共享内存,不同的长链分配内存"。

从长儿子继承信息后,暴力从短儿子合并,根据长链剖分的性质,复杂度为 \(\mathcal O(n)\)

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int NN=1e6+5;
int n,ans[NN],son[NN],lng[NN],fa[NN];
int dp[NN];//虽然f都是指针,但具体存储的位置还是要开好,还有个方法是存位置,但用指针好看一些 
int*f[NN],*now=dp;
vector<int> ed[NN];
void DFS(int x){
    lng[x]=1; 
    for(int y:ed[x]){
        if(y==fa[x])continue;
        fa[y]=x;
        DFS(y);
        lng[x]=max(lng[x],lng[y]+1);
        if(lng[son[x]]<lng[y])son[x]=y;
    }
    return;
}
void DP(int x){
    f[x][0]=1;
    if(!son[x])return;
    f[son[x]]=f[x]+1;//同一条长链共享内存 
    DP(son[x]);
    ans[x]=ans[son[x]]+1;    
    for(int y:ed[x]){
        if(y==fa[x]||y==son[x])continue;
        f[y]=now;now+=lng[y];//不同的长链分配内存 
        DP(y);
        for(int i=1;i<=lng[y];i++){
            f[x][i]+=f[y][i-1];
            if(f[x][i]>f[x][ans[x]]||f[x][i]==f[x][ans[x]]&&i<ans[x])ans[x]=i;
        }
    }
    if(f[x][ans[x]]==1)ans[x]=0;
    return;
}
int main(){
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v;cin>>u>>v;
        ed[u].push_back(v);
        ed[v].push_back(u);
    }
    DFS(1);
    f[1]=now,now+=lng[1];
    DP(1);
    for(int i=1;i<=n;i++)cout<<ans[i]<<"\n";
    return 0;
} 

进阶题:P4292 [WC2010] 重建计划

首先二分答案后变成求边数 \(\in[L,U]\) 的权值和最大的链的权值和。首先写出普通 dp 方程,令 \(f_{x,d}\) 表示以 \(x\) 为根的子树中,一端是 \(x\) 的长度为 \(d\) 的链中权值和最大值。枚举 \(x\) 的儿子 \(y\),有转移方程:

\[\begin{aligned} &f_{x,d}\leftarrow \max(f_{x,d},f_{y,d-1})\\ &ans\leftarrow ans+\max_{L\le i+j+1\le U}f_{x,i}+f_{y,j} \end{aligned} \]

发现在更新答案时,当 \(j\) 确定后变成了求区间最值,于是我们需要一个线段树维护整体 dp

数组。由于加入了线段树,无法再用指针的形式进行处理,这里设 \(v_x\) 表示 \(f_{x,0}\) 在整体 dp 数组中的下标,其他的处理方式照常。

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int NN=1e6+5;
const double eps=1e-5,INF=1e9;
int n,son[NN],lng[NN],fa[NN],L,U;
double ans,val[NN];//由于维护了线段数,这里不能用指针处理 
int v[NN],now=1;
typedef pair<int,double> paid;
vector<paid> ed[NN];
void DFS(int x){
    lng[x]=1;son[x]=0;
    for(paid eg:ed[x]){
        int y=eg.first;
        if(y==fa[x])continue;
        fa[y]=x;val[y]=eg.second;
        DFS(y);
        lng[x]=max(lng[x],lng[y]+1);
        if(lng[son[x]]<lng[y])son[x]=y;
    }
    return;
}
struct SegTrNode{
    double mx,add;
    #define ls(x) (x<<1)
    #define rs(x) (x<<1|1)
    #define mx(x) sgt[x].mx
    #define add(x) sgt[x].add
}sgt[NN<<2];
inline void Up(int p){
    mx(p)=max(mx(ls(p)),mx(rs(p)));
    return;
}
inline void Alter(int p,double v){
    add(p)+=v;mx(p)+=v;
    return;
}
inline void Down(int p){
    Alter(ls(p),add(p));
    Alter(rs(p),add(p));
    add(p)=0;
    return;
}
double Max(int l,int r,int p=1,int L=1,int R=n){
    if(l>R||L>r)return -INF;
    if(l<=L&&R<=r)return mx(p);
    int mid=L+R>>1;
    Down(p);
    return max(Max(l,r,ls(p),L,mid),Max(l,r,rs(p),mid+1,R));
}
void Prune(int pos,double val,int p=1,int L=1,int R=n){
    if(L==R){
        mx(p)=val;
        return;
    }
    int mid=L+R>>1;
    Down(p);
    if(pos<=mid)Prune(pos,val,ls(p),L,mid);
    else        Prune(pos,val,rs(p),mid+1,R);
    return Up(p);
}
void Add(int l,int r,double v,int p=1,int L=1,int R=n){
    if(l>R||L>r)return;
    if(l<=L&&R<=r)return Alter(p,v);
    int mid=L+R>>1;
    Down(p);
    Add(l,r,v,ls(p),L,mid);
    Add(l,r,v,rs(p),mid+1,R);
    return Up(p);
}
void Build(int p=1,int L=1,int R=n){
    add(p)=0,mx(p)=-INF;
    if(L==R)return;
    int mid=L+R>>1;
    Build(ls(p),L,mid);
    Build(rs(p),mid+1,R);
    return;
}
inline double Ask(int pos){return Max(pos,pos);}
inline int Num(int x,int i){return v[x]+i;}
void DP(int x){
    Prune(Num(x,0),0);
    if(!son[x])return;
    v[son[x]]=v[x]+1;//同一条长链共享内存 
    DP(son[x]);
    Add(Num(x,1),Num(x,lng[x]-1),val[son[x]]);
    for(paid eg:ed[x]){
        int y=eg.first;double val=eg.second;
        if(y==fa[x]||y==son[x])continue;
        v[y]=now;now+=lng[y];//不同的长链分配内存 
        DP(y);
        for(int i=0;i<lng[y];i++)
            ans=max(ans,Ask(Num(y,i))+Max(Num(x,max(0,L-i-1)),Num(x,min(lng[x]-1,U-i-1)))+val);
        if(ans>0)throw "success";
        for(int i=1;i<=lng[y];i++)
            Prune(Num(x,i),max(Ask(Num(x,i)),Ask(Num(y,i-1))+val));
    }
    ans=max(ans,Max(Num(x,L),Num(x,min(lng[x]-1,U))));
    if(ans>0)throw "success";//throw大法好 
    return;
}
struct Edge{
    int u,v,w;
    inline void read(){scanf("%d%d%d",&u,&v,&w);}
}edge[NN];
inline bool Check(double s){
    ans=-INF;
    for(int i=1;i<=n;i++)ed[i].clear();
    for(int i=1;i<n;i++){
        int u=edge[i].u,v=edge[i].v;
        double w=edge[i].w-s;
        ed[u].push_back({v,w});
        ed[v].push_back({u,w});
    }
    DFS(1);
    Build();
    v[1]=now=1;now+=lng[1];
    try{DP(1);}catch(...){}
    return ans>0;
}
int main(){
    scanf("%d%d%d",&n,&L,&U);
    for(int i=1;i<n;i++)edge[i].read();
    double l=0,r=1e6;
    while(r-l>eps){
        double mid=(l+r)/2;
        if(Check(mid))l=mid;
        else r=mid;
    }
    printf("%.3f",l);
    return 0;
}

变式(tai)题:tf7z不萌萌模拟赛T3 浏览

题面:

给你一个 \(n\) 个节点的树,一个合法的联通块满足任意两个点 \(dis\le k\)\(dis\) 定义为经过的边数,求合法的联通块数对 \(998244353\) 取模的结果。

\(1\le k< n\le 5\times 10^5\)

题解:

\(f_{x,d}\) 表示 \(x\) 为根的子树中与 \(x\) 最远的距离为 \(d\) 的合法联通块树,枚举 \(x\) 的儿子 \(y\),可以列出朴素 dp 方程:

\[f_{x,i}\leftarrow f_{x,i}\times(1+\sum_{j=0}^{\min(i,k-i)-1}f_{y,j})+f_{y,i-1}\times\sum_{j=0}^{\min(i-1,k-i)}f_{x,j} \]

不难注意到如果要用长链剖分优化 dp 的话要使用维护区间和的线段树。

但一般长链剖分是与 \(lng_y\) 相关,这里是与 \(lng_x\) 相关,是不行的。

发现当 \(i\in [lng_y+2,k-lng_y-2]\) 时与 \(y\) 相关的部分是定值,可以直接用区间乘解决,这样复杂度就正确了。

代码如下:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int NN=1e6+5,MOD=998244353;
int n,k,ans;
vector<int> ed[NN];
int v[NN],dp[NN],son[NN],lng[NN],now;
int Num(int x,int i){
	return v[x]+i;
}
void DFS(int x,int fa){
	lng[x]=1;
	for(int y:ed[x]){
		if(y==fa)continue;
		DFS(y,x);
		lng[x]=max(lng[x],lng[y]+1);
		if(lng[son[x]]<lng[y])son[x]=y;
	}
}
struct SegTr{
	int sum,mul;
	SegTr(){sum=0,mul=1;}
	#define ls(x) (x<<1)
	#define rs(x) (x<<1|1)
	#define sum(x) sgt[x].sum
	#define mul(x) sgt[x].mul
}sgt[NN<<2];
void Alter(int p,int v){
	(mul(p)*=v)%=MOD;
	(sum(p)*=v)%=MOD;
	return;
}
void Down(int p){
    if(mul(p)==1)return;//卡常
	Alter(ls(p),mul(p));
	Alter(rs(p),mul(p));
	mul(p)=1;
	return;
}
void Up(int p){
	sum(p)=(sum(ls(p))+sum(rs(p)))%MOD;
	return;
}
void Add(int pos,int val,int p=1,int L=1,int R=n){
	if(L==R)return void((sum(p)+=val)%=MOD);
	int mid=L+R>>1;
	Down(p);
	if(pos<=mid)Add(pos,val,ls(p),L,mid);
	else		Add(pos,val,rs(p),mid+1,R);
	return Up(p);
}
int Sum(int l,int r,int p=1,int L=1,int R=n){
	if(l>R||L>r)return 0;
	if(l<=L&&R<=r)return sum(p);
	int mid=L+R>>1;
	Down(p);
	return (Sum(l,r,ls(p),L,mid)+Sum(l,r,rs(p),mid+1,R))%MOD;
}
void Mul(int l,int r,int v,int p=1,int L=1,int R=n){
	if(l>R||L>r)return;
	if(l<=L&&R<=r)return Alter(p,v);
	int mid=L+R>>1;
	Down(p);
	Mul(l,r,v,ls(p),L,mid);
	Mul(l,r,v,rs(p),mid+1,R);
	return Up(p);
}
int Ask(int pos){
	return Sum(pos,pos);
}
int QZ(int x,int p){
	return Sum(Num(x,0),Num(x,min(lng[x]-1,p)));
}
struct Node{int pos,val;}q[NN];int top=0;//数组不能定义在Merge里,因为定义申请内存的时间是线性的 
void Merge(int x,int y){
	for(int i=1;i<=min(lng[y],k);i++){
		q[++top]={i,Ask(Num(y,i-1))*QZ(x,min(i-1,k-i))%MOD};
		if(i<k)q[++top]={i,Ask(Num(x,i))*QZ(y,min(i,k-i)-1)%MOD};
	}
	for(int i=max(lng[y]+1,k-lng[y]);i<k;i++){
		q[++top]={i,Ask(Num(x,i))*QZ(y,min(i,k-i)-1)%MOD};
	}
	Mul(Num(x,lng[y]+1),Num(x,k-lng[y]-1),(QZ(y,k)+1));
	for(;top;top--)Add(Num(x,q[top].pos),q[top].val);
	return;
}
void DP(int x,int fa){
	Add(Num(x,0),1);
	if(son[x]){
		v[son[x]]=v[x]+1;
		DP(son[x],x);
		for(int y:ed[x]){
			if(y==fa||y==son[x])continue;
			v[y]=now,now+=lng[y];
			DP(y,x);
			Merge(x,y);
		}
	}
	(ans+=QZ(x,k))%=MOD;
	return;
}
signed main(){
	cin>>n>>k;
	for(int i=1;i<n;i++){
		int u,v;cin>>u>>v;
		ed[u].push_back(v);
		ed[v].push_back(u);
	}
	DFS(1,1);
	v[1]=now=1,now+=lng[1];
	DP(1,1);
	cout<<ans;
	return 0;
} 
posted @ 2025-02-07 21:37  lupengheyyds  阅读(141)  评论(0)    收藏  举报