DDP 详解

例题链接

本文参考了部分资料,参见文末。

如果没有特殊声明,本文中所有的矩阵乘法均为广义矩阵乘法,请不要混淆。


DDP(动态 DP)可用于解决一类带修改操作的 DP 问题。一般用来解决树上带有点(边)权修改的 DP 问题。

引入

以较为简单的最大子段和问题来引入。

例题:SPOJ GSS3 - Can you answer these queries III

给定一个长度为 \(n\) 的序列 \(a_1,a_2,\cdots,a_n\),你需要完成 \(m\) 个操作,每个操作是以下两个操作之一:

  • 给定 \(i\),修改 \(a_i\)

  • 给定 \(l,r\),查询 \(a_l,a_{l+1},a_{l+2},\cdots,a_r\) 的最大子段和。

常规解法

这其实是一个经典问题,不带修,常规解决方法有 DP 和 贪心(滚动数组优化 DP)。

带修其实问题也不大,甚至可以通过线段树等数据结构拓展到求解区间内的最大子段和。即对于线段树节点,维护前后缀最大子段和、区间内最大子段和,合并信息即可。如果这个都不会,不建议学 DDP。

DDP 求解带修最大子段和

其实这种解法无论常数还是实现难度都完全劣于线段树解法,仅仅是为了引入 DDP。

设序列 \(a_1,a_2,\cdots,a_n\)

\(f_i\) 表示以 \(i\) 为结尾的最大子段和,\(g_i\) 表示 \(a_1,a_2,\cdots,a_i\) 的最大子段和大小(不考虑子段为空),答案即 \(g_n\)。显然,有转移:

\[\begin{aligned} f_i&=\max(f_{i-1}+a_i,a_i)\\ g_i&=\max(g_{i-1},f_i) \end{aligned} \]

这样,查询复杂度是 \(\mathcal O(1)\),而修改复杂度是 \(\mathcal O(n)\)。(若只维护 \(f\),则相反)需要一些数据结构或者算法来优化这个 DP。

DDP 所利用的,便是矩阵

广义矩阵乘法

矩阵乘法能够成立,是因为乘法具有关于加法的分配律。即 \(a(b+c)=ab+ac\)

注意到加法关于 \(\max\)\(\min\) 也具有分配律,即:

\[\begin{aligned} a+\max(b,c)&=\max(a+b,a+c)\\ a+\min(b,c)&=\min(a+b,a+c)\\ \end{aligned} \]

\(n_A\times m_A\) 的矩阵 \(A\)\(n_B\times m_B\) 的矩阵 \(B\)。对于矩阵\(A,B\),定义广义矩阵乘法 \(A\times B=C\) 为:

\[C_{i,j}=\max_{k=1}^{m_A}(A_{i,k}+B_{k,j}) \]

广义矩阵乘法同样满足结合律

证明

设 $n_C\times m_C$ 的矩阵 $C$,令 $m_A=n_B,m_B=n_C$,作广义矩阵乘法 $A\times B\times C$。令:

$$ \begin{aligned} D&=A\times B\\ E&=D\times C\\ F&=B\times C\\ G&=A\times F\\ \end{aligned} $$

则有:

$$ \begin{aligned} D_{i,j}&=\max_{k=1}^{m_A}(A_{i,k}+B_{k,j})\\ E_{i,j}&=\max_{k=1}^{m_B}(D_{i,k}+C_{k,j})\\ &=\max_{k=1}^{m_B}\left(\max_{t=1}^{m_A}(A_{i,t}+B_{t,k})+C_{k,j}\right)\\ &=\max_{k=1}^{m_B}\max_{t=1}^{m_A}(A_{i,t}+B_{t,k}+C_{k,j})\\ F_{i,j}&=\max_{k=1}^{m_B}(B_{i,k}+C_{k,j})\\ G_{i,j}&=\max_{k=1}^{m_A}(A_{i,k}+F_{k,j})\\ &=\max_{k=1}^{m_A}\left(A_{i,k}+\max_{t=1}^{m_B}(B_{k,t}+C_{t,j})\right)\\ &=\max_{k=1}^{m_A}\max_{t=1}^{m_B}(A_{i,k}+B_{k,t}+C_{t,j})\\ &=\max_{t=1}^{m_B}\max_{k=1}^{m_A}(A_{i,k}+B_{k,t}+C_{t,j})\\ \end{aligned} $$

注意到:

$$ \begin{aligned} E_{i,j}&=\max_{k=1}^{m_B}\max_{t=1}^{m_A}(A_{i,t}+B_{t,k}+C_{k,j})\\ G_{i,j}&=\max_{t=1}^{m_B}\max_{k=1}^{m_A}(A_{i,k}+B_{k,t}+C_{t,j})\\ \end{aligned} $$

显然,$E_{i,j}=G_{i,j}$,故 $(AB)C=A(BC)$,即广义矩阵乘法满足乘法结合律


事实上,广义矩阵乘法可以推广到任意两种满足分配律的运算,且广义矩阵乘法满足结合律,因此可以用于加速 DP。

广义矩阵乘法下的单位矩阵

在广义矩阵乘法下,单位矩阵不像普通矩阵乘法那样,即主对角线均为 \(1\),其余为 \(0\),而是需要根据实际定义的运算来构造。需要两个特殊值来保证正确性。

例如在关于 \(\max\) 和加法的广义矩阵乘法中,单位矩阵为:

\[\begin{bmatrix} 0&-\infty&\cdots&-\infty\\ -\infty&0&\cdots&-\infty\\ \vdots&\vdots&\ddots&\vdots\\ -\infty&-\infty&\cdots&0 \end{bmatrix} \]

令单位矩阵为 \(I\)

则有:

\[A_{i,j}=\max_{k=1}^{m_A}(A_{i,k}+I_{k,j}) \]

\(k=j\) 时,\(I_{k,j}=0\) 对于 \(\max\)没有影响的,因此取 \(I_{j,j}=0\)

否则,\(k\neq j\) 时的答案是不应当被计算的,因此取 \(I_{k,j}=-\infty\)确保不会取到其他值。

代码实现

struct Matrix{
	int n,m;
	int a[2][2];//这里的上限可以自己设置
	Matrix(int nn=0,int mm=-1){//构造函数
		if(mm==-1){
			mm=nn;
		}
		n=nn,m=mm;
	}
	void unit(){//这是单位矩阵,也可以不要
		for(int i=0;i<n;i++){
			for(int j=0;j<m;j++){
				a[i][j]=-inf;
			}
			a[i][i]=0;
		}
	}
	void print(){//用于调试的输出函数,也可以不要
		for(int i=0;i<n;i++){
			for(int j=0;j<m;j++){
				if(a[i][j]==-inf){
					cerr<<setw(5)<<"-inf";
				}else{
					cerr<<setw(4)<<a[i][j];
				}
			}
			cerr<<endl;
		}
	}
};
Matrix operator*(Matrix A,Matrix B){
	Matrix C(A.n,B.m);
	for(int i=0;i<C.n;i++){
		for(int j=0;j<C.m;j++){
			C.a[i][j]=-inf;//inf 如果可能溢出则需要特判,这里取的 0x3f3f3f3f
			for(int k=0;k<A.m;k++){
				C.a[i][j]=max(C.a[i][j],A.a[i][k]+B.a[k][j]);
			}
		}
	}
	return C;
}
Matrix& operator*=(Matrix &A,Matrix B){
	return A=A*B;
}

利用广义矩阵乘法求解

矩阵乘法优化 DP 是将递推式改写为矩阵乘法的形式,推广到广义矩阵乘法也是如此。原递推式为:

\[\begin{aligned} f_i&=\max(f_{i-1}+a_i,a_i)\\ g_i&=\max(g_{i-1},f_i)\\ &=\max(f_{i-1}+a_i,g_{i-1},a_i) \end{aligned} \]

利用广义矩阵乘法,现在便可以优化上述递推式。

\[\begin{bmatrix} f_i\\ g_i\\ 0 \end{bmatrix} = \begin{bmatrix} a_i&-\infty&a_i\\ a_i&0&a_i\\ -\infty&-\infty&0 \end{bmatrix} \begin{bmatrix} f_{i-1}\\ g_{i-1}\\ 0 \end{bmatrix} \]

则有:

\[\begin{aligned} \begin{bmatrix} f_n\\ g_n\\ 0 \end{bmatrix} &= \begin{bmatrix} a_n&-\infty&a_n\\ a_n&0&a_n\\ -\infty&-\infty&0 \end{bmatrix} \begin{bmatrix} f_{n-1}\\ g_{n-1}\\ 0 \end{bmatrix}\\ &=\begin{bmatrix} a_n&-\infty&a_n\\ a_n&0&a_n\\ -\infty&-\infty&0 \end{bmatrix} \begin{bmatrix} a_{n-1}&-\infty&a_{n-1}\\ a_{n-1}&0&a_{n-1}\\ -\infty&-\infty&0 \end{bmatrix} \begin{bmatrix} f_{n-2}\\ g_{n-2}\\ 0 \end{bmatrix}\\ &=\cdots\\ &= \begin{bmatrix} a_n&-\infty&a_n\\ a_n&0&a_n\\ -\infty&-\infty&0 \end{bmatrix} \begin{bmatrix} f_{n-1}\\ g_{n-1}\\ 0 \end{bmatrix}\\ &=\begin{bmatrix} a_n&-\infty&a_n\\ a_n&0&a_n\\ -\infty&-\infty&0 \end{bmatrix} \begin{bmatrix} a_{n-1}&-\infty&a_{n-1}\\ a_{n-1}&0&a_{n-1}\\ -\infty&-\infty&0 \end{bmatrix} \cdots \begin{bmatrix} a_2&-\infty&a_2\\ a_2&0&a_2\\ -\infty&-\infty&0 \end{bmatrix} \begin{bmatrix} f_1\\ g_1\\ 0 \end{bmatrix} \end{aligned} \]

注意到广义矩阵乘法满足结合律,因此可以使用线段树维护区间 \([l,r]\) 内的广义矩阵乘法的乘积。注意运算顺序,广义矩阵乘法不一定满足交换律。即维护 \([l,r]\) 内的矩阵乘积,应当用 \(\left[\left\lfloor\dfrac{l+r}2\right\rfloor+1,r\right]\) 的矩阵乘 \(\left[l,\left\lfloor\dfrac{l+r}2\right\rfloor\right]\) 的矩阵。

也因此可以扩展到 \(a_l,a_{l+1},\cdots,a_r\) 内的最大子段和 \(\textit{ans}\)。记:

\[A= \begin{bmatrix} a_r&-\infty&a_r\\ a_r&0&a_r\\ -\infty&-\infty&0 \end{bmatrix} \begin{bmatrix} a_{r-1}&-\infty&a_{r-1}\\ a_{r-1}&0&a_{r-1}\\ -\infty&-\infty&0 \end{bmatrix} \cdots \begin{bmatrix} a_{l+1}&-\infty&a_{l+1}\\ a_{l+1}&0&a_{l+1}\\ -\infty&-\infty&0 \end{bmatrix} \begin{bmatrix} a_l\\ a_l\\ 0 \end{bmatrix} \]

则有:

\[\textit{ans}=A_{2,1} \]

注意广义矩阵乘法不一定满足交换律,因此一定要注意运算顺序!

\(q\) 为操作次数,这样即可在 \(\mathcal O(q\log n)\) 的复杂度内完成。单次修改和单次查询均为 \(\mathcal O(\log n)\)

注意乘法顺序即可。

参考代码
//#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cstdio>
#include<string>
#include<vector>
#include<cmath>
#include<ctime>
#include<deque>
#include<queue>
#include<stack>
#include<list>
using namespace std;
constexpr const int N=50000,inf=0x3f3f3f3f;
struct Matrix{
	int n,m;
	int a[4][4];
	Matrix(int nn=0,int mm=-1){
		if(mm==-1){
			mm=nn;
		}
		n=nn,m=mm;
	}
	bool unit(){
		if(n!=m){
			return false;
		}
		for(int i=1;i<=n;i++){
			for(int j=1;j<=m;j++){
				a[i][j]=-inf;
			}
			a[i][i]=0;
		}
		return true;
	}
};
Matrix operator *(Matrix A,Matrix B){
	Matrix C(A.n,B.m); 
	for(int i=1;i<=C.n;i++){
		for(int j=1;j<=C.m;j++){
			C.a[i][j]=-inf;
			for(int k=1;k<=A.m;k++){
				if(A.a[i][k]==-inf||B.a[k][j]==-inf){
					continue;
				}
				C.a[i][j]=max(C.a[i][j],A.a[i][k]+B.a[k][j]);
			}
		}
	}
	return C;
}
Matrix& operator *=(Matrix&A,Matrix B){
	return A=A*B;
}
int n,a[N+1];
struct segTree{
	struct node{
		Matrix value;
		int l,r;
	}t[N<<2|1];
	
	Matrix create(int x){
		Matrix ans(3);
		ans.a[1][1]=ans.a[1][3]=ans.a[2][1]=ans.a[2][3]=x;
		ans.a[1][2]=ans.a[3][1]=ans.a[3][2]=-inf;
		ans.a[2][2]=ans.a[3][3]=0;
		return ans;
	}
	void up(int p){
		//注意乘法顺序!! 
		t[p].value=t[p<<1|1].value*t[p<<1].value;
	}
	void build(int p,int l,int r){
		t[p].l=l,t[p].r=r;
		if(l==r){
			t[p].value=create(a[l]);
			return;
		}
		int mid=l+r>>1;
		build(p<<1,l,mid);
		build(p<<1|1,mid+1,r);
		up(p);
	}
	void change(int p,int x,int k){
		if(t[p].l==t[p].r){
			t[p].value=create(k);
			return; 
		}
		if(x<=t[p<<1].r){
			change(p<<1,x,k);
		}else{
			change(p<<1|1,x,k);
		}
		up(p);
	}
	Matrix query(int p,int l,int r){
		if(l<=t[p].l&&t[p].r<=r){ 
			return t[p].value;
		}
		Matrix ans(3,3);
		ans.unit();
		//注意乘法顺序!!!! 
		if(t[p<<1|1].l<=r){
			ans*=query(p<<1|1,l,r);
		}
		if(l<=t[p<<1].r){
			ans*=query(p<<1,l,r);
		}
		return ans;
	}
}t;
int main(){
	/*freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);*/
	
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	
	cin>>n;
	for(int i=1;i<=n;i++){
		cin>>a[i];
	}
	t.build(1,1,n);
	int q;
	cin>>q;
	while(q--){
		int op,x,y;
		cin>>op>>x>>y;
		if(op){
			Matrix pl(3,1);
			pl.a[1][1]=pl.a[2][1]=a[x];
			pl.a[3][1]=0;
			if(x+1<=y){
				pl=t.query(1,x+1,y)*pl;
			}
			cout<<pl.a[2][1]<<'\n';
		}else{
			a[x]=y; 
			t.change(1,x,y);
		}
	}
	
	cout.flush();
	 
	/*fclose(stdin);
	fclose(stdout);*/
	return 0;
}
/*
4
1 2 3 4
4
1 1 3
0 3 -3
1 2 4
1 3 3

6
4
-3
*/
双倍经验

洛谷 P4513 小白逛公园

但是使用 DDP 写法需要卡一下空间。

//#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cstdio>
#include<string>
#include<vector>
#include<cmath>
#include<ctime>
#include<deque>
#include<queue>
#include<stack>
#include<list>
using namespace std;
constexpr const int N=5e5,inf=0x3f3f3f3f;
struct Matrix{
	int n,m;
	int a[3][3];
	Matrix(int nn=0,int mm=-1){
		if(mm==-1){
			mm=nn;
		}
		n=nn,m=mm;
	}
	bool unit(){
		if(n!=m){
			return false;
		}
		for(int i=0;i<n;i++){
			for(int j=0;j<m;j++){
				a[i][j]=-inf;
			}
			a[i][i]=0;
		}
		return true;
	}
	void print(){
		for(int i=1;i<=n;i++){
			for(int j=1;j<=m;j++){
				if(a[i][j]!=-inf){
					cerr<<setw(4)<<a[i][j]<<' ';
				}else{
					cerr<<setw(5)<<"-inf ";
				}
			}
			cerr<<endl;
		}
	}
};
Matrix operator *(Matrix A,Matrix B){
	Matrix C(A.n,B.m); 
	for(int i=0;i<C.n;i++){
		for(int j=0;j<C.m;j++){
			C.a[i][j]=-inf;
			for(int k=0;k<A.m;k++){
				if(A.a[i][k]==-inf||B.a[k][j]==-inf){
					continue;
				}
				C.a[i][j]=max(C.a[i][j],A.a[i][k]+B.a[k][j]);
			}
		}
	}
	return C;
}
Matrix& operator *=(Matrix&A,Matrix B){
	return A=A*B;
}
int n,a[N+1];
struct segTree{
	struct node{
		Matrix value;
	}t[N<<2|1];
	
	Matrix create(int x){
		Matrix ans(3);
		ans.a[0][0]=ans.a[0][2]=ans.a[1][0]=ans.a[1][2]=x;
		ans.a[0][1]=ans.a[2][0]=ans.a[2][1]=-inf;
		ans.a[1][1]=ans.a[2][2]=0;
		return ans;
	}
	void up(int p){
		t[p].value=t[p<<1|1].value*t[p<<1].value;
	}
	void build(int p,int l,int r){
		if(l==r){
			t[p].value=create(a[l]);
			return;
		}
		int mid=l+r>>1;
		build(p<<1,l,mid);
		build(p<<1|1,mid+1,r);
		up(p);
	}
	void change(int p,int tl,int tr,int x,int k){
		if(tl==tr){
			t[p].value=create(k);
			return; 
		}
		int mid=tl+tr>>1;
		if(x<=mid){
			change(p<<1,tl,mid,x,k);
		}else{
			change(p<<1|1,mid+1,tr,x,k);
		}
		up(p);
	}
	Matrix query(int p,int tl,int tr,int l,int r){
		if(l<=tl&&tr<=r){ 
			return t[p].value;
		}
		Matrix ans(3,3);
		ans.unit();
		int mid=tl+tr>>1;
		if(mid+1<=r){
			ans*=query(p<<1|1,mid+1,tr,l,r);
		}
		if(l<=mid){
			ans*=query(p<<1,tl,mid,l,r);
		}
		return ans;
	}
}t;
int main(){
	/*freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);*/
	
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	
	int q;
	cin>>n>>q;
	for(int i=1;i<=n;i++){
		cin>>a[i];
	}
	t.build(1,1,n);
	while(q--){
		int op,x,y;
		cin>>op>>x>>y;
		if(op==1){
			if(x>y){
				swap(x,y);
			} 
			Matrix pl(3,1);
			pl.a[0][0]=pl.a[1][0]=a[x];
			pl.a[2][0]=0;
			if(x+1<=y){
				pl=t.query(1,1,n,x+1,y)*pl;
			}
			cout<<pl.a[1][0]<<'\n';
		}else{
			a[x]=y; 
			t.change(1,1,n,x,y);
		}
	}
	
	cout.flush();
	 
	/*fclose(stdin);
	fclose(stdout);*/
	return 0;
}
/*
4
1 2 3 4
4
1 1 3
0 3 -3
1 2 4
1 3 3

6
4
-3
*/

总结

可以发现,DDP 求解带修最大子段和问题时,利用矩阵乘法优化 DP 的思路,将状态转移方程写为了广义矩阵乘法的形式,并利用相关数据结构维护矩阵乘积,从而优化 DP。

DDP 维护树上信息

维护树上最大权独立集为例。

给定一棵 \(n\) 个点的树,点带点权。

\(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

\(1\leq n\leq m\leq10^5\)

朴素 DP

显然需要 DP 求解。不妨钦定 \(1\) 为根节点。

\(f_{x,1},f_{x,0}\) 分别表示是否选择 \(x\) 时的 \(x\) 子树内的最大权独立集的权值大小。

\(v_x\) 表示 \(x\) 的子节点集,则有:

\[\begin{aligned} f_{x,0}&=\sum_{y\in v_x}\max(f_{y,0},f_{y,1})\\ f_{x,1}&=a_x+\sum_{y\in v_x}f_{y,0} \end{aligned} \]

答案即 \(\max(f_{1,0},f_{1,1})\)

如果不考虑修改,这些东西都可以通过一次 \(\mathcal O(n)\) 树形 DP 求出来。


容易发现,更新一个点 \(i\) 的点权 \(a_i\),受到影响的 \(f_{x,0},f_{x,1}\)\(x\) 必须是 \(i\) 的祖先

最坏情况下,整棵树为一条链,修改的复杂度便是 \(\mathcal O(n)\)。总时间复杂度 \(\mathcal O(nm)\) 的,我们需要优化这个朴素 DP。

树链剖分

发现修改的信息是从修改节点 \(i\) 到根节点 \(1\) 的一条链上的信息,因此可以尝试使用树链剖分优化。

\(x\) 的重子节点为 \(\textit{son}_x\)

不妨将重子节点和轻子节点的贡献分开算,设 \(g_{x,0},g_{x,1}\) 表示是否选 \(x\) 时,子树内除重子节点及其子树外的答案。即:

\[\begin{aligned} g_{x,0}&=\sum_{y\in v_x}[y\neq\textit{son}_x]\max(f_{y,0},f_{y,1})\\ g_{x,1}&=a_x+\sum_{y\in v_x}[y\neq\textit{son}_x]f_{y,0} \end{aligned} \]

则有:

\[\begin{aligned} f_{x,0}&=g_{x,0}+\max(f_{\textit{son}_x,0},f_{\textit{son}_x,1})\\ f_{x,1}&=g_{x,1}+f_{\textit{son}_x,0} \end{aligned} \]

\(f_{x,0},f_{x,1}\),对于 \(g_{x,0},g_{x,1}\),在修改操作前都可以通过一次 \(\mathcal O(n)\) 的树形 DP 求出来。

同时,修改点权 \(a_i\) 即修改 \(g_{i,1}\)。不妨设 \(a_i\) 被修改为了 \(a_i+\Delta\),则将 \(g_{i,1}\) 修改为 \(g_{i,1}+\Delta\) 即可。

考虑利用广义矩阵乘法优化递推,不难设计:

\[\begin{bmatrix} f_{x,0}\\ f_{x,1} \end{bmatrix} = \begin{bmatrix} g_{x,0}&g_{x,0}\\ g_{x,1}&-\infty \end{bmatrix} \begin{bmatrix} f_{\textit{son}_x,0}\\ f_{\textit{son}_x,1} \end{bmatrix} \]

\(\begin{bmatrix} g_{x,0}&g_{x,0}\\ g_{x,1}&-\infty \end{bmatrix}\)\(x\)转移矩阵

现在来考虑修改 \(g_{i,1}\),如何加速修改操作。

对于节点 \(x\),记 \(\textit{father}_x,\textit{top}_x\) 分别为 \(x\) 的父节点、\(x\) 所在重链链顶节点。

也就是说,对于 \(x\) 所在的一条重链,只有链顶 \(\textit{top}_x\) 是轻节点。那么修改节点 \(i\) 时,链内有且仅有 \(\textit{top}_i\) 的父节点 \(\textit{father}_{\textit{top}_i}\) 的转移矩阵会被修改。考虑会修改成什么样。

\(x\) 所在链的链底节点为 \(\textit{bottom}_x\),注意到:

\[\begin{aligned} \begin{bmatrix} f_{x,0}\\ f_{x,1} \end{bmatrix} &= \begin{bmatrix} g_{x,0}&g_{x,0}\\ g_{x,1}&-\infty \end{bmatrix} \begin{bmatrix} f_{\textit{son}_x,0}\\ f_{\textit{son}_x,1} \end{bmatrix} \\ &= \begin{bmatrix} g_{x,0}&g_{x,0}\\ g_{x,1}&-\infty \end{bmatrix} \begin{bmatrix} g_{\textit{son}_x,0}&g_{\textit{son}_x,0}\\ g_{\textit{son}_x,1}&-\infty \end{bmatrix} \begin{bmatrix} f_{\textit{son}_{\textit{son}_x},0}\\ f_{\textit{son}_{\textit{son}_x},1} \end{bmatrix}\\ &= \begin{bmatrix} g_{x,0}&g_{x,0}\\ g_{x,1}&-\infty \end{bmatrix} \begin{bmatrix} g_{\textit{son}_x,0}&g_{\textit{son}_x,0}\\ g_{\textit{son}_x,1}&-\infty \end{bmatrix} \cdots \begin{bmatrix} g_{\textit{father}_{\textit{bottom}_x},0}&g_{\textit{father}_{\textit{bottom}_x},0}\\ g_{\textit{father}_{\textit{bottom}_x},1}&-\infty \end{bmatrix} \begin{bmatrix} f_{\textit{bottom}_x,0}\\ f_{\textit{bottom}_x,1} \end{bmatrix}\\ \end{aligned} \]

又注意到 \(\textit{bottom}_x\) 必为叶节点(否则 \(\textit{bottom}_x\) 存在子节点,不为链底节点),因此可以确定:

\[\begin{bmatrix} f_{\textit{bottom}_x,0}\\ f_{\textit{bottom}_x,1} \end{bmatrix} = \begin{bmatrix} 0\\ a_{\textit{bottom}_x} \end{bmatrix} \]

首先,可以通过线段树维护重链上的转移矩阵,并同时维护区间乘积,但是需要注意运算顺序。此处因为进行了树链剖分,则有 \(\textit{dfn}_{x}=\textit{dfn}_{\textit{son}_x}-1\),因此线段树维护时“左乘右”即可。

修改 \(g_{i,1}\) 时,只需要在线段树上单点修改转移矩阵即可。

之后便是算出 \(\begin{bmatrix}f_{\textit{top}_x,0}\\f_{\textit{top}_x,1}\end{bmatrix}\),从而利用 \(f_{\textit{top}_x,0},f_{\textit{top}_x,1}\) 更新 \(\textit{father}_{\textit{top}_x}\) 的转移矩阵。

考虑到只有 \(\mathcal O(\log n)\) 条重链,而线段树单次修改 \(\mathcal O(\log n)\),故修改总复杂度 \(\mathcal O\left(\log^2n\right)\)

代码实现时,需要注意区分点的实际编号与点在树链剖分意义下的 DFS 序

总时间复杂度:\(\mathcal O\left(n+m\log^2 n\right)\)

参考代码
//#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cstdio>
#include<string>
#include<vector>
#include<cmath>
#include<ctime>
#include<deque>
#include<queue>
#include<stack>
#include<list>
using namespace std;
//#define DEBUG 
constexpr const int N=1e5,inf=0x3f3f3f3f;
int n,a[N+1];
int f[N+1][2],g[N+1][2];
vector<int>edge[N+1];
struct Matrix{
	int n,m;
	int a[2][2];
	Matrix(int nn=0,int mm=-1){
		if(mm==-1){
			mm=nn;
		}
		n=nn,m=mm;
	}
	void unit(){
		for(int i=0;i<n;i++){
			for(int j=0;j<m;j++){
				a[i][j]=-inf;
			}
			a[i][i]=0;
		}
	}
};
Matrix operator*(Matrix A,Matrix B){
	Matrix C(A.n,B.m);
	for(int i=0;i<C.n;i++){
		for(int j=0;j<C.m;j++){
			C.a[i][j]=-inf;
			for(int k=0;k<A.m;k++){
				C.a[i][j]=max(C.a[i][j],A.a[i][k]+B.a[k][j]);
			}
		}
	}
	return C;
}
Matrix& operator*=(Matrix &A,Matrix B){
	return A=A*B;
}
namespace hld{
	int size[N+1],father[N+1],son[N+1];
	void dfs1(int x,int fx){
		father[x]=fx;
		size[x]=1;
		for(int i:edge[x]){
			if(i==fx){
				continue;
			}
			dfs1(i,x);
			size[x]+=size[i];
			if(size[i]>size[son[x]]){
				son[x]=i;
			}
		}
	}
	int top[N+1],bottom[N+1],dfn[N+1],rnk[N+1];
	void dfs2(int x,int topx){
		top[x]=topx;
		static int cnt;
		dfn[x]=++cnt;
		rnk[cnt]=x;
		if(son[x]){
			dfs2(son[x],topx);
			bottom[x]=bottom[son[x]];
			for(int i:edge[x]){
				if(i==father[x]||i==son[x]){
					continue;
				}
				dfs2(i,i);
			}
		}else{
			bottom[x]=x;
		}
	}
	void build(){
		dfs1(1,0);
		dfs2(1,1);
	}
	struct segTree{
		struct node{
			Matrix value;
			int l,r;
		}t[N<<2|1];
		
		Matrix create(int x){
			Matrix ans(2);
			ans.a[0][0]=ans.a[0][1]=g[x][0];
			ans.a[1][0]=g[x][1];
			ans.a[1][1]=-inf;
			return ans;
		}
		void up(int p){
			t[p].value=t[p<<1].value*t[p<<1|1].value;
		}
		void build(int p,int l,int r){
			t[p].l=l,t[p].r=r;
			if(l==r){
				t[p].value=create(rnk[l]);
				return;
			}
			int mid=l+r>>1;
			build(p<<1,l,mid);
			build(p<<1|1,mid+1,r);
			up(p);
		}
		Matrix query(int p,int l,int r){
			if(r<l){
				Matrix ans(2);
				ans.unit();
				return ans;
			}
			if(l<=t[p].l&&t[p].r<=r){
				return t[p].value;
			}
			Matrix ans(2);
			ans.unit();
			if(l<=t[p<<1].r){
				ans*=query(p<<1,l,r);
			}
			if(t[p<<1|1].l<=r){
				ans*=query(p<<1|1,l,r);
			}
			return ans;
		}
		void change(int p,int x){
			if(t[p].l==t[p].r){
				t[p].value=create(rnk[x]);
				return;
			}
			if(x<=t[p<<1].r){
				change(p<<1,x);
			}else{
				change(p<<1|1,x);
			}
			up(p);
		}
	}segTree;
	
	void change(int x,int y){
		g[x][1]+=-a[x]+y;
		a[x]=y;
		segTree.change(1,dfn[x]);
		x=top[x];
		while(x!=1){
			Matrix pl(2,1);
			pl.a[0][0]=0;
			pl.a[1][0]=a[bottom[x]];
			pl=segTree.query(1,dfn[x],dfn[bottom[x]]-1)*pl;
			int fx0=f[x][0],fx1=f[x][1];
			f[x][0]=pl.a[0][0];
			f[x][1]=pl.a[1][0];
			g[father[x]][0]+=max(f[x][0],f[x][1])-max(fx0,fx1);
			g[father[x]][1]+=f[x][0]-fx0;
			segTree.change(1,dfn[father[x]]);
			x=top[father[x]];
		}
	}
	int query(){
		Matrix pl(2,1);
		pl.a[0][0]=0;
		pl.a[1][0]=a[bottom[1]];
		if(1<=dfn[bottom[1]]-1){
			pl=segTree.query(1,1,dfn[bottom[1]]-1)*pl;
		}
		return max(pl.a[0][0],pl.a[1][0]);
	}
}

void dfs(int x,int fx){
	for(int i:edge[x]){
		if(i==fx){
			continue;
		}
		dfs(i,x);
		f[x][0]+=max(f[i][0],f[i][1]);
		f[x][1]+=f[i][0];
	}
	f[x][1]+=a[x];
	g[x][0]=f[x][0]-max(f[hld::son[x]][0],f[hld::son[x]][1]);
	g[x][1]=f[x][1]-f[hld::son[x]][0];
}
void pre(){
	hld::build();
	dfs(1,0);
	hld::segTree.build(1,1,n);
}
int main(){
	/*freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);*/
	
	ios::sync_with_stdio(false);
	cin.tie(0);cout.tie(0);
	
	int m;
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		cin>>a[i];
	}
	for(int i=1;i<n;i++){
		int u,v;
		cin>>u>>v;
		edge[u].push_back(v);
		edge[v].push_back(u);
	}
	pre();
	while(m--){
		int x,y;
		cin>>x>>y;
		hld::change(x,y);
		cout<<hld::query()<<'\n';
	}
	
	cout.flush();
	 
	/*fclose(stdin);
	fclose(stdout);*/
	return 0;
}
/*
10 1
-11 98 -99 -76 56 38 92 -51 -44 47 
2 1
3 1
4 3
5 2
6 2
7 1
8 2
9 4
10 7
7 -58

145
*/

参考资料

  1. captain1 的DDP入门
  2. RenaMoe 的广义矩阵乘法与动态 DP 学习笔记
  3. 动态 DP - OI Wiki
posted @ 2025-07-28 11:02  TH911  阅读(34)  评论(0)    收藏  举报