差值 dp

最近遇到的一类 dp 题。

适用范围

某些问题规定了变量 \(A\) 和变量 \(B\) 的大小关系,但并不在意其具体数值。

如,\(A>B,A=B\) 等。

暴力 dp 通常为 \(f_{i,j}\) 表示变量 \(A\) 的值为 \(i\),变量 \(B\) 的值为 \(j\),这样需要转移两维。

于是考虑设 \(f_i\) 表示变量 \(A-B\) 的值为 \(i\),可以少转移一维。

注意 \(A-B\) 的值可能为负,通常处理方式是下标整体右移。

例题

  • P1651 塔

\(f_{i,j}\) 表示考虑前 \(i\) 块积木,第一座塔减去第二座塔的高度为 \(j\) 的第一座塔的高度。

转移可以不放,放第一座,放第二座。

  • \(f_{i,j}\gets f_{i-1,j}\)
  • \(f_{i,j}\gets f_{i-1,j-a_i}+a_i\)
  • \(f_{i,j}\gets f_{i-1,j+a_i}\)
#include<bits/stdc++.h>
#define sd std::
//#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
#define inf 2e7
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=110;
int n,a[N];
int f[2][1000010];
int sum;
void solve()
{
	n=read();
	F(i,1,n) a[i]=read(),sum+=a[i];
	F(j,0,2*sum) f[1][j]=f[0][j]=-inf;
	f[0][sum]=0;//初始化
	F(i,1,n)//j为第一座塔减去第二座塔的高度
	{
		F(j,-sum,sum)
		{
			f[i&1][j+sum]=f[(i-1)&1][j+sum];//不放
			if(j-a[i]>=-sum) f[i&1][j+sum]=sd max(f[i&1][j+sum],f[(i-1)&1][j-a[i]+sum]+a[i]);//第一座塔
			if(j+a[i]<=sum) f[i&1][j+sum]=sd max(f[i&1][j+sum],f[(i-1)&1][j+a[i]+sum]);//第二座塔
		}
	}
	put(f[n&1][sum]>0?f[n&1][sum]:-1);
}
int main()
{
	int T=1;
//	T=read();
	while(T--) solve();
    return 0;
}
  • 「CSP-S2019」 Emiya 家今天的饭

发现不满足第二个条件的列最多只有一个,考虑容斥。假设上面两个条件是 \(P1,P2\)

答案即满足 \(P1\) 的方案数减去满足 \(P1\) 不满足 \(P2\) 的方案数。

  • 满足 \(P1\) 的方案数。

\(f_{i}\) 为前 \(i\) 行的方案数,\(f_0=1\)

\(f_i=f_{i-1}\times (1+\sum a_{i,k})\),最后 \(f_n-1\) 就是方案数,减去全不选的方案。

  • 满足 \(P1\),不满足 \(P2\) 的方案数。

枚举第 \(k\) 列不满足。

设第 \(k\) 列选了 \(a\) 个,其他列选了 \(b\) 个。

\(a> \lfloor\frac{a+b}{2}\rfloor\)

分类一下:

  • \(a+b\) 为奇数,则 \(a\ge b\)
  • \(a+b\) 为偶数,则 \(a> b\)

考虑差值 dp。

\(g_{i,j,0/1}\) 为前 \(i\) 行,\(a-b\) 的值,\(a+b\) 为偶数/奇数的方案。

转移分为三类:

  1. \(i\) 行不选。

\[g_{i,j,1}\gets g_{i-1,j,1} \]

\[g_{i,j,0}\gets g_{i-1,j,0} \]

  1. \(i\) 行选择第 \(k\) 列。共有 \(a_{i,k}\) 种选法。

\[g_{i,j,1}\gets g_{i-1,j-1,0}\times a_{i,k} \]

\[g_{i,j,0}\gets g_{i-1,.j-1,1}\times a_{i,k} \]

  1. \(i\) 行选择其他列。共有 \(\sum\limits_{p\not =k} a_{i,p}\) 种选法。

\[g_{i,j,1}\gets g_{i-1,j+1,0}\times \sum\limits_{p\not =k} a_{i,p} \]

\[g_{i,j,0}\gets g_{i-1,j+1,1}\times \sum\limits_{p\not =k} a_{i,p} \]

因为 \(j\) 可能是负数,所以右移 \(n\) 个。

初始化 \(g_{0,0,0}=1\),右移后即 \(g_{0,n,0}=1\)

最终答案为 \(\sum g_{n,p>n,0}+g_{n,p\ge n,1}\)

#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=2010,P=998244353;
int n,m,a[210][N],g[210][210][2];
int sum[N];
void solve()
{
	n=read();m=read();
	F(i,1,n) F(j,1,m) a[i][j]=read(),(sum[i]+=a[i][j])%=P;
	int val=1;
	F(i,1,n) val=val*(sum[i]+1)%P;
	val=(val-1+P)%P;
	g[0][n][0]=1;
	int ans=0;
	F(k,1,m)//枚举哪一列
	{
		F(i,1,n) F(j,n-i,n+i) g[i][j][0]=g[i][j][1]=0;
		F(i,1,n)//行
		{
			F(j,n-i,n+i)
			{
				g[i][j][0]=g[i-1][j][0];
				g[i][j][1]=g[i-1][j][1];
				(g[i][j][0]+=g[i-1][j+1][1]*(sum[i]-a[i][k])%P)%=P;
				(g[i][j][1]+=g[i-1][j+1][0]*(sum[i]-a[i][k])%P)%=P;
				if(j) (g[i][j][1]+=g[i-1][j-1][0]*a[i][k]%P)%=P;
				if(j) (g[i][j][0]+=g[i-1][j-1][1]*a[i][k]%P)%=P;
			}
		}
		F(i,n,2*n) (ans+=g[n][i][1])%=P;
		F(i,n+1,2*n) (ans+=g[n][i][0])%=P;
	}
	put((val-ans+P)%P);
}
signed main()
{
	int T=1;
//	T=read();
	while(T--) solve();
    return 0;
}
  • 「AGC043D」 Merge Triplets

考虑通过计算 \(A\) 的数量来计算 \(P\)

但某些不同的 \(A\) 计算出的 \(P\) 是相同的。比如若每个大小为 \(3\) 的块都长度递增的话,最终 \(P\) 都是 \(1\sim n\)

假设一个块内的数是 \(x,y,z\)

发现若 \(y>z\),则总序列中 \(z\) 一定会接在 \(y\) 后面。同样的,若 \(x>y\),则总序列中 \(y\) 一定会接在 \(x\) 后面。

于是这个块的弹出方式有 \([1,3]\) 个一起弹出。

不妨将一起弹出的数看作一个段,然后 \(A\) 生成 \(P\) 的方式就是将段按头排序,然后依次生成。

考虑这些段的限制:

  1. 段的长度为 \([1,3]\)
  2. 同一个段长度递减。
  3. 长度为 \(1\) 的段比长度为 \(2\) 的段多。因为从一个长度为 \(3\) 的块中生成一个 \(2\) 段就必须生成一个 \(1\) 段。

我们只关心合法的段头集合和剩下的数的排列方案。

于是我们先选择一些数作为段头,然后计算剩下的数有多少种放置方式。

假设这里有一个集合 \(1\sim n\),被选择作为段头的数被取走,剩余不是段头的数。

考虑每次取走一个最大的数作为段头,然后在剩余的数中选择一些作为这段的数,显然剩余的数都比这个数小。

可能有一个疑问是可以不取最大的数,先取其它的,但由于段的限制 \(2\),这个最大的数最终还是会变为段头,而我们只关心段头集合,不关心实际取的顺序。所以这样选是不重不漏的。

考虑差值 dp。

\(f_{i,j}\) 为取 \(i\) 个数,\(1\) 段比 \(2\) 段多 \(j\) 个的总方案。

转移:

  • 可以是一个 \(3\) 段:

\[f_{i,j}\gets f_{i-3,j}\times (n-i+2)\times (n-i+1) \]

此时已经选了 \(i-3\) 个数,然后选了一个最大的作为段头,还剩 \((n-i+2)\) 个数可选,从中选择两个(因为选择的段头是最大的,剩余的都比它小)作为段内的数,因为有顺序关系所以不是 \(\binom{n-i+2}{2}\)

  • 可以是一个 \(2\) 段:

\[f_{i,j}\gets f_{i-2,j+1}\times (n-i+1) \]

  • 可以是一个 \(1\) 段:

\[f_{i,j}\gets f_{i-1,j-1} \]

初始值 \(f_{0,0}=1\)

#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=6010,M=N*2;
int n,P;
int f[N][M],jc[M],inv[M];
void prework(int n)
{
    inv[0]=jc[0]=inv[1]=1;
    F(i,1,n) jc[i]=jc[i-1]*i%P;
    F(i,2,n) inv[i]=(P-P/i)*inv[P%i]%P;
    F(i,1,n) inv[i]=inv[i-1]*inv[i]%P;
}
int C(int n,int m)
{
    if(n<m) return 0;
    return jc[n]*inv[m]%P*inv[n-m]%P;
}
void solve()
{
	n=read(),P=read();
	n*=3;
	prework(12000);
	f[0][n]=1;//整体右移 n
	F(i,1,n)
	{
		F(j,0,2*n)
		{
			if(j) f[i][j]=f[i-1][j-1];
			if(i>1) f[i][j]=(f[i][j]+f[i-2][j+1]*(n-i+1)%P)%P;
			if(i>2) f[i][j]=(f[i][j]+f[i-3][j]*(n-i+2)*(n-i+1)%P)%P;
		}
	}
	int ans=0;
	F(j,0,n) ans=(ans+f[n][j+n])%P;
	put(ans);
}
signed main()
{
	int T=1;
//	T=read();
	while(T--) solve();
    return 0;
}
posted @ 2025-08-13 21:28  _E_M_T  阅读(13)  评论(0)    收藏  举报