题解:P5664 [CSP-S2019] Emiya 家今天的饭
一些闲话
原题链接
挺好的一道 DP 题,容斥+压缩维度优化 DP。
题意简述
Emiya 是个擅长做菜的高中生,他共掌握 \(n\) 种烹饪方法,且会使用 \(m\) 种主要食材做菜。我们对烹饪方法从 \(1\sim n\) 编号,对主要食材从 \(1\sim m\) 编号。Emiya 做的每道菜都使用恰好一种烹饪方法与恰好一种主要食材。更具体地,Emiya 会做 \(a_{i,j}\) 道不同的使用烹饪方法 \(i\) 和主要食材 \(j\) 的菜。
Emiya 今天要准备一桌饭招待 Yazid 和 Rin 这对好朋友,然而三个人对菜的搭配有不同的要求。对于一种包含 \(k\) 道菜的搭配方案而言:
- Emiya 将做至少一道菜,即 \(k \geq 1\);
- Rin 希望品尝不同烹饪方法做出的菜,因此她要求每道菜的烹饪方法互不相同;
- Yazid 不希望品尝太多同一食材做出的菜,因此他要求每种主要食材至多在一半的菜(即 \(\lfloor\frac{k}{2}\rfloor\) 道菜)中被使用。
Emiya 想知道共有多少种不同的符合要求的搭配方案。两种方案不同,当且仅当存在至少一道菜在一种方案中出现,而不在另一种方案中出现。
请你帮他计算符合所有要求的搭配方案数对质数 \(998244353\) 取模的结果。
对于所有测试点,\(1 \leq n \leq 100\),\(1 \leq m \leq 2000\),\(0 \leq a_{i,j} \lt 998244353\)。
题解
题目差不多就是在网格里选位置的方案数。
这题主要的难度在于第三个要求是难以刻画的。所以正难则反,我们考虑容斥,用总方案数减去不合法的方案数。不合法的方案意味着有且仅有一列被选择的数量超过了总数量的一半。
令 \(s_i=\sum_{j=1}^{m}{a_{i,j}}\),于是总方案数即为
容易想到 DP 求不合法的方案数。我们枚举被选择的数量超过了总数量一半的那一列 \(c\),然后令 \(f_{i,j,k}\) 表示考虑前 \(i\) 行,总共选择了 \(j\) 个位置,第 \(c\) 列选择了 \(k\) 个位置的方案数。转移时考虑当前行不选、选第 \(c\) 列之外的位置或选第 \(c\) 列的位置,即可得到转移方程:
转移时注意 \(j\) 和 \(k\) 从 0 开始枚举。最终答案就是
时间复杂度为 \(O(mn^3)\)。可以拿下 84 分的高分。
考虑进一步优化。这一步是比较巧妙的:我们注意到我们只关心 \(j\) 和 \(k\) 的相对大小关系,并且在计算最终答案时
所以我们无需枚举 \(j\) 和 \(k\) 的具体值,而是转而枚举差值 \(d=2k-j\) 即可,状态转移方程只需做轻微改动:
其中 \(-i\leq d\leq i\)。计算最终答案时减去满足 \(1\leq d\leq n\) 的对应 DP 值即可。实际实现时记得给 \(d\) 加上偏移量。
时间复杂度少了一个 \(n\),为 \(O(mn^2)\),可以通过本题。
代码
#include <iostream>
#include <cstring>
using namespace std;
#define add_mod(x, v) (x) = ((ll)(x) + (v)) % MOD
#define sub_mod(x, v) (x) = (((ll)(x) - (v)) % MOD + MOD) % MOD
#define mul_mod(x, v) (x) = (1ll * (x) * (v)) % MOD
typedef long long ll;
typedef pair<int, int> pii;
const int MOD = 998244353;
const int MAX_N = 105, MAX_M = 2e3 + 5;
int n, m, a[MAX_N][MAX_M];
ll s[MAX_N], tot = 1, f[MAX_N][MAX_N << 1];
int main() {
ios::sync_with_stdio(false); cin.tie(nullptr);
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
cin >> a[i][j];
add_mod(s[i], a[i][j]);
}
mul_mod(tot, s[i] + 1);
}
sub_mod(tot, 1);
for (int c = 1; c <= m; ++c) {
memset(f, 0, sizeof(f));
f[0][n] = 1;
for (int i = 1; i <= n; ++i)
for (int d = -i; d <= i; ++d) {
add_mod(f[i][d + n], f[i - 1][d + n]);
add_mod(f[i][d + n], f[i - 1][d - 1 + n] * a[i][c] % MOD);
add_mod(f[i][d + n], f[i - 1][d + 1 + n] * (s[i] - a[i][c]) % MOD);
}
for (int d = 1; d <= n; ++d) sub_mod(tot, f[n][d + n]);
}
cout << tot << '\n';
return 0;
}

浙公网安备 33010602011771号