Luogu P13272 [NOI 2025] 序列变换
因为昨天没写题所以我只能拿场上过的题凑数了(好摆啊好摆啊。
首先因为这题要计数,所以尝试去找一个刻画方式刻画出所有能被 \(a\) 生成的序列。
考虑操作实际就是让 \(a_i, a_{i + 1}\) 同时减去 \(\min\{a_i, a_{i + 1}\}\)。
在这之后一定会有一个数变为 \(0\),且发现若操作的两个数中有 \(0\),那么情况一定不变。
于是可以知道,如果多次选择了 \((i, i + 1)\),那么除第一次的操作肯定都是无效的。
那么直接指定每对 \((i, i + 1)\) 至多被操作一次即可。
考虑操作时若 \(a_i \le a_{i + 1}\),那么认为有 \(i\to i + 1\);若 \(a_i \ge a_{i + 1}\),那么有 \(i\gets i + 1\)(相等会有点问题,但是这里先暂时不考虑)。
这个边集一定能被分成若干个 \(\cdots\to\to\to\gets\gets\gets\cdots\) 的连续段,即左端从 \(l\),右端从 \(r\) 往内汇聚,直到一个分界线 \(k\) 挡住了两测。
于是可以首先预处理出单侧的情况,记 \(lv_{l, k}\) 表示 \(l\) 汇聚到 \(k\) 时的值,易知有 \(lv_{l, l} = a_l, lv_{l, k + 1} = a_{k + 1} - lv_{l, k}\)。
需要保证中途都是可以操作的,即若 \(lv_{l, k} < 0\),那么其实都走不到 \(k\) 这里,从 \(k\) 开始的 \(lv_{l, k}\) 都是不合法的。
同理定义 \(rv_{r, k}\)。
那么考虑暴力的做法,直接枚举 \((l, r, k)\),首先肯定要满足 \(l, r\) 都能走到 \(k\),然后再来分讨一下 \(k\) 处的情况:
- \(lv_{l, k} + rv_{r, k} < a_k\),那 \(k\) 这里的操作肯定怎样都不合法。
- \(lv_{l, k} + rv_{r, k} > a_k\),最后的 \(a\) 一定形如,\([l, k), (k, r]\) 为 \(0\),但是 \(k\) 处不为 \(0\)。
- \(lv_{l, k} + rv_{r, k} = a_k\),最后的 \(a\) 一定满足 \([l, r]\) 都为 \(0\)。
(这里与 \(a_k\) 比较只是因为前面定义的 \(lv_{l, k}, rv_{r, k}\) 相加时会多一个 \(a_k\),实际上还是与 \(0\) 比较。)
于是现在就有了个 \(\mathcal{O}(n^3)\) 的做法,不过会发现这个做法其实有点问题,计数的时候可能会计重。
例如 \(a = [1, 1, 1, 1]\),会认为 \([1, 2], [3, 4], [1, 4]\) 都可以消除,那这就爆了。
思考一下原因,对于 \((l, r, k)\),其实如果删除一段前缀或后继(不包括 \(k\)),假设通过这段后得到的值是 \(x\),那么对最终 \(a_k\) 的影响肯定是有 \(+x\) 或 \(-x\) 的,那么肯定就需要另一边同样调整以使最终 \(a_k\) 不变,但是这样又会影响周边的段继续修改,所以一定会继续往外扩张的,一定不合法——吗?
于是发现了问题:当 \(x = 0\) 时,且如果换一个方向传入的值也是 \(0\) 时,那么这一段是可以分到任意一边的。
为了解决这个问题,只需要考虑小修一下 \(lv_{l, k}, rv_{r, k}\) 的定义:当 \(\le 0\) 时就是不合法的。这样的 \(\mathcal{O}(n^3)\) 做法就是正确的。
\(\mathcal{O}(n^3)\) 的代码也会放在后面供参考。
接下来考虑继续优化,发现 dp 中直接对于 \(r\) 枚举 \(l\) 并统计贡献看着已经比较优了,于是需要解决的应当是快速处理对于所有 \(k\),\((l, r, k)\) 的贡献。
上文已经发现了一个事实:对于一个数 \(x\),其对中心的贡献一定是 \(+x\) 或 \(-x\)。
从这里入手,能发现符号一定是与中心距离为偶数时为 \(+\),为奇数时为 \(-\)。即包括中心的右侧符号应当形如 \(+-+-+-+-\cdots\),左侧也是对称的。
于是对于 \(l, r\) 来说,若 \(k\) 的奇偶性相同,则 \(a_k\) 最后得到的值一定也是相同的,且不同奇偶性的 \(k\) 得到的值一定互为相反数(不过不知道这个也没有啥)。
记 \(l\) 能扩展到的最远的 \(k\) 为 \(lb_l\),同理记 \(rb_r\)。
那么合法的 \(k\) 一定是在 \([\max(l, rb_r), \min(r, lb_l)]\) 中的所有奇数或偶数或全部(中间值 \(= 0\) 时),只需要知道这部分的 \(\max \{-b_k\}\) 和 \(\sum \frac{1}{c_k}\),可以写个 st 表前缀和,不过因为数据范围并不大,直接 \(\mathcal{O}(n^2)\) 预处理也可以。
一个小细节:上述判断中间值 \(= 0\) 的方法还是有点问题,不过能发现问题只出在 \(a_i = a_{i + 1}\),特殊处理一下即可。
代码是复刻的,应该没有问题。
\(\mathcal{O}(n^3)\)(判断 \(=0\) 的方式有点不一样):
inline void solve() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
for (int i = 1; i <= n; i++) {
memset(lv[i] + 1, -1, sizeof(int) * n);
lv[i][i] = a[i];
for (int j = i; j < n && lv[i][j] < a[j + 1]; j++) {
lv[i][j + 1] = a[j + 1] - lv[i][j];
}
}
for (int i = n; i >= 1; i--) {
memset(rv[i] + 1, -1, sizeof(int) * n);
rv[i][i] = a[i];
for (int j = i; j > 1 && rv[i][j] < a[j - 1]; j--) {
rv[i][j - 1] = a[j - 1] - rv[i][j];
}
}
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
ok0[l][r] = false;
}
}
for (int i = 1; i < n; i++) {
for (int l = 1; l <= i; l++) {
for (int r = i + 1; r <= n; r++) {
ok0[l][r] |= lv[l][i] != -1 && rv[r][i + 1] != -1 && lv[l][i] == rv[r][i + 1];
}
}
}
for (int i = 1; i <= n; i++) {
for (int l = 1; l <= i; l++) {
for (int r = i; r <= n; r++) {
ok[i][l][r] = lv[l][i] != -1 && rv[r][i] != -1 && lv[l][i] + rv[r][i] > a[i];
}
}
}
pre[0] = 0;
for (int i = 1; i <= n; i++) pre[i] = pre[i - 1] + b[i];
f[0] = 0;
for (int i = 1; i <= n; i++) {
f[i] = f[i - 1];
for (int j = 1; j <= i; j++) {
if (ok0[j][i]) f[i] = std::max(f[i], f[j - 1] + pre[i] - pre[j - 1]);
for (int k = j; k <= i; k++) {
if (ok[k][j][i]) f[i] = std::max(f[i], f[j - 1] + pre[i] - pre[j - 1] - b[k]);
}
}
}
pr[0] = ipr[0] = 1;
for (int i = 1; i <= n; i++) {
ic[i] = qpow(c[i], mod - 2);
pr[i] = pr[i - 1] * c[i] % mod;
ipr[i] = ipr[i - 1] * ic[i] % mod;
}
g[0] = 1;
for (int i = 1; i <= n; i++) {
g[i] = 0;
for (int j = 1; j <= i; j++) {
if (ok0[j][i]) g[i] = (g[i] + g[j - 1] * pr[i] % mod * ipr[j - 1]) % mod;
for (int k = j; k <= i; k++) {
if (ok[k][j][i]) g[i] = (g[i] + g[j - 1] * pr[i] % mod * ipr[j - 1] % mod * ic[k]) % mod;
}
}
}
printf("%lld %lld\n", f[n], g[n]);
}
\(\mathcal{O}(n^2)\):
#include <bits/stdc++.h>
using ll = long long;
constexpr ll mod = 1e9 + 7;
inline ll qpow(ll a, ll b) {
ll v = 1;
for (; b; b >>= 1, a = a * a % mod) {
if (b & 1) v = v * a % mod;
}
return v;
}
constexpr int maxn = 5000 + 10;
int n;
int a[maxn], b[maxn], c[maxn];
int lb[maxn], rb[maxn];
bool ok0[maxn][maxn];
ll preb[maxn], fv[maxn][maxn], f[maxn];
ll prec[maxn], iprec[maxn], ic[maxn], gv[maxn][maxn], g[maxn];
ll prea[maxn];
int pic[maxn][maxn][2], mxb[maxn][maxn][2];
inline void solve() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
prea[0] = 0;
for (int i = 1; i <= n; i++) {
prea[i] = prea[i - 1] + (i % 2 ? a[i] : -a[i]);
}
preb[0] = 0, prec[0] = iprec[0] = 1;
for (int i = 1; i <= n; i++) {
ic[i] = qpow(c[i], mod - 2);
preb[i] = preb[i - 1] + b[i];
prec[i] = prec[i - 1] * c[i] % mod;
iprec[i] = iprec[i - 1] * ic[i] % mod;
}
for (int l = 1; l <= n; l++) {
pic[l][l - 1][0] = pic[l][l - 1][1] = 0;
mxb[l][l - 1][0] = mxb[l][l - 1][1] = -2e9;
for (int r = l; r <= n; r++) {
pic[l][r][0] = pic[l][r - 1][0];
pic[l][r][1] = pic[l][r - 1][1];
mxb[l][r][0] = mxb[l][r - 1][0];
mxb[l][r][1] = mxb[l][r - 1][1];
pic[l][r][r % 2] = (pic[l][r][r % 2] + ic[r]) % mod;
mxb[l][r][r % 2] = std::max(mxb[l][r][r % 2], -b[r]);
}
}
for (int i = 1; i <= n; i++) {
lb[i] = i;
for (int x = a[i]; lb[i] < n && x < a[lb[i] + 1]; ) {
x = a[++lb[i]] - x;
}
}
for (int i = n; i >= 1; i--) {
rb[i] = i;
for (int x = a[i]; rb[i] > 1 && x < a[rb[i] - 1]; ) {
x = a[--rb[i]] - x;
}
}
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
fv[l][r] = -1e18, gv[l][r] = 0;
}
}
auto conv = [&](const int x) {
return x == -2e9 ? (ll)-1e18 : (ll)x;
};
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
if (prea[r] == prea[l - 1] || lb[l] < rb[r]) continue;
const int op = prea[r] > prea[l - 1];
const int st = std::max(rb[r], l);
const int ed = std::min(lb[l], r);
fv[l][r] = std::max(fv[l][r], preb[r] - preb[l - 1] + conv(mxb[st][ed][op]));
gv[l][r] = (gv[l][r] + prec[r] * iprec[l - 1] % mod * pic[st][ed][op]) % mod;
}
}
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
ok0[l][r] = prea[r] - prea[l - 1] == 0 && rb[r] <= lb[l];
}
}
for (int i = 1; i < n; i++) {
ok0[i][i + 1] |= a[i] == a[i + 1];
}
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
if (ok0[l][r]) {
fv[l][r] = std::max(fv[l][r], preb[r] - preb[l - 1]);
gv[l][r] = (gv[l][r] + prec[r] * iprec[l - 1]) % mod;
}
}
}
f[0] = 0;
for (int i = 1; i <= n; i++) {
f[i] = -1e18;
for (int j = 1; j <= i; j++) {
f[i] = std::max(f[i], f[j - 1] + fv[j][i]);
}
}
g[0] = 1;
for (int i = 1; i <= n; i++) {
g[i] = 0;
for (int j = 1; j <= i; j++) {
g[i] = (g[i] + g[j - 1] * gv[j][i]) % mod;
}
}
printf("%lld %lld\n", f[n], g[n]);
}
int main() {
freopen("sequence.in", "r", stdin);
freopen("sequence.out", "w", stdout);
int testid, t;
scanf("%d%d", &testid, &t);
while (t--) solve();
return 0;
}
浙公网安备 33010602011771号