题解:SP5830 ALTPERM - Alternating Permutations
题意:给你 \(K\) 个下标,保证 \(A_1=1,A_K=N\),且对任意的 \(i<N\) 有 \(A_i<A_{i+1}\)。
如果一个排列,在下标 \(A_1\) 到 \(A_2\) 处单调递增,在下标 \(A_2\) 到 \(A_3\) 处单调递减,在下标 \(A_3\) 到 \(A_4\) 处单调递增,依此类推,那么这个排列是合法的。
求合法的 \(N\) 长度合法排列数量,对 \(10^9+7\) 取模。
\(N\le 2\times 10^4,K\le 22\)。多组数据。
做法:
考虑最大值会在哪里取到,手玩一下,发现只有可能在满足 \(2\nmid i\) 的 \(A_i\) 取到。类似讨论最小值也会在 \(A\) 中取到,但是也有可能会在两端取到。
考虑直接记状态 \(dp_{l,r}\) 代表区间 \([l,r]\) 的答案,枚举最大值或者最小值情况可以转移,但是复杂度太大了。
我们再手玩,发现其实无论什么时候,我的区间内其实都有一端为 \(A\) 作为最小值或者最大值,所以我们状态可以改为 \(dp_{i,j}\),代表我一个端点为 \(A_i\),另一个端点为 \(j\),构成的区间 \((A_i,j]\) 或者 \([j,A_i)\),且 \(A_i\) 为极大值或极小值。这样状态数就压到 \(O(NK)\) 了。
转移考虑,如果我目前 \(A_i\) 为极大值,那么最小值就一定位于中间端点或者 \(j\),当然在 \(j\) 处取到最小值的条件是到其前面一个断点是递减的才行,枚举即可,枚举之后还需要乘上一个将数分到两侧的组合数系数。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 2e4 + 5, mod = 1e9 + 7;
int n = 20000, k, pre[maxn], suf[maxn], a[maxn], jc[maxn], revjc[maxn];
int C(int m, int n) {
return jc[m] * revjc[n] % mod * revjc[m - n] % mod;
}
int dp[23][maxn], vis[23][maxn];
int cal(int p, int x) {
if(vis[p][x])
return dp[p][x];
vis[p][x] = 1;
//cout << p << " " << x << endl;
if(a[p] < x) {
if(a[p + 1] >= x)
return dp[p][x] = 1;
if((pre[x] % 2) == p % 2)
dp[p][x] = (dp[p][x] + cal(p, x - 1)) % mod;
for (int i = p + 1; i <= k && a[i] < x; i += 2)
dp[p][x] = (dp[p][x] + C(x - a[p] - 1, a[i] - a[p] - 1) * cal(i, a[p] + 1) % mod * cal(i, x)) % mod;
}
else {
if(a[p - 1] <= x)
return dp[p][x] = 1;
if(suf[x] % 2 == p % 2)
dp[p][x] = (dp[p][x] + cal(p, x + 1)) % mod;
for (int i = p - 1; i > 0 && a[i] > x; i -= 2)
dp[p][x] = (dp[p][x] + C(a[p] - x - 1, a[p] - a[i] - 1) * cal(i, a[p] - 1) % mod * cal(i, x)) % mod;
}
return dp[p][x];
}
void solve() {
cin >> n >> k;
for (int i = 1; i <= k; i++)
cin >> a[i];
for (int i = 1; i < k; i++) {
for (int j = a[i] + 1; j <= a[i + 1]; j++)
pre[j] = i;
}
for (int i = k; i > 1; i--) {
for (int j = a[i] - 1; j >= a[i - 1]; j--)
suf[j] = i;
}
suf[n] = k + 1, pre[1] = 0;
int ans = 0;
memset(dp, 0, sizeof(dp));
memset(vis, 0, sizeof(vis));
for (int i = 2; i <= k; i += 2)
ans = (ans + C(n - 1, a[i] - 1) * cal(i, 1) % mod * cal(i, n)) % mod;
cout << ans << endl;
}
void prepare() {
jc[0] = jc[1] = revjc[0] = revjc[1] = 1;
for (int i = 2; i <= n; i++)
revjc[i] = (mod - mod / i) * revjc[mod % i] % mod,
jc[i] = jc[i - 1] * i % mod;
for (int i = 2; i <= n; i++)
revjc[i] = revjc[i - 1] * revjc[i] % mod;
}
signed main() {
prepare();
int T; cin >> T;
while(T--)
solve();
return 0;
}

浙公网安备 33010602011771号