[dp] [组合计数] [神仙题] ARC146e Simple Speed
显然这是 \(\rm dp\)。
容易想到一个以 \(B\) 序列长度定义的做法,但是没法优化。
改变角度,从值域入手,即将元素从小到大依次插入 \(B\) 序列,这样的好处是元素大小有序,即新插入的元素一定大于原来插入的元素。
考虑插入型 \(dp\),定义 \(f_{i,j}\) 表示考虑 \([1,i]\) 的元素构成了 \(j\) 个合法区间的方案数,注意到每个区间的左边界和右边界必须为 \(i\),否则此后无法被合并成大的合法区间。
设当前插入 \(i\),枚举 \(f_{i-1,j}\)。那么每个合法区间的间隙都必须插入一个 \(i\),然后我们随意将剩下的 \(a_i-(j-1)\) 个 \(i\) 安排在这些间隙,这是个经典问题,方案数是可以算出来的。最后考虑插入完毕后的合法区间数,易发现安排完 \(j-1\) 个 \(i\) 后,剩下的所有 \(i\) 要么将一个合法区间断开成为两个,要么独自构成一个合法区间,总的来说就是每添加一个 \(i\) 就会使得区间数 \(+1\),因此最终区间数为 \(a_i-(j-1)+1\)。
但是这样会少算一些方案,我们注意到最左 / 最右区间的左边 / 右边不一定要是 \(i\),于是完善状态: \(f_{i,j,0/1,0/1}\) 表示考虑最左边是否为 \(i\) 和最右边是否为 \(i\),若钦定某一边必须放 \(i\),就在一开始分配一个 \(i\) 给它,并且之后随意分配时允许在这一边放 \(i\)。
理清思路后,转移方程就很简单了,建议自己手推。
至此,复杂度为 \(O(n\sum a_i)\),不过容易发现 \(j\) 这一维必须 \(\le a\),否则此后不可能将合法区间合并为一个区间,复杂度 \(O(n^2)\)。
然后怎么办?如果你将状态转移画成图,就能发现对于每个 \(i\),可能有值的 \(f_{i,j}\) 不超过 \(3\) 个,这道题神就神在这了。于是拿个 map 或 set 保存状态即可。
代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ADD(a, b) (a) = ((a) + (b)) % mod
const int N = 2e5 + 5, mod = 998244353;
int n, a, f[N][2][2], f2[N][2][2];
int jc[N << 1], jcinv[N << 1];
set<int>now, now2;
inline int qstp(int a, int k) {int res = 1; for(; k; a = a * a % mod, k >>= 1) if(k & 1) res = res * a % mod; return res;}
inline void init() {jcinv[0] = jc[0] = 1; for(int i = 1; i < N << 1; ++i) jcinv[i] = qstp(jc[i] = jc[i - 1] * i % mod, mod - 2);}
inline int C(int n, int m) {return (n < 0 || m < 0 || n < m) ? 0 : jc[n] * jcinv[n - m] % mod * jcinv[m] % mod;}
inline int S(int n, int m) {return C(m + n - 1, n - 1);}
signed main(){
init();
cin >> n;
for(int i = 1; i <= n; ++i){
scanf("%lld", &a), now2.clear();
if(i == 1) {
f[a][1][1] = 1, now.insert(a);
continue;
}
for(auto j : now){
int cnt = a - j + 2, sum = a - j + 1;
if(sum < 0) continue;
if(f[j][0][0])
ADD(f2[cnt][0][0], f[j][0][0] * S(j - 1, sum)), now2.insert(cnt);
if(f[j][0][1]){
ADD(f2[cnt][0][0], f[j][0][1] * S(j - 1, sum)), now2.insert(cnt);
ADD(f2[cnt - 1][0][1], f[j][0][1] * S(j, sum - 1)), now2.insert(cnt - 1);
}
if(f[j][1][0]){
ADD(f2[cnt][0][0], f[j][1][0] * S(j - 1, sum)), now2.insert(cnt);
ADD(f2[cnt - 1][1][0], f[j][1][0] * S(j, sum - 1)), now2.insert(cnt - 1);
}
if(f[j][1][1]){
ADD(f2[cnt][0][0], f[j][1][1] * S(j - 1, sum)), now2.insert(cnt);
ADD(f2[cnt - 1][0][1], f[j][1][1] * S(j, sum - 1)), now2.insert(cnt - 1);
ADD(f2[cnt - 1][1][0], f[j][1][1] * S(j, sum - 1)), now2.insert(cnt - 1);
ADD(f2[cnt - 2][1][1], f[j][1][1] * S(j + 1, sum - 2)), now2.insert(cnt - 2);
}
}
for(auto j : now)
f[j][0][0] = f[j][0][1] = f[j][1][0] = f[j][1][1] = 0;
for(auto j : now2)
for(int q = 0; q < 2; ++q)
for(int p = 0; p < 2; ++p)
f[j][q][p] = f2[j][q][p], f2[j][q][p] = 0;
now = now2;
}
cout << (f[1][0][0] + f[1][0][1] + f[1][1][0] + f[1][1][1]) % mod;
return 0;
}

浙公网安备 33010602011771号