概念

DDP,可以理解为转移会发生改变的动态规划。

当然这个改变是题目中给的,包括系数,转移位置的改变。显然暴力枚举这些改变是不现实的,我们要把改变体现到其他地方。

最经典的,体现到矩阵上。

我们把转移写成矩阵,那么改变转移就是改变转移矩阵。

具体的改变会落实到具体的题目上。

广义矩阵乘法

因为转移的多样性,矩阵乘法不一定需要用一般乘法的乘完相加。在满足结合律的情况下,可以是乘完取 \(\min\),加完取 \(\max\) 等。

如 CF750E,要删除最少,转移中需要取 \(\min\),所以写成矩阵时,重载乘法就用到了加完取 \(\min\),同时因为其有结合律,其仍旧可以像一般矩阵乘法进行上树等操作。

线段树维护

矩阵满足结合律,可以用线段树维护。

面对每一位转移不同的题目或者只需统计区间答案的题目时,使用线段树维护区间转移矩阵的积是很必要的。

主要是代码实现的难度。

struct mat
{
	int mat[6][6];
}a,c;
mat operator *(mat a,mat b)
{
    mat c;
    memset(c.mat,63,sizeof(c.mat));
    for(int k=0;k<5;k++)
    {
        for(int i=0;i<5;i++)
        {
            for(int z=0;z<5;z++)
            {
                c.mat[i][z]=min(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
            }
        }
    }
    return c;
}
mat mul(mat a,mat b)
{
    mat c;
    memset(c.mat,63,sizeof(c.mat));
    for(int k=0;k<5;k++)
    {
        for(int i=0;i<1;i++)
        {
            for(int z=0;z<5;z++)
            {
                c.mat[i][z]=min(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
            }
        }
    }
    return c;
}
int n,m,q,rt,w[200001];
mat sum[800001],inn;
void add(int o,int l,int r,int x,mat y)
{
    if(l==r)
    {
        sum[o]=y;
        return;
    }
    int mid=r+l>>1;
    if(x<=mid) add((o<<1),l,mid,x,y);
    else add((o<<1)+1,mid+1,r,x,y);
    sum[o]=sum[(o<<1)]*sum[(o<<1)+1];
}
mat get(int o,int l,int r,int x,int y)
{
    if(x<=l&&y>=r) return sum[o];
    int mid=l+r>>1;
    if(mid>=y)
    {
    	return get(o<<1,l,mid,x,y);
	}
	if(x>mid)
	{
		return get((o<<1)+1,mid+1,r,x,y);
	}
    return get(o<<1,l,mid,x,y)*get((o<<1)+1,mid+1,r,x,y);	
}

解决树上DDP问题

使用树链剖分把树断为链,重链内是序列问题可以自己解决。而重链之间的转移成为难点。

我们称一个重链顶与他的父亲组成一个卡口。改变一个点的值后,所有他到父亲的卡口值会改变。体现轻重链,我们设 \(g_u\) 为只与 \(u\) 亲儿子有关的转移,\(f_{uw}\)\(u\) 的重儿子的 \(DP\) 值,我们必须把 \(f_u\) 转移写成只与 \(g_u\)\(f_{uw}\) 有关的式子。

为什么呢?

保证时间复杂度,因为每个重链内是序列问题,它是不用改变的,而到了卡口,\(g\) 值会变。若和其他 \(f\) 有关,那么改变一个点的值将导致他到根的所有 \(f\) 值改变,因为他们的转移都依赖于此。

模板题

#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
struct mat
{
	int mat[2][2];
}gg[100001];
mat operator *(mat a,mat b)
{
    mat c;
    for(int i=0;i<2;i++)
    {
    	for(int z=0;z<2;z++)
    	{
    		c.mat[i][z]=-100000000;
		}
	}
    for(int k=0;k<2;k++)
    {
        for(int i=0;i<2;i++)
        {
            for(int z=0;z<2;z++)
            {
                c.mat[i][z]=max(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
            }
        }
    }
    return c;
}
mat mul(mat a,mat b)
{
    mat c;
    for(int i=0;i<2;i++)
    {
    	for(int z=0;z<2;z++)
    	{
    		c.mat[i][z]=-100000000;
		}
	}
    for(int k=0;k<2;k++)
    {
        for(int i=0;i<1;i++)
        {
            for(int z=0;z<2;z++)
            {
                c.mat[i][z]=max(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
            }
        }
    }
    return c;
}
int n,m,q,rt,w[200001];
mat sum[800001];
int fat[100001],siz[100001],dep[100001],hson[100001],top[100001],cnt,dfn[100001],dis[100001],f[100001][2],downd[100001];
vector<int> g[1000001];
void add(int o,int l,int r,int x,mat y)
{
    if(l==r)
    {
        sum[o]=y;
        return;
    }
    int mid=r+l>>1;
    if(x<=mid) add((o<<1),l,mid,x,y);
    else add((o<<1)+1,mid+1,r,x,y);
    sum[o]=sum[(o<<1)]*sum[(o<<1)+1];
}
mat get(int o,int l,int r,int x,int y)
{
    if(x<=l&&y>=r) return sum[o];
    int mid=l+r>>1;
    if(mid>=y)
    {
    	return get(o<<1,l,mid,x,y);
	}
	if(x>mid)
	{
		return get((o<<1)+1,mid+1,r,x,y);
	}
    return get(o<<1,l,mid,x,y)*get((o<<1)+1,mid+1,r,x,y);	
}
void getdfsh(int u,int fa)
{
    fat[u]=fa;
    dep[u]=dep[fa]+1;
    int lll=0;
    f[u][1]=w[u];
    for(int i=0;i<g[u].size();i++)
    {
        int v=g[u][i];
        if(v==fa) continue;
        getdfsh(v,u);
        if(siz[v]>lll)
        {
            hson[u]=v;
            lll=siz[v];
        }
        siz[u]+=siz[v];
	    f[u][1]+=f[v][0];
	    f[u][0]+=max(f[v][0],f[v][1]);
    }
    siz[u]++;
}
void gettd(int u,int fa)
{
	gg[u].mat[1][0]=w[u];
	gg[u].mat[1][1]=-100000000;
    dfn[u]=++cnt;
    dis[u]=cnt;
    if(hson[fat[u]]==u)
    {
        top[u]=top[fa];
        downd[top[u]]=dfn[u];
    }
    else
    {
        top[u]=u;
        downd[top[u]]=dfn[u];
    }
    if(hson[u]!=0) gettd(hson[u],u);
    for(int i=0;i<g[u].size();i++)
    {
        int v=g[u][i];
        if(v==fa||v==hson[u]) continue;
        gettd(v,u);
	    gg[u].mat[0][0]+=max(f[v][0],f[v][1]);
	    gg[u].mat[1][0]+=f[v][0];
    }
  	gg[u].mat[0][1]=gg[u].mat[0][0];
}
void getdis(int u, int fa) {
    for(int i=0;i<g[u].size();i++)
	{
        int v=g[u][i];
        if (v==fa) continue;
        getdis(v,u);
        dis[u]=max(dis[u],dis[v]);
    }
}
void update(int x,int val)
{
 	gg[x].mat[1][0]+=val-w[x];
	w[x]=val;
 	while(x)
	{
   	 	mat las=get(1,1,n,dfn[top[x]],downd[top[x]]);
   	 	add(1,1,n,dfn[x],gg[x]);
   	 	mat now=get(1,1,n,dfn[top[x]],downd[top[x]]);
   	 	x=fat[top[x]];
   	 	gg[x].mat[0][0]+=max(now.mat[0][0],now.mat[1][0])-max(las.mat[0][0],las.mat[1][0]);
   	 	gg[x].mat[0][1]=gg[x].mat[0][0];
   	 	gg[x].mat[1][0]+=now.mat[0][0]-las.mat[0][0];
	}
}
signed main()
{
	scanf("%d",&n);
	scanf("%d",&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&w[i]);
    }
    for(int i=1,u,v;i<n;i++)
    {
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    getdfsh(1,0);
    gettd(1,0);
    getdis(1,0);
    for(int i=1;i<=n;i++)
    {
    	add(1,1,n,dfn[i],gg[i]);
	}
  	for(int i=1;i<=m;i++)
	{
  	  	int x,val;
  	 	scanf("%d%d",&x,&val);
  		update(x,val);
  	 	mat ans=get(1,1,n,1,downd[1]);
  	 	printf("%d\n",max(ans.mat[0][0],ans.mat[1][0]));
	}
}
posted on 2023-07-11 21:09  lizhous  阅读(12)  评论(0编辑  收藏  举报