Luogu P14379 【MX-S9-T2】「LAOI-16」摩天大楼 题解
Broken Prob Description. 形式化题意如下:
维护一个长度为 \(n\) 的序列 \(a\),要求支持单点修改,每次修改后需要对所有区间 \([l, r]\) 计算 \(f(l, r)\) 之和,其中 \(1 \leq l \lt r \leq n\)。定义 \(f(l, r)\) 的计算法则如下:当且仅当存在一个正整数 \(k\) 使得 \(k \in [l, r)\) 使得左半区间 \([l, k]\) 的 \(\rm mex\) 不等于右半区间 \([k + 1, r]\) 的 \(\rm mex\) 时 \(f(l, r) = 1\),否则 \(f(l, r) = 0\),这里的 \(\rm mex\) 的定义域为 \([1, \infty)\)。
这个 \(\rm mex\) 的限制比较棘手,套路化的,我们先求出所有 \([l, r]\) 对数 \(\frac{n(n - 1)}{2}\),然后从方案中扣去不产生贡献的 \(f(l, r) = 0\) 的情况。我们从 \(\rm mex\) 的角度入手,注意到其作为最小的缺失的正整数,总会是序列中从 \(1\) 开始的连续整数集合的第一个缺口。也就是说,\(\rm mex\) 每次都是从 \(1\) 开始向上找。注意到我们的目标仅仅是找到两个 \(\rm mex\) 不同的区间,数字 \(1, 2\) 作为两个最小的两个正整数决定了 \(\rm mex\) 的基本行为,只要区间中同时包含 \(1, 2\) 就一定能找到一个分割点使得左右 \(\rm mex\) 不同,无非是一边区间少了一种数或者都少了。
因此,关键的发现为,\(f(l, r) = 0\) 的情况有且仅有两种:
- 区间中不存在 \(1\)
- 区间形如 \([x_l = 1, \dots, x_r = 1]\),其中 \([l + 1, r - 1]\) 非 \(2\)
统计答案时,具体地:
- 记录极长不含 \(1\) 段长度为 \(x\),则有贡献为 \(0\) 的区间 \(\frac{x(x - 1)}{2}\) 个
- 记录极长无 \(2\) 段中 \(1\) 的个数为 \(c\),则有贡献为 \(0\) 的区间 \(\frac{c(c - 1)}{2}\) 个
对于这两种情况各自维护出线段树,对于无 \(1\) 段维护每个区间左侧极长非 \(1\) 段长度、右侧极长非 \(1\) 段长度以及是否区间内都非 \(1\);对于无 \(2\) 段维护每个区间左、右侧的极长非 \(2\) 段中 \(1\) 的个数以及是否区间内都非 \(2\)。
修改时跑线段树合并,复杂度是 \(O(q \log n)\) 的。
#include <bits/stdc++.h>
using i64 = long long;
constexpr int N = 1e6 + 7;
int n, q;
int a[N];
struct Seg0 {
struct node {
int zl, zr, l, r, len;
i64 ans;
} tr[N << 2];
#define ls(o) (o << 1)
#define rs(o) (o << 1 | 1)
i64 calc(int len) {
return 1ll * len * (len + 1) / 2;
}
node merge(node L, node R) {
node res;
res.len = L.len + R.len;
if (L.zl == L.len) {
res.zl = L.zl + R.zl;
res.l = L.l + R.l;
} else {
res.zl = L.zl;
res.l = L.l;
}
if (R.zr == R.len) {
res.zr = R.zr + L.zr;
res.r = R.r + L.r;
} else {
res.zr = R.zr;
res.r = R.r;
}
res.ans = L.ans + R.ans;
res.ans += 1ll * L.r * R.l;
return res;
}
void build(int o, int l, int r) {
if (l == r) {
if (a[l] == 1) tr[o] = {1, 1, 1, 1, 1, 1};
else if (a[l] == 2) tr[o] = {0, 0, 0, 0, 1, 0};
else tr[o] = {1, 1, 0, 0, 1, 0};
return;
}
int mid = (l + r) >> 1;
build(ls(o), l, mid);
build(rs(o), mid + 1, r);
tr[o] = merge(tr[ls(o)], tr[rs(o)]);
}
void update(int o, int l, int r, int p) {
if (l == r) {
if (a[l] == 1) tr[o] = {1, 1, 1, 1, 1, 1};
else if (a[l] == 2) tr[o] = {0, 0, 0, 0, 1, 0};
else tr[o] = {1, 1, 0, 0, 1, 0};
return;
}
int mid = (l + r) >> 1;
if (p <= mid) update(ls(o), l, mid, p);
else update(rs(o), mid + 1, r, p);
tr[o] = merge(tr[ls(o)], tr[rs(o)]);
}
node query(int o, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) return tr[o];
int mid = (l + r) >> 1;
if (qr <= mid) return query(ls(o), l, mid, ql, qr);
if (ql > mid) return query(rs(o), mid + 1, r, ql, qr);
return merge(query(ls(o), l, mid, ql, mid), query(rs(o), mid + 1, r, mid + 1, qr));
}
} seg0;
struct Seg1 {
struct node {
int l, r, len;
i64 ans;
} tr[N << 2];
#define ls(o) (o << 1)
#define rs(o) (o << 1 | 1)
node merge(node L, node R) {
node res;
res.len = L.len + R.len;
if (L.l == L.len) res.l = L.l + R.l;
else res.l = L.l;
if (R.r == R.len) res.r = R.r + L.r;
else res.r = R.r;
res.ans = L.ans + R.ans;
res.ans += 1ll * L.r * R.l;
return res;
}
void build(int o, int l, int r) {
if (l == r) {
if (a[l] == 1) tr[o] = {0, 0, 1, 0};
else tr[o] = {1, 1, 1, 1};
return;
}
int mid = (l + r) >> 1;
build(ls(o), l, mid);
build(rs(o), mid + 1, r);
tr[o] = merge(tr[ls(o)], tr[rs(o)]);
}
void update(int o, int l, int r, int p) {
if (l == r) {
if (a[l] == 1) tr[o] = {0, 0, 1, 0};
else tr[o] = {1, 1, 1, 1};
return;
}
int mid = (l + r) >> 1;
if (p <= mid) update(ls(o), l, mid, p);
else update(rs(o), mid + 1, r, p);
tr[o] = merge(tr[ls(o)], tr[rs(o)]);
}
node query(int o, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) return tr[o];
int mid = (l + r) >> 1;
if (qr <= mid) return query(ls(o), l, mid, ql, qr);
if (ql > mid) return query(rs(o), mid + 1, r, ql, qr);
return merge(query(ls(o), l, mid, ql, mid), query(rs(o), mid + 1, r, mid + 1, qr));
}
} seg1;
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cin >> n >> q;
for (int i = 1; i <= n; i++) {
std::cin >> a[i];
}
seg0.build(1, 1, n);
seg1.build(1, 1, n);
while (q--) {
int x, v;
std::cin >> x >> v;
a[x] = v;
seg0.update(1, 1, n, x);
seg1.update(1, 1, n, x);
i64 ans = 1ll * n * (n + 1) / 2;
ans -= seg0.query(1, 1, n, 1, n).ans;
ans -= seg1.query(1, 1, n, 1, n).ans;
std::cout << ans << "\n";
}
return 0;
}

浙公网安备 33010602011771号