Luogu P14031 【MX-X20-T5】「FAOI-R7」连接时光 II
场上死活不会,下来再看一下就会了,两周没开电脑导致的。
首先考虑如何计算 \(f_S(p)\)。
首先因为 \(f_S(p)\) 的限制都是对于前缀的图的限制,所以先来考察前缀的图的结构和变化情况。
经过手玩能够知道,对于前 \(i\) 个数的图,根据值域划分,连通块就为一些相邻的区间 。
然后在最后加入了一个数 \(a_{i + 1} = x\)(此时考虑的是相对大小)后,就相当于是加入了 \([x, x + 1)\) 这个区间,并且把区间右端点大于 \(x\) 的区间都合并到一起。
此时能够发现,合并的一定都是这些区间里的后缀。
再结合这个限制,相当于是要求只有一个区间。
贪心的考虑,因为每次合并都是合并一段后缀,那么前 \(i\) 个数形成的连通块如果不是一个区间(不满足限制),最靠前的区间的左端点一定不为 \(i\)。
于是会发现在过程中只关心最靠前的区间的右端点,那就可以考虑设计 dp 了。
设 \(f_{i, j}\) 表示前 \(i\) 个数,最靠前的区间的右端点是 \(j\)(前 \(i\) 个的相对顺序)的答案。
考虑如果加入了 \(a_{i + 1} = x\)(前 \(i + 1\) 个的相对顺序),此时 \(j\) 的变化是什么。
经过分讨容易知道有 \(j' = \begin{cases}i + 1 & 1\le x\le j\\j & j < x \le i + 1\end{cases}\)。
那么转移就很好写出了:\(f_{i + 1, i + 1}\gets f_{i, j}\times \sum\limits_{k = 1}^j a_{i + 1}^{i + 1 - k}, f_{i + 1, j}\gets f_{i, j}\times \sum\limits_{k = j + 1}^{i + 1} a_{i + 1}^{i + 1 - k}\)。
因为这都对应的是 \(a_{i + 1}^k\) 的前缀和或后缀和,所以容易优化到 \(\mathcal{O}(1)\)。
当然也可以把 \(f_{i + 1, i + 1}\) 视作“不合法”的方案数,类似总和减掉合法方案数,求出 \(f_{i + 1, j}(j\le i)\) 有 \(f_{i + 1, i + 1} = \sum\limits_{j = 1}^i f_{i, j}\sum\limits_{j = 0}^i a_{i + 1}^j - \sum\limits_{j = 1}^i f_{i + 1, j}\),这样就只需要关心 \(a_{i + 1}^k\) 的前缀和,会方便写一些。
对于 \(s_i = 1\) 的位置,要求前缀必须为一个连通块,把不合法的 \(f_{i, j}(j < i)\) 都置为 \(0\) 即可。
最后套上外层的 \(\sum\limits_{T\subseteq S}\),对于 \(s_i = 1\) 的位置,可以是为 \(1\),就只保留 \(f_{i, i}\);也可以为 \(0\),所有数都不变。
于是扩展到 \(\sum\limits_{T\subseteq S}\) 只需要加上一个 \(f_{i, i}\gets f_{i, i}\times (1 + s_i)\)。
时间复杂度 \(\mathcal{O}(n^2)\)。
#include <bits/stdc++.h>
using ll = long long;
constexpr ll mod = 998244353;
constexpr int maxn = 5000 + 10;
int n, a[maxn];
char s[maxn];
ll f[maxn];
inline void solve() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
scanf("%s", s + 1);
f[1] = 1 + (s[1] - '0');
for (int i = 2; i <= n; i++) {
ll pw = 1, sum = 1, sumf = 0;
for (int j = i - 1; j >= 1; j--) {
sumf = (sumf + f[j]) % mod;
f[j] = f[j] * sum % mod;
pw = pw * a[i] % mod;
sum = (sum + pw) % mod;
}
f[i] = sumf * sum % mod;
for (int j = 1; j < i; j++) {
f[i] = (f[i] - f[j] + mod) % mod;
}
f[i] = f[i] * (1 + s[i] - '0') % mod;
}
ll ans = 0;
for (int i = 1; i <= n; i++) {
ans = (ans + f[i]) % mod;
}
printf("%lld\n", ans);
}
int main() {
int t;
scanf("%d", &t);
while (t--) {
solve();
}
return 0;
}
浙公网安备 33010602011771号