数据结构杂题

Trick

  1. 对于很难直接维护的修改,可以在上一次的答案上加上这次修改造成的贡献。(CF1928F Digital Patterns)

  2. 对于难维护的东西,可以考虑它的组合意义,变成 dp 问题,再用矩阵解决。(Distance)

  3. 对一个区间中的数去重后排序,可以考虑莫队。(Fibonacci-ish II)

  4. 快速维护斐波那契数列的变换,考虑用矩阵。(Fibonacci-ish II)

  5. 关于 \(\rm{mex}\) 的题目,可以考虑当 \(\rm{mex}\)\(x\) 时的贡献。(Yet Another MEX Problem)

  6. \(a_i\in[l,r]\) 的所有 \(a_i\) 的逆序对,可以令 \(id_{a_i}=i\),转化成求 \(i\in[l,r]\)\(id_i\) 的逆序对。(Book Sorting)

  7. 对于两棵树换根的题目,考虑一棵树换根,另一棵树用数据结构维护在它与第一棵树的关系。(Two tree)

  8. 无向图,每次添加一条边。对于一条边,它能够出现在一个 SCC 中的时间是满足单调性的,所以可以用二分求出一条边出现 SCC 中的最小时间。(Simultaneous Coloring)

  9. 动态维护处理边的一些操作(改边权、删边、加边),可以考虑时间分治。([HNOI2010] 城市建设)

  10. 对于同时涉及位运算与加减的东西,考虑拆位维护每一位上的信息。(「FeOI Round 4.5」はぐ)

  11. 对于中位数的题目,考虑二分判定是否存在 \(\ge x\) 的中位数。(Big Wins!)

  12. 区间开根考虑势能线段树。(Farmer John's Favorite Function)

  13. 对于同时涉及位运算与加减的东西,考虑拆位维护每一位上的信息。(树锯解构)

  14. 当区间同时又包含与相交关系时,可以考虑之维护其中之一。(Building Forest Trails)

  15. 求的东西很难维护时,考虑拆贡献,然后对拆完的贡献逐项维护。(点心)

题目

Process with Constant Sum

中文题面:

这种题,先手玩一下样例。

然后就可以发现,不管怎么操作,最终得到的结果都一样,那么就考虑怎么快速模拟这个过程。

分析第一种操作,发现就是把当前数的值移动 \(2\) 到下一个数,那么一直做第一种操作,直到做不了为止,最后的序列满足除了最后一个数,剩下的数要么 \(0\) 要么 \(1\)

这启发是可以给所有数值对 \(2\) 取模的。

然后再分析二操作,如果说对 1 0 1 操作,那么会得到 0 0 00 0 2\(2\) 取模后就是 0 0 0);若对 1 0 0 操作,则会得到 1 0 1

再手玩一下,就可以发现最后序列的形态一定为 0 0 0 ... 1 1 1 0/1,序列的最后一个数可以为 \(0\) 也可以为 \(1\)

当然,除了最后一个数以外为 \(0\) 的地方,它不对 \(2\) 取模的值也肯定是 \(0\),对于最后一个数就不一定了。

所以这里只需要考虑最后一个位置的真实值,不难发现,如果初始序列有大于 \(1\) 的数,最后一个数就不可能是 \(0\);或者初始序列都小于等于 \(1\),那么也不能出现 1 0 1 这样的形式,因为这样会合成出一个 \(2\)

所以,这个序列的初始值之和必须要等于最后的形态(也就是 0 0 0 ... 1 1 1 0/1)的和,才能确定最后一个数的真实值是否为 \(0\)

现在考虑怎么快速模拟此操作,不难发现是可以放在线段树上搞的。

显然线段树上的一个区间所代表的形态肯定也形如 0 0 0 ... 1 1 1 0/1,那么就考虑如何合并两个区间。

如果两个区间 \(1\) 中间的 \(0\) 的个数为奇数,那么就会抵消 \(1\),也就是说对于 0 0 0 1 1 10 0 0 1 1 1 1 合并,最后的结果是 0 0 0 0 0 0 0 0 0 0 0 0 1

如果两个区间 \(1\) 中间的 \(0\) 的个数为偶数,那么就会合并 \(1\),也就是说对于 0 0 0 1 1 10 0 1 1 1 1 合并,最后的结果是 0 0 0 0 0 1 1 1 1 1 1 1

所以合并是可以 \(O(1)\) 做的,讨论一下即可。

代码
#include <bits/stdc++.h>
#define int long long
#define ls k << 1
#define rs k << 1 | 1
#define mid ((l + r) >> 1)

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 2e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n, q;
int a[N];

struct sgt {
    int cnt, op, len, sum;
} tr[N * 4];

// cnt: 区间中 1 的个数;op: 区间的最后一个数是 0 还是 1;len: 区间长度;sum: 区间的初始值之和。

sgt merge( sgt a, sgt b) {
    if ((a.op + b.len - b.cnt - b.op) % 2) {
        if (b.cnt > a.cnt) return {b.cnt - a.cnt, b.op, a.len + b.len, a.sum + b.sum};
        else if (b.cnt == a.cnt) return {0, 1, a.len + b.len, a.sum + b.sum};
        else return {a.cnt - b.cnt, b.op ^ 1, a.len + b.len, a.sum + b.sum};
    }

    return {a.cnt + b.cnt, b.op, a.len + b.len, a.sum + b.sum};
}

void build( int k = 1, int l = 1, int r = n) {
    if (l == r) {
        if (a[l] % 2) tr[k] = {1, 0, 1, a[l]};
        else tr[k] = {0, 1, 1, a[l]};

        return ;
    }

    build(ls, l, mid), build(rs, mid + 1, r);
    tr[k] = merge(tr[ls], tr[rs]);
}

void upd( int x, int v, int k = 1, int l = 1, int r = n) {
    if (l == r) {
        if (v % 2) tr[k] = {1, 0, 1, v};
        else tr[k] = {0, 1, 1, v};

        return ;
    }

    x <= mid ? upd(x, v, ls, l, mid) : upd(x, v, rs, mid + 1, r);
    tr[k] = merge(tr[ls], tr[rs]);
}

sgt ask( int x, int y, int k = 1, int l = 1, int r = n) {
    if (x <= l && r <= y) return tr[k];

    if (y <= mid) return ask(x, y, ls, l, mid);
    if (x > mid) return ask(x, y, rs, mid + 1, r);

    return merge(ask(x, y, ls, l, mid), ask(x, y, rs, mid + 1, r));
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n;
    for ( int i = 1; i <= n; i ++) cin >> a[i];

    build(1, 1, n);
    cin >> q;

    while (q --) {
        int op, l, r;
        cin >> op >> l >> r;

        if (op == 1) {
            upd(l, r);
        } else {
            sgt res = ask(l, r);
            cout << res.len - res.cnt - (res.op ? (res.sum > res.cnt ? 1 : 0) : 0) << '\n';
        }
    }

    return 0;
}

CF1928F Digital Patterns

首先考虑一个合法的正方形是怎么出来的。

分析后发现,如果一个左上角为 \((x,y)\) 长度为 \(d\) 的正方形合法,只需要满足 \(i\in[x,x+d-1),a_i\neq a_{i+1}\)\(j\in[y,y+d-1),b_j\neq b_{j+1}\) 即可。

然后考虑去维护极长的一个连续段 \(X\) 满足没有相邻的 \(a\) 相同,和极长连续段 \(Y\) 满足没有相邻的 \(b\) 相同。

假设 \(|X|\le |Y|\)\(|X|\) 指连续段 \(X\) 的长度)。

那么这两个连续段对合法正方形的贡献记为 \(f(X,Y)\),就有:

\[f(X,Y)=\sum_{i=1}^{|X|}(|X|-i+1)(|Y|-i+1) \]

这个式子显然可以化成只与 \(|X|\)\(|Y|\) 有关的式子:

\[f(X,Y)=\dfrac{|X|(|X|+1)(2|X|+1)}{6}-\dfrac{|X|^2(|X|+1)}{2}+\dfrac{|X|(|X|+1)|Y|}{2} \]

相当于知道 \(|X|\)\(|Y|\) 后可以 \(O(1)\) 计算贡献。

那么现在考虑给 \(a\) 数组分成若干个极长连续段,记为 \(X_1\)\(X_2\)\(\dots\)\(X_p\)

同理把 \(b\) 也分成若干极长连续段,记为 \(Y_1\)\(Y_2\)\(\dots\)\(Y_q\)

如何维护连续段?差分以后,两个 \(0\) 之间就是一个连续段,用 set 维护这些 \(0\) 的位置即可。

为了简化式子,再记 \(u(i)=i\)\(v(i)=\frac{i(i+1)}{2}\)\(w(i)=\frac{i(i+1)(2i+1)}{6}-\frac{i^2(i+1)}{2}\)

那么最后求的就是:

\[\sum_i^p\sum_j^q f(X_i,Y_j) \]

化到最简就是:

\[\sum_i^p(w(X_i)\sum_{Y_j\ge X_i}1+v(X_i)\sum_{Y_j\ge X_i}u(Y_j)+\sum_{Y_j\lt X_i}w(Y_j)+u(X_i)\sum_{Y_j\lt X_i}v(Y_j)) \]

维护这个东西,可以在修改之前先求一遍答案,然后考虑修改会造成的影响。

以修改 \(a\) 数组为例,没次修改只会使极长连续段的个数造成 \(O(1)\) 的变化,那么直接 set 里面改就行,计算贡献时用一个树状数组来处理形如这样 \(\sum_{Y_j\ge X_i}\) 的贡献。

修改 \(b\) 数组与修改 \(a\) 数组同理。

代码
#include <bits/stdc++.h>
#define int long long

void Freopen() {
    freopen(".in", "r", stdin);
    freopen(".out", "w", stdout);
}

using namespace std;
const int N = 3e5 + 10, M = 2e5 + 10, inf = 1e16, mod = 998244353;

int n, m, q;
int a[N], b[N], da[N], db[N];

int u( int i) {
    return i;
}

int v( int i) {
    return i * (i + 1) / 2;
}

int w ( int i) {
    return i * (i + 1) * (2 * i + 1) / 6 - i * i * (i + 1) / 2;
}

struct bit {
    int tr[N], n;

    void add( int u, int v) {
        if (u <= 0) return ;
        for (; u <= n; u += (u & -u)) tr[u] += v;
    }

    int ask( int u, int res = 0) {
        if (u <= 0) return 0;
        if (u > n) u = n;
        for (; u; u -= (u & -u)) res += tr[u];
        return res;
    }
} ;

bit y, yu, yw, yv, x, xu, xw, xv;

void addx( int k, int op) {
    x.add(n - k + 1, op), xu.add(n - k + 1, u(k) * op);
    xw.add(k, w(k) * op), xv.add(k, v(k) * op);
}

void addy( int k, int op) {
    y.add(m - k + 1, op), yu.add(m - k + 1, u(k) * op);
    yw.add(k, w(k) * op), yv.add(k, v(k) * op);
}

int askx( int k) {
    return w(k) * x.ask(n - k + 1) + v(k) * xu.ask(n - k + 1) + xw.ask(k - 1) + u(k) * xv.ask(k - 1);
}

int asky( int k) {
    return w(k) * y.ask(m - k + 1) + v(k) * yu.ask(m - k + 1) + yw.ask(k - 1) + u(k) * yv.ask(k - 1);
}

set< int> s1, s2;
int ans = 0;

void upd1( int k, int x) {
    if (k == n + 1) return ;
    if (! da[k] && da[k] + x) {
        auto it = s1.find(k), pre = prev(it), nxt = next(it);
        addx(* it - * pre, -1), ans -= asky(* it - * pre);
        addx(* nxt - * it, -1), ans -= asky(* nxt - * it);
        addx(* nxt - * pre, 1), ans += asky(* nxt - * pre);
        s1.erase(it);
    }
    if (da[k] && ! (da[k] + x)) {
        auto it = s1.insert(k).first, pre = prev(it), nxt = next(it);
        addx(* nxt - * pre, -1), ans -= asky(* nxt - * pre);
        addx(* it - * pre, 1), ans += asky(* it - * pre);
        addx(* nxt - * it, 1), ans += asky(* nxt - * it);
    }

    da[k] += x;
}

void upd2( int k, int x) {
    if (k == m + 1) return ;
    if (! db[k] && db[k] + x) {
        auto it = s2.find(k), pre = prev(it), nxt = next(it);
        addy(* it - * pre, -1), ans -= askx(* it - * pre);
        addy(* nxt - * it, -1), ans -= askx(* nxt - * it);
        addy(* nxt - * pre, 1), ans += askx(* nxt - * pre);
        s2.erase(it);
    }
    if (db[k] && ! (db[k] + x)) {
        auto it = s2.insert(k).first, pre = prev(it), nxt = next(it);
        addy(* nxt - * pre, -1), ans -= askx(* nxt - * pre);
        addy(* it - * pre, 1), ans += askx(* it - * pre);
        addy(* nxt - * it, 1), ans += askx(* nxt - * it);
    }

    db[k] += x;
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> m >> q;
    a[0] = b[0] = inf;
    y.n = yu.n = yw.n = yv.n = m;
    x.n = xu.n = xw.n = xv.n = n;

    for ( int i = 1; i <= n; i ++)
        cin >> a[i], da[i] = a[i] - a[i - 1];

    for ( int i = 1; i <= m; i ++)
        cin >> b[i], db[i] = b[i] - b[i - 1];

    s1.insert(1), s2.insert(1), s1.insert(n + 1), s2.insert(m + 1);
    for ( int i = 1; i <= n; i ++) if (da[i] == 0) s1.insert(i);
    for ( int i = 1; i <= m; i ++) if (db[i] == 0) s2.insert(i);

    for ( auto it = s1.begin(); next(it) != s1.end(); it ++) addx(* next(it) - * it, 1);

    for ( auto it = s2.begin(); next(it) != s2.end(); it ++) addy(* next(it) - * it, 1);
    
    for ( auto it = s1.begin(); next(it) != s1.end(); it ++)
        ans += asky(* next(it) - * it);

    cout << ans << '\n';

    while (q --) {
        int op, l, r, x;
        cin >> op >> l >> r >> x;

        if (op == 1) upd1(l, x), upd1(r + 1, -x);
        if (op == 2) upd2(l, x), upd2(r + 1, -x);

        cout << ans << '\n';
    }
    
    return 0;
}

Distance

没有原题。

\(n,q\le3\times 10^5\)

考虑这个询问怎么刻画。

枚举路径 \(x\to y\) 上的点 \(z\),让 \(z\) 成为 \(x\)\(y\)\(lca\)

那么当 \(z\)\(lca\) 时,贡献就是 \((\sum_{e\in x\to z}W_e)(S_z)(\sum_{e\in z\to y}W_e)\)\(W_e\) 表示边 \(e\) 的权值,\(S_z\) 表示不经过路径能够到达点 \(z\) 的点的数量)。

如果把 \(S\) 看作点的点权,那么它的组合意义就是在路径上有顺序地选择一条边、一个点、一条边,把它们权值乘起来的和。

这似乎可以 dp,考虑在序列上怎么求。

\(f_{i,0/1/2}\) 表示选了一条边/一条边、一个点/一条边、一个点、一条边的权值乘积之和,转移就是 \(O(1)\) 合并。

这启发可以放在矩阵上转移,然后树剖维护。

但是现在有些细节的问题,就比如 \(S_u\) 怎么处理?因为不同的路径 \(S_u\) 的值不一样。

可以这样考虑,把 \(S_{fa_u}\) 的值挂在 \(u\) 上,这样 \(S_{fa_u}\) 的值唯一,也就是 \(siz_{fa_u}-siz_u\),同时把 \((fa_u,u)\) 这条边的权值挂在 \(u\) 上。

当然,对于路径的 \(lca\) 是特殊情况,因为 \(S_{lca}\)\(n\) 减去 \(lca\) 在路径上的两个儿子的子树大小,所以考虑把 \(lca\) 的矩阵单独处理出来。

还有些细节,比如说线段树合并的方向,树剖重链合并的方向都要注意。

当然 \(\log^2\) 加上矩阵乘法的复杂度是很高的,所以考虑只维护矩阵会改变的位置(上三角),一共只有 \(6\) 个位置,可以看作 \(O(1)\)

代码
#include <bits/stdc++.h>
#define int unsigned int

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 3e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n, q;
vector< pair< int, int> > G[N];

int W[N];
int dep[N], siz[N], fa[N], son[N];
int top[N], dfn[N], rev[N];
int tot;

void dfs1( int u, int fu) {
    fa[u] = fu, siz[u] = 1, dep[u] = dep[fu] + 1;

    for ( auto [v, w] : G[u]) {
        if (v == fu) continue ;

        dfs1(v, u);
        W[v] = w, siz[u] += siz[v];
        son[u] = (siz[v] > siz[son[u]] ? v : son[u]);
    }
}

void dfs2( int u, int topt) {
    top[u] = topt, dfn[u] = ++ tot, rev[tot] = u;
    if (son[u]) dfs2(son[u], topt);

    for ( auto [v, w] : G[u])
        if (v != fa[u] && v != son[u])
            dfs2(v, v);
}

int lca( int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        u = fa[top[u]];
    }

    return dep[u] < dep[v] ? u : v;
}

int Kth( int u, int k) {
    while (dep[u] - dep[top[u]] + 1 <= k) {
        k -= (dep[u] - dep[top[u]] + 1);
        u = fa[top[u]];
    }

    return rev[dfn[u] - k];
}

struct mat {
    int a[6];

    void reset() {
        for ( int i = 0; i < 6; i ++)
            a[i] = 0;
    }

    mat operator * ( const mat & x) const {
        mat z;
        z.a[0] = a[0] + x.a[0];
        z.a[1] = x.a[1] + a[0] * x.a[3] + a[1];
        z.a[2] = x.a[2] + x.a[4] * a[0] + x.a[5] * a[1] + a[2];
        z.a[3] = x.a[3] + a[3];
        z.a[4] = x.a[4] + a[3] * x.a[5] + a[4];
        z.a[5] = x.a[5] + a[5];

        return z;
    }
} ;

void tag( mat & a, int u, int op) {
    a.reset();

    int S = siz[fa[u]] - siz[u];
    a.a[0] = a.a[5] = W[u];
    a.a[3] = S;

    if (! op) a.a[1] = W[u] * S;
    else a.a[4] = W[u] * S;
}

struct sgt {
    mat tr[2][N * 4];

    #define ls k << 1
    #define rs k << 1 | 1
    #define mid ((l + r) >> 1)

    void psu( int k) {
        tr[0][k] = tr[0][rs] * tr[0][ls];
        tr[1][k] = tr[1][ls] * tr[1][rs];
    }

    void build( int k = 1, int l = 1, int r = n) {
        tr[0][k].reset(), tr[1][k].reset();

        if (l == r) {
            tag(tr[0][k], rev[l], 0);
            tag(tr[1][k], rev[l], 1);

            return ;
        }

        build(ls, l, mid), build(rs, mid + 1, r);
        psu(k);
    }

    void upd( int x, int w, int k = 1, int l = 1, int r = n) {
        if (l == r) {
            W[rev[l]] = w;
            tag(tr[0][k], rev[l], 0);
            tag(tr[1][k], rev[l], 1);

            return ;
        }

        x <= mid ? upd(x, w, ls, l, mid) : upd(x, w, rs, mid + 1, r);
        psu(k);
    }

    mat ask( int x, int y, int op, int k = 1, int l = 1, int r = n) {
        if (x <= l && r <= y) return tr[op][k];

        if (y <= mid) return ask(x, y, op, ls, l, mid);
        if (x > mid) return ask(x, y, op, rs, mid + 1, r);

        return (op ? 
        ask(x, y, op, ls, l, mid) * ask(x, y, op, rs, mid + 1, r) : 
        ask(x, y, op, rs, mid + 1, r) * ask(x, y, op, ls, l, mid));
    }

    #undef ls
    #undef rs
    #undef mid
} T;

mat ask( int u, int v, int op) {
    mat res;
    res.reset();

    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);

        res = (op ? T.ask(dfn[top[u]], dfn[u], op) * res : res * T.ask(dfn[top[u]], dfn[u], op));
        u = fa[top[u]];
    }

    if (dep[u] < dep[v]) swap(u, v);
    res = (op ? T.ask(dfn[v], dfn[u], op) * res : res * T.ask(dfn[v], dfn[u], op));

    return res;
}

struct edge {
    int u, v;
} E[N];

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> q;

    for ( int i = 1; i < n; i ++) {
        int u, v, w; cin >> u >> v >> w;
        G[u].push_back({v, w});
        G[v].push_back({u, w});
        E[i] = {u, v};
    }

    dfs1(1, 0), dfs2(1, 1);
    T.build();

    for ( int i = 1; i < n; i ++)
        if (dep[E[i].u] < dep[E[i].v]) swap(E[i].u, E[i].v);

    while (q --) {
        int op, x, y; cin >> op >> x >> y;

        if (op == 1) T.upd(dfn[E[x].u], y);
        else {
            if (x == y) {
                cout << 0 << '\n';
                continue ;
            }

            int lca = :: lca(x, y);

            mat ans;
            ans.reset();

            if (y == lca) {
                int u = Kth(x, dep[x] - dep[lca] - 1);
                ans = ans * ask(x, u, 0);
                cout << ans.a[2] << '\n';
                continue ;
            }

            if (x == lca) {
                int v = Kth(y, dep[y] - dep[lca] - 1);
                ans = ans * ask(y, v, 1);
                cout << ans.a[2] << '\n';
                continue ;
            }

            mat L;

            int u = Kth(x, dep[x] - dep[lca] - 1);
            int v = Kth(y, dep[y] - dep[lca] - 1);

            int S = n - siz[u] - siz[v];
            L.a[0] = L.a[5] = W[u] + W[v];
            L.a[1] = W[u] * S;
            L.a[2] = W[u] * W[v] * S;
            L.a[4] = W[v] * S;
            L.a[3] = S;

            if (dep[x] - dep[lca] >= 2) {
                int X = Kth(x, dep[x] - dep[lca] - 2);
                ans = ans * ask(x, X, 0);
            }

            ans = ans * L;

            if (dep[y] - dep[lca] >= 2) {
                int Y = Kth(y, dep[y] - dep[lca] - 2);
                ans = ans * ask(y, Y, 1);
            }

            cout << ans.a[2] << '\n';
        }
    }

    return 0;
}

Fibonacci-ish II

对于去重排序的询问,考虑莫队。

那重点是当插入一个数时,贡献会如何变化。

主要问题是斐波那契数列的项数很难表示,所以考虑用矩阵来处理。

斐波那契数列有:

\[\begin{bmatrix} f_{i} & f_{i+1} \end{bmatrix} = \begin{bmatrix} f_{i-1} & f_i \end{bmatrix}\times \begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix} \]

所以可以给每个数维护一个矩阵,假设这个数为 \(x\),那么它对应的矩阵就可以是 \(\begin{bmatrix} x\times f_{i-1} & x\times f_i \end{bmatrix}\)

当插入一个数的时候,先看一下它前面有多少个数,算出 \(\begin{bmatrix} f_{i-1} & f_i \end{bmatrix}\) 的值,然后把在它后面的数的矩阵都乘上 \(\begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix}\),这可以用权值线段树做。

对于删除,因为矩阵 \(\begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix}\) 有逆,所以也可做。

复杂度 \(O(n\sqrt{n}\log n)\)

很神秘的是手动维护矩阵还是过不了,不想卡常了。

代码
#include <iostream>
#include <algorithm>
#include <math.h>

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 3e4 + 10, M = 2e5 + 10, inf = 1e9;

int n, m, mod, blk;
int a[N], ind[N], f[N], ans[N];
int tot;

int L = 1, R;

struct que {
    int l, r, id;

    bool operator < ( const que & rhs) const {
        return (l / blk == rhs.l / blk) ? ((l / blk) & 1 ? r < rhs.r : r > rhs.r) : l < rhs.l;
    }
} q[N];

int add( int x, int v) {
    x += v;
    return x >= mod ? x - mod : x;
}

struct mat {
    int a00, a01, a10, a11;

    void reset() {
        a00 = a01 = a10 = a11 = 0;
    }

    void init() {
        a00 = a11 = 1, a01 = a10 = 0;
    }

    mat operator * ( const mat & x) const {
        mat z;

        z.a00 = (a00 * x.a00 + a01 * x.a10) % mod;
        z.a01 = (a00 * x.a01 + a01 * x.a11) % mod;
        z.a10 = (a10 * x.a00 + a11 * x.a10) % mod;
        z.a11 = (a10 * x.a01 + a11 * x.a11) % mod;

        return z;
    }

    mat operator + ( const mat & x) const {
        mat z;

        z.a00 = add(a00, x.a00);
        z.a01 = add(a01, x.a01);
        z.a10 = add(a10, x.a10);
        z.a11 = add(a11, x.a11);

        return z;
    }
} ;

mat A, B;

struct sgt {
    #define ls k << 1
    #define rs k << 1 | 1
    #define mid ((l + r) >> 1)

    int sum[N * 4];
    mat tr[N * 4], lz[N * 4];

    void build( int k = 1, int l = 1, int r = tot) {
        tr[k].reset();
        lz[k].init();

        if (l == r) return ;

        build(ls, l, mid), build(rs, mid + 1, r);
    }

    void psu( int k) {
        tr[k] = tr[ls] + tr[rs];
        sum[k] = sum[ls] + sum[rs];
    }

    void pst( int k, const mat & v) {
        tr[k] = tr[k] * v;
        lz[k] = lz[k] * v;
    }

    void psd( int k) {
        if (lz[k].a00 == 1 && lz[k].a11 == 1 && lz[k].a01 == 0 && lz[k].a10 == 0) return ;

        pst(ls, lz[k]), pst(rs, lz[k]);
        lz[k].init();
    }

    void insert( int x, int v, int id, int k = 1, int l = 1, int r = tot) {
        if (l == r) {
            sum[k] += v;

            if (id) {
                tr[k].a00 = f[id - 1] * ind[x] % mod;
                tr[k].a01 = f[id] * ind[x] % mod;                
            } else tr[k].a00 = tr[k].a01 = 0;

            return ;
        }

        psd(k);
        x <= mid ? insert(x, v, id, ls, l, mid) : insert(x, v, id, rs, mid + 1, r);
        psu(k);
    }

    void upd( int x, int y, const mat & v, int k = 1, int l = 1, int r = tot) {
        if (x > y) return ;
        if (x <= l && r <= y) return pst(k, v);
        psd(k);

        if (x <= mid) upd(x, y, v, ls, l, mid);
        if (y > mid) upd(x, y, v, rs, mid + 1, r);

        psu(k);
    }

    int ask( int x, int y, int k = 1, int l = 1, int r = tot) {
        if (x > y) return 0;
        if (x <= l && r <= y) return sum[k];
        psd(k);

        int res = 0;
        if (x <= mid) res += ask(x, y, ls, l, mid);
        if (y > mid) res += ask(x, y, rs, mid + 1, r);

        return res;
    }

    #undef ls
    #undef rs
    #undef mid
} T;

int cnt[N];

void add( int x) {
    cnt[x] ++;
    if (cnt[x] != 1) return ;

    int res = T.ask(1, x - 1) + 1;
    T.insert(x, 1, res);
    T.upd(x + 1, tot, A);
}

void del( int x) {
    cnt[x] --;
    if (cnt[x] != 0) return ;

    T.insert(x, -1, 0);
    T.upd(x + 1, tot, B);
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> mod;
    blk = sqrt(n);

    A.a00 = 0;
    A.a01 = A.a10 = A.a11 = 1;

    B.a00 = mod - 1;
    B.a01 = B.a10 = 1;
    B.a11 = 0;

    f[1] = f[2] = 1;
    for ( int i = 3; i <= n; i ++) f[i] = add(f[i - 1], f[i - 2]);

    for ( int i = 1; i <= n; i ++) cin >> a[i], ind[i] = a[i];
    sort(ind + 1, ind + n + 1), tot = unique(ind + 1, ind + n + 1) - ind - 1;

    for ( int i = 1; i <= n; i ++) a[i] = lower_bound(ind + 1, ind + tot + 1, a[i]) - ind;
    for ( int i = 1; i <= n; i ++) ind[i] %= mod;

    T.build();

    cin >> m;

    for ( int i = 1; i <= m; i ++) {
        int l, r; cin >> l >> r;
        q[i] = {l, r, i};
    }

    sort(q + 1, q + m + 1);

    for ( int i = 1; i <= m; i ++) {
        int l = q[i].l, r = q[i].r, id = q[i].id;

        while (R < r) add(a[++ R]);
        while (L > l) add(a[-- L]);
        while (R > r) del(a[R --]);
        while (L < l) del(a[L ++]);

        ans[id] = T.tr[1].a01;
    }

    for ( int i = 1; i <= m; i ++) cout << ans[i] << '\n';

    return 0;
}

Yet Another MEX Problem

可以发现,对于一个区间,如果取的数不是 \(\rm{mex}\),而是任意一个没出现的数,那么求出来的贡献肯定小于等于取 \(\rm{mex}\) 时的答案。

这启发可以对 \(0\sim n\) 中的所有数求出取它的时候的贡献,取最大值就是答案。

那么考虑维护一个数组 \(f_x\),表示取 \(x\) 时,\(lst_x+1\sim r\) 中比 \(x\) 大的数的个数(\(lst_x\) 表示 \(x\) 最后一次出现的位置)。

初始 \(f_x=0\),每当 \(r\) 向右移动时,会添加 \(a_r\) 这个数,\(a_r\) 会给 \(0\sim a_r-1\) 中的 \(f_x\)\(+1\),然后将 \(f_{a_r}\leftarrow 0\),最后询问 \(\max_{x=0}^n f_x\) 即可。

发现这就是前缀加、单点改、全局查询最大值,直接线段树即可。

代码
#include <bits/stdc++.h>
#define ls k << 1
#define rs k << 1 | 1
#define mid ((l + r) >> 1)

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 3e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n;
int a[N], lst[N];

int mx[N * 4], lz[N * 4];

void build( int k, int l, int r) {
    mx[k] = lz[k] = 0;
    if (l == r) return ;

    build(ls, l, mid), build(rs, mid + 1, r);
}

void pst( int k, int v) {
    mx[k] += v, lz[k] += v;
}

void psu( int k) {
    mx[k] = max(mx[ls], mx[rs]);
}

void psd( int k) {
    if (! lz[k]) return ;
    pst(ls, lz[k]), pst(rs, lz[k]);
    lz[k] = 0;
}

void upd( int x, int k = 1, int l = 0, int r = n) {
    if (l == r) return mx[k] = 0, void();
    psd(k);
    x <= mid ? upd(x, ls, l, mid) : upd(x, rs, mid + 1, r);
    psu(k);
}

void mdf( int x, int y, int k = 1, int l = 0, int r = n) {
    if (x > y) return ;
    if (x <= l && r <= y) return pst(k, 1);

    psd(k);

    if (x <= mid) mdf(x, y, ls, l, mid);
    if (y > mid) mdf(x, y, rs, mid + 1, r);

    psu(k);
}

void solve() {
    cin >> n;
    build(1, 0, n);

    for ( int i = 1; i <= n; i ++) {
        int x; cin >> x;

        mdf(0, x - 1), upd(x);
        cout << mx[1] << ' ';
    }

    cout << '\n';
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    int T; cin >> T;
    while (T --) solve();

    return 0;
}

Book Sorting

中文题面:

给出一个长为 \(n\) 的排列。每一次可以选择以下的任一操作进行:

  1. 交换相邻的两个数
  2. 将排列中一个数挪到序列开头
  3. 将排列中一个数挪到序列结尾

求使得排列有序的最小总操作次数。

\(n\le 5\times 10^5\)

可以先分析一下操作大概会是什么样的,不难发现,对于每个数只会做 \(2\)\(3\) 中的一个,且只会做一次。

那么可以想到最终的操作形式是,对值域在 \([1,x]\) 内的数做 \(2\) 操作,对值域在 \([y,n]\) 内的数做 \(3\) 操作,对值域在 \((x,y)\) 内的数做 \(1\) 操作。

然后发现,对 \((x,y)\) 内的数做一操作的次数,等于值域在 \((x,y)\) 内的数的逆序对个数。

由于值域上的逆序对很难求,希望能够放在序列上求,所以考虑令 \(id_{a_i}=i\),这样就可以把值域在 \([l,r]\) 内的数的逆序对转化成下标在 \([l,r]\) 内的数的逆序对。

打个表后能发现,对于 \(x\),它最优的 \(y\) 满足单调性,这启发可以分治处理。

然后维护序列的逆序对,使用树状数组即可。

代码
#include <bits/stdc++.h>
#define int long long

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 5e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n, ans = inf;
int a[N], p[N];

int L = 1, R = 0;

int tr1[N], tr2[N];

void add1( int u, int v) {
    for (; u <= n; u += (u & -u)) tr1[u] += v;
}

void add2( int u, int v) {
    for (; u; u -= (u & -u)) tr2[u] += v;
}

int ask1( int u, int res = 0) {
    for (; u; u -= (u & -u)) res += tr1[u];
    return res;
}

int ask2( int u, int res = 0) {
    for (; u <= n; u += (u & -u)) res += tr2[u];
    return res;
}

int now;

int cal( int l, int r) {
    if (l > r) return 0;

    if (R < l || L > r) {
        for ( int i = L; i <= R; i ++) add1(p[i], -1), add2(p[i], -1);
        now = 0;
        for ( int i = l; i <= r; i ++) now += ask2(p[i]), add1(p[i], 1), add2(p[i], 1);
        L = l, R = r;
        return now;
    }

    while (R < r) R ++, now += ask2(p[R]), add1(p[R], 1), add2(p[R], 1);
    while (L > l) L --, now += ask1(p[L]), add1(p[L], 1), add2(p[L], 1);
    while (R > r) add1(p[R], -1), add2(p[R], -1), now -= ask2(p[R]), R --;
    while (L < l) add1(p[L], -1), add2(p[L], -1), now -= ask1(p[L]), L ++;

    return now;
}

void solve( int l, int r, int ql, int qr) {
    if (l > r) return ;

    int mid = (l + r) >> 1;
    int mi = inf, qmid = 0;

    for ( int i = max(mid + 1, ql); i <= qr; i ++) {
        int res = mid + cal(mid + 1, i - 1) + (n - i + 1);
        if (mi > res) mi = res, qmid = i;
    }

    ans = min(ans, mi);

    solve(l, mid - 1, ql, qmid);
    solve(mid + 1, r, qmid, qr);
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n;
    for ( int i = 1; i <= n; i ++) cin >> a[i];
    for ( int i = 1; i <= n; i ++) p[a[i]] = i;

    solve(0, n + 1, 0, n + 1);
    cout << ans << '\n';

    return 0;
}

Two tree

中文题面:

\(F_{1/2,u,v}\) 表示在第 \(1/2\) 棵树上,当 \(u\) 为根时,\(v\) 的子树大小。

对于所有 \(u\),求出有多少 \(v\),满足 \(F_{1,u,v}\gt F_{2,u,v}\)

既然对于所有根都要求答案,自然想到换根,所以先对每棵树都是 \(1\) 为根时求一遍子树大小。

显然这里不能两棵树同时换根,所以考虑对第一棵树换根,这样就可以得到第一棵树的所有点的子树大小,记为 \(siz1\)

考虑在第二棵树上,把 \(u\) 变成根后有什么影响。

发现只有在 \(u\to 1\) 这条链上的所有点的子树大小会改变,除了点 \(u\)(点 \(u\) 为根时子树大小肯定是 \(n\)),其余点的子树大小为 \(n\) 减去它在链上的儿子的子树大小。

如何维护这个东西呢?不妨考虑用重链剖分,记 \(val_u\)\(n\) 减去 \(u\) 的重儿子的子树大小,对于重链直接线段树维护 \(siz1_u\gt val_u\) 的个数,对于轻边单独判断即可。

因为链上只会有最多 \(O(\log n)\) 条轻边,所以复杂度 \(O(n\log^2 n)\)

但是会发现这样做有点问题,可能会有:“换根前就大于,换根后也大于。”的情况,这样答案会多统计一次。

记第二棵树的点的子树大小为 \(siz2\),每次再减去 \(siz1\gt siz2\) 的个数就行了,这个再开一棵线段树维护即可。

代码
#include <bits/stdc++.h>

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 5e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n;
vector< int> G1[N];
vector< int> G2[N];

int siz1[N], fa1[N];

void dfs( int u, int fu) {
    siz1[u] = 1, fa1[u] = fu;

    for ( auto v : G1[u]) {
        if (v == fu) continue ;
        dfs(v, u);
        siz1[u] += siz1[v];
    }
}

int siz2[N], fa2[N], son[N];

void dfs1( int u, int fu) {
    siz2[u] = 1, fa2[u] = fu;

    for ( auto v : G2[u]) {
        if (v == fu) continue ;

        dfs1(v, u);
        siz2[u] += siz2[v];
        son[u] = (siz2[v] > siz2[son[u]] ? v : son[u]);
    }
}

int tot;
int top[N], dfn[N], rev[N];

void dfs2( int u, int topt) {
    top[u] = topt, dfn[u] = ++ tot, rev[tot] = u;

    if (son[u]) dfs2(son[u], topt);

    for ( auto v : G2[u])
        if (v != son[u] && v != fa2[u])
            dfs2(v, v);
}

struct sgt {
    #define ls k << 1
    #define rs k << 1 | 1
    #define mid ((l + r) >> 1)
    int sum[N * 4], val[N * 4];

    void build( int op, int k = 1, int l = 1, int r = n) {
        sum[k] = 0;

        if (l == r) {
            val[rev[l]] = (op ? siz2[rev[l]] : (n - siz2[son[rev[l]]]));
            sum[k] = (siz1[rev[l]] > val[rev[l]]);
            return ;
        }

        build(op, ls, l, mid), build(op, rs, mid + 1, r);
        sum[k] = sum[ls] + sum[rs];

    }

    void upd( int x, int v, int k = 1, int l = 1, int r = n) {
        if (l == r) return sum[k] = (v > val[rev[l]]), void();

        x <= mid ? upd(x, v, ls, l, mid) : upd(x, v, rs, mid + 1, r);
        sum[k] = sum[ls] + sum[rs];
    }

    int ask( int x, int y, int k = 1, int l = 1, int r = n) {
        if (x > y) return 0;
        if (x <= l && r <= y) return sum[k];

        int res = 0;

        if (x <= mid) res += ask(x, y, ls, l, mid);
        if (y > mid) res += ask(x, y, rs, mid + 1, r);

        return res;
    }
} T0, T1;

int ans[N], SUM;

void DP( int u) {
    ans[u] = SUM;

    int x = u;
    while (x) {
        ans[u] += T0.ask(dfn[top[x]], dfn[x] - 1);
        if (fa2[top[x]]) ans[u] += (siz1[fa2[top[x]]] > n - siz2[top[x]]);
        ans[u] -= T1.ask(dfn[top[x]], dfn[x]);

        x = fa2[top[x]];
    }

    int sizu = siz1[u], sum = SUM;
    for ( auto v : G1[u]) {
        if (v == fa1[u]) continue ;

        int sizv = siz1[v];
        SUM -= (siz1[u] > siz2[u]), SUM -= (siz1[v] > siz2[v]);
        siz1[u] = n - siz1[v], siz1[v] = n;
        T0.upd(dfn[u], siz1[u]), T1.upd(dfn[u], siz1[u]);
        T0.upd(dfn[v], siz1[v]), T1.upd(dfn[v], siz1[v]);
        SUM += (siz1[u] > siz2[u]), SUM += (siz1[v] > siz2[v]);
        DP(v);
        siz1[u] = sizu, siz1[v] = sizv, SUM = sum;
        T0.upd(dfn[u], siz1[u]), T1.upd(dfn[u], siz1[u]);
        T0.upd(dfn[v], siz1[v]), T1.upd(dfn[v], siz1[v]);
    }
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n;

    for ( int i = 1, u, v; i < n; i ++) {
        cin >> u >> v;
        G1[u].push_back(v);
        G1[v].push_back(u);
    }

    for ( int i = 1, u, v; i < n; i ++) {
        cin >> u >> v;
        G2[u].push_back(v);
        G2[v].push_back(u);
    }

    dfs(1, 0);
    dfs1(1, 0), dfs2(1, 1);

    T0.build(0), T1.build(1);

    for ( int i = 1; i <= n; i ++) SUM += (siz1[i] > siz2[i]);

    DP(1);

    for ( int i = 1; i <= n; i ++) cout << ans[i] << ' ';

    return 0;
}

Simultaneous Coloring

首先分析操作,不难发现只有最后一次操作有用。

那么对于 \(R\) 的限制,肯定先涂列,再涂行,\(B\) 就是反过来。

考虑把图建出来,\((x,y)\)\(R\),就连 \((y+n,x)\),如果是 \(B\) 就连 \((x,y+n)\)

然后发现建出来的图如果是 DAG,花费就是 \(0\),如果出现了环,就需要选择这个环的所有点,花费是环的点数的平方。

也就说答案等于图上所有 SCC 的点数平方和。

那么就把问题转化成动态加边,维护 SCC 的点数。

考虑一条边 \(i\) 能够出现在 SCC 中的一个最小时间 \(t_i\),这个 \(t_i\) 是有单调性的。

那么可以对所有边二分出来这个 \(t_i\)。(具体就是把 \([1,mid]\) 的边给建图,然后看这条边的两个端点是否在一个连通块里。)

求出 \(t_i\) 后,顺序遍历一遍时间,在对应时间把边加进并查集里,维护一下答案即可。

这样复杂度是 \(O(q(n+m)\log q)\),因为要同时对所有边二分,不妨考虑整体二分。

对于一个区间 \([l,r]\),记 \(G\) 表示所有最小出现时间在 \([l,r]\) 的边的集合。

每次把时间在 \([1,mid]\) 的边跑一个 Tarjan,如果两端是同一个 SCC 的,就递归进左边,否则递归进右边。

当然不能每次都把 \([1,mid]\) 的边加进去,可以考虑只加 \([l,mid]\) 的边,分治时右区间继承一下左区间加的边即可,这里具体可以看代码。

复杂度 \(O(q\log q)\)

代码
#include <bits/stdc++.h>

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 5e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n, m, q;

struct dsu {
    int fa[N], siz[N], n;
    long long res;

    void init( int _n) {
        n = _n, res = 0;
        for ( int i = 1; i <= n; i ++) fa[i] = i, siz[i] = 1;
    }

    int find( int x) {
        return x == fa[x] ? x : (fa[x] = find(fa[x]));
    }

    void merge( int x, int y) {
        x = find(x), y = find(y);
        if (x == y) return ;

        if (siz[x] > 1) res -= 1ll * siz[x] * siz[x];
        if (siz[y] > 1) res -= 1ll * siz[y] * siz[y];

        fa[y] = x, siz[x] += siz[y], siz[y] = 0;
        res += 1ll * siz[x] * siz[x];
    }
} D;

vector< int> sta;
vector< int> G[N];
int vis[N], dfn[N], low[N], scc[N];
int tot;
int cnt;

void tarjan( int u) {
    sta.push_back(u), vis[u] = 1;
    dfn[u] = low[u] = ++ tot;
    
    for ( auto v : G[u]) {
        if (! dfn[v]) tarjan(v), low[u] = min(low[u], low[v]);
        else if (vis[v]) low[u] = min(low[u], dfn[v]);
    }
    
    if (dfn[u] != low[u]) return ;
    
    int v; cnt ++;
    
    do {
        v = sta.back(); sta.pop_back(), vis[v] = 0;
        scc[v] = cnt;
    } while (v != u) ;
}

vector< tuple< int, int, int> > e;
vector< int> vec[N];

void clear( int u) {
    G[u].clear();
    dfn[u] = low[u] = scc[u] = 0;
}

void solve( int l, int r, const vector< tuple< int, int, int> > & E) {
    int mid = (l + r) >> 1;

    cnt = tot = 0;
    for ( auto [u, v, w] : E)
        clear(u), clear(v);

    for ( auto [u, v, w] : E)
        if (w <= mid) G[u].push_back(v);

    for ( auto [u, v, w] : E)
        if (! dfn[u]) tarjan(u);

    if (l == r) {
        for ( auto [u, v, w] : E) 
            if (scc[u] == scc[v])
                D.merge(get<0>(e[w - 1]), get<1>(e[w - 1]));

        cout << D.res << '\n';

        return ;
    }

    vector< tuple< int, int, int> > q1, q2;

    for ( auto [u, v, w] : E) {
        if (scc[u] == scc[v]) {
            if (w <= mid) q1.push_back({u, v, w});
        } else q2.push_back({scc[u], scc[v], w}); //这里直接用求出来的 scc 编号,就相当于继承了左区间跑的 Tarjan 信息。
    }

    solve(l, mid, q1), solve(mid + 1, r, q2);
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> m >> q;

    for ( int i = 1; i <= q; i ++) {
        int u, v;
        string op;
        cin >> u >> v >> op;

        if (op[0] == 'R') e.push_back({v + n, u, i});
        else e.push_back({u, v + n, i});
    }

    D.init(n + m);
    solve(1, q, e);

    return 0;
}

[HNOI2010] 城市建设

考虑时间分治。

对于一个时间区间,把这个时间内要修改的边成为“动态边”,把不在这个时间内的边成为“静态边”。

那么关于最小生成树有一些结论:

  • 如果把所有“动态边”的边权改成 \(\infin\) 跑 Kruskal,没有被选进去的“静态边”以后就不可能被选进去。
  • 如果把所有“动态边”的边权改成 \(-\infin\) 跑 Kruskal,被选进去的“静态边”以后就一定会被选进去。

根据这个,不难想到维护一个集合 \(may\),表示可能被选进的边集。

对于永远不会被选进的边,就把它去掉;对于一定会选进去的边,把这条边合并成一个点,加上边权后去掉。这样每个区间的边数一定是 \(O(len)\) 的。

这种做法就是比较经典的时间分治。

具体实现看代码。

代码
#include <bits/stdc++.h>
#define int long long

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 2e4 + 10, M = 5e4 + 10, inf = 1e9, mod = 998244353;

int n, m, q;
int RES;

int ans[M];

struct edge {
    int u, v, w;
} E[M], tE[M];

struct mdf {
    int i, w;
} Q[M];

struct DSU {
    int fa[N], siz[N], n;
    vector< pair< int, int> > del;
    set< int> S; // S 表示缩完点之后的点集

    DSU() {}

    DSU( int _n) {
        n = _n; del.clear(), S.clear();
        for ( int i = 1; i <= n; i ++) siz[i] = 1, fa[i] = i, S.insert(i);
    }

    int find( int x) {
        return x == fa[x] ? x : find(fa[x]);
    }

    int find2( int x) {
        return x == fa[x] ? x : fa[x] = find2(fa[x]);
    }

    int merge( int u, int v) {
        u = find(u), v = find(v);
        if (u == v) return 0;

        if (siz[u] < siz[v]) swap(u, v);
        del.push_back({u, v});
        siz[u] += siz[v], fa[v] = u, S.erase(v);
        return 1;
    }

    void Del() {
        if (! del.size()) return ;
        auto [u, v] = del.back(); del.pop_back();
        siz[u] -= siz[v], fa[v] = v, S.insert(v);
    }
} D1, D2;
// D1:维护缩点的可撤销并查集
// D2:求最小生成树时的普通并查集

set< int> get( set< int> s) {
    vector< int> vec;

    for ( auto i : D1.S) D2.fa[i] = i;

    for ( auto i : s) vec.push_back(i);

    sort(vec.begin(), vec.end(), [&]( int a, int b) {
        return E[a].w < E[b].w;
    });

    set< int> res;

    for ( auto i : vec) {
        auto [u, v, w] = E[i];
        u = D1.find(u), v = D1.find(v);
        u = D2.find2(u), v = D2.find2(v);
        if (u == v) continue ;
        D2.fa[v] = u;
        res.insert(i);
    }

    return res;
}
// 对边集 s 求出最小生成树的边集

set< int> may;
// 可能选进的边的集合

void solve( int l, int r) {
    if (l == r) {
        ans[l] = RES;
        if (D1.siz[D1.find(1)] != n)
            ans[l] += min(Q[l].w, (int)may.size() ? E[* may.begin()].w : inf);

        // 判断图是否连通,若不连通,要么选自己这条边,要么选最后剩下的那条“静态边”

        return ;
    }

    int mid = (l + r) >> 1;
    set< int> L, R, tmay = may, supt;
    int tRES = RES, tot = 0;
    for ( int i = l; i <= mid; i ++) L.insert(Q[i].i);
    for ( int i = mid + 1; i <= r; i ++) R.insert(Q[i].i);
    
    for ( auto i : R) if (! L.count(i)) may.insert(i);
    may = get(may);
    supt = may;

    for ( int i = l; i <= mid; i ++) E[Q[i].i].w = -inf, supt.insert(Q[i].i);
    supt = get(supt);

    for ( int i = l; i <= mid; i ++) E[Q[i].i].w = tE[Q[i].i].w;

    for ( auto i : supt) if (may.count(i))
        may.erase(i), tot += D1.merge(E[i].u, E[i].v), RES += E[i].w;

    solve(l, mid);
    while (tot) D1.Del(), tot --;
    may = tmay, RES = tRES;

    for ( int i = l; i <= mid; i ++) E[Q[i].i].w = tE[Q[i].i].w = Q[i].w;
    // 将左区间的边修改后,处理右区间

    for ( auto i : L) if (! R.count(i)) may.insert(i);
    may = get(may);
    supt = may;

    for ( int i = mid + 1; i <= r; i ++) E[Q[i].i].w = -inf, supt.insert(Q[i].i);
    supt = get(supt);

    for ( int i = mid + 1; i <= r; i ++) E[Q[i].i].w = tE[Q[i].i].w;

    for ( auto i : supt) if (may.count(i))
        may.erase(i), tot += D1.merge(E[i].u, E[i].v), RES += E[i].w;

    solve(mid + 1, r);
    while (tot) D1.Del(), tot --;
    may = tmay, RES = tRES; 
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> m >> q;
    D1 = DSU(n);

    for ( int i = 1; i <= m; i ++) {
        int u, v, w;
        cin >> u >> v >> w;
        tE[i] = E[i] = {u, v, w}, may.insert(i);
    }

    for ( int i = 1; i <= q; i ++) {
        int id, w; cin >> id >> w;
        Q[i] = {id, w};
        may.erase(id);
    }

    may = get(may);
    set< int> supt = may;

    for ( int i = 1; i <= q; i ++) E[Q[i].i].w = -inf, supt.insert(Q[i].i);
    supt = get(supt);

    for ( int i = 1; i <= q; i ++) E[Q[i].i].w = tE[Q[i].i].w;

    for ( auto i : supt) if (may.count(i))
        may.erase(i), D1.merge(E[i].u, E[i].v), RES += E[i].w;
    // 先对全局跑一次

    solve(1, q);

    for ( int i = 1; i <= q; i ++) cout << ans[i] << '\n';

    return 0;
}

「FeOI Round 4.5」はぐ

因为异或满足可减性,对于一条路径 \((u,v)\) 考虑树上差分变成四条链的询问,记 \(l\)\({\rm lca}(u,v)\)

  • \(\oplus_{x\in u\to 1}a_x-dep_x+dep_u\)
  • \(\oplus_{x\in v\to 1}a_x+dep_x+dep_u-2dep_{l}\)
  • \(\oplus_{x\in l\to 1}a_x-dep_x+dep_u\)
  • \(\oplus_{x\in fa_l\to 1}a_x+dep_x+dep_u-2dep_{l}\)

\(a_x-dep_x+dep_u\)\(a_x+dep_x+dep_u-2dep_l\) 显然是两种对称询问,这里以前者为例。

\(w_x\)\(a_x-dep_x\)\(dep_u\)\(val\),询问就转化成把一条链上的数给加上一个常数,然后求异或和。

既然又有异或又有加减,考虑拆位。

对于第 \(k\) 位,记 \(op\) 表示 \(w_x+val\) 的第 \(k\) 位是否进位,那么对于 \(x\),第 \(k\) 位的值就是 \({\rm bit}_k(w_x)\oplus{\rm bit}_k(val)\oplus op\)

考虑把这三个值分开维护。

对于 \(\oplus {\rm bit}_k(w_x)\),就是树上前缀异或和;对于 \(\oplus {\rm bit}_k(val)\),只需要考虑 \(dep_x\);对于 \(\oplus op\),首先 \(op=[w_x\bmod 2^k+val\bmod 2^k\ge 2^k]\),可以变形成 \(op=[w_x\bmod 2^k\ge 2^k-val\bmod 2^k]\),把 \(\ge\) 后面的看作常数,求这个就可以在 dfs 时同时用树状数组维护,具体是 \(x\) 刚进递归栈时,把它加入树状数组,出栈后就在树状数组中删除,查询就是一个后缀。

代码实现唯一注意的是可能有负数,这个用 unsigned int 处理即可。

代码
#include <bits/stdc++.h>
#define ui unsigned int

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 2e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n, q;
int a[N];

struct que {
    ui val;
    int id;
} ;

struct bit {
    int tr[1 << 20];

    void upd( int u) {
        u ++;
        for (; u; u -= (u & -u)) tr[u] ^= 1;
    }

    int ask( int u, int res = 0) {
        u ++;
        for (; u < (1 << 20); u += (u & -u)) res ^= tr[u];
        return res;
    }
} t;

vector< que> q1[N], q2[N];
ui w1[N], w2[N];
int ans[N];

int siz[N], dep[N], fa[N], son[N], top[N];

vector< int> G[N];

void dfs1( int u, int fu) {
    fa[u] = fu, siz[u] = 1, dep[u] = dep[fu] + 1;

    for ( auto v : G[u]) {
        if (v == fu) continue ;

        dfs1(v, u);
        siz[u] += siz[v];
        son[u] = (siz[v] > siz[son[u]] ? v : son[u]);
    }
}

void dfs2( int u, int topt) {
    top[u] = topt;

    if (son[u]) dfs2(son[u], topt);

    for ( auto v : G[u]) {
        if (v == fa[u] || v == son[u]) continue ;
        dfs2(v, v);
    }
}

int lca( int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        u = fa[top[u]];
    }

    return (dep[u] < dep[v] ? u : v);
}

void dfs3( int u, int k, ui sx, int op) {
    int mask = ((1 << k) - 1);
    ui w = (op ? w1[u] : w2[u]);
    sx ^= w;
    t.upd(w & mask);

    for ( auto [val, id] : (op ? q1[u] : q2[u])) {
        ui res = t.ask((1 << k) - (val & mask));
        ui bit = ((val >> k) & 1);

        ans[id] ^= ((((sx >> k) & 1) ^ res ^ ((dep[u] & 1) ? bit : 0)) << k);
    }

    for ( auto v : G[u])
        if (v != fa[u]) dfs3(v, k, sx, op);

    t.upd(w & mask);
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> q;

    for ( int i = 1; i <= n; i ++) cin >> a[i];

    for ( int i = 1; i < n; i ++) {
        int u, v; cin >> u >> v;
        G[u].push_back(v), G[v].push_back(u);
    }
    
    dfs1(1, 0), dfs2(1, 1);

    for ( int i = 1; i <= n; i ++) w1[i] = (ui)(a[i] - dep[i]), w2[i] = (ui)(a[i] + dep[i]);
    for ( int i = 1; i <= q; i ++) {
        int u, v; cin >> u >> v;
        int l = lca(u, v);

        q1[u].push_back({(ui)dep[u], i});
        q2[v].push_back({(ui)(dep[u] - 2 * dep[l]), i});
        q1[l].push_back({(ui)dep[u], i});
        q2[fa[l]].push_back({(ui)(dep[u] - 2 * dep[l]), i});
    }

    for ( int i = 0; i < 20; i ++)
        dfs3(1, i, 0, 1), dfs3(1, i, 0, 0);

    for ( int i = 1; i <= q; i ++) cout << ans[i] << '\n';

    return 0;
}

Big Wins!

对于最小值,很套路的把它的极长最小区间求出来,可以用单调栈。

对于一个最小值为 \(x\) 的区间 \([l,r]\),如何求出里面最大的中位数?

考虑去二分这个中位数 \(y\),很经典的结论,若把 \(\lt y\) 的数看作 \(-1\),把 \(\ge y\) 的数看作 \(1\),那如果存在 \(\ge y\) 的中位数,要满足它们的和 \(\ge 0\)

所以为了尽可能满足能够选出 \(y\),肯定想要和最大。

那么相当于在 \([l,i-1]\) 选出一个最大后缀和、\([i+1,r]\) 选出一个最大前缀和即可。

最大后缀和与最大前缀和可以用线段树维护出来,但是因为要二分中位数 \(y\),需要得到所有 \(y\) 对应的 \(-1\)\(1\) 序列,这里可以用主席树处理。

复杂度 \(O(n\log^2n)\)

代码
#include <bits/stdc++.h>

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 5e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n, tot;
int a[N], rt[N];

struct sgt {
    int sum, lmax, rmax;
} T[N * 20];

int ls[N * 20], rs[N * 20];

sgt merge( sgt a, sgt b) {
    return {a.sum + b.sum, max(a.lmax, a.sum + b.lmax), max(b.rmax, b.sum + a.rmax)};
}

void ins( int lk, int & k, int l, int r, int x, int v) {
    k = ++ tot;
    T[k] = T[lk], ls[k] = ls[lk], rs[k] = rs[lk];

    if (l == r) {
        T[k] = {v, max(0, v), max(0, v)};
        return ;
    }

    int mid = (l + r) >> 1;
    x <= mid ? ins(ls[lk], ls[k], l, mid, x, v) : ins(rs[lk], rs[k], mid + 1, r, x, v);

    T[k] = merge(T[ls[k]], T[rs[k]]);
}

sgt ask( int k, int l, int r, int x, int y) {
    if (x > y || ! k) return {0, 0, 0};
    if (x <= l && r <= y) return T[k];

    int mid = (l + r) >> 1;
    if (y <= mid) return ask(ls[k], l, mid, x, y);
    if (x > mid) return ask(rs[k], mid + 1, r, x, y);

    return merge(ask(ls[k], l, mid, x, y), ask(rs[k], mid + 1, r, x, y));
}

vector< int> vec[N];
int sta[N], top;

int L[N], R[N];

void solve() {
    cin >> n;

    for ( int i = 0; i <= n; i ++) rt[i] = 0, vec[i].clear();
    for ( int i = 1; i <= n; i ++) cin >> a[i], vec[a[i]].push_back(i);
    a[n + 1] = 0;

    for ( int i = 0; i <= tot; i ++)
        T[i] = {0, 0, 0}, ls[i] = rs[i] = 0;
    tot = 0;

    for ( int i = 1; i <= n; i ++) {
        int root = rt[0];
        ins(root, rt[0], 1, n, i, 1);
    }

    for ( int i = 1; i <= n; i ++) {
        rt[i] = rt[i - 1];

        for ( auto v : vec[i - 1]) {
            int root = rt[i];
            ins(root, rt[i], 1, n, v, -1);
        }
    }

    top = 0;
    for ( int i = 1; i <= n + 1; i ++) {
        while (top && a[sta[top]] > a[i]) R[sta[top]] = i - 1, top --;
        sta[++ top] = i;
    }

    top = 0;
    for ( int i = n; i >= 0; i --) {
        while (top && a[sta[top]] > a[i]) L[sta[top]] = i + 1, top --;
        sta[++ top] = i;       
    }

    int ans = 0;
    for ( int i = 1; i <= n; i ++) {
        int l = 1, r = n;

        while (l < r) {
            int mid = (l + r + 1) >> 1;

            if (ask(rt[mid], 1, n, L[i], i - 1).rmax + (a[i] < mid ? -1 : 1) + ask(rt[mid], 1, n, i + 1, R[i]).lmax >= 0) l = mid;
            else r = mid - 1;
        }

        ans = max(ans, l - a[i]);
    }

    cout << ans << '\n';
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    int T; cin >> T;
    while (T --) solve();

    return 0;
}

Farmer John's Favorite Function

首先发现若每次只保留 \(f(i)\) 的整数部分,答案是不变的,所以只用考虑每次对 \(f(i)\) 向下取整的结果。

可以想到一个做法,对时间开一颗线段树,然后从 \(1\sim n\) 枚举 \(a_i\),给对应位置加上对应的 \(a_i\)\(a_i\) 的相邻两次修改内值是一样的,那么就是一个区间加)。

然后每次给所有位置开根号即可。

考虑势能线段树,发现一个数开几次根号就会趋近于 \(1\)

那么线段树上维护一个最大值 \(mx\) 与最小值 \(mi\),若线段树上一个区间满足 \(mx-\lfloor\sqrt mx\rfloor=mi -\lfloor \sqrt mi\rfloor\),那么这整段区间都相当于减去了 \(\lfloor\sqrt mx\rfloor-mx\),打一个加的标记即可,对于其他的位置就递归到叶子即可。

复杂度 \(O(m\log m\log V)\)

代码
#include <bits/stdc++.h>
#define int long long

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 2e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n, m;
int a[N];

int mysqrt( int x) {
    int res = sqrtl(x);
    while (res * res > x) res -= 1;
    while ((res + 1) * (res + 1) <= x) res += 1;
    return res;
}

struct sgt {
    #define ls k << 1
    #define rs k << 1 | 1
    #define mid ((l + r) >> 1)
    #define lson ls, l, mid
    #define rson rs, mid + 1, r
    int mx[N * 4], mi[N * 4], tag[N * 4];

    void psu( int k) {
        mx[k] = max(mx[ls], mx[rs]);
        mi[k] = min(mi[ls], mi[rs]);
    }

    void pst( int k, int v) {
        mx[k] += v, mi[k] += v, tag[k] += v;
    }

    void psd( int k) {
        if (! tag[k]) return ;
        pst(ls, tag[k]), pst(rs, tag[k]);
        tag[k] = 0;
    }

    void upd( int x, int y, int v, int k = 1, int l = 1, int r = m) {
        if (x > y) return ;
        if (x <= l && r <= y) return pst(k, v);
        psd(k);

        if (x <= mid) upd(x, y, v, lson);
        if (y > mid) upd(x, y, v, rson);

        psu(k);
    }

    void sQrt( int x, int y, int k = 1, int l = 1, int r = m) {
        if (x > y) return ;
        if (l == r) return mx[k] = mysqrt(mx[k]), mi[k] = mysqrt(mi[k]), void();
        if (x <= l && r <= y && mx[k] - mysqrt(mx[k]) == mi[k] - mysqrt(mi[k])) return pst(k, mysqrt(mx[k]) - mx[k]);
        psd(k);

        if (x <= mid) sQrt(x, y, lson);
        if (y > mid) sQrt(x, y, rson);

        psu(k);
    }

    void print( int k = 1, int l = 1, int r = m) {
        if (l == r) return cout << mi[k] << '\n', void();
        psd(k);
        print(lson), print(rson);
    }

    #undef ls
    #undef rs
    #undef mid
    #undef lson
    #undef rson
} T;

vector< pair< int, int> > que[N];

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> m;
    for ( int i = 1; i <= n; i ++) cin >> a[i];

    for ( int i = 1; i <= m; i ++) {
        int p, x; cin >> p >> x;
        que[p].push_back({x, i});
    }

    for ( int i = 1; i <= n; i ++) {
        int lstt = 1, lstv = a[i];

        for ( auto [val, tim] : que[i]) {
            T.upd(lstt, tim - 1, lstv);
            lstt = tim, lstv = val;
        }

        T.upd(lstt, m, lstv);
        T.sQrt(1, m);
    }

    T.print();

    return 0;
}

树锯解构

没有原题。

给一个序列,区间加对 \(2^m\) 取模,询问区间与。

\(n,q\le 5\times 10^5,m\le32\)

考虑拆位,对 \(a_i\) 考虑它的第 \(k\) 位是 \(1\) 的情况,发现可以求出一个 \([l,r]\),满足 \([a_i+l\bmod 2^k,a_i+r\bmod 2^k)\)\([2^k,2^{k+1})\) 中,也就是说给 \(a_i\) 加上 \([l,r]\) 中的一个数,可以使得 \(a_i\) 的第 \(k\) 位是 \(1\)

那么每次区间加操作都是把 \(l\)\(r\) 减去一个数,其实就是把 \(l\)\(r\) 看作在一个长度为 \(2^{k+1}\) 的环上移动。

询问时若第 \(k\) 位的区间与是 \(1\),那么就要满足所有的 \([l,r]\) 交集有 \(0\)

放在线段树上维护 \(l,r\) 的交集情况即可。

代码
#include <bits/stdc++.h>
#define int long long

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 5e5 + 10, M = 2e5 + 10, inf = 1e9;

int n, q, m;
int a[N];

int mod[32];

struct sgt {
    #define ls k << 1
    #define rs k << 1 | 1
    #define mid ((l + r) >> 1)

    struct node {
        int l[32], r[32];

        node() {
            memset(l, -1, sizeof l);
            memset(r, -1, sizeof r);
        }
    } tr[N * 4];
    int tag[N * 4];

    node merge( node a, node b) {
        node res = node();

        for ( int i = 0; i < m; i ++) {
            int l1 = a.l[i], r1 = a.r[i], l2 = b.l[i], r2 = b.r[i];
            if (l1 == -1 || l2 == -1) continue ;

            if (l1 > r1) {
                if (l2 > r2) res.l[i] = max(l1, l2), res.r[i] = min(r1, r2);
                else if (l2 <= r1) res.l[i] = l2, res.r[i] = min(r1, r2);
                else if (r2 >= l1) res.r[i] = r2, res.l[i] = max(l1, l2);
            } else {
                if (l2 <= r2) {
                    res.l[i] = max(l1, l2), res.r[i] = min(r1, r2);
                    if (res.l[i] > res.r[i]) res.l[i] = res.r[i] = -1;
                } else {
                    swap(l1, l2), swap(r1, r2);
                    if (l2 <= r1) res.l[i] = l2, res.r[i] = min(r1, r2);
                    else if (r2 >= l1) res.r[i] = r2, res.l[i] = max(l1, l2);
                }
            }
        }

        return res;
    }

    void pst( int k, int v) {
        for ( int i = 0; i < m; i ++) {
            if (tr[k].l[i] == -1) continue ;
            tr[k].l[i] = (tr[k].l[i] - (v & (mod[i] - 1)) + mod[i]) & (mod[i] - 1);
            tr[k].r[i] = (tr[k].r[i] - (v & (mod[i] - 1)) + mod[i]) & (mod[i] - 1);
        }
        tag[k] += v;
    }

    void psd( int k) {
        if (! tag[k]) return ;
        pst(ls, tag[k]), pst(rs, tag[k]);
        tag[k] = 0;
    }

    void build( int k = 1, int l = 1, int r = n) {
        tr[k] = node(), tag[k] = 0;
        if (l == r) {
            for ( int i = 0; i < m; i ++) {
                int x = a[l] & (mod[i] - 1);
                tr[k].l[i] = ((mod[i] >> 1) - x + mod[i]) & (mod[i] - 1);
                tr[k].r[i] = (tr[k].l[i] + (mod[i] >> 1) - 1) & (mod[i] - 1);
            }

            return ;
        }

        build(ls, l, mid), build(rs, mid + 1, r);
        tr[k] = merge(tr[ls], tr[rs]);
    }

    void upd( int x, int y, int v, int k = 1, int l = 1, int r = n) {
        if (x <= l && r <= y) return pst(k, v);
        psd(k);

        if (x <= mid) upd(x, y, v, ls, l, mid);
        if (y > mid) upd(x, y, v, rs, mid + 1, r);

        tr[k] = merge(tr[ls], tr[rs]);
    }

    node ask( int x, int y, int k = 1, int l = 1, int r = n) {
        if (x <= l && r <= y) return tr[k];
        psd(k);

        if (y <= mid) return ask(x, y, ls, l, mid);
        if (x > mid) return ask(x, y, rs, mid + 1, r);

        return merge(ask(x, y, ls, l, mid), ask(x, y, rs, mid + 1, r));
    }
} T;


signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> q >> m;

    for ( int i = 0; i < m; i ++) mod[i] = (1ll << (i + 1));
    for ( int i = 1; i <= n; i ++) cin >> a[i];

    T.build();

    while (q --) {
        int op, l, r, x;
        cin >> op >> l >> r;
        l ++;

        if (! op) {
            cin >> x;
            T.upd(l, r, x);
        } else {
            auto res = T.ask(l, r);
            int ans = 0;

            for ( int i = 0; i < m; i ++) if (res.l[i] != -1 && (res.l[i] > res.r[i] || res.l[i] == 0)) ans |= (1ll << i);
            cout << ans << '\n';
        }
    }
  
    return 0;
}

Building Forest Trails

先把环断开,如果两个区间有交集,那么就代表这两个区间构成一个连通块。

对于一个连通块,内部的连边是不重要的,可以只连接有用的边。(这个连通块构成的凸多边形上的边)

那么对于不同的连通块,它们之间的边是不交的,只有包含关系。

就比如这个图,\(1\)\(2\)\(4\)\(8\) 是一个连通块,\(3\) 是一个连通块,\(5\)\(6\)\(7\) 是一个连通块。

然后考虑去维护这个包含的关系,记 \(h_x\)\(x\) 上面的弧的个数(端点不算)。

考虑给 \(x\)\(y\) 加边,默认 \(x\lt y\),这里分两种情况。

  1. \(h_x\gt h_y\),那么考虑找到 \(x\) 右边第一个 \(p\) 满足 \(h_p\lt h_x\),显然这个 \(p\lt y\) ,那么把 \(x\)\(p\) 的连通块合并。
  2. \(h_x\lt h_y\),那么考虑找到 \(y\) 左边第一个 \(p\),满足 \(h_p\lt h_y\),然后把 \(p\)\(y\) 连通块合并。
  3. \(h_x=h_y\),考虑找到 \(x\) 右边第一个 \(p\),若 \(p\ge y\) 说明连接了 \(x\)\(y\) 不与其他连通块产生交集,可以直接合并,否则就合并 \(x\)\(p\) 的连通块。

一只做,直到 \(x\)\(y\) 在同一个连通块,合并的总次数是 \(O(n)\) 的。

考虑如何合并?

若合并的两个连通块没有包含关系,那么直接连接第一个连通块的右端点和第二个连通块的左端点即可,也就是把 \((R_1,L_2)\)\(h\) 给加一。

如果合并的两个连通块有包含关系,那么会删掉大连通块包含小连通块的那条边,然后加上小连通块与大连通块的左、右端点相连的两条边,也就是把 \([\max (L_1, L_2),\min(R_1,R_2)]\)\(h\) 减一。

所以维护一颗 \(h\) 的线段树即可,找 \(p\) 可以用线段树二分。

复杂度 \(O(n\log n)\)

代码
#include <bits/stdc++.h>

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 2e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n, m;

struct DSU {
    int fa[N], L[N], R[N];

    void init() {
        for ( int i = 1; i <= n; i ++) fa[i] = L[i] = R[i] = i;
    }

    int find( int x) {
        return x == fa[x] ? x : fa[x] = find(fa[x]);
    }

    void merge( int u, int v) {
        u = find(u), v = find(v);
        if (u == v) return ;
        fa[v] = u, L[u] = min(L[u], L[v]), R[u] = max(R[u], R[v]);
    }
} D;

struct sgt {
    #define ls k << 1
    #define rs k << 1 | 1
    #define mid ((l + r) >> 1)
    #define lson ls, l, mid
    #define rson rs, mid + 1, r

    int mi[N * 4], tag[N * 4];

    void psu( int k) {
        mi[k] = min(mi[ls], mi[rs]);
    }

    void pst( int k, int v) {
        mi[k] += v, tag[k] += v;
    }

    void psd( int k) {
        if (! tag[k]) return ;
        pst(ls, tag[k]), pst(rs, tag[k]);
        tag[k] = 0;
    }

    void upd( int x, int y, int v, int k = 1, int l = 1, int r = n) {
        if (x > y) return ;
        if (x <= l && r <= y) return pst(k, v);

        psd(k);
        if (x <= mid) upd(x, y, v, lson);
        if (y > mid) upd(x, y, v, rson);

        psu(k);
    }

    int ask( int x, int k = 1, int l = 1, int r = n) {
        if (l == r) return mi[k];

        psd(k);
        return x <= mid ? ask(x, lson) : ask(x, rson);
    }

    int findl( int x, int v, int k = 1, int l = 1, int r = n) {
        if (r <= x) {
            if (mi[k] >= v) return 0;
            if (l == r) return l;

            psd(k);
            if (mi[rs] < v) return findl(x, v, rson);
            return findl(x, v, lson);
        }

        int res = 0;
        psd(k);
        if (x > mid) res = findl(x, v, rson);
        if (res == 0) res = findl(x, v, lson);
        return res; 
    }

    int findr( int x, int v, int k = 1, int l = 1, int r = n) {
        if (l >= x) {
            if (mi[k] >= v) return n + 1;
            if (l == r) return l;

            psd(k);
            if (mi[ls] < v) return findr(x, v, lson);
            return findr(x, v, rson);
        }

        int res = n + 1;
        psd(k);
        if (x <= mid) res = findr(x, v, lson);
        if (res == n + 1) res = findr(x, v, rson);
        return res; 
    }

    #undef ls
    #undef rs
    #undef mid
    #undef lson
    #undef rson
} T;

void merge( int u, int v) {
    u = D.find(u), v = D.find(v);
    if (D.R[u] < D.L[v]) T.upd(D.R[u] + 1, D.L[v] - 1, 1);
    else T.upd(max(D.L[u], D.L[v]), min(D.R[u], D.R[v]), -1);

    D.merge(u, v);
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> m;
    D.init();

    while (m --) {
        int op, x, y;
        cin >> op >> x >> y;

        if (op == 1) {
            if (x > y) swap(x, y);

            while (D.find(x) != D.find(y)) {
                int vx = T.ask(x), vy = T.ask(y);

                if (vx == vy) {
                    int p = T.findr(x, vx);
                    if (p < y) merge(x, p);
                    else merge(x, y);
                } else if (vx > vy) merge(x, T.findr(x, vx));
                else merge(y, T.findl(y, vy));
            }
        } else  cout << (D.find(x) == D.find(y));
    }

    return 0;
}

点心

没有原题。

给一个 \(n\times m\) 的矩阵。

定义一个子矩阵的价值是其中的元素和的平方,但如果子矩阵中出现了 \(-1\),那么价值为 \(0\)
求所有子矩阵的价值和。

\(n,m\le 2500,-1\le a_{i,j}\le 998244353\)

对于一个子矩阵,它的贡献可以看作 \((A-B-C+D)^2\),把它拆开有:

\[A^2+B^2+C^2+D^2-2AB-2AC-2CD-2BD+2AD+2BC \]

发现本质不同的只有三种类型:

  1. 平方,例如 \(AA\)
  2. 矩形平行项乘积,例如 \(AB\)
  3. 矩形对角线乘积,例如 \(AD\)

考虑对这些东西分开计算。

枚举一个矩形的下边界,然后对每一列求出上方第一个 \(-1\),记 \(h_j\) 表示第 \(j\) 列从下边界到 \(-1\) 的距离。

\(h_j\) 建立笛卡尔树,对于一个 \(h_j\) 为最小值的区间 \([L,R]\),所有跨过 \(j\) 的子矩阵都是合法的,统计即可。

统计需要用到二维前缀和,二维前缀和的二维前缀和,二维前缀和的平方的二维前缀和。具体看代码。

代码
#include <bits/stdc++.h>
// #define long long

void Freopen() {
    freopen("rena.in", "r", stdin);
    freopen("rena.out", "w", stdout);
}

using namespace std;
const int N = 2e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

int n, m;
int a[2501][2501];
int s[2501][2501], sum[2501][2501], spw[2501][2501];

int add( int x, int v) {
    x += v;
    return x >= mod ? x - mod : x;
}

int ask( int a, int b, int c, int d, int S[][2501]) {
    a = max(0, a - 1), b = max(0, b - 1);
    return (0ll + S[c][d] - S[a][d] + mod - S[c][b] + mod + S[a][b]) % mod;
}

int h[N];
int ls[N], rs[N], sta[N], top;

int A, B, C;
// A:平方项的和 B:平行项乘积和 C:对角线乘积和

void build( int n) {
    for ( int i = 1; i <= n; i ++) ls[i] = rs[i] = 0;
    top = 0;

    for ( int i = 1; i <= n; i ++) {
        int k = top;
        while (top && h[sta[top]] > h[i]) top --;
        if (k != top) ls[i] = sta[top + 1];
        if (top) rs[sta[top]] = i;
        sta[++ top] = i;
    }
}

void cal1( int u, int l, int r, int i, int op) {
    if (ls[u]) cal1(ls[u], l, u - 1, i, op);
    if (rs[u]) cal1(rs[u], u + 1, r, i, op);
    if (! h[u]) return ;

    A = add(A, 1ll * h[u] * (u - l + 1) % mod * ask(i, u, i, r, spw) % mod);
    A = add(A, 1ll * h[u] * (r - u + 1) % mod * ask(i, l - 1, i, u - 1, spw) % mod);
    B = add(B, 1ll * h[u] * ask(i, l - 1, i, u - 1, sum) % mod * ask(i, u, i, r, sum) % mod);

    if (op) return ;
    C = add(C, 1ll * ask(i, u, i, r, sum) * ask(i - h[u], l - 1, i - 1, u - 1, sum) % mod);
    C = add(C, 1ll * ask(i, l - 1, i, u - 1, sum) * ask(i - h[u], u, i - 1, r, sum) % mod);
}

void cal2( int u, int l, int r, int i) {
    if (ls[u]) cal2(ls[u], l, u - 1, i);
    if (rs[u]) cal2(rs[u], u + 1, r, i);
    if (! h[u]) return ;

    B = add(B, 1ll * h[u] * ask(l - 1, i, u - 1, i, sum) % mod * ask(u, i, r, i, sum) % mod);
}

signed main() {
    Freopen();

    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> m;

    for ( int i = 1; i <= n; i ++)
        for ( int j = 1; j <= m; j ++)
            cin >> a[i][j];

    for ( int i = 1; i <= n; i ++)
        for ( int j = 1; j <= m; j ++) {
            s[i][j] = (0ll + (a[i][j] == -1 ? 0 : a[i][j]) + s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1] + mod) % mod;
            sum[i][j] = (0ll + s[i][j] + sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1] + mod) % mod;
            spw[i][j] = (0ll + 1ll * s[i][j] * s[i][j] % mod + spw[i - 1][j] + spw[i][j - 1] - spw[i - 1][j - 1] + mod) % mod;
        }

    for ( int i = 1; i <= m; i ++ ) h[i] = 0;

    for ( int i = 1; i <= n; i ++) {
        for ( int j = 1; j <= m; j ++) h[j] = (a[i][j] == -1 ? 0 : h[j] + 1);
        build(m);
        cal1(sta[1], 1, m, i, 0);
    }

    for ( int i = 1; i <= m; i ++) h[i] = 0;

    for ( int i = n; i; i --) {
        for ( int j = 1; j <= m; j ++) h[j] = (a[i][j] == -1 ? 0 : h[j] + 1);
        build(m);
        cal1(sta[1], 1, m, i - 1, 1);
    }

    for ( int i = 1; i <= n; i ++) h[i] = 0;

    for ( int j = 1; j <= m; j ++) {
        for ( int i = 1; i <= n; i ++) h[i] = (a[i][j] == -1 ? 0 : h[i] + 1);
        build(n);
        cal2(sta[1], 1, n, j);
    }

    for ( int i = 1; i <= n; i ++) h[i] = 0;

    for ( int j = m; j; j --) {
        for ( int i = 1; i <= n; i ++) h[i] = (a[i][j] == -1 ? 0 : h[i] + 1);
        build(n);
        cal2(sta[1], 1, n, j - 1);
    }

    cout << ((0ll + A - B - B + C + C) % mod + mod) % mod << '\n';

    return 0;
}
posted @ 2025-09-08 18:38  咚咚的锵  阅读(37)  评论(0)    收藏  举报