HDU #4747 MEX (线段树的应用)
题目描述:
定义$mex(i,j)$为序列中第$i$项到第$j$项中没有出现的最小自然数。给定序列,求$\sum^{n}_{i=1}\sum^{n}_{j=i}mex(i,j)$。
解题思路:
首先我们可以$O(n)$预处理出$mex(1,1\sim n)$,因为显然的是$mex$是递增的。然后我们考虑怎么从$mex(i,i\sim n)$推出$mex(i+1,i+1\sim n)$,我们删掉$a_i$这个数后,哪些区间的$mex$会改变呢?其实就是到下一个与$a_i$相等的数出现前$mex$大于$a_i$的区间,因为这段区间没有了$a_i$这个数,而他们原本的mex却大于$a_i$,所以可以变小。所以要区间查询、修改、求和,用线段树就可以了。
代码:
#include <cstdio> #include <cstring> #include <algorithm> #include <map> #define i64 long long using namespace std; const int N = 2e5 + 10; int n, a[N], mex[N], nxt[N]; i64 ans; map<int, int> mp; struct node { int s, mx, tag; } tr[N * 8]; void init() { int now = 0; for (int i = 1; i <= n; i ++) { mp[a[i]] = 1; while (mp.count(now)) now ++; mex[i] = now; } mp.clear(); for (int i = n; i; i --) { if (mp.count(a[i])) nxt[i] = mp[a[i]]; else nxt[i] = n + 1; mp[a[i]] = i; } } void build(int o, int l, int r) { if (l == r) { tr[o].s = tr[o].mx = mex[l]; tr[o].tag = -1; return; } tr[o].tag = -1; int m = l + r >> 1; build(o << 1, l, m); build(o << 1 | 1, m + 1, r); tr[o].s = tr[o << 1].s + tr[o << 1 | 1].s; tr[o].mx = max(tr[o << 1].mx, tr[o << 1 | 1].mx); } void pushdown(int o, int l, int r) { if (tr[o].tag == -1) return; tr[o << 1].tag = tr[o << 1 | 1].tag = tr[o].tag; tr[o].s = tr[o].tag * (r - l + 1); tr[o].mx = tr[o].tag; tr[o].tag = -1; } int find(int o, int l, int r, int v) { if (l == r) return l; int m = l + r >> 1; pushdown(o << 1, l, m); pushdown(o << 1 | 1, m + 1, r); if (tr[o << 1].mx > v) return find(o << 1, l, m, v); else return find(o << 1 | 1, m + 1, r, v); } void updata(int o, int l, int r) { int m = l + r >> 1, x, y; if (tr[o << 1].tag != -1) x = tr[o << 1].tag; else x = tr[o << 1].mx; if (tr[o << 1 | 1].tag != -1) y = tr[o << 1 | 1].tag; else y = tr[o << 1 | 1].mx; tr[o].mx = max(x, y); if (tr[o << 1].tag != -1) x = tr[o << 1].tag * (m - l + 1); else x = tr[o << 1].s; if (tr[o << 1 | 1].tag != -1) y = tr[o << 1 | 1].tag * (r - m); else y = tr[o << 1 | 1].s; tr[o].s = x + y; } void modify(int o, int l, int r, int x, int y, int v) { if (x <= l && r <= y) { tr[o].tag = v; return; } pushdown(o, l, r); int m = l + r >> 1; if (x <= m) modify(o << 1, l, m, x, y, v); if (y > m) modify(o << 1 | 1, m + 1, r, x, y, v); updata(o, l, r); } int query(int o, int l, int r, int x, int y) { pushdown(o, l, r); if (x <= l && r <= y) return tr[o].s; int m = l + r >> 1, t = 0; if (x <= m) t = query(o << 1, l, m, x, y); if (y > m) t += query(o << 1 | 1, m + 1, r, x, y); return t; } void work() { ans += (i64)query(1, 1, n, 1, n - 1); for (int i = 1; i < n - 1; i ++) { pushdown(1, 1, n); int k = find(1, 1, n, a[i]); if (k < nxt[i]) modify(1, 1, n, k, nxt[i] - 1, a[i]); ans += (i64)query(1, 1, n, i + 1, n - 1); } printf("%lld", ans); } int main() { scanf("%d", &n); for (int i = 1; i <= n; i ++) scanf("%d", &a[i]); init(); mex[++ n] = N; build(1, 1, n); work(); return 0; }

浙公网安备 33010602011771号