差值 dp
最近遇到的一类 dp 题。
适用范围
某些问题规定了变量 \(A\) 和变量 \(B\) 的大小关系,但并不在意其具体数值。
如,\(A>B,A=B\) 等。
暴力 dp 通常为 \(f_{i,j}\) 表示变量 \(A\) 的值为 \(i\),变量 \(B\) 的值为 \(j\),这样需要转移两维。
于是考虑设 \(f_i\) 表示变量 \(A-B\) 的值为 \(i\),可以少转移一维。
注意 \(A-B\) 的值可能为负,通常处理方式是下标整体右移。
例题
- P1651 塔
设 \(f_{i,j}\) 表示考虑前 \(i\) 块积木,第一座塔减去第二座塔的高度为 \(j\) 的第一座塔的高度。
转移可以不放,放第一座,放第二座。
- \(f_{i,j}\gets f_{i-1,j}\)
- \(f_{i,j}\gets f_{i-1,j-a_i}+a_i\)
- \(f_{i,j}\gets f_{i-1,j+a_i}\)
#include<bits/stdc++.h>
#define sd std::
//#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
#define inf 2e7
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=110;
int n,a[N];
int f[2][1000010];
int sum;
void solve()
{
n=read();
F(i,1,n) a[i]=read(),sum+=a[i];
F(j,0,2*sum) f[1][j]=f[0][j]=-inf;
f[0][sum]=0;//初始化
F(i,1,n)//j为第一座塔减去第二座塔的高度
{
F(j,-sum,sum)
{
f[i&1][j+sum]=f[(i-1)&1][j+sum];//不放
if(j-a[i]>=-sum) f[i&1][j+sum]=sd max(f[i&1][j+sum],f[(i-1)&1][j-a[i]+sum]+a[i]);//第一座塔
if(j+a[i]<=sum) f[i&1][j+sum]=sd max(f[i&1][j+sum],f[(i-1)&1][j+a[i]+sum]);//第二座塔
}
}
put(f[n&1][sum]>0?f[n&1][sum]:-1);
}
int main()
{
int T=1;
// T=read();
while(T--) solve();
return 0;
}
- 「CSP-S2019」 Emiya 家今天的饭
发现不满足第二个条件的列最多只有一个,考虑容斥。假设上面两个条件是 \(P1,P2\)。
答案即满足 \(P1\) 的方案数减去满足 \(P1\) 不满足 \(P2\) 的方案数。
- 满足 \(P1\) 的方案数。
设 \(f_{i}\) 为前 \(i\) 行的方案数,\(f_0=1\)。
则 \(f_i=f_{i-1}\times (1+\sum a_{i,k})\),最后 \(f_n-1\) 就是方案数,减去全不选的方案。
- 满足 \(P1\),不满足 \(P2\) 的方案数。
枚举第 \(k\) 列不满足。
设第 \(k\) 列选了 \(a\) 个,其他列选了 \(b\) 个。
则 \(a> \lfloor\frac{a+b}{2}\rfloor\)。
分类一下:
- 若 \(a+b\) 为奇数,则 \(a\ge b\)。
- 若 \(a+b\) 为偶数,则 \(a> b\)。
考虑差值 dp。
设 \(g_{i,j,0/1}\) 为前 \(i\) 行,\(a-b\) 的值,\(a+b\) 为偶数/奇数的方案。
转移分为三类:
- 第 \(i\) 行不选。
- 第 \(i\) 行选择第 \(k\) 列。共有 \(a_{i,k}\) 种选法。
- 第 \(i\) 行选择其他列。共有 \(\sum\limits_{p\not =k} a_{i,p}\) 种选法。
因为 \(j\) 可能是负数,所以右移 \(n\) 个。
初始化 \(g_{0,0,0}=1\),右移后即 \(g_{0,n,0}=1\)。
最终答案为 \(\sum g_{n,p>n,0}+g_{n,p\ge n,1}\)。
#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=2010,P=998244353;
int n,m,a[210][N],g[210][210][2];
int sum[N];
void solve()
{
n=read();m=read();
F(i,1,n) F(j,1,m) a[i][j]=read(),(sum[i]+=a[i][j])%=P;
int val=1;
F(i,1,n) val=val*(sum[i]+1)%P;
val=(val-1+P)%P;
g[0][n][0]=1;
int ans=0;
F(k,1,m)//枚举哪一列
{
F(i,1,n) F(j,n-i,n+i) g[i][j][0]=g[i][j][1]=0;
F(i,1,n)//行
{
F(j,n-i,n+i)
{
g[i][j][0]=g[i-1][j][0];
g[i][j][1]=g[i-1][j][1];
(g[i][j][0]+=g[i-1][j+1][1]*(sum[i]-a[i][k])%P)%=P;
(g[i][j][1]+=g[i-1][j+1][0]*(sum[i]-a[i][k])%P)%=P;
if(j) (g[i][j][1]+=g[i-1][j-1][0]*a[i][k]%P)%=P;
if(j) (g[i][j][0]+=g[i-1][j-1][1]*a[i][k]%P)%=P;
}
}
F(i,n,2*n) (ans+=g[n][i][1])%=P;
F(i,n+1,2*n) (ans+=g[n][i][0])%=P;
}
put((val-ans+P)%P);
}
signed main()
{
int T=1;
// T=read();
while(T--) solve();
return 0;
}
- 「AGC043D」 Merge Triplets
考虑通过计算 \(A\) 的数量来计算 \(P\)。
但某些不同的 \(A\) 计算出的 \(P\) 是相同的。比如若每个大小为 \(3\) 的块都长度递增的话,最终 \(P\) 都是 \(1\sim n\)。
假设一个块内的数是 \(x,y,z\)。
发现若 \(y>z\),则总序列中 \(z\) 一定会接在 \(y\) 后面。同样的,若 \(x>y\),则总序列中 \(y\) 一定会接在 \(x\) 后面。
于是这个块的弹出方式有 \([1,3]\) 个一起弹出。
不妨将一起弹出的数看作一个段,然后 \(A\) 生成 \(P\) 的方式就是将段按头排序,然后依次生成。
考虑这些段的限制:
- 段的长度为 \([1,3]\)。
- 同一个段长度递减。
- 长度为 \(1\) 的段比长度为 \(2\) 的段多。因为从一个长度为 \(3\) 的块中生成一个 \(2\) 段就必须生成一个 \(1\) 段。
我们只关心合法的段头集合和剩下的数的排列方案。
于是我们先选择一些数作为段头,然后计算剩下的数有多少种放置方式。
假设这里有一个集合 \(1\sim n\),被选择作为段头的数被取走,剩余不是段头的数。
考虑每次取走一个最大的数作为段头,然后在剩余的数中选择一些作为这段的数,显然剩余的数都比这个数小。
可能有一个疑问是可以不取最大的数,先取其它的,但由于段的限制 \(2\),这个最大的数最终还是会变为段头,而我们只关心段头集合,不关心实际取的顺序。所以这样选是不重不漏的。
考虑差值 dp。
设 \(f_{i,j}\) 为取 \(i\) 个数,\(1\) 段比 \(2\) 段多 \(j\) 个的总方案。
转移:
- 可以是一个 \(3\) 段:
此时已经选了 \(i-3\) 个数,然后选了一个最大的作为段头,还剩 \((n-i+2)\) 个数可选,从中选择两个(因为选择的段头是最大的,剩余的都比它小)作为段内的数,因为有顺序关系所以不是 \(\binom{n-i+2}{2}\)。
- 可以是一个 \(2\) 段:
- 可以是一个 \(1\) 段:
初始值 \(f_{0,0}=1\)。
#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=6010,M=N*2;
int n,P;
int f[N][M],jc[M],inv[M];
void prework(int n)
{
inv[0]=jc[0]=inv[1]=1;
F(i,1,n) jc[i]=jc[i-1]*i%P;
F(i,2,n) inv[i]=(P-P/i)*inv[P%i]%P;
F(i,1,n) inv[i]=inv[i-1]*inv[i]%P;
}
int C(int n,int m)
{
if(n<m) return 0;
return jc[n]*inv[m]%P*inv[n-m]%P;
}
void solve()
{
n=read(),P=read();
n*=3;
prework(12000);
f[0][n]=1;//整体右移 n
F(i,1,n)
{
F(j,0,2*n)
{
if(j) f[i][j]=f[i-1][j-1];
if(i>1) f[i][j]=(f[i][j]+f[i-2][j+1]*(n-i+1)%P)%P;
if(i>2) f[i][j]=(f[i][j]+f[i-3][j]*(n-i+2)*(n-i+1)%P)%P;
}
}
int ans=0;
F(j,0,n) ans=(ans+f[n][j+n])%P;
put(ans);
}
signed main()
{
int T=1;
// T=read();
while(T--) solve();
return 0;
}

浙公网安备 33010602011771号