hdu 4747 (线段树)

题意:有一个序列 a[],mex(L, R) 表示序列 a 在区间 [L, R] 内没有出现过的最小非负整数。对于序列 a[],求所有 mex(L, R) 的和(1 <= L <= R <= n,1 <= n <= 200000,0 <= ai <= 10^9)。

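先求出所有的 mex(1, i) 并用线段树维护;接着删去第 1 个元素,线段树中保存的就变成所有的 mex(2, i);再删去第 2 个元素,得到所有的 mex(3, i);……最后只剩 mex(n, n)。每一轮线段树的总和累加起来即是答案。删去 a[i] 时,设 next[i] 为 a[i] 下一次出现的位置(不存在则为 n + 1):对 j >= next[i],mex 不变;对 i < j < next[i],mex(i + 1, j) = min(mex(i, j), a[i])。由于 mex(i, j) 关于 j 单调不减,需要改成 a[i] 的位置恰好是一段连续区间 [p, next[i] - 1](p 为第一个 mex 大于 a[i] 的位置),用线段树区间赋值即可,总复杂度 O(n log n)。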

  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cstring>
  4 #include <algorithm>
  5 #include <cmath>
  6 
  7 using namespace std;
  8 #define ls rt<<1
  9 #define rs rt<<1|1
 10 #define lson l, m, rt<<1
 11 #define rson m + 1, r, rt<<1|1
 12 typedef long long ll;
 13 const int maxn = 2e5 + 5;
 14 int n, a[maxn], vis[maxn], next[maxn], mex1[maxn << 2];
 15 struct SegTree{
 16     int Max, lazy;
 17     ll sum;
 18 }seg[maxn << 2];
 19 
 20 void pushUp(int rt){
 21     seg[rt].sum = seg[ls].sum + seg[rs].sum;
 22     seg[rt].Max = max(seg[ls].Max, seg[rs].Max);
 23 }
 24 void pushDown(int rt, int len){
 25     if (seg[rt].lazy != -1){
 26         seg[ls].lazy = seg[rs].lazy = seg[rt].lazy;
 27         seg[ls].Max = seg[rt].lazy;
 28         seg[rs].Max = seg[rt].lazy;
 29         seg[ls].sum = (ll)((len + 1) / 2) * ((ll)seg[rt].lazy);
 30         seg[rs].sum = (ll)(len / 2) * ((ll)seg[rt].lazy);
 31         seg[rt].lazy = -1;
 32     }
 33 }
 34 void build(int l, int r, int rt){
 35     seg[rt].lazy = -1;
 36     if (l == r){
 37         seg[rt].sum = seg[rt].Max = mex1[l];
 38         return ;
 39     }
 40     int m = (l + r) >> 1;
 41     build(lson);
 42     build(rson);
 43     pushUp(rt);
 44 }
 45 int find(int key, int l, int r, int rt){//找到第一个mex大于a[i]的下标
 46     if (l == r) return l;
 47     pushDown(rt, r - l + 1);
 48     int m = (l + r) >> 1;
 49     if (seg[ls].Max > key) return find(key, lson);
 50     else return find(key, rson);
 51 }
 52 void update(int val, int L, int R, int l, int r, int rt){
 53     if (L <= l && r <= R){
 54         //pushDown(rt, r - l + 1);
 55         seg[rt].Max = val;
 56         seg[rt].sum = (ll) val * (ll) (r - l + 1);
 57         seg[rt].lazy = val;
 58         return ;
 59     }
 60     int m = (l + r) >> 1;
 61     //cout << seg[rt].lazy << " l = " << l << " r= " << r << endl;
 62     pushDown(rt, r - l + 1);
 63     if (L <= m ) update(val, L, R, lson);
 64     if (R > m) update(val, L, R, rson);
 65     pushUp(rt);
 66     /*if (L == 2 && R == 4){
 67         cout << " ll = " << l << " rr = " << r << endl;
 68         cout << " sum = " << seg[rt].sum << " rt = " << rt << endl;
 69     }*/
 70 }
 71 int main(){
 72     while (~scanf("%d", &n) && n){
 73         for (int i = 1; i <= n; ++i){
 74             scanf("%d", &a[i]);
 75             if (a[i] > n) a[i] = n;
 76         }
 77         //得到mex[1, i]
 78         int tmp = 0;
 79         memset(vis, 0, sizeof(vis));
 80         for (int i = 1; i <= n; ++i){
 81             vis[a[i]] = 1;
 82             while(vis[tmp]) tmp++;
 83             mex1[i] = tmp;
 84         }
 85         //得到next值
 86         for (int i = 0; i <= n; ++i) vis[i] = n + 1;
 87         for (int i = n; i >= 1; --i){
 88             next[i] = vis[a[i]];
 89             vis[a[i]] = i;
 90         }
 91         build(1, n, 1);
 92         ll ans = 0;
 93         for (int i = 1; i <= n; ++i){
 94         //    cout << seg[1].sum << endl;
 95             ans += seg[1].sum;
 96             update(0, i, i, 1, n, 1);
 97             if (a[i] < seg[1].Max){
 98                 int l = find(a[i], 1, n, 1), r = next[i] - 1;
 99                 //cout << " l = " << l << " r = " << r << endl;
100                 if (l <= r) update(a[i], l, r, 1, n, 1);
101             }
102         }
103         printf("%I64d\n", ans);
104     }
105     return 0;
106 }

 

posted @ 2013-10-10 15:51  Missa  阅读(1012)  评论(0)  编辑  收藏  举报