[题解]CF2077C Binary Subsequence Value Sum

感谢此题将我送上 Master。

思路

注意观察 \(F(v,l,r)\) 的定义,容易将其刻画成 \(v_{l \sim r}\)\(1\) 的数量减去 \(0\) 的数量。

不妨将 \(1\) 的权值记作 \(1\)\(0\) 的权值记作 \(-1\),令这个序列的权值序列为 \(val_i\)。则对于一个子序列 \(v\) 的得分是 \(\lfloor \frac{S^2}{4} \rfloor\),其中 \(S = \sum val_i\)。考虑如下证明:因为无论选择哪一个点作为分界点其左右两边的 \(F\) 函数的和不变,由均值不等式得到该函数的最大值。

那么问题转化为了求所有子序列的 \(\lfloor \frac{S^2}{4} \rfloor\) 的和,将下取整去掉,即:

\[\frac{\sum_{T}{\sum S_T^2 - S_T \bmod 2}}{4} \]

注意到 \(val_i\) 的取值并不会影响 \(\sum_T (S_T \bmod 2)\) 的结果,因此原式等价于:

\[\frac{\sum_{T}{\sum S_T^2 - |T| \bmod 2}}{4} = \frac{(\sum_{T} \sum S_T^2) - 2^{n - 1}}{4} \]

接下来只需要维护所有子序列的 \(\sum S^2\),即 \(\sum_T(\sum_{x \in T} val_x)^2\)。考虑拆掉平方项,得:

\[\sum_T(\sum_{x \in T}val_x^2 + 2\sum_{x,y \in T \wedge x < y}val_x val_y) \]

分别计算 \(\sum_T\sum_{x \in T}val_x^2\)\(2\sum_T\sum_{x,y \in T \wedge x < y}val_x val_y\)

  • 对于前者,注意到 \(\forall i \in T,val_i \in \{1,-1\}\),因此无论给定字符串是什么 \(val_i^2\) 恒为 \(1\)。即计算 \(\sum_T |T|\),因为 \(T\) 要非空,因此贡献为 \(n \times 2^{n - 1}\)
  • 对于后者,考虑计算每一对 \(i < j\) 的贡献和,即计算 \(2 \sum_{i < j}2^{n - 2}val_i val_j = 2^{n - 1}\sum_{i < j}val_i val_j\)。接下来只需化简 \(\sum_{i < j}val_i val_j\),注意到如下等式:

\[(\sum_{i = 1}^{n}val_i)^2 = \sum_{i = 1}^{n}val_i^2 + 2\sum_{i < j}val_i val_j \]

\(sum = \sum_{i = 1}^{n}val_i\),则移项易得 \(\sum_{i < j}val_i val_j = \frac{sum^2 - n}{2}\)

整理一下可以得到 \(\sum S^2 = n \times 2^{n - 1} + 2^{n - 2} \times (sum^2 - n) = 2^{n - 2} \times (sum^2 + n)\)。答案为:\(\frac{2^{n - 2} \times (sum^2 + n) - 2^{n - 1}}{4}\)

动态维护 \(sum\) 是容易的。其实完全可以做到区间修改。

Code

#include <bits/stdc++.h>
#define re register
#define int long long
#define Add(a,b) (((a) + (b)) % mod)
#define Sub(a,b) (((a) - (b) + mod) % mod)
#define Mul(a,b) ((a) * (b) % mod)
#define chMul(a,b) (a = Mul(a,b))

using namespace std;

const int mod = 998244353;
const int inv = 748683265;
const int N = 2e5 + 10;
int n,q;
int val[N];
char s[N];

inline int read(){
    int r = 0,w = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') w = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        r = (r << 3) + (r << 1) + (c ^ 48);
        c = getchar();
    }
    return r * w;
}

inline int qmi(int a,int b){
    int res = 1;
    while (b){
        if (b & 1) chMul(res,a);
        chMul(a,a); b >>= 1;
    } return res;
}

inline void solve(){
    n = read(),q = read();
    scanf("%s",s + 1);
    int sum = 0;
    for (re int i = 1;i <= n;i++) sum += (val[i] = (s[i] == '1') ? 1 : -1);
    while (q--){
        int x; x = read();
        sum -= (2 * val[x]); val[x] = -val[x];
        printf("%lld\n",Mul(Sub(Mul(qmi(2,n - 2 + mod - 1),Add(Mul(sum,sum),n)),qmi(2,n - 1)),inv));
    }
}

signed main(){
    int T; T = read();
    while (T--) solve();
    return 0;
}
posted @ 2025-03-14 14:40  WBIKPS  阅读(48)  评论(0)    收藏  举报