Luogu P13272 [NOI 2025] 序列变换

因为昨天没写题所以我只能拿场上过的题凑数了(好摆啊好摆啊。

首先因为这题要计数,所以尝试去找一个刻画方式刻画出所有能被 \(a\) 生成的序列。

考虑操作实际就是让 \(a_i, a_{i + 1}\) 同时减去 \(\min\{a_i, a_{i + 1}\}\)

在这之后一定会有一个数变为 \(0\),且发现若操作的两个数中有 \(0\),那么情况一定不变。

于是可以知道,如果多次选择了 \((i, i + 1)\),那么除第一次的操作肯定都是无效的。
那么直接指定每对 \((i, i + 1)\) 至多被操作一次即可。

考虑操作时若 \(a_i \le a_{i + 1}\),那么认为有 \(i\to i + 1\);若 \(a_i \ge a_{i + 1}\),那么有 \(i\gets i + 1\)(相等会有点问题,但是这里先暂时不考虑)。
这个边集一定能被分成若干个 \(\cdots\to\to\to\gets\gets\gets\cdots\) 的连续段,即左端从 \(l\),右端从 \(r\) 往内汇聚,直到一个分界线 \(k\) 挡住了两测。

于是可以首先预处理出单侧的情况,记 \(lv_{l, k}\) 表示 \(l\) 汇聚到 \(k\) 时的值,易知有 \(lv_{l, l} = a_l, lv_{l, k + 1} = a_{k + 1} - lv_{l, k}\)
需要保证中途都是可以操作的,即若 \(lv_{l, k} < 0\),那么其实都走不到 \(k\) 这里,从 \(k\) 开始的 \(lv_{l, k}\) 都是不合法的。
同理定义 \(rv_{r, k}\)

那么考虑暴力的做法,直接枚举 \((l, r, k)\),首先肯定要满足 \(l, r\) 都能走到 \(k\),然后再来分讨一下 \(k\) 处的情况:

  • \(lv_{l, k} + rv_{r, k} < a_k\),那 \(k\) 这里的操作肯定怎样都不合法。
  • \(lv_{l, k} + rv_{r, k} > a_k\),最后的 \(a\) 一定形如,\([l, k), (k, r]\)\(0\),但是 \(k\) 处不为 \(0\)
  • \(lv_{l, k} + rv_{r, k} = a_k\),最后的 \(a\) 一定满足 \([l, r]\) 都为 \(0\)

(这里与 \(a_k\) 比较只是因为前面定义的 \(lv_{l, k}, rv_{r, k}\) 相加时会多一个 \(a_k\),实际上还是与 \(0\) 比较。)

于是现在就有了个 \(\mathcal{O}(n^3)\) 的做法,不过会发现这个做法其实有点问题,计数的时候可能会计重。

例如 \(a = [1, 1, 1, 1]\),会认为 \([1, 2], [3, 4], [1, 4]\) 都可以消除,那这就爆了。
思考一下原因,对于 \((l, r, k)\),其实如果删除一段前缀或后继(不包括 \(k\)),假设通过这段后得到的值是 \(x\),那么对最终 \(a_k\) 的影响肯定是有 \(+x\)\(-x\) 的,那么肯定就需要另一边同样调整以使最终 \(a_k\) 不变,但是这样又会影响周边的段继续修改,所以一定会继续往外扩张的,一定不合法——吗?
于是发现了问题:当 \(x = 0\) 时,且如果换一个方向传入的值也是 \(0\) 时,那么这一段是可以分到任意一边的。

为了解决这个问题,只需要考虑小修一下 \(lv_{l, k}, rv_{r, k}\) 的定义:当 \(\le 0\) 时就是不合法的。这样的 \(\mathcal{O}(n^3)\) 做法就是正确的。

\(\mathcal{O}(n^3)\) 的代码也会放在后面供参考。

接下来考虑继续优化,发现 dp 中直接对于 \(r\) 枚举 \(l\) 并统计贡献看着已经比较优了,于是需要解决的应当是快速处理对于所有 \(k\)\((l, r, k)\) 的贡献。

上文已经发现了一个事实:对于一个数 \(x\),其对中心的贡献一定是 \(+x\)\(-x\)
从这里入手,能发现符号一定是与中心距离为偶数时为 \(+\),为奇数时为 \(-\)。即包括中心的右侧符号应当形如 \(+-+-+-+-\cdots\),左侧也是对称的。

于是对于 \(l, r\) 来说,若 \(k\) 的奇偶性相同,则 \(a_k\) 最后得到的值一定也是相同的,且不同奇偶性的 \(k\) 得到的值一定互为相反数(不过不知道这个也没有啥)。

\(l\) 能扩展到的最远的 \(k\)\(lb_l\),同理记 \(rb_r\)
那么合法的 \(k\) 一定是在 \([\max(l, rb_r), \min(r, lb_l)]\) 中的所有奇数或偶数或全部(中间值 \(= 0\) 时),只需要知道这部分的 \(\max \{-b_k\}\)\(\sum \frac{1}{c_k}\),可以写个 st 表前缀和,不过因为数据范围并不大,直接 \(\mathcal{O}(n^2)\) 预处理也可以。

一个小细节:上述判断中间值 \(= 0\) 的方法还是有点问题,不过能发现问题只出在 \(a_i = a_{i + 1}\),特殊处理一下即可。

代码是复刻的,应该没有问题。

\(\mathcal{O}(n^3)\)(判断 \(=0\) 的方式有点不一样):

inline void solve() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
    
    for (int i = 1; i <= n; i++) {
        memset(lv[i] + 1, -1, sizeof(int) * n);
        lv[i][i] = a[i];
        for (int j = i; j < n && lv[i][j] < a[j + 1]; j++) {
            lv[i][j + 1] = a[j + 1] - lv[i][j];
        }
    }
    for (int i = n; i >= 1; i--) {
        memset(rv[i] + 1, -1, sizeof(int) * n);
        rv[i][i] = a[i];
        for (int j = i; j > 1 && rv[i][j] < a[j - 1]; j--) {
            rv[i][j - 1] = a[j - 1] - rv[i][j];
        }
    }

    for (int l = 1; l <= n; l++) {
        for (int r = l; r <= n; r++) {
            ok0[l][r] = false;
        }
    }
    for (int i = 1; i < n; i++) {
        for (int l = 1; l <= i; l++) {
            for (int r = i + 1; r <= n; r++) {
                ok0[l][r] |= lv[l][i] != -1 && rv[r][i + 1] != -1 && lv[l][i] == rv[r][i + 1];
            }
        }
    }

    for (int i = 1; i <= n; i++) {
        for (int l = 1; l <= i; l++) {
            for (int r = i; r <= n; r++) {
                ok[i][l][r] = lv[l][i] != -1 && rv[r][i] != -1 && lv[l][i] + rv[r][i] > a[i];
            }
        }
    }

    pre[0] = 0;
    for (int i = 1; i <= n; i++) pre[i] = pre[i - 1] + b[i];
    f[0] = 0;
    for (int i = 1; i <= n; i++) {
        f[i] = f[i - 1];
        for (int j = 1; j <= i; j++) {
            if (ok0[j][i]) f[i] = std::max(f[i], f[j - 1] + pre[i] - pre[j - 1]);
            for (int k = j; k <= i; k++) {
                if (ok[k][j][i]) f[i] = std::max(f[i], f[j - 1] + pre[i] - pre[j - 1] - b[k]);
            }
        }
    }

    pr[0] = ipr[0] = 1;
    for (int i = 1; i <= n; i++) {
        ic[i] = qpow(c[i], mod - 2);
        pr[i] = pr[i - 1] * c[i] % mod;
        ipr[i] = ipr[i - 1] * ic[i] % mod;
    }
    g[0] = 1;
    for (int i = 1; i <= n; i++) {
        g[i] = 0;
        for (int j = 1; j <= i; j++) {
            if (ok0[j][i]) g[i] = (g[i] + g[j - 1] * pr[i] % mod * ipr[j - 1]) % mod;
            for (int k = j; k <= i; k++) {
                if (ok[k][j][i]) g[i] = (g[i] + g[j - 1] * pr[i] % mod * ipr[j - 1] % mod * ic[k]) % mod;
            }
        }
    }

    printf("%lld %lld\n", f[n], g[n]);
}

\(\mathcal{O}(n^2)\)

#include <bits/stdc++.h>

using ll = long long;

constexpr ll mod = 1e9 + 7;

inline ll qpow(ll a, ll b) {
    ll v = 1;
    for (; b; b >>= 1, a = a * a % mod) {
        if (b & 1) v = v * a % mod;
    }
    return v;
}

constexpr int maxn = 5000 + 10;

int n;
int a[maxn], b[maxn], c[maxn];

int lb[maxn], rb[maxn];

bool ok0[maxn][maxn];

ll preb[maxn], fv[maxn][maxn], f[maxn];
ll prec[maxn], iprec[maxn], ic[maxn], gv[maxn][maxn], g[maxn];

ll prea[maxn];

int pic[maxn][maxn][2], mxb[maxn][maxn][2];

inline void solve() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);

    prea[0] = 0;
    for (int i = 1; i <= n; i++) {
        prea[i] = prea[i - 1] + (i % 2 ? a[i] : -a[i]);
    }

    preb[0] = 0, prec[0] = iprec[0] = 1;
    for (int i = 1; i <= n; i++) {
        ic[i] = qpow(c[i], mod - 2);
        preb[i] = preb[i - 1] + b[i];
        prec[i] = prec[i - 1] * c[i] % mod;
        iprec[i] = iprec[i - 1] * ic[i] % mod;
    }

    for (int l = 1; l <= n; l++) {
        pic[l][l - 1][0] = pic[l][l - 1][1] = 0;
        mxb[l][l - 1][0] = mxb[l][l - 1][1] = -2e9;
        for (int r = l; r <= n; r++) {
            pic[l][r][0] = pic[l][r - 1][0];
            pic[l][r][1] = pic[l][r - 1][1];
            mxb[l][r][0] = mxb[l][r - 1][0];
            mxb[l][r][1] = mxb[l][r - 1][1];
            pic[l][r][r % 2] = (pic[l][r][r % 2] + ic[r]) % mod;
            mxb[l][r][r % 2] = std::max(mxb[l][r][r % 2], -b[r]);
        }
    }
    
    for (int i = 1; i <= n; i++) {
        lb[i] = i;
        for (int x = a[i]; lb[i] < n && x < a[lb[i] + 1]; ) {
            x = a[++lb[i]] - x;
        }
    }
    for (int i = n; i >= 1; i--) {
        rb[i] = i;
        for (int x = a[i]; rb[i] > 1 && x < a[rb[i] - 1]; ) {
            x = a[--rb[i]] - x;
        }
    }

    for (int l = 1; l <= n; l++) {
        for (int r = l; r <= n; r++) {
            fv[l][r] = -1e18, gv[l][r] = 0;
        }
    }

    auto conv = [&](const int x) {
        return x == -2e9 ? (ll)-1e18 : (ll)x;
    };

    for (int l = 1; l <= n; l++) {
        for (int r = l; r <= n; r++) {
            if (prea[r] == prea[l - 1] || lb[l] < rb[r]) continue;
            const int op = prea[r] > prea[l - 1];
            const int st = std::max(rb[r], l);
            const int ed = std::min(lb[l], r);
            fv[l][r] = std::max(fv[l][r], preb[r] - preb[l - 1] + conv(mxb[st][ed][op]));
            gv[l][r] = (gv[l][r] + prec[r] * iprec[l - 1] % mod * pic[st][ed][op]) % mod;
        }
    }

    for (int l = 1; l <= n; l++) {
        for (int r = l; r <= n; r++) {
            ok0[l][r] = prea[r] - prea[l - 1] == 0 && rb[r] <= lb[l];
        }
    }
    for (int i = 1; i < n; i++) {
        ok0[i][i + 1] |= a[i] == a[i + 1];
    }

    for (int l = 1; l <= n; l++) {
        for (int r = l; r <= n; r++) {
            if (ok0[l][r]) {
                fv[l][r] = std::max(fv[l][r], preb[r] - preb[l - 1]);
                gv[l][r] = (gv[l][r] + prec[r] * iprec[l - 1]) % mod;
            }
        }
    }

    f[0] = 0;
    for (int i = 1; i <= n; i++) {
        f[i] = -1e18;
        for (int j = 1; j <= i; j++) {
            f[i] = std::max(f[i], f[j - 1] + fv[j][i]);
        }
    }

    g[0] = 1;
    for (int i = 1; i <= n; i++) {
        g[i] = 0;
        for (int j = 1; j <= i; j++) {
            g[i] = (g[i] + g[j - 1] * gv[j][i]) % mod;
        }
    }

    printf("%lld %lld\n", f[n], g[n]);
}

int main() {
    freopen("sequence.in", "r", stdin);
    freopen("sequence.out", "w", stdout);

    int testid, t;
    scanf("%d%d", &testid, &t);
    while (t--) solve();

    return 0;
}
posted @ 2025-07-15 14:22  rizynvu  阅读(162)  评论(0)    收藏  举报