Loading

Luogu P14379 【MX-S9-T2】「LAOI-16」摩天大楼 题解

Link

Broken Prob Description. 形式化题意如下:

维护一个长度为 \(n\) 的序列 \(a\),要求支持单点修改,每次修改后需要对所有区间 \([l, r]\) 计算 \(f(l, r)\) 之和,其中 \(1 \leq l \lt r \leq n\)。定义 \(f(l, r)\) 的计算法则如下:当且仅当存在一个正整数 \(k\) 使得 \(k \in [l, r)\) 使得左半区间 \([l, k]\)\(\rm mex\) 不等于右半区间 \([k + 1, r]\)\(\rm mex\)\(f(l, r) = 1\),否则 \(f(l, r) = 0\),这里的 \(\rm mex\) 的定义域为 \([1, \infty)\)

这个 \(\rm mex\) 的限制比较棘手,套路化的,我们先求出所有 \([l, r]\) 对数 \(\frac{n(n - 1)}{2}\),然后从方案中扣去不产生贡献的 \(f(l, r) = 0\) 的情况。我们从 \(\rm mex\) 的角度入手,注意到其作为最小的缺失的正整数,总会是序列中从 \(1\) 开始的连续整数集合的第一个缺口。也就是说,\(\rm mex\) 每次都是从 \(1\) 开始向上找。注意到我们的目标仅仅是找到两个 \(\rm mex\) 不同的区间,数字 \(1, 2\) 作为两个最小的两个正整数决定了 \(\rm mex\) 的基本行为,只要区间中同时包含 \(1, 2\) 就一定能找到一个分割点使得左右 \(\rm mex\) 不同,无非是一边区间少了一种数或者都少了。

因此,关键的发现为,\(f(l, r) = 0\) 的情况有且仅有两种:

  • 区间中不存在 \(1\)
  • 区间形如 \([x_l = 1, \dots, x_r = 1]\),其中 \([l + 1, r - 1]\)\(2\)

统计答案时,具体地:

  • 记录极长不含 \(1\) 段长度为 \(x\),则有贡献为 \(0\) 的区间 \(\frac{x(x - 1)}{2}\)
  • 记录极长无 \(2\) 段中 \(1\) 的个数为 \(c\),则有贡献为 \(0\) 的区间 \(\frac{c(c - 1)}{2}\)

对于这两种情况各自维护出线段树,对于无 \(1\) 段维护每个区间左侧极长非 \(1\) 段长度、右侧极长非 \(1\) 段长度以及是否区间内都非 \(1\);对于无 \(2\) 段维护每个区间左、右侧的极长非 \(2\) 段中 \(1\) 的个数以及是否区间内都非 \(2\)

修改时跑线段树合并,复杂度是 \(O(q \log n)\) 的。

#include <bits/stdc++.h>

using i64 = long long;

constexpr int N = 1e6 + 7;

int n, q;
int a[N];

struct Seg0 {
    struct node {
        int zl, zr, l, r, len;
        i64 ans;
    } tr[N << 2];

    #define ls(o) (o << 1)
    #define rs(o) (o << 1 | 1)

    i64 calc(int len) {
        return 1ll * len * (len + 1) / 2;
    }

    node merge(node L, node R) {
        node res;
        res.len = L.len + R.len;
        if (L.zl == L.len) {
            res.zl = L.zl + R.zl;
            res.l = L.l + R.l;
        } else {
            res.zl = L.zl;
            res.l = L.l;
        }
        if (R.zr == R.len) {
            res.zr = R.zr + L.zr;
            res.r = R.r + L.r;
        } else {
            res.zr = R.zr;
            res.r = R.r;
        }
        res.ans = L.ans + R.ans;
        res.ans += 1ll * L.r * R.l;
        return res;
    }

    void build(int o, int l, int r) {
        if (l == r) {
            if (a[l] == 1) tr[o] = {1, 1, 1, 1, 1, 1};
            else if (a[l] == 2) tr[o] = {0, 0, 0, 0, 1, 0};
            else tr[o] = {1, 1, 0, 0, 1, 0};
            return;
        }
        int mid = (l + r) >> 1;
        build(ls(o), l, mid);
        build(rs(o), mid + 1, r);
        tr[o] = merge(tr[ls(o)], tr[rs(o)]);
    }

    void update(int o, int l, int r, int p) {
        if (l == r) {
            if (a[l] == 1) tr[o] = {1, 1, 1, 1, 1, 1};
            else if (a[l] == 2) tr[o] = {0, 0, 0, 0, 1, 0};
            else tr[o] = {1, 1, 0, 0, 1, 0};
            return;
        }
        int mid = (l + r) >> 1;
        if (p <= mid) update(ls(o), l, mid, p);
        else update(rs(o), mid + 1, r, p);
        tr[o] = merge(tr[ls(o)], tr[rs(o)]);
    }

    node query(int o, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) return tr[o];
        int mid = (l + r) >> 1;
        if (qr <= mid) return query(ls(o), l, mid, ql, qr);
        if (ql > mid) return query(rs(o), mid + 1, r, ql, qr);
        return merge(query(ls(o), l, mid, ql, mid), query(rs(o), mid + 1, r, mid + 1, qr));
    }
} seg0;

struct Seg1 {
    struct node {
        int l, r, len;
        i64 ans;
    } tr[N << 2];

    #define ls(o) (o << 1)
    #define rs(o) (o << 1 | 1)

    node merge(node L, node R) {
        node res;
        res.len = L.len + R.len;
        if (L.l == L.len) res.l = L.l + R.l;
        else res.l = L.l;
        if (R.r == R.len) res.r = R.r + L.r;
        else res.r = R.r;
        res.ans = L.ans + R.ans;
        res.ans += 1ll * L.r * R.l;
        return res;
    }

    void build(int o, int l, int r) {
        if (l == r) {
            if (a[l] == 1) tr[o] = {0, 0, 1, 0};
            else tr[o] = {1, 1, 1, 1};
            return;
        }
        int mid = (l + r) >> 1;
        build(ls(o), l, mid);
        build(rs(o), mid + 1, r);
        tr[o] = merge(tr[ls(o)], tr[rs(o)]);
    }

    void update(int o, int l, int r, int p) {
        if (l == r) {
            if (a[l] == 1) tr[o] = {0, 0, 1, 0};
            else tr[o] = {1, 1, 1, 1};
            return;
        }
        int mid = (l + r) >> 1;
        if (p <= mid) update(ls(o), l, mid, p);
        else update(rs(o), mid + 1, r, p);
        tr[o] = merge(tr[ls(o)], tr[rs(o)]);
    }

    node query(int o, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) return tr[o];
        int mid = (l + r) >> 1;
        if (qr <= mid) return query(ls(o), l, mid, ql, qr);
        if (ql > mid) return query(rs(o), mid + 1, r, ql, qr);
        return merge(query(ls(o), l, mid, ql, mid), query(rs(o), mid + 1, r, mid + 1, qr));
    }
} seg1;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n >> q;
    for (int i = 1; i <= n; i++) {
        std::cin >> a[i];
    }
    seg0.build(1, 1, n);
    seg1.build(1, 1, n);
    while (q--) {
        int x, v;
        std::cin >> x >> v;
        a[x] = v;
        seg0.update(1, 1, n, x);
        seg1.update(1, 1, n, x);
        i64 ans = 1ll * n * (n + 1) / 2;
        ans -= seg0.query(1, 1, n, 1, n).ans;
        ans -= seg1.query(1, 1, n, 1, n).ans;
        std::cout << ans << "\n";
    }
    return 0;
}
posted @ 2025-11-05 15:56  夢回路  阅读(6)  评论(0)    收藏  举报