sol. ABC246Ex

[ABC246Ex] 01? Queries

算法标签:动态 DP。

题目描述

给一个长度为 \(N\) 只包含 ?10 的字符串 \(a\)\(Q\) 次操作,每次操作会修改字符串中的一个字符,并求出此时整个字符串中本质不同的子序列个数。

数据范围:\(1 \le N,Q \le 10^5\)

题解

众所周知,动态 DP 依然是 DP 的一种。

先考虑没有修改操作怎么做,那么本题变为一个普通的 DP,我们设 \(dp_{i,0/1}\) 分别表示到第 \(i\) 位,以 \(0/1\) 结尾的本质不同的子序列个数。

考虑转移,以 \(a_i =\) 0 为例,那么在 \(1\)\(i-1\) 之间的所有本质不同的子序列后都可以加第 \(i\) 位的这个 \(0\),这些串依旧本质不同且长度不超过 \(2\),再加上 0 自己本身,因此 \(dp_{i,0}=dp_{i-1,1}+dp_{i-1,0}+1\)

又因为 \(a_i\)\(0\),不可能形成更多的以 \(1\) 结尾的本质不同子序列,所以 \(dp_{i,1}=dp_{i-1,1}\)

\(a_i\)1? 的转移同理,因此:

\(a_i\)0 时,

  • \(dp_{i,0}=dp_{i-1,0}+(dp_{i-1,1}+1)\)

  • \(dp_{i,1}=dp_{i-1,1}\)

\(a_i\)1 时,

  • \(dp_{i,0}=dp_{i-1,0}\)

  • \(dp_{i,1}=dp_{i-1,1}+(dp_{i-1,0}+1)\)

\(a_i\)? 时,

  • \(dp_{i,0}=dp_{i-1,0}+(dp_{i-1,1}+1)\)

  • \(dp_{i,1}=dp_{i-1,1}+(dp_{i-1,0}+1)\)

最终答案为 \(dp_{n,0}+dp_{n,1}\)

如果对于每次修改,重新做一遍 DP,时间复杂度变为 \(\mathcal{O}(NQ)\),无法通过,注意到方程的转移是一个线性递推式,这启示我们用矩阵优化转移

我们把递推式写成矩阵的形式,可以用以下的式子描述:

\[S_{a_i}\begin{bmatrix} dp_{i-1,0} \\ dp_{i-1,1} \\ 1\end{bmatrix}=\begin{bmatrix} dp_{i,0} \\ dp_{i,1} \\ 1\end{bmatrix} \]

不难写出:

\[S_{0}=\begin{bmatrix} 1&1&1 \\ 0&1&0 \\ 0&0&1\end{bmatrix},S_{1}=\begin{bmatrix} 1&0&0 \\ 1&1&1 \\ 0&0&1\end{bmatrix},S_{?}=\begin{bmatrix} 1&1&1 \\ 1&1&1 \\ 0&0&1\end{bmatrix} \]

那么整个字符串的转移可以用以下式子表示:

\[\prod_{i=1}^n S_{a_i} \times \begin{bmatrix}0\\0\\1\end{bmatrix} = \begin{bmatrix} dp_{n,0} \\ dp_{n,1} \\ 1\end{bmatrix} \]

本质上,一次修改字符的操作就是修改一个字符的转移矩阵。又由上面的式子可知要维护的信息是整个字符串的 \(S_{a_i}\) 的乘积。

由于矩阵乘法具有结合律,所以可以用线段树维护。

时间复杂度为 \(O(k^3(N + Q \log N))\),本题中 \(k = 3\)

#include<bits/stdc++.h>
#define L p << 1
#define R p << 1 | 1
#define mid ((l + r) >> 1)
#define int long long
#define fi first
#define se second
#define MULT_TEST 0
using namespace std;
typedef unsigned long long ull;
const int INF = 0x3f3f3f3f;
const int MOD = 998244353;
const int N = 100005;
int tot = 0;
char a[N];
inline int read() {
    int w = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        w = (w << 1) + (w << 3) + ch - 48;
        ch = getchar();
    }
    return w * f;
}
struct Matrix {
    int n, m;
    vector<vector<int>> val;
    Matrix() {}
    Matrix(int n, int m) : n(n), m(m), val(n, vector<int>(m, 0)) {}
    inline void resize(int _n, int _m) { 
        n = _n; m = _m; val.resize(n);
        for (int i = 0; i < n; i++) val[i].resize(m, 0);
    }
    inline vector<int> & operator [] (int x) { return val[x]; }
    inline Matrix operator * (Matrix &T) const {
        assert(m == T.n);
        Matrix ans(n, T.m);
        for (int i = 0; i < n; i++) 
            for (int k = 0; k < m; k++) {
                if (!val[i][k]) continue;
                for (int j = 0; j < T.m; j++) {
                    ans[i][j] += val[i][k] * T[k][j] % MOD;
                    ans[i][j] = (ans[i][j] % MOD + MOD) % MOD;
                }
            }
        return ans;
    }
}; // 矩阵
Matrix S1, S2, S3;
struct Seg {
    Matrix val;
} tr[N << 2];
inline void Pushup(int p) {
    tr[p].val = tr[L].val * tr[R].val;
}
inline void Build(int p, int l, int r) {
    tr[p].val.resize(3, 3);
    if (l == r) {
        if (a[l] == '0') tr[p].val = S1;
        else if (a[l] == '1') tr[p].val = S2;
        else tr[p].val = S3;
        return ;
    }
    Build(L, l, mid);
    Build(R, mid + 1, r);
    Pushup(p);
}
inline void Modify(int p, int l, int r, int x, Matrix d) {
    if (l == r) {
        tr[p].val = d;
        return ;
    }
    if (x <= mid) Modify(L, l, mid, x, d);
    else Modify(R, mid + 1, r, x, d);
    Pushup(p);
} // 线段树
signed main() {
    int n, Q;
    n = read(); Q = read();
    scanf("%s", a + 1);
    S1.resize(3, 3); S2.resize(3, 3); S3.resize(3, 3); 
    for (int i = 0; i < 3; i++) S1[i][i] = S2[i][i] = S3[i][i] = 1;
    S1[0][1] = S1[0][2] = 1;
    S2[1][0] = S2[1][2] = 1;
    S3[0][1] = S3[0][2] = S3[1][0] = S3[1][2] = 1;
    // 表示三个不同的转移矩阵
    Build(1, 1, n);
    while (Q--) {
        int x = read();
        char c[2];
        scanf("%s", c + 1);
        if (c[1] == '0') Modify(1, 1, n, x, S1);
        else if (c[1] == '1') Modify(1, 1, n, x, S2);
        else Modify(1, 1, n, x, S3);
        Matrix res(3, 1);
        res[2][0] = 1;
        res = tr[1].val * res;
        printf("%lld\n", (res[0][0] + res[1][0]) % MOD);
    }
    return 0;
}
posted @ 2023-08-18 13:59  WTR2007  阅读(24)  评论(0)    收藏  举报