HDU 4747 Mex 线段树

题意:

\(S\)为一个自然数集合,定义函数\(mex(S)\)为集合中没有出现的最小自然数。
给出一个长度为为\(n\)序列\(a\),设\(S_{l,r}\)表示由\(a_l \sim a_r\)构成的集合。
求:

\[\sum\limits_{1 \leq l \leq r \leq n}mex(S_{l,r}) \]

分析:

有这样一个事实:往集合\(S\)中任意加入一个元素,\(mex(S)\)的值不会变小。
固定区间左端点来统计答案。
首先计算一下\(mex{S_{1,1}},mex{S_{1,2}}, \cdots, mex{S_{1,n}}\),所以这是一个非递减的序列。
假设现在计算出\(mex{S_{i,i}},mex{S_{i,i+1}}, \cdots, mex{S_{i,n}}\),考虑区间左端点向右移动。
相当于从这些集合中都删去了一个\(a_i\),如果有一个最小的\(j>i\)\(a_i=a_j\),那么删去\(a_i\)\([j,n]\)这段区间没有影响,因为这段区间对应的集合没有改变。
然后考虑区间\([i+1,j-1]\),找到\(mex\)值大于\(a_i\)的区间,把它们的值都变为\(a_i\)
因为集合中少了\(a_i\),所以根据\(mex\)函数的定义,\(mex\)值为\(a_i\)
而且由于区间是非递减的,所以\(mex\)值大于\(a_i\)的区间也是连续的,用线段树维护即可。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

typedef long long LL;
const int maxn = 200000 + 10;
const int maxnode = maxn * 4;

int n;
int a[maxn], b[maxn], tot;
int pos[maxn], nxt[maxn];

bool vis[maxn];
int mex[maxn];

//Segment Tree
LL sum[maxnode];
int setv[maxnode], minv[maxnode], maxv[maxnode];

void pushup(int o) {
    sum[o] = sum[o<<1] + sum[o<<1|1];
    minv[o] = min(minv[o<<1], minv[o<<1|1]);
    maxv[o] = max(maxv[o<<1], maxv[o<<1|1]);
}

void build(int o, int L, int R) {
    if(L == R) {
        sum[o] = minv[o] = maxv[o] = mex[L];
        return;
    }
    int M = (L + R) / 2;
    build(o<<1, L, M);
    build(o<<1|1, M+1, R);
    pushup(o);
}

void pushdown(int o, int L, int R) {
    if(setv[o] != -1) {
        int lc = o<<1, rc = o<<1|1;
        setv[lc] = setv[rc] = setv[o];
        minv[lc] = minv[rc] = setv[o];
        maxv[lc] = maxv[rc] = setv[o];
        int M = (L + R) / 2;
        sum[lc] = (LL)setv[o] * (M - L + 1);
        sum[rc] = (LL)setv[o] * (R - M);
        setv[o] = -1;
    }
}

void update(int o, int L, int R, int qL, int qR, int v) {
    if(qL <= L && R <= qR && minv[o] > v) {
        setv[o] = minv[o] = maxv[o] = v;
        sum[o] = (LL)v * (R - L + 1);
        return;
    }
    pushdown(o, L, R);
    int M = (L + R) / 2;
    if(qL <= M && maxv[o<<1] > v) update(o<<1, L, M, qL, qR, v);
    if(qR > M && maxv[o<<1|1] > v) update(o<<1|1, M+1, R, qL, qR, v);
    pushup(o);
}

LL query(int o, int L, int R, int qL, int qR) {
    if(qL <= L && R <= qR) return sum[o];
    pushdown(o, L, R);
    int M = (L + R) / 2;
    LL ans = 0;
    if(qL <= M) ans += query(o<<1, L, M, qL, qR);
    if(qR > M) ans += query(o<<1|1, M+1, R, qL, qR);
    return ans;
}

int main()
{
    while(scanf("%d", &n) == 1 && n) {
        for(int i = 1; i <= n; i++) {
            scanf("%d", a + i);
            if(a[i] >= maxn) a[i] = maxn - 1;
            b[i] = a[i];
        }
        sort(b + 1, b + 1 + n);
        tot = unique(b + 1, b + 1 + n) - b - 1;
        for(int i = 1; i <= n; i++)
            a[i] = lower_bound(b + 1, b + 1 + tot, a[i]) - b;
        for(int i = 1; i <= tot; i++) pos[i] = n + 1;
        for(int i = n; i > 0; i--) {
            nxt[i] = pos[a[i]];
            pos[a[i]] = i;
        }

        memset(vis, false, sizeof(vis));
        int p = 0;
        for(int i = 1; i <= n; i++) {
            vis[b[a[i]]] = true;
            while(vis[p]) p++;
            mex[i] = p;
        }

        memset(setv, -1, sizeof(setv));
        build(1, 1, n);
        LL ans = sum[1];
        for(int i = 2; i <= n; i++) {
            int j = nxt[i - 1];
            if(j > i) update(1, 1, n, i, j - 1, b[a[i-1]]);
            ans += query(1, 1, n, i, n);
        }

        printf("%lld\n", ans);
    }

    return 0;
}
posted @ 2016-04-01 10:52  AOQNRMGYXLMV  阅读(183)  评论(0编辑  收藏  举报