Luogu5664 Emiya 家今天的饭
https://www.luogu.com.cn/problem/P5664
\(DP\)
首先,根据题意,每一行至多取一个
直接求状态数太多,正难则反
我们用全部方案数减去不符合要求的方案数(注意除去全不取情况)
这里不符合要求的方案即食材数超过一半的那些方案
枚举哪一列食材数超过一半,其他任意取,记录非当前列取了多少次,当前列取了多少次
\(dp_{i,j,k}\)表示前\(i\)行,非当前列取了\(j\)次,当前列取了\(k\)次的方案数
\(S_i\)表示第\(i\)行的所有数之和
注:以下式子中数据范围略去
\[dp_{i,j,k}=dp_{i-1,j,k}+dp_{i-1,j-1,k}\times (S_i-a_{i,j})+dp_{i-1,j,k-1}\times a_{i,j}\\
ans=\sum_{j<k} dp_{n,j,k}
\]
可以滚动掉一维数组
时间复杂度:\(O(n^3m)\)
\(Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define N 105
#define M 2005
#define p 998244353
#define ll long long
using namespace std;
int n,m,sm[N],a[N][M],dp[2][N][N];
int ans,del;
void solve1()
{
ans=1;
for (int i=1;i<=n;i++)
ans=(ll)ans*(sm[i]+1)%p;
}
void pl(int &x,int y)
{
x=(x+y)%p;
}
void solve2()
{
for (int i=1;i<=m;i++)
{
int cur=0;
memset(dp,0,sizeof(dp));
dp[cur][0][0]=1;
for (int j=1;j<=n;j++)
{
cur^=1;
for (int k=0;k<=j;k++)
for (int t=0;t<=j-k;t++)
{
dp[cur][k][t]=dp[cur^1][k][t];
if (k)
pl(dp[cur][k][t],(ll)dp[cur^1][k-1][t]*(sm[j]-a[j][i])%p);
if (t)
pl(dp[cur][k][t],(ll)dp[cur^1][k][t-1]*a[j][i]%p);
}
}
for (int j=0;j<=n;j++)
for (int k=j+1;k<=n-j;k++)
pl(del,dp[cur][j][k]);
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
for (int j=1;j<=m;j++)
scanf("%d",&a[i][j]),sm[i]=(sm[i]+a[i][j])%p;
solve1();
solve2();
ans-=del+1;
ans=(ans%p+p)%p;
printf("%d\n",ans);
return 0;
}
观察\(dp\)方程,我们只关心\(j,k\)的差,因此\(dp\)方程中仅记录\(j,k\)之差即可
令\(t=k-j+n\)(防负数)
\[dp_{i,t}=dp_{i-1,t}+dp_{i-1,t+1}\times (S_i-a_{i,j})+dp_{i-1,t-1}\times a_{i,j}\\
ans=\sum_{t>n} dp_{n,t}
\]
同样可以滚动掉一维数组
时间复杂度:\(O(n^2m)\)
\(C++ Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define N 105
#define M 2005
#define p 998244353
#define ll long long
using namespace std;
int n,m,sm[N],a[N][M],dp[2][N << 1];
int ans,del;
int read()
{
int s=0;
char c=getchar();
while (c<'0' || c>'9')
c=getchar();
while ('0'<=c && c<='9')
{
s=s*10+c-'0';
c=getchar();
}
return s;
}
void solve1()
{
ans=1;
for (int i=1;i<=n;i++)
ans=(ll)ans*(sm[i]+1)%p;
}
void pl(int &x,int y)
{
x=(x+y)%p;
}
void solve2()
{
for (int i=1;i<=m;i++)
{
int cur=0;
memset(dp,0,sizeof(dp));
dp[cur][n]=1;
for (int j=1;j<=n;j++)
{
cur^=1;
for (int k=n-j;k<=n+j;k++)
{
dp[cur][k]=dp[cur^1][k];
if (k<(n << 1))
pl(dp[cur][k],(ll)dp[cur^1][k+1]*(sm[j]-a[j][i])%p);
if (k)
pl(dp[cur][k],(ll)dp[cur^1][k-1]*a[j][i]%p);
}
}
for (int k=n+1;k<=(n << 1);k++)
pl(del,dp[cur][k]);
}
}
int main()
{
n=read(),m=read();
for (int i=1;i<=n;i++)
for (int j=1;j<=m;j++)
a[i][j]=read(),sm[i]=(sm[i]+a[i][j])%p;
solve1();
solve2();
ans-=del+1;
ans=(ans%p+p)%p;
printf("%d\n",ans);
return 0;
}

浙公网安备 33010602011771号