ARC104E Random LIS

题面

题解

假设每个位置的值已经确定,为 \(a_i\),那么将 \((a_i, i)\) 二元组排序的方法唯一。

枚举最后的排序结果 \(p\)\(p_i\) 表示排序之后在排名为 \(i\) 的下标为 \(p_i\)

那么可以反推出 \(a_{p_1} (< \mathrm{or} \leq) \ a_{p_2} (< \mathrm{or} \leq) \ \cdots \ (< \mathrm{or} \leq) \ a_{p_n}\)

其中如果 \(p_i > p_{i + 1}\),那么 \(a_{p_i} \leq a_{p_{i + 1}}\) 否则是 \(<\)

考虑将 \(<\) 变成 \(\leq\):若 \(a_{p_i}\)\(a_{p_{i + 1}}\) 之间的限制为 \(<\),那么将所有 \(a_{p_j} (j > i)\) 的值全部减少 \(1\),那么这个位置的限制就变成了 \(\leq\),也就是说,\(p_{j} (j > i)\) 的所有位置的限制都减小了 \(1\)

接下来就可以各显神通了:可以用这题的方法 dp,也可以用下面介绍的方法(orz Itst):

将限制放到格路上,那么相当于从 \((1, 1)\) 走到 \((n + 1, \infty)\),只能向上向右走,当横坐标为 \(i\) 时纵坐标必须 \(\leq A_i\)(也就是之前的限制)的方案数,可以容斥计算。这样,如果不将计算组合数的时间算在复杂度里面的话,时间复杂度为 \(\mathcal O(n^2)\)

综上,时间复杂度为 \(\mathcal O(n!n^3)\)比 std 好像要优秀一些

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>

inline int read()
{
	int data = 0, w = 1; char ch = getchar();
	while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
	if (ch == '-') w = -1, ch = getchar();
	while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
	return data * w;
}

const int N(10), Mod(1e9 + 7);
int n, a[N], id[N], h[N], ans, f[N];
int fastpow(int x, int y)
{
	int ans = 1;
	for (; y; y >>= 1, x = 1ll * x * x % Mod)
		if (y & 1) ans = 1ll * ans * x % Mod;
	return ans;
}

int C(int n, int m)
{
	int s = 1;
	for (int i = 1; i <= m; i++) s = 1ll * s * (n - i + 1) % Mod * fastpow(i, Mod - 2) % Mod;
	return s;
}

int main()
{
	n = read();
	for (int i = 1; i <= n; i++) a[i] = read(), id[i] = i;
	do
	{
		for (int i = 1; i <= n; i++) h[i] = a[id[i]] - 1;
		for (int i = 1; i <= n; i++) if (id[i] < id[i + 1])
			for (int j = i + 1; j <= n; j++) --h[j];
		for (int i = n - 1; i; i--) h[i] = std::min(h[i], h[i + 1]);
		std::memset(f, 0, sizeof f);
		for (int i = 1; i <= n; i++)
		{
			f[i] = C(h[i] + i - 1, i - 1);
			for (int j = 1; j < i; j++)
				f[i] = (f[i] - 1ll * f[j] * C(h[i] - h[j] - 1 + i - j, i - j) % Mod + Mod) % Mod;
		}
		int res = C(h[n] + n, n), mx = 0;
		for (int i = 1; i <= n; i++)
			res = (res - 1ll * f[i] * C(h[n] - h[i] + n - i, n - i + 1) % Mod + Mod) % Mod;
		std::memset(f, 0, sizeof f);
		for (int i = 1; i <= n; mx = std::max(mx, f[i++]))
			for (int j = f[i] = 1; j < i; j++)
				if (id[j] < id[i]) f[i] = std::max(f[i], f[j] + 1);
		ans = (ans + 1ll * res * mx) % Mod;
	} while (std::next_permutation(id + 1, id + n + 1));
	for (int i = 1; i <= n; i++) ans = 1ll * ans * fastpow(a[i], Mod - 2) % Mod;
	printf("%d\n", ans);
	return 0;
}
posted @ 2020-10-04 17:28  xgzc  阅读(354)  评论(2编辑  收藏  举报