「ZJOI2019」麻将(dp套dp)

Address

loj3042

Solution

\(ans_i\) 表示在给定的 \(13\) 张牌以外,再选出 \(i\) 张牌,使得这 \(13+i\) 张牌不存在 胡的子集的方案数,那么答案就是 \((\frac{1}{(4n-13)!}\sum_{i=1}^{4n-13}i!(4n-13-i)!ans_i)+1\)

接下来,考虑给你一个牌的集合,怎么判断它是否存在一个胡的子集。

首先判断胡的第二个条件:记 \(a_i\) 表示集合中有多少张第 \(i\) 种牌。若 \(\sum [a_i\ge 2]\ge 7\),则存在胡的子集。

再判断第一个条件:考虑 \(dp\)。记 \(f_{i,j,k}\) 表示考虑前 \(i\) 种牌,拿走 \(j\)\((i-1,i)\),拿走 \(k\)\(i\),剩下的牌最多能组成多少个面子。(注意 \(j\)\((i-1,i)\)\(k\)\(i\) 拿出来必须跟后面的牌组成面子)特殊地,\(f_{i,j,k}=-1\) 表示不存在这种状态。

\(g_{i,j,k}\) 表示考虑前 \(i\) 种牌,拿走 \(j\)\((i-1,i)\),拿走 \(k\)\(i\)再拿走一个对子,剩下的牌最多能组成多少个面子。

考虑到 \(3\) 个相同的顺子(形如 \(x,x+1,x+2\))可以变成 \(3\) 个相同的刻子 (形如 \(x,x,x\)),因此 \(j,k\in[0,2]\)

记集合中最大的牌为 \(m\),如果存在 \(g_{m,j,k}\ge 4\),那么存在胡的子集。

考虑转移,枚举加入 \(x\) 张大小为 \(i+1\) 的牌,枚举拿走 \(h\)\(i+1\),那么要组成 \(k\)\((i,i+1)\),组成 \(j\)\((i-1,i,i+1)\),再枚举要不要拿走 \(i+1\) 当对子,有:

\[f_{i+1,k,h}=max(f_{i,j,k}+j+[x-j-k-h\ge 3]),j+k+h\le x\ \&\&\ k\le 2 \]

\[g_{i+1,k,h}=max(f_{i,j,k}+j),j+k+h\le x-2 \]

\[g_{i+1,k,h}=max(g_{i,j,k}+j+[x-j-k-h\ge 3]),j+k+h\le x\ \&\&\ k\le 2 \]

考虑建一个自动机,自动机上的每一个节点对应一些不存在胡的子集的集合。每个节点都记录信息:\(f_{m,j,k},g_{m,j,k},cnt\)\(f_{m,j,k},g_{m,j,k},cnt\) 都相同的集合对应同一个节点,注意 \(m\) 可以不同,所以只要记 \(f_{j,k},g_{j,k},cnt\)。节点之间的转移边权 \(x\) 表示加入 \(x\) 张大小为 \(m+1\) 的牌。

初始节点:\(f_{0,0}=cnt=0\),其它为 \(-1\)。考虑用 \(dfs\) 构造自动机,枚举加入 \(x(x∈[0,4])\) 张新牌转移即可。转移可能成环,扩展出重复状态要剪枝。\(dfs\) 后可得节点数为 \(2091\)

\(ch_{x,y}\) 表示节点 \(x\) 走转移边 \(y\) 到达的节点,\(dp_{i,j,k}\) 表示考虑前 \(i\) 种牌,总共取走 \(j\) 张,目前走到自动机上的节点 \(k\) 的方案数。枚举第 \(i\) 张牌取了 \(h\) 张,有:$$dp_{i,j+h,ch_{k,h}}+=dp_{i-1,j,k}*c_{4-b_i}^{h-b_i}$$

其中 \(b_i\) 表示给定的 \(13\) 张牌中,有多少张大小为 \(i\) 的牌。

\(dp\) 要使用滚动数组,时间复杂度 \(O(2091×n^2)\)

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

template <class t>
inline void read(t & res)
{
	char ch;
	while (ch = getchar(), !isdigit(ch));
	res = ch ^ 48;
	while (ch = getchar(), isdigit(ch))
	res = res * 10 + (ch ^ 48);
}

const int e = 505, o = 3005, mod = 998244353;

struct point
{
	int cnt, f[3][3], g[3][3];
	
	inline bool check()
	{
		if (cnt >= 7) return 1;
		for (int i = 0; i <= 2; i++)
		for (int j = 0; j <= 2; j++)
		if (g[i][j] >= 4) return 1;
		return 0;
	}
	
	inline point trans(int x)
	{
		point a;
		int i, j, k;
		a.cnt = min(7, cnt + (x >= 2));
		for (i = 0; i <= 2; i++)
		for (j = 0; j <= 2; j++)
		a.f[i][j] = a.g[i][j] = -1;
		for (i = 0; i <= 2; i++)
		for (j = 0; j <= 2; j++)
		{
			if (f[i][j] != -1)
			{
				for (k = 0; i + j + k <= x && k <= 2; k++)
				a.f[j][k] = max(a.f[j][k], f[i][j] + i + (x - i - j - k >= 3));
				for (k = 0; i + j + k <= x - 2; k++)
				a.g[j][k] = max(a.g[j][k], f[i][j] + i);
			}
			if (g[i][j] != -1)
			{
				for (k = 0; i + j + k <= x && k <= 2; k++)
				a.g[j][k] = max(a.g[j][k], g[i][j] + i + (x - i - j - k >= 3));
			}
		}
		for (i = 0; i <= 2; i++)
		for (j = 0; j <= 2; j++)
		a.f[i][j] = min(a.f[i][j], 4), a.g[i][j] = min(a.g[i][j], 4);
		return a;
	}
};

inline bool operator < (point a, point b)
{
	if (a.cnt != b.cnt) return a.cnt < b.cnt;
	int i, j;
	for (i = 0; i <= 2; i++)
	for (j = 0; j <= 2; j++)
	{
		if (a.f[i][j] != b.f[i][j]) return a.f[i][j] < b.f[i][j];
		if (a.g[i][j] != b.g[i][j]) return a.g[i][j] < b.g[i][j];
	}
	return 0;
}

inline bool operator == (point a, point b)
{
	if (a.cnt != b.cnt) return 0;
	int i, j;
	for (i = 0; i <= 2; i++)
	for (j = 0; j <= 2; j++)
	{
		if (a.f[i][j] != b.f[i][j]) return 0;
		if (a.g[i][j] != b.g[i][j]) return 0;
	}
	return 1;
}

map<point, int> id;
int cnt, fac[e], inv[e], ch[o][6], n, m, a[e], dp[2][e][o], ans;

inline void dfs(point a)
{
	int x = id[a];
	for (int i = 0; i <= 4; i++)
	{
		point b = a.trans(i);
		if (b.check()) continue;
		int y = id[b];
		if (y) ch[x][i] = y;
		else
		{
			id[b] = ++cnt;
			ch[x][i] = cnt;
			dfs(b);
		}
	}
}

inline int ksm(int x, int y)
{
	int res = 1;
	while (y)
	{
		if (y & 1) res = (ll)res * x % mod;
		y >>= 1;
		x = (ll)x * x % mod;
	}
	return res;
}

inline void init()
{
	point s;
	s.cnt = 0;
	int i, j;
	for (i = 0; i <= 2; i++)
	for (j = 0; j <= 2; j++)
	s.f[i][j] = s.g[i][j] = -1;
	s.f[0][0] = 0;
	id[s] = cnt = 1;
	dfs(s);
}

inline void add(int &x, int y)
{
	(x += y) >= mod && (x -= mod);
}

inline int c(int x, int y)
{
	return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod;
}

inline void prepare()
{
	int i;
	fac[0] = 1;
	for (i = 1; i <= m; i++) fac[i] = (ll)fac[i - 1] * i % mod;
	inv[m] = ksm(fac[m], mod - 2);
	for (i = m - 1; i >= 0; i--) inv[i] = (ll)inv[i + 1] * (i + 1) % mod;
}

int main()
{
	freopen("mahjong.in", "r", stdin);
	freopen("mahjong.out", "w", stdout);
	read(n); m = n << 2;
	int i, j, k, h, x, y, sum = 0;
	init(); prepare();
	for (i = 1; i <= 13; i++) read(x), read(y), a[x]++;
	dp[0][0][1] = 1;
	for (i = 1; i <= n; i++)
	{
		int nxt = i & 1, lst = nxt ^ 1;
		for (j = 0; j <= sum + 4; j++)
		for (k = 1; k <= cnt; k++)
		dp[nxt][j][k] = 0;
		for (j = 0; j <= sum; j++)
		for (k = 1; k <= cnt; k++)
		if (dp[lst][j][k])
		{
			int v = dp[lst][j][k];
			for (h = a[i]; h <= 4; h++)
			if (ch[k][h])
			dp[nxt][j + h][ch[k][h]] = (dp[nxt][j + h][ch[k][h]] + (ll)v
			* c(4 - a[i], h - a[i])) % mod;
		}
		sum += 4;
	}
	for (i = 1; i <= m - 13; i++)
	{
		sum = 0;
		for (j = 1; j <= cnt; j++) add(sum, dp[n & 1][13 + i][j]);
		ans = (ans + (ll)sum * fac[i] % mod * fac[m - 13 - i]) % mod;
	}
	ans = (ll)ans * inv[m - 13] % mod;
	add(ans, 1);
	cout << ans << endl;
	fclose(stdin);
	fclose(stdout);
	return 0;
}

posted @ 2020-02-11 22:52  花淇淋  阅读(300)  评论(1编辑  收藏  举报