「2020noip模拟赛 fzh」期望收益 | CodeChef - EXPREP Expected Repetitions 题解
前言
题目链接:CodeChef。
题意简述
给定仅包含小写字母字符串 \(S\) 和每个小写字母对应的权值 \(\omega_{\mathtt{a}\sim\mathtt{z}}\)。
定义 \(w(S):=\sum_{c\in S}\omega_c\),即每一个字符的权值和。定义 \(f(S)\) 为满足以下条件的字符串 \(r\) 的 \(w(r)\) 之和:\(S\) 可以被表示成 \(r+r+\cdots+r+p\),其中 \(+\) 为字符串拼接,\(p\) 为 \(r\) 的一个前缀,可以为空。
从 \(S\) 的所有连续子串中均匀随机选取一个,记为 \(T\),你要求 \(f(T)\) 的期望值,对 \(998244353\) 取模。
\(n\leq5\times10^5\)。
题目分析
期望显然没用,考虑求和。
显然有 \(\mathcal{O}(n^2)\) 枚举 \(r\),\(\mathcal{O}(\log n)\) 二分配合哈希找到 \(r\) 能贡献的区间,累加答案的做法。首先枚举了 \(r=S[a,b]\),然后二分出一个极长子串 \(S[a,p]\),使得对于每一个 \(d\in[b,p]\),\(w(r)\) 都对 \(S[a,d]\) 有贡献。
考虑我们到底在做什么。考虑二分判断的时候,记 \(c=d-(b-a)\),那么我们就是判断 \(S[a,c)\) 是否等于 \(S(b,d]\)。也就是说,只要 \(d\) 满足这个条件,\(S[a,b]\) 就对 \(S[a,d]\) 有贡献。

转换视角,不再枚举 \(r\),而是对于 \(S\) 中两个相等的子串 \(S[a,c)\) 和 \(S(b,d]\),其中 \(a\leq b\),统计 \(S[a,b]\) 对 \(S[a,d]\) 有贡献。把子串的形式变好看一点就是 \(S[l_1,r_1]=S[l_2,r_2]\),其中 \(r_1\leq r_2\),统计 \(S[l_1,l_2)\) 对 \(S[l_1,r_2]\) 的贡献。\(r_1=r_2\) 的情况是简单的,我们后文仅考虑 \(r_1\lt r_2\),即不包括和自身的贡献。
对于 \(S\) 的某一个子串 \(T\),它在 \(S\) 中出现的位置集合为 \(\Big\{[l_i,r_i]\Big\}_{i=1}^m\),那么 \(T\) 的贡献为 \(\sum\limits_{i=1}^m\sum\limits_{j=i+1}^m w\Big(S[l_i,l_j)\Big)\)。
子串所有出现的位置?考虑放到 SAM 上进行。那么就和 \(\operatorname{endpos}\) 有关了,但是还有点区别,这里 \([l_i,l_j)\) 是左端点,而 \(\operatorname{endpos}\) 是右端点。其实只用把 \(S\) 翻转,在反串 \(S'\) 上统计就行了。对于每个本质不同的子串 \(T\),贡献为 \(\sum\limits_{i=1}^{|\operatorname{edp}(T)|}\sum\limits_{j=i+1}^{|\operatorname{edp}(T)|} w\Big(S'(\operatorname{edp}_i,\operatorname{edp}_j]\Big)\)。
在 SAM 的一个节点 \(u\) 上,\(\Big(\operatorname{len}(\operatorname{fa}(u)),\operatorname{len}(u)\Big]\) 的 \(\operatorname{endpos}\) 相同,于是一并统计。如何维护 \(\operatorname{endpos}\)?显然为 parent 树上,孩子的 \(\operatorname{endpos}\) 和自身的并。
问题变到了树上。考虑用 dsu on tree 来维护集合 \(A\)。不妨设 \(p_i\) 表示 \(w\Big(S[1,i]\Big)\)。需要支持插入、删除一个数,维护 \(\sum\limits_{i=1}^{|A|}\sum\limits_{j=i+1}^{|A|} (p_{A_j}-p_{A_i})\)。
还是不直观,转换计数视角,考虑 \(p_{A_i}\) 在减号前统计了 \(i-1\) 次,在减号后统计了 \(|A|-i\) 次,所以等价表示为 \(\sum\limits_{i=1}^{|A|}p_{A_i}(2i-1-|A|)=2\sum\limits_{i=1}^{|A|}i\cdot p_{A_i}-(1+|A|)\sum\limits_{i=1}^{|A|}p_{A_i}\)。注意 \(i\) 的意义是在 \(A\) 中的排名。两个求和可以用线段树搞定。插入的时候不要忘记对右子树打上排名提升的标记。时间复杂度 \(\mathcal{O}(n\log^2 n)\),但是大常数,不太能过。
事实上,可以直接用两个树状数组,分别维护个数和 \(p\) 之和,在插入 \(x\) 的时候,分别统计 \(x\) 和两侧的贡献即可。时间复杂度 \(\mathcal{O}(n\log^2n)\)。
dsu on tree 常见优化方法为线段树合并,本题也可以。pushup 的时候,考虑跨过中点的贡献对即可。时间复杂度 \(\mathcal{O}(n\log n)\)。
代码
dsu on tree + segment tree
#include <cstdio>
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
using ll = long long;
const int N = 1e6 + 10;
const int M = 26;
const int mod = 998244353;
const int inv2 = (mod + 1) >> 1;
inline int add(int a, int b) { return a += b, a >= mod ? a - mod : a; }
inline int sub(int a, int b) { return a -= b, a < 0 ? a + mod : a; }
inline int mul(int a, int b) { return 1ll * a * b % mod; }
inline int pow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1, a = mul(a, a))
        if (b & 1) res = mul(res, a);
    return res;
}
inline int inv(int a) { return pow(a, mod - 2); }
int tr[N][M], fa[N], len[N], idx[N];
int tot, lst;
vector<int> e[N];
inline int gen() {
    ++tot;
    len[tot] = idx[tot] = fa[tot] = 0;
    memset(tr[tot], 0x00, sizeof(*tr));
    return tot;
}
inline void init() { tot = 0, lst = gen(); }
inline void append(char s, int i) {
    int x = s - 'a';
    int p = lst, w = gen();
    idx[w] = i, len[w] = len[p] + 1;
    for (; p && !tr[p][x]; p = fa[p]) tr[p][x] = w;
    if (!p) {
        fa[w] = 1;
    } else {
        int q = tr[p][x];
        if (len[q] == len[p] + 1) {
            fa[w] = q;
        } else {
            int o = gen();
            memcpy(tr[o], tr[q], sizeof(*tr));
            len[o] = len[p] + 1;
            fa[o] = fa[q];
            for (; p && tr[p][x] == q; p = fa[p]) tr[p][x] = o;
            fa[w] = fa[q] = o;
        }
    }
    lst = w;
}
inline void build() {
    for (int i = 1; i <= tot; ++i) e[i].clear();
    for (int i = 2; i <= tot; ++i) {
        e[fa[i]].emplace_back(i);
    }
}
int n, ans;
char s[N];
int w[M], p[N];
int L[N], R[N], dfn[N], tim;
int siz[N], son[N];
void dfs(int u) {
    L[u] = tim + 1;
    if (idx[u]) dfn[++tim] = u;
    siz[u] = 1, son[u] = 0;
    for (int v : e[u]) {
        dfs(v);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
    R[u] = tim;
}
#define ls (u << 1)
#define rs (u << 1 | 1)
struct {
    int s, sp, tg, ct;
} T[N << 2];
inline void ptag(int u, int x) {
    T[u].tg = add(T[u].tg, x);
    T[u].sp = add(T[u].sp, mul(x, T[u].s));
}
void upd(int u, int l, int r, int x, int f, int i) {
    if (l == r) {
        T[u].s = add(T[u].s, mul(f, p[l]));
        T[u].sp = add(T[u].sp, mul(f, mul(p[l], i)));
        T[u].ct = add(T[u].ct, f);
        return;
    }
    if (T[u].tg) {
        ptag(ls, T[u].tg), ptag(rs, T[u].tg);
        T[u].tg = 0;
    }
    int mid = (l + r) >> 1;
    if (x <= mid) {
        upd(ls, l, mid, x, f, i);
        ptag(rs, f);
    } else {
        upd(rs, mid + 1, r, x, f, i + T[ls].ct);
    }
    T[u].s = add(T[ls].s, T[rs].s);
    T[u].sp = add(T[ls].sp, T[rs].sp);
    T[u].ct = add(T[ls].ct, T[rs].ct);
}
#undef ls
#undef rs
inline void upd(int u, int f) {
    upd(1, 1, n, u, f, 1);
}
inline void upd(int l, int r, int f) {
    for (int i = l; i <= r; ++i)
        if (idx[dfn[i]]) {
            upd(idx[dfn[i]], f);
        }
}
void redfs(int u) {
    for (int v : e[u])
        if (v != son[u]) {
            redfs(v);
            upd(L[v], R[v], mod - 1);
        }
    if (son[u]) redfs(son[u]);
    if (idx[u]) upd(idx[u], 1);
    for (int v : e[u])
        if (v != son[u]) {
            upd(L[v], R[v], 1);
        }
    if (fa[u]) {
        int res = sub(mul(2, T[1].sp), mul(T[1].ct + 1, T[1].s));
        ans = add(ans, mul(res, len[u] - len[fa[u]]));
    }
}
void solve() {
    ans = 0;
    scanf("%s", s + 1);
    for (int i = 0; i < M; ++i) {
        scanf("%d", &w[i]);
    }
    n = strlen(s + 1);
    for (int i = 1; i <= n; ++i) {
        p[i] = add(p[i - 1], w[s[i] - 'a']);
        ans = add(ans, mul(p[i], sub(i * 2, n)));
    }
    init();
    for (int i = n; i; --i) {
        append(s[i], i);
    }
    build();
    tim = 0;
    dfs(1);
    redfs(1);
    upd(1, tot, mod - 1);
    ans = mul(ans, inv(mul(mul(n, n + 1), inv2)));
    printf("%d\n", ans);
}
int main() {
#ifndef XuYueming
    freopen("exprep.in", "r", stdin);
    freopen("exprep.out", "w", stdout);
#endif
    int t;
    scanf("%d", &t);
    while (t--) solve();
    return 0;
}
dsu on tree + bit
#include <cstdio>
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
using ll = long long;
const int N = 1e6 + 10;
const int M = 26;
const int mod = 998244353;
const int inv2 = (mod + 1) >> 1;
inline int add(int a, int b) { return a += b, a >= mod ? a - mod : a; }
inline int sub(int a, int b) { return a -= b, a < 0 ? a + mod : a; }
inline int mul(int a, int b) { return 1ll * a * b % mod; }
inline int pow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1, a = mul(a, a))
        if (b & 1) res = mul(res, a);
    return res;
}
inline int inv(int a) { return pow(a, mod - 2); }
int tr[N][M], fa[N], len[N], idx[N];
int tot, lst;
vector<int> e[N];
inline int gen() {
    ++tot;
    len[tot] = idx[tot] = fa[tot] = 0;
    memset(tr[tot], 0x00, sizeof(*tr));
    return tot;
}
inline void init() { tot = 0, lst = gen(); }
inline void append(char s, int i) {
    int x = s - 'a';
    int p = lst, w = gen();
    idx[w] = i, len[w] = len[p] + 1;
    for (; p && !tr[p][x]; p = fa[p]) tr[p][x] = w;
    if (!p) {
        fa[w] = 1;
    } else {
        int q = tr[p][x];
        if (len[q] == len[p] + 1) {
            fa[w] = q;
        } else {
            int o = gen();
            memcpy(tr[o], tr[q], sizeof(*tr));
            len[o] = len[p] + 1;
            fa[o] = fa[q];
            for (; p && tr[p][x] == q; p = fa[p]) tr[p][x] = o;
            fa[w] = fa[q] = o;
        }
    }
    lst = w;
}
inline void build() {
    for (int i = 1; i <= tot; ++i) e[i].clear();
    for (int i = 2; i <= tot; ++i) {
        e[fa[i]].emplace_back(i);
    }
}
int n, ans;
char s[N];
int w[M], p[N];
int L[N], R[N], dfn[N], tim;
int siz[N], son[N];
void dfs(int u) {
    L[u] = tim + 1;
    if (idx[u]) dfn[++tim] = u;
    siz[u] = 1, son[u] = 0;
    for (int v : e[u]) {
        dfs(v);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
    R[u] = tim;
}
int t[2][N];
inline void upd(int t[], int x, int v) {
    for (; x <= n; x += x & -x) t[x] = add(t[x], v);
}
inline int qry(int t[], int x) {
    int r = 0;
    for (; x; x &= x - 1) r = add(r, t[x]);
    return r;
}
inline int qry(int t[], int l, int r) {
    return sub(qry(t, r), qry(t, l - 1));
}
int res;
inline void upd(int u, int f) {
    int R = sub(qry(t[0], u, n), mul(qry(t[1], u, n), p[u]));
    int L = sub(mul(qry(t[1], 1, u), p[u]), qry(t[0], 1, u));
    res = add(res, mul(f, add(L, R)));
    upd(t[1], u, f), upd(t[0], u, mul(f, p[u]));
}
inline void upd(int l, int r, int f) {
    for (int i = l; i <= r; ++i)
        if (idx[dfn[i]]) {
            upd(idx[dfn[i]], f);
        }
}
void redfs(int u) {
    for (int v : e[u])
        if (v != son[u]) {
            redfs(v);
            upd(L[v], R[v], mod - 1);
        }
    if (son[u]) redfs(son[u]);
    if (idx[u]) upd(idx[u], 1);
    for (int v : e[u])
        if (v != son[u]) {
            upd(L[v], R[v], 1);
        }
    if (fa[u]) {
        ans = add(ans, mul(res, len[u] - len[fa[u]]));
    }
}
void solve() {
    ans = 0;
    scanf("%s", s + 1);
    for (int i = 0; i < M; ++i) {
        scanf("%d", &w[i]);
    }
    n = strlen(s + 1);
    for (int i = 1; i <= n; ++i) {
        p[i] = add(p[i - 1], w[s[i] - 'a']);
        ans = add(ans, mul(p[i], sub(i * 2, n)));
    }
    init();
    for (int i = n; i; --i) {
        append(s[i], i);
    }
    build();
    tim = 0;
    dfs(1);
    redfs(1);
    upd(1, n, mod - 1);
    ans = mul(ans, inv(mul(mul(n, n + 1), inv2)));
    printf("%d\n", ans);
}
int main() {
#ifndef XuYueming
    freopen("exprep.in", "r", stdin);
    freopen("exprep.out", "w", stdout);
#endif
    int t;
    scanf("%d", &t);
    while (t--) solve();
    return 0;
}
segment tree merging
#include <cstdio>
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
using ll = long long;
const int N = 1e6 + 10;
const int M = 26;
const int mod = 998244353;
const int inv2 = (mod + 1) >> 1;
inline int add(int a, int b) { return a += b, a >= mod ? a - mod : a; }
inline int sub(int a, int b) { return a -= b, a < 0 ? a + mod : a; }
inline int mul(int a, int b) { return 1ll * a * b % mod; }
inline int pow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1, a = mul(a, a))
        if (b & 1) res = mul(res, a);
    return res;
}
inline int inv(int a) { return pow(a, mod - 2); }
int tr[N][M], fa[N], len[N], idx[N];
int tot, lst;
vector<int> e[N];
inline int gen() {
    ++tot;
    len[tot] = idx[tot] = fa[tot] = 0;
    memset(tr[tot], 0x00, sizeof(*tr));
    return tot;
}
inline void init() { tot = 0, lst = gen(); }
inline void append(char s, int i) {
    int x = s - 'a';
    int p = lst, w = gen();
    idx[w] = i, len[w] = len[p] + 1;
    for (; p && !tr[p][x]; p = fa[p]) tr[p][x] = w;
    if (!p) {
        fa[w] = 1;
    } else {
        int q = tr[p][x];
        if (len[q] == len[p] + 1) {
            fa[w] = q;
        } else {
            int o = gen();
            memcpy(tr[o], tr[q], sizeof(*tr));
            len[o] = len[p] + 1;
            fa[o] = fa[q];
            for (; p && tr[p][x] == q; p = fa[p]) tr[p][x] = o;
            fa[w] = fa[q] = o;
        }
    }
    lst = w;
}
inline void build() {
    for (int i = 1; i <= tot; ++i) e[i].clear();
    for (int i = 2; i <= tot; ++i) {
        e[fa[i]].emplace_back(i);
    }
}
int n, ans;
char s[N];
int w[M], p[N];
struct {
    int ls, rs;
    int s, c, ans;
} T[N * 30];
int Tc;
inline void pushup(int u) {
    int r = T[u].rs, l = T[u].ls;
    T[u].ans = sub(mul(T[r].s, T[l].c), mul(T[l].s, T[r].c));
    T[u].ans = add(T[u].ans, add(T[l].ans, T[r].ans));
    T[u].s = add(T[l].s, T[r].s);
    T[u].c = add(T[l].c, T[r].c);
}
int merge(int u, int v) {
    if (!u || !v) return u | v;
    T[u].ls = merge(T[u].ls, T[v].ls);
    T[u].rs = merge(T[u].rs, T[v].rs);
    pushup(u);
    return u;
}
void upd(int &u, int l, int r, int x) {
    if (!u) T[u = ++Tc] = { 0, 0, 0, 0, 0 };
    if (l == r) {
        T[u].s = p[x];
        T[u].c = 1;
        return;
    }
    int mid = (l + r) >> 1;
    if (x <= mid)
        upd(T[u].ls, l, mid, x);
    else
        upd(T[u].rs, mid + 1, r, x);
    pushup(u);
}
int rt[N];
void dfs(int u) {
    rt[u] = 0;
    for (int v : e[u]) {
        dfs(v);
        rt[u] = merge(rt[u], rt[v]);
    }
    if (idx[u]) upd(rt[u], 1, n, idx[u]);
    if (fa[u]) {
        int res = T[rt[u]].ans;
        ans = add(ans, mul(res, len[u] - len[fa[u]]));
    }
}
void solve() {
    ans = 0;
    scanf("%s", s + 1);
    for (int i = 0; i < M; ++i) {
        scanf("%d", &w[i]);
    }
    n = strlen(s + 1);
    for (int i = 1; i <= n; ++i) {
        p[i] = add(p[i - 1], w[s[i] - 'a']);
        ans = add(ans, mul(p[i], sub(i * 2, n)));
    }
    init();
    for (int i = n; i; --i) {
        append(s[i], i);
    }
    build();
    Tc = 0;
    dfs(1);
    ans = mul(ans, inv(mul(mul(n, n + 1), inv2)));
    printf("%d\n", ans);
}
int main() {
#ifndef XuYueming
    freopen("exprep.in", "r", stdin);
    freopen("exprep.out", "w", stdout);
#endif
    int t;
    scanf("%d", &t);
    while (t--) solve();
    return 0;
}
本文作者:XuYueming,转载请注明原文链接:https://www.cnblogs.com/XuYueming/p/18989536。
若未作特殊说明,本作品采用 知识共享署名-非商业性使用 4.0 国际许可协议 进行许可。

                
            
        
浙公网安备 33010602011771号