[ZJOI2019]麻将

dp 套 dp 经典例题。

这种题一般都是给你一个奇怪的合法条件,然后去做一些计数之类的东西,直接设计状态很不好做。我们考虑先设计一个判定合法的 dp,以这个 dp 的状态和结果作为状态去 dp。

更一般的,我们发现 dp 的过程有初始状态和终止状态,转移看成有向边,可以建出一个自动机来。dp 套 dp就是在这个自动机上再跑dp。

对于这道题,我们先考虑判断能否胡牌。假设已经考虑了前 \(i-1\) 种牌,正在加第 \(i\) 种牌,\(f_{0/1,j,k}\) 表示 是否预留了雀头,留了 \(j\)\((i,i-1)\), \(k\)\((i,i)\),能组成的最大面子,显然 \(j,k<3\)(三个顺子可以变成三个刻子)。对于七对子的情况,我们还要记对子的个数特判。

于是我们用三元组 \((cnt,f[0],f[1])\) 来表示自动机的一个节点。其中 \(cnt\) 记对子个数用来特判七对,\(f[0/1]\)\(3\times 3\) 的矩阵,意义如上。每个非终止状态会伸出四条转移边,代表加入 \(0~3\) 张牌转移到的节点。bfs一遍就可以构建出整个自动机。节点数只有1092个,很少。

于是我们将检查是否合法的过程变成了一种一种牌的加入,在自动机上游走的过程。我们考虑在这个 DAG 上 dp。设 \(f_i\) 表示摸了 \(i\) 张牌仍无法胡牌的方案数,容易通过 \(f_i\) 算出答案。

\(dp_{i,j,k}\) 表示考虑前 \(i\) 种牌,摸了 \(j\) 张牌,自动机上走到点 \(k\) 的概率。转移直接枚举摸了多少张牌即可。

#include <bits/stdc++.h>

using namespace std;

const int N = 405, mod = 998244353;

inline int read() {
	register int s = 0, f = 1; register char ch = getchar();
	while (!isdigit(ch)) f = (ch == '-' ? -1 : 1), ch = getchar();
	while (isdigit(ch)) s = (s * 10) + (ch & 15), ch = getchar();
	return s * f;
}

inline int power(int a, int b) {
	int t = 1, y = a, k = b;
	while (k) {
		if (k & 1) t = 1ll * t * y % mod;
		y = 1ll * y * y % mod; k >>= 1;
	} return t;
}

struct DP {
	int a[3][3];
	DP() { memset(a, -1, sizeof a); }
	inline int* operator [](int x) { return a[x]; }
	inline bool operator < (const DP &b) const {
		for (int i = 0; i <= 2; ++i)
			for (int j = 0; j <= 2; ++j)
				if (a[i][j] != b.a[i][j]) return a[i][j] < b.a[i][j];
		return 0;
	}
	inline bool operator > (const DP &b) const {
		for (int i = 0; i <= 2; ++i)
			for (int j = 0; j <= 2; ++j)
				if (a[i][j] != b.a[i][j]) return a[i][j] > b.a[i][j];
		return 0;
	}
	inline bool operator == (const DP &b) const {
		for (int i = 0; i <= 2; ++i)
			for (int j = 0; j <= 2; ++j)
				if (a[i][j] != b.a[i][j]) return 0;
		return 1;
	}
};

inline DP calc(DP a, int x) {
	DP res;
	for (int i = 0; i <= 2; ++i) {
		for (int j = 0; j <= 2; ++j) {
			if (a[i][j] == -1) continue;
			for (int k = 0; k <= 2 && k <= x - i - j; ++k)
				res[j][k] = max(res[j][k], min(4, a[i][j] + i + (x - i - j - k) / 3));
		}
	} return res;
}

struct node {
	int pr;
	DP f[2];
	node() { pr = 0; }
	inline bool check() {
		if (pr >= 7) return 1;
		for (int i = 0; i <= 2; ++i)
			for (int j = 0; j <= 2; ++j)
				if (f[1][i][j] >= 4) return 1;
		return 0; 
	}
	inline bool operator < (const node &b) const { return pr == b.pr ? (f[0] == b.f[0] ? f[1] < b.f[1] : f[0] < b.f[0]) : pr < b.pr; }
	inline bool operator > (const node &b) const { return pr == b.pr ? (f[0] == b.f[0] ? f[1] > b.f[1] : f[0] > b.f[0]) : pr > b.pr; }
	inline bool operator == (const node &b) const { return pr == b.pr && f[0] == b.f[0] && f[1] == b.f[1]; }
	inline DP& operator [](int x) { return f[x]; }
};

inline node calc(node a, int x) {
	if (a.pr == -1) return a; node res;
	res.pr = a.pr + (x >= 2);
	res[0] = calc(a[0], x);
	res[1] = calc(a[1], x);
	if (x >= 2) {
		DP q = calc(a[0], x - 2);
		for (int i = 0; i <= 2; ++i)
			for (int j = 0; j <= 2; ++j)
				res[1][i][j] = max(res[1][i][j], q[i][j]);
	}
	if (res.check()) {
		memset(res[0].a, -1, sizeof res[0].a);
		memset(res[1].a, -1, sizeof res[1].a);
		res.pr = -1;
	}
	return res;
} 

map<node, int> id;
int tr[N << 3][5], cnt = 0, T = 0;

inline void bfs() {
	queue<node> q; node st; st[0][0][0] = 0;
	id[st] = ++cnt; q.push(st);
	while (!q.empty()) {
		node x = q.front(); int t = id[x]; q.pop();
		if (x.pr == -1) T = t;
		for (int i = 0; i <= 4; ++i) {
			node v = calc(x, i);
			if (id.find(v) == id.end())
				id[v] = ++cnt, q.push(v);
			tr[t][i] = id[v];
		}
	}
}

int fac[N], ifac[N], c[N], f[2][N][N << 3];

inline int C(int n, int m) {
	if (n < m) return 0;
	return 1ll * fac[n] * (1ll * ifac[m] * ifac[n - m] % mod) % mod;
} 

int main() {
	int n = read();
	for (int i = 1; i <= 13; ++i) {
		++c[read()]; read();
	} fac[0] = 1;
	for (int i = 1; i <= (n << 2); ++i) fac[i] = 1ll * i * fac[i - 1] % mod;
	ifac[n << 2] = power(fac[n << 2], mod - 2);
	for (int i = n << 2; i; --i) ifac[i - 1] = 1ll * i * ifac[i] % mod;
	bfs(); f[0][0][1] = 1; 
	for (int i = 1; i <= n; ++i) {
		int x = i & 1, y = x ^ 1;
		memset(f[x], 0, sizeof f[x]);
		for (int j = 1; j <= cnt; ++j)
			for (int k = 0; k <= 4 * i - 4; ++k)
				for (int l = 0; l <= 4 - c[i]; ++l) {
					f[x][k + l][tr[j][l + c[i]]] += 1ll * f[y][k][j] * C(4 - c[i], l) % mod;
					if (f[x][k + l][tr[j][l + c[i]]] >= mod)
						f[x][k + l][tr[j][l + c[i]]] -= mod;
				}
	} int res = 0;
	for (int i = 0; i <= (n << 2) - 13; ++i)
		for (int j = 1; j <= cnt; ++j)
			if (j != T) {
				res += 1ll * power(C((n << 2) - 13, i), mod - 2) * f[n & 1][i][j] % mod;
				if (res >= mod) res -= mod; 
			}
	printf("%d\n", res); return 0;
}
posted @ 2023-05-30 18:33  Smallbasic  阅读(61)  评论(0)    收藏  举报