【CSP-S 2019】D2T1 Emiya 家今天的饭
Description
Solution
算法1 32pts
爆搜,复杂度\(O((m+1)^n)\)
算法2 84pts
裸的dp,复杂度\(O(n^3m)\)
首先有一个显然的性质要知道:
最多只有一种主要食材出现在超过一半的主要食材里。
接下来考虑如果只有前两个限制条件的情况,那么答案就是
其中\(sum_i = \sum \limits_{j=1}^m a_{i,j}\),\(+1\)是因为对于每一行只有选一道菜或者不选这些选择,\(-1\)是因为要去除一道菜都不选的情况。
对于第3个限制条件,发现直接做不太好做,考虑容斥,即用总方案数,也就是上面的式子,减去不合法的方案数。
由最开始的那个性质可以得到一个做法:
枚举不合法的那一种主要食材,然后进行\(dp\)。发现我们并不需要知道每一种主要食材具体用在了多少道菜上,只需要知道当前枚举的食材用在了多少道菜,其它的并不影响方案。那么设\(f_{i,j,k}\)表示前\(i\)中烹饪方式,选了\(j\)道菜,其中\(k\)道的主要食材是枚举的不合法食材。转移分三种情况:令\(s\)表示当前枚举的不合法食材,
- 
不在这一种烹饪方式中进行选择:\(f_{i,j,k}=f_{i-1,j,k}\) 
- 
在这种烹饪方式中选择了合法的食材:\(f_{i,j,k}=(sum_i-a_{i,s}) \times f_{i,j-1,k}\) 
- 
在这种烹饪方式中选择了不合法的食材:\(f_{i,j,k}=a_{i,s}\times f_{i,j-1,k-1}\) 
那么不合法的方案数就是
code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int _ = 100 + 10;
const int __ = 2000 + 10;
int n, m, A[_][__];
ll sum[_], f[_][_ << 1], tmp, ans = 1;
int main() {
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= m; ++j) {
			scanf("%d", &A[i][j]);
			sum[i] = (sum[i] + A[i][j]) % mod;
		}
		ans = ans * (sum[i] + 1) % mod;
	}
	ans = (ans - 1 + mod) % mod;
	for (int k = 1; k <= m; ++k) {
		memset(f, 0, sizeof(f));
		f[0][n] = 1;
		for (int i = 1; i <= n; ++i) {
			for (int j = -i + n; j <= i + n; ++j) {
				f[i][j] = (f[i - 1][j] + f[i - 1][j - 1] * A[i][k] % mod + f[i - 1][j + 1] * (sum[i] - A[i][k]) % mod) % mod;
				if (i == n && j > n) tmp = (tmp + f[i][j]) % mod;
			}
		}
	}
	ans = (ans - tmp + mod) % mod;
	printf("%lld\n", ans);
	return 0;
}
算法三 100pts
考虑如何对算法二的\(dp\)进行优化,减少不必要的状态。对限制三进行转化,限制三即为
发现并不需要关心使用了食材的菜的具体数量,只需要关心合法与不合法的菜的差值即可,即这个差值与原来差值相同的状态的集合是对应的,那么我们就可以以此为状态进行dp,转移与上面是类似的。
唯一要注意的一点是可能出现负数,要加上一个偏移量\(n\)
- 
不在这一种烹饪方式中进行选择:\(f_{i,j}=f_{i-1,j}\) 
- 
在这种烹饪方式中选择了合法的食材:\(f_{i,j}=(sum_i-a_{i,s}) \times f_{i,j+1}\) 
- 
在这种烹饪方式中选择了不合法的食材:\(f_{i,j}=a_{i,s}\times f_{i,j-1}\) 
code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int _ = 100 + 10;
const int __ = 2000 + 10;
int n, m, A[_][__];
ll sum[_], f[_][_ << 1], tmp, ans = 1;
int main() {
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= m; ++j) {
			scanf("%d", &A[i][j]);
			sum[i] = (sum[i] + A[i][j]) % mod;
		}
		ans = ans * (sum[i] + 1) % mod;
	}
	ans = (ans - 1 + mod) % mod;
	for (int k = 1; k <= m; ++k) {
		memset(f, 0, sizeof(f));
		f[0][n] = 1;
		for (int i = 1; i <= n; ++i) {
			for (int j = -i + n; j <= i + n; ++j) {
				f[i][j] = (f[i - 1][j] + f[i - 1][j - 1] * A[i][k] % mod + f[i - 1][j + 1] * (sum[i] - A[i][k]) % mod) % mod;
				if (i == n && j > n) tmp = (tmp + f[i][j]) % mod;
			}
		}
	}
	ans = (ans - tmp + mod) % mod;
	printf("%lld\n", ans);
	return 0;
}

 
                
            
         浙公网安备 33010602011771号
浙公网安备 33010602011771号