mex数组
给定数列 \(a\),对于每个 \(i\) 询问多少个区间的 mex 等于 \(i\)
考虑转化成求多少个 \(\text{mex} \ge i\) 的区间,设为 \(ans_i\),容易发现,设右端点为 \(r\), \(p_i = \max\{j\le r|a_j=i\}\),则 mex 不小于于 \(i\) 的合法左端点数量为 \(\min \{p_j | j < i\}\)
每次暴力维护 \(p_i\),你就有了 50 分的做法。。
再考虑神秘做法,显然 \(ans_0 = {n(n+1)\over 2}\),我们从 \(0~n\) 枚举一个数 \(i\),只考虑 \(ans_i\),容易发现实际上比 \(i\) 大的数怎么出现对 \(ans_i\) 都没有影响,我们要维护的是 \(p\) 从 \(0~i\) 的前缀最小值。\(ans_i\) 实际上就是我们之前右端点枚举到每个位置时的前缀最小值之和。
考虑 \(0~i-1\) 的贡献我们都已经算过,我们枚举序列里每一个为 \(i\) 的位置,它对序列 \(p_i\) 的影响 本质上还是一个前缀取 \(\min\),可以吉司机线段树维护。
#include <bits/stdc++.h>
using namespace std;
const int N = 2e6 + 5;
int mx[N << 2], mn[N << 2], lazy[N << 2];
long long sum[N << 2];
inline void add(int now, int x, int l, int r) {
mx[now] = mn[now] = lazy[now] = x;
sum[now] = 1ll * (r - l + 1) * x;
}
inline void update(int now) {
mn[now] = mn[now << 1]; mx[now] = mx[now << 1 | 1];
sum[now] = sum[now << 1] + sum[now << 1 | 1];
}
inline void pushdown(int now, int l, int r) {
if (lazy[now] == -1) return ;
add(now << 1, lazy[now], l, l + 1 >> 1);
add(now << 1 | 1, lazy[now], (l + r >> 1) + 1, r);
lazy[now] = -1;
}
inline void build(int now, int l, int r) {
lazy[now] = -1;
if (l >= r) {
add(now, l, l, r); lazy[now] = -1;
return ;
} build(now << 1, l, l + r >> 1);
build(now << 1 | 1, (l + r >> 1) + 1, r);
update(now);
}
inline void modify(int now, int l, int r, int q, int x) {
if (q < l || mx[now] <= x) return ;
if (r <= q && mn[now] >= x) {
add(now, x, l, r);
return ;
} pushdown(now, l, r);
int mid = l + r >> 1;
modify(now << 1, l, mid, q, x);
if (q > mid) modify(now << 1 | 1, mid + 1, r, q, x);
update(now);
}
inline void otp(long long x) {
(x >= 10) ?otp(x / 10), putchar((x % 10) ^ 48) : putchar(x ^ 48);
}
inline int read() {
register int s = 0; register char ch = getchar();
while (!isdigit(ch)) ch = getchar();
while (isdigit(ch)) s = (s << 1) + (s << 3) + (ch & 15), ch = getchar();
return s;
}
int a[N], lst[N], pre[N];
long long ans[N];
int main() {
int n = read();
for (int i = 1; i <= n; ++i) {
a[i] = read();
pre[i] = lst[a[i]]; lst[a[i]] = i;
}
ans[0] = 1ll * n * (n + 1) / 2;
build(1, 1, n);
for (int i = 0; i <= n; ++i) {
int l = n;
for (int j = lst[i]; j; l = j - 1, j = pre[j]) {
modify(1, 1, n, l, j);
}
modify(1, 1, n, l, 0);
ans[i + 1] += sum[1];
otp(ans[i] - ans[i + 1]); putchar(' ');
} return 0;
}

浙公网安备 33010602011771号