长链剖分优化树形DP
长链剖分优化树形DP
当 dp 状态与深度有关时,考虑长链剖分。
基础题:CF1009F Dominant Indices
首先需要列出普通 dp 方程。
令 \(f_{x,d}\) 表示节点 \(x\) 子树中与 \(x\) 距离为 \(d\) 的点的个数。有转移:
直接做时 \(\mathcal O(n^2)\) 的。用长链剖分,可以每次直接继承长儿子的信息,做到总复杂度 \(\mathcal O(n)\),具体如下:
首先开一个整体 dp 数组申请一片空间,令 \(f_x\) 表示 \(x\) 在 dp 数组中的起点(指针)。设 \(y\) 为 \(x\)
的长儿子,则令 \(f_v\leftarrow f_u+1\) 这样 \(f_{v,*}\) 的信息就可以直接继承到 \(f_{u}\)。对于 \(x\) 的非长儿子 \(z\),一定是一条长链的链顶,我们为这条长链申请一段连续空间,并使 \(f_z\) 指向起始位置。总结起来就两句话:“同一条长链共享内存,不同的长链分配内存"。
从长儿子继承信息后,暴力从短儿子合并,根据长链剖分的性质,复杂度为 \(\mathcal O(n)\)。
代码如下:
#include<bits/stdc++.h>
using namespace std;
const int NN=1e6+5;
int n,ans[NN],son[NN],lng[NN],fa[NN];
int dp[NN];//虽然f都是指针,但具体存储的位置还是要开好,还有个方法是存位置,但用指针好看一些
int*f[NN],*now=dp;
vector<int> ed[NN];
void DFS(int x){
lng[x]=1;
for(int y:ed[x]){
if(y==fa[x])continue;
fa[y]=x;
DFS(y);
lng[x]=max(lng[x],lng[y]+1);
if(lng[son[x]]<lng[y])son[x]=y;
}
return;
}
void DP(int x){
f[x][0]=1;
if(!son[x])return;
f[son[x]]=f[x]+1;//同一条长链共享内存
DP(son[x]);
ans[x]=ans[son[x]]+1;
for(int y:ed[x]){
if(y==fa[x]||y==son[x])continue;
f[y]=now;now+=lng[y];//不同的长链分配内存
DP(y);
for(int i=1;i<=lng[y];i++){
f[x][i]+=f[y][i-1];
if(f[x][i]>f[x][ans[x]]||f[x][i]==f[x][ans[x]]&&i<ans[x])ans[x]=i;
}
}
if(f[x][ans[x]]==1)ans[x]=0;
return;
}
int main(){
cin>>n;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
ed[u].push_back(v);
ed[v].push_back(u);
}
DFS(1);
f[1]=now,now+=lng[1];
DP(1);
for(int i=1;i<=n;i++)cout<<ans[i]<<"\n";
return 0;
}
进阶题:P4292 [WC2010] 重建计划
首先二分答案后变成求边数 \(\in[L,U]\) 的权值和最大的链的权值和。首先写出普通 dp 方程,令 \(f_{x,d}\) 表示以 \(x\) 为根的子树中,一端是 \(x\) 的长度为 \(d\) 的链中权值和最大值。枚举 \(x\) 的儿子 \(y\),有转移方程:
发现在更新答案时,当 \(j\) 确定后变成了求区间最值,于是我们需要一个线段树维护整体 dp
数组。由于加入了线段树,无法再用指针的形式进行处理,这里设 \(v_x\) 表示 \(f_{x,0}\) 在整体 dp 数组中的下标,其他的处理方式照常。
代码如下:
#include<bits/stdc++.h>
using namespace std;
const int NN=1e6+5;
const double eps=1e-5,INF=1e9;
int n,son[NN],lng[NN],fa[NN],L,U;
double ans,val[NN];//由于维护了线段数,这里不能用指针处理
int v[NN],now=1;
typedef pair<int,double> paid;
vector<paid> ed[NN];
void DFS(int x){
lng[x]=1;son[x]=0;
for(paid eg:ed[x]){
int y=eg.first;
if(y==fa[x])continue;
fa[y]=x;val[y]=eg.second;
DFS(y);
lng[x]=max(lng[x],lng[y]+1);
if(lng[son[x]]<lng[y])son[x]=y;
}
return;
}
struct SegTrNode{
double mx,add;
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define mx(x) sgt[x].mx
#define add(x) sgt[x].add
}sgt[NN<<2];
inline void Up(int p){
mx(p)=max(mx(ls(p)),mx(rs(p)));
return;
}
inline void Alter(int p,double v){
add(p)+=v;mx(p)+=v;
return;
}
inline void Down(int p){
Alter(ls(p),add(p));
Alter(rs(p),add(p));
add(p)=0;
return;
}
double Max(int l,int r,int p=1,int L=1,int R=n){
if(l>R||L>r)return -INF;
if(l<=L&&R<=r)return mx(p);
int mid=L+R>>1;
Down(p);
return max(Max(l,r,ls(p),L,mid),Max(l,r,rs(p),mid+1,R));
}
void Prune(int pos,double val,int p=1,int L=1,int R=n){
if(L==R){
mx(p)=val;
return;
}
int mid=L+R>>1;
Down(p);
if(pos<=mid)Prune(pos,val,ls(p),L,mid);
else Prune(pos,val,rs(p),mid+1,R);
return Up(p);
}
void Add(int l,int r,double v,int p=1,int L=1,int R=n){
if(l>R||L>r)return;
if(l<=L&&R<=r)return Alter(p,v);
int mid=L+R>>1;
Down(p);
Add(l,r,v,ls(p),L,mid);
Add(l,r,v,rs(p),mid+1,R);
return Up(p);
}
void Build(int p=1,int L=1,int R=n){
add(p)=0,mx(p)=-INF;
if(L==R)return;
int mid=L+R>>1;
Build(ls(p),L,mid);
Build(rs(p),mid+1,R);
return;
}
inline double Ask(int pos){return Max(pos,pos);}
inline int Num(int x,int i){return v[x]+i;}
void DP(int x){
Prune(Num(x,0),0);
if(!son[x])return;
v[son[x]]=v[x]+1;//同一条长链共享内存
DP(son[x]);
Add(Num(x,1),Num(x,lng[x]-1),val[son[x]]);
for(paid eg:ed[x]){
int y=eg.first;double val=eg.second;
if(y==fa[x]||y==son[x])continue;
v[y]=now;now+=lng[y];//不同的长链分配内存
DP(y);
for(int i=0;i<lng[y];i++)
ans=max(ans,Ask(Num(y,i))+Max(Num(x,max(0,L-i-1)),Num(x,min(lng[x]-1,U-i-1)))+val);
if(ans>0)throw "success";
for(int i=1;i<=lng[y];i++)
Prune(Num(x,i),max(Ask(Num(x,i)),Ask(Num(y,i-1))+val));
}
ans=max(ans,Max(Num(x,L),Num(x,min(lng[x]-1,U))));
if(ans>0)throw "success";//throw大法好
return;
}
struct Edge{
int u,v,w;
inline void read(){scanf("%d%d%d",&u,&v,&w);}
}edge[NN];
inline bool Check(double s){
ans=-INF;
for(int i=1;i<=n;i++)ed[i].clear();
for(int i=1;i<n;i++){
int u=edge[i].u,v=edge[i].v;
double w=edge[i].w-s;
ed[u].push_back({v,w});
ed[v].push_back({u,w});
}
DFS(1);
Build();
v[1]=now=1;now+=lng[1];
try{DP(1);}catch(...){}
return ans>0;
}
int main(){
scanf("%d%d%d",&n,&L,&U);
for(int i=1;i<n;i++)edge[i].read();
double l=0,r=1e6;
while(r-l>eps){
double mid=(l+r)/2;
if(Check(mid))l=mid;
else r=mid;
}
printf("%.3f",l);
return 0;
}
变式(tai)题:tf7z不萌萌模拟赛T3 浏览
题面:
给你一个 \(n\) 个节点的树,一个合法的联通块满足任意两个点 \(dis\le k\),\(dis\) 定义为经过的边数,求合法的联通块数对 \(998244353\) 取模的结果。
\(1\le k< n\le 5\times 10^5\)
题解:
令 \(f_{x,d}\) 表示 \(x\) 为根的子树中与 \(x\) 最远的距离为 \(d\) 的合法联通块树,枚举 \(x\) 的儿子 \(y\),可以列出朴素 dp 方程:
不难注意到如果要用长链剖分优化 dp 的话要使用维护区间和的线段树。
但一般长链剖分是与 \(lng_y\) 相关,这里是与 \(lng_x\) 相关,是不行的。
发现当 \(i\in [lng_y+2,k-lng_y-2]\) 时与 \(y\) 相关的部分是定值,可以直接用区间乘解决,这样复杂度就正确了。
代码如下:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int NN=1e6+5,MOD=998244353;
int n,k,ans;
vector<int> ed[NN];
int v[NN],dp[NN],son[NN],lng[NN],now;
int Num(int x,int i){
return v[x]+i;
}
void DFS(int x,int fa){
lng[x]=1;
for(int y:ed[x]){
if(y==fa)continue;
DFS(y,x);
lng[x]=max(lng[x],lng[y]+1);
if(lng[son[x]]<lng[y])son[x]=y;
}
}
struct SegTr{
int sum,mul;
SegTr(){sum=0,mul=1;}
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define sum(x) sgt[x].sum
#define mul(x) sgt[x].mul
}sgt[NN<<2];
void Alter(int p,int v){
(mul(p)*=v)%=MOD;
(sum(p)*=v)%=MOD;
return;
}
void Down(int p){
if(mul(p)==1)return;//卡常
Alter(ls(p),mul(p));
Alter(rs(p),mul(p));
mul(p)=1;
return;
}
void Up(int p){
sum(p)=(sum(ls(p))+sum(rs(p)))%MOD;
return;
}
void Add(int pos,int val,int p=1,int L=1,int R=n){
if(L==R)return void((sum(p)+=val)%=MOD);
int mid=L+R>>1;
Down(p);
if(pos<=mid)Add(pos,val,ls(p),L,mid);
else Add(pos,val,rs(p),mid+1,R);
return Up(p);
}
int Sum(int l,int r,int p=1,int L=1,int R=n){
if(l>R||L>r)return 0;
if(l<=L&&R<=r)return sum(p);
int mid=L+R>>1;
Down(p);
return (Sum(l,r,ls(p),L,mid)+Sum(l,r,rs(p),mid+1,R))%MOD;
}
void Mul(int l,int r,int v,int p=1,int L=1,int R=n){
if(l>R||L>r)return;
if(l<=L&&R<=r)return Alter(p,v);
int mid=L+R>>1;
Down(p);
Mul(l,r,v,ls(p),L,mid);
Mul(l,r,v,rs(p),mid+1,R);
return Up(p);
}
int Ask(int pos){
return Sum(pos,pos);
}
int QZ(int x,int p){
return Sum(Num(x,0),Num(x,min(lng[x]-1,p)));
}
struct Node{int pos,val;}q[NN];int top=0;//数组不能定义在Merge里,因为定义申请内存的时间是线性的
void Merge(int x,int y){
for(int i=1;i<=min(lng[y],k);i++){
q[++top]={i,Ask(Num(y,i-1))*QZ(x,min(i-1,k-i))%MOD};
if(i<k)q[++top]={i,Ask(Num(x,i))*QZ(y,min(i,k-i)-1)%MOD};
}
for(int i=max(lng[y]+1,k-lng[y]);i<k;i++){
q[++top]={i,Ask(Num(x,i))*QZ(y,min(i,k-i)-1)%MOD};
}
Mul(Num(x,lng[y]+1),Num(x,k-lng[y]-1),(QZ(y,k)+1));
for(;top;top--)Add(Num(x,q[top].pos),q[top].val);
return;
}
void DP(int x,int fa){
Add(Num(x,0),1);
if(son[x]){
v[son[x]]=v[x]+1;
DP(son[x],x);
for(int y:ed[x]){
if(y==fa||y==son[x])continue;
v[y]=now,now+=lng[y];
DP(y,x);
Merge(x,y);
}
}
(ans+=QZ(x,k))%=MOD;
return;
}
signed main(){
cin>>n>>k;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
ed[u].push_back(v);
ed[v].push_back(u);
}
DFS(1,1);
v[1]=now=1,now+=lng[1];
DP(1,1);
cout<<ans;
return 0;
}

浙公网安备 33010602011771号