#全局平衡二叉树,树链剖分#洛谷 4751 【模板】动态 DP(加强版)

题目传送门


分析

正常的树形dp是 \(f[x][0]+=\max(f[y][0],f[y][1]),f[x][1]+=f[y][0]\)

按照重儿子和轻儿子进行拆分,那么 \(f[x][0]=g[x][0]+\max(f[big[x]][0],f[big[x]][1]),f[x][1]=g[x][1]+f[big[x]][0]\)

那么可以转化为广义矩阵乘法

\[\begin{bmatrix} f[x][0] \\ f[x][1] \end{bmatrix} = \begin{bmatrix} g[x][0] & g[x][0] \\ g[x][1] & -\infty \end{bmatrix} * \begin{bmatrix} f[big[x]][0] \\ f[big[x]][1] \end{bmatrix} \]

可以发现单点修改的时候只需要对整个重链进行查询,因此可以对每个重链开一棵线段树就能卡进时限


代码

#include <cstdio>
#include <cctype>
using namespace std;
const int N=1000011,inf=0x3f3f3f3f;
struct node{int y,next;}e[N<<1];
int dfn[N],big[N],Top[N],siz[N],nfd[N],tot,ofn[N],ls[N<<2],rs[N<<2];
int dp[N][2],a[N],dep[N],fat[N],et=1,n,m,as[N],rt[N],cnt;
int iut(){
	int ans=0,f=1; char c=getchar();
	while (!isdigit(c)) f=(c=='-')?-f:f,c=getchar();
	while (isdigit(c)) ans=ans*10+c-48,c=getchar();
	return ans*f;
}
inline void print(int ans){
	if (ans<0) putchar('-'),ans=-ans;
	if (ans>9) print(ans/10);
	putchar(ans%10+48);
}
int max(int a,int b){return a>b?a:b;}
struct maix{
	int p[2][2];
	inline maix operator *(const maix &B)const{
	    maix C;
	    C.p[0][0]=max(p[0][0]+B.p[0][0],p[0][1]+B.p[1][0]),
	    C.p[0][1]=max(p[0][0]+B.p[0][1],p[0][1]+B.p[1][1]),
	    C.p[1][0]=max(p[1][0]+B.p[0][0],p[1][1]+B.p[1][0]),
	    C.p[1][1]=max(p[1][0]+B.p[0][1],p[1][1]+B.p[1][1]);
	    return C;
	}
}w[N<<2],A[N];
void build(int &rt,int l,int r){
	rt=++cnt;
	if (l==r) {w[rt]=A[nfd[l]]; return;}
	int mid=(l+r)>>1;
	build(ls[rt],l,mid);
	build(rs[rt],mid+1,r);
	w[rt]=w[ls[rt]]*w[rs[rt]];
}
void update(int rt,int l,int r,int x){
	if (l==r) {w[rt]=A[nfd[x]]; return;}
	int mid=(l+r)>>1;
	if (x<=mid) update(ls[rt],l,mid,x);
	    else update(rs[rt],mid+1,r,x);
	w[rt]=w[ls[rt]]*w[rs[rt]];	
}
void dfs1(int x,int fa){
	fat[x]=fa,siz[x]=1,dep[x]=dep[fa]+1;
	for (int i=as[x],SIZ=-1;i;i=e[i].next)
	if (e[i].y!=fa){
		dfs1(e[i].y,x),siz[x]+=siz[e[i].y];
		if (SIZ<siz[e[i].y]) big[x]=e[i].y,SIZ=siz[e[i].y];
	}
}
void dfs2(int x,int linp){
	dfn[x]=++tot,nfd[tot]=x,Top[x]=linp,ofn[linp]=tot;
	dp[x][0]=0,dp[x][1]=a[x],A[x].p[1][1]=-inf,
	A[x].p[0][0]=A[x].p[0][1]=0,A[x].p[1][0]=a[x];
	if (!big[x]) return; dfs2(big[x],linp);
	dp[x][0]+=max(dp[big[x]][0],dp[big[x]][1]);
	dp[x][1]+=dp[big[x]][0];
	for (int i=as[x];i;i=e[i].next)
	if (e[i].y!=fat[x]&&e[i].y!=big[x]){
	    dfs2(e[i].y,e[i].y);
	    int now=max(dp[e[i].y][0],dp[e[i].y][1]);
	    dp[x][0]+=now,dp[x][1]+=dp[e[i].y][0],
	    A[x].p[0][0]+=now,A[x].p[1][0]+=dp[e[i].y][0];
	}
	A[x].p[0][1]=A[x].p[0][0];
}
inline void Update(int x,int z){
	for (A[x].p[1][0]+=z-a[x],a[x]=z;x;){
		maix B1=w[rt[Top[x]]];
		update(rt[Top[x]],dfn[Top[x]],ofn[Top[x]],dfn[x]);
		maix B2=w[rt[Top[x]]];
		x=fat[Top[x]];
		A[x].p[0][0]+=max(B2.p[0][0],B2.p[1][0])-max(B1.p[0][0],B1.p[1][0]),
		A[x].p[0][1]=A[x].p[0][0],A[x].p[1][0]+=B2.p[0][0]-B1.p[0][0];
	}
}
int main(){
	n=iut(); m=iut();
	for (int i=1;i<=n;++i) a[i]=iut();
	for (int i=1;i<n;++i){
		int x=iut(),y=iut();
		e[++et]=(node){y,as[x]},as[x]=et;
		e[++et]=(node){x,as[y]},as[y]=et;
	}
	dfs1(1,0),dfs2(1,1);
	for (int x=1;x<=n;++x) if (Top[x]==x) build(rt[Top[x]],dfn[Top[x]],ofn[Top[x]]);
	for (int i=1,lans=0;i<=m;++i,putchar(10)){
		int x=iut()^lans,z=iut(); Update(x,z);
		print(lans=max(w[1].p[0][0],w[1].p[1][0]));
	}
	return 0;
}

分析(全局平衡二叉树)

然而这样复杂度仍然是 log 方的,考虑重链能不能也变成 log 呢,其实是可以的,

不妨对重链按照每个节点轻儿子的大小从中位数分治,轻儿子认父不认子,就得到了全局平衡二叉树,

这样树高是 log 级别的,对这棵辅助树跳father修改即可。


代码

#include <cstdio>
#include <cctype>
#include <vector>
using namespace std;
const int N=1000011,inf=0x3f3f3f3f; struct node{int y,next;}e[N<<1];
int fat[N],siz[N],a[N],f[N][2],g[N][2],light[N],et=1,n,big[N],lights[N],son[N][2],as[N],father[N],gbrt,Q;
int iut(){
	int ans=0,f=1; char c=getchar();
	while (!isdigit(c)) f=(c=='-')?-f:f,c=getchar();
	while (isdigit(c)) ans=ans*10+c-48,c=getchar();
	return ans*f;
}
inline void print(int ans){
	if (ans<0) putchar('-'),ans=-ans;
	if (ans>9) print(ans/10);
	putchar(ans%10+48);
}
int max(int a,int b){return a>b?a:b;}
struct maix{
	int p[2][2];
	inline maix operator *(const maix &B)const{
	    maix C;
	    C.p[0][0]=max(p[0][0]+B.p[0][0],p[0][1]+B.p[1][0]),
	    C.p[0][1]=max(p[0][0]+B.p[0][1],p[0][1]+B.p[1][1]),
	    C.p[1][0]=max(p[1][0]+B.p[0][0],p[1][1]+B.p[1][0]),
	    C.p[1][1]=max(p[1][0]+B.p[0][1],p[1][1]+B.p[1][1]);
	    return C;
	}
}w[N],A[N];
void dfs1(int x,int fa){
	fat[x]=fa,siz[x]=1;
	f[x][0]=0,f[x][1]=a[x];
	for (int i=as[x],SIZ=-1;i;i=e[i].next)
	if (e[i].y!=fa){
		dfs1(e[i].y,x),siz[x]+=siz[e[i].y];
		if (SIZ<siz[e[i].y]) big[x]=e[i].y,SIZ=siz[e[i].y];
		f[x][0]+=max(f[e[i].y][0],f[e[i].y][1]);
		f[x][1]+=f[e[i].y][0];
	}
	if (big[x]){
		light[x]=siz[x]-siz[big[x]];
		g[x][0]=f[x][0]-max(f[big[x]][0],f[big[x]][1]);
		g[x][1]=f[x][1]-f[big[x]][0];
	}else{
		light[x]=siz[x];
		g[x][0]=f[x][0];
		g[x][1]=f[x][1];
	}
}
void reset(int x){
	A[x].p[0][0]=A[x].p[0][1]=g[x][0];
	A[x].p[1][0]=g[x][1],A[x].p[1][1]=-inf;
}
void pup(int x){
	if (son[x][0]) w[x]=w[son[x][0]]*A[x];
	     else w[x]=A[x];
	if (son[x][1]) w[x]=w[x]*w[son[x][1]];
}
int build(vector<int>heavy,int l,int r,int fa){
	if (l>r) return 0;
	lights[r+1]=0;
	for (int i=r;i>=l;--i) lights[i]=lights[i+1]+light[heavy[i]];
	int mid=l;
	for (int i=r;i>l;--i)
	if (lights[i]>=lights[l]-lights[i]){
		mid=i;
		break;
	}
	int x=heavy[mid];
	father[x]=fa;
	son[x][0]=build(heavy,l,mid-1,x);
	son[x][1]=build(heavy,mid+1,r,x);
	reset(x),pup(x);
	return x;
}
int dfs2(int x,int fa){
	vector<int>heavy;
	for (int u=x;u;u=big[u]) heavy.push_back(u);
	int rt=build(heavy,0,heavy.size()-1,fa);
	for (int u:heavy)
	for (int i=as[u];i;i=e[i].next)
	if (e[i].y!=fat[u]&&e[i].y!=big[u]) dfs2(e[i].y,u);
	return rt;
}
void update(int x,int y){
	g[x][1]+=y-a[x],a[x]=y,reset(x);
	for (int fa;x;x=fa){
		fa=father[x];
		if (fa&&son[fa][0]!=x&&son[fa][1]!=x){
			maix B1=w[x]; pup(x); maix B2=w[x];
			g[fa][0]+=max(B2.p[0][0],B2.p[1][0])-max(B1.p[0][0],B1.p[1][0]),
		    g[fa][1]+=B2.p[0][0]-B1.p[0][0],reset(fa);
		}else pup(x);
	}
}
int main(){
	n=iut(),Q=iut();
	for (int i=1;i<=n;++i) a[i]=iut();
	for (int i=1;i<n;++i){
		int x=iut(),y=iut();
		e[++et]=(node){y,as[x]},as[x]=et;
		e[++et]=(node){x,as[y]},as[y]=et;
	}
	dfs1(1,0);
	gbrt=dfs2(1,0);
	for (int i=1,lans=0;i<=Q;++i,putchar(10)){
		int x=iut()^lans,z=iut(); update(x,z);
		print(lans=max(w[gbrt].p[0][0],w[gbrt].p[1][0]));
	}
	return 0;
}
posted @ 2025-08-22 14:50  lemondinosaur  阅读(2)  评论(0)    收藏  举报