[bzoj3196]Tyvj 1730 二逼平衡树——线段树套平衡树

题目

Description
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)
Input
第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继
Output
对于操作1,2,4,5各输出一行,表示查询结果

题解

树套树模板题。
对于每一个线段树上的节点,我们都在上面建一个平衡树。
对于操作1,我们把一个长度为len的区间分解为\(\lfloor log_2(len) \rfloor\)个子区间,分别处理,把每个区间的排名加起来就好了。
对于操作2,我们二分答案,对于二分出来的一个答案执行操作1,check以下即可。
对于操作3,我们删除再插入。
对于操作4和操作5,我们把每个区间(前驱/后继)的(最大值/最小值)搞一搞就好啦。
PS.谁能告诉我bzoj上那些2000ms的是怎么做的。。。

代码

#include <algorithm>
#include <cctype>
#include <cstdio>
using namespace std;
const int maxn = 4e6 + 5;
const int inf = 1e9;
int ans, n, m, opt, l, r, k, pos, sz, Max;
int a[maxn], fa[maxn], ch[maxn][2], size[maxn], cnt[maxn], data[maxn], rt[maxn];
inline int read() {
  int x = 0, f = 1;
  char ch = getchar();
  while (!isdigit(ch)) {
    if (ch == '-')
      f = -1;
    ch = getchar();
  }
  while (isdigit(ch)) {
    x = x * 10 + ch - '0';
    ch = getchar();
  }
  return x * f;
}
struct Splay {
  void clear(int x) {
    fa[x] = ch[x][0] = ch[x][1] = size[x] = cnt[x] = data[x] = 0;
  }
  void update(int x) {
    if (x) {
      size[x] = cnt[x];
      if (ch[x][0])
        size[x] += size[ch[x][0]];
      if (ch[x][1])
        size[x] += size[ch[x][1]];
    }
  }
  void zig(int x) {
    int y = fa[x], z = fa[y], l = (ch[y][1] == x), r = l ^ 1;
    fa[ch[y][l] = ch[x][r]] = y;
    fa[ch[x][r] = y] = x;
    fa[x] = z;
    if (z)
      ch[z][ch[z][1] == y] = x;
    update(y);
    update(x);
  }
  void splay(int i, int x, int aim = 0) {
    for (int y; (y = fa[x]) != aim; zig(x))
      if (fa[y] != aim)
        zig((ch[fa[y]][0] == y) == (ch[y][0] == x) ? y : x);
    if (aim == 0)
      rt[i] = x;
  }
  void insert(int i, int v) {
    int x = rt[i], y = 0;
    if (x == 0) {
      rt[i] = x = ++sz;
      fa[x] = ch[x][0] = ch[x][1] = 0;
      size[x] = cnt[x] = 1;
      data[x] = v;
      return;
    }
    while (1) {
      if (data[x] == v) {
        cnt[x]++;
        update(y);
        splay(i, x);
        return;
      }
      y = x;
      x = ch[x][v > data[x]];
      if (x == 0) {
        ++sz;
        fa[sz] = y;
        ch[sz][0] = ch[sz][1] = 0;
        size[sz] = cnt[sz] = 1;
        data[sz] = v;
        ch[y][v > data[y]] = sz;
        update(y);
        splay(i, sz);
        rt[i] = sz;
        return;
      }
    }
  }
  void find(int i, int v) {
    int x = rt[i];
    while (1) {
      if (data[x] == v) {
        splay(i, x);
        return;
      } else
        x = ch[x][v > data[x]];
    }
  }
  int pre(int i) {
    int x = ch[rt[i]][0];
    while (ch[x][1])
      x = ch[x][1];
    return x;
  }
  int succ(int i) {
    int x = ch[rt[i]][1];
    while (ch[x][0])
      x = ch[x][0];
    return x;
  }
  void del(int i) {
    int x = rt[i];
    if (cnt[x] > 1) {
      cnt[x]--;
      return;
    }
    if (!ch[x][0] && !ch[x][1]) {
      clear(rt[i]);
      rt[i] = 0;
      return;
    }
    if (!ch[x][0]) {
      int oldroot = x;
      rt[i] = ch[x][1];
      fa[rt[i]] = 0;
      clear(oldroot);
      return;
    }
    if (!ch[x][1]) {
      int oldroot = x;
      rt[i] = ch[x][0];
      fa[rt[i]] = 0;
      clear(oldroot);
      return;
    }
    int y = pre(i), oldroot = x;
    splay(i, y);
    rt[i] = y;
    ch[rt[i]][1] = ch[oldroot][1];
    fa[ch[oldroot][1]] = rt[i];
    clear(oldroot);
    update(rt[i]);
    return;
  }
  int rank(int i, int v) {
    int x = rt[i], ans = 0;
    while (1) {
      if (!x)
        return ans;
      if (data[x] == v)
        return ((ch[x][0]) ? size[ch[x][0]] : 0) + ans;
      else if (data[x] < v) {
        ans += ((ch[x][0]) ? size[ch[x][0]] : 0) + cnt[x];
        x = ch[x][1];
      } else if (data[x] > v) {
        x = ch[x][0];
      }
    }
  }
  int find_pre(int i, int v) {
    int x = rt[i];
    while (x) {
      if (data[x] < v) {
        if (ans < data[x])
          ans = data[x];
        x = ch[x][1];
      } else
        x = ch[x][0];
    }
    return ans;
  }
  int find_succ(int i, int v) {
    int x = rt[i];
    while (x) {
      if (v < data[x]) {
        if (ans > data[x])
          ans = data[x];
        x = ch[x][0];
      } else
        x = ch[x][1];
    }
    return ans;
  }
} sp;
void insert(int k, int l, int r, int x, int v) {
  int mid = (l + r) >> 1;
  sp.insert(k, v);
  if (l == r)
    return;
  if (x <= mid)
    insert(k << 1, l, mid, x, v);
  else
    insert(k << 1 | 1, mid + 1, r, x, v);
}
void askrank(int k, int l, int r, int x, int y, int val) {
  int mid = (l + r) >> 1;
  if (x <= l && r <= y) {
    ans += sp.rank(k, val);
    return;
  }
  if (x <= mid)
    askrank(k << 1, l, mid, x, y, val);
  if (mid + 1 <= y)
    askrank(k << 1 | 1, mid + 1, r, x, y, val);
}
void change(int k, int l, int r, int pos, int val) {
  int mid = (l + r) >> 1;
  sp.find(k, a[pos]);
  sp.del(k);
  sp.insert(k, val);
  if (l == r)
    return;
  if (pos <= mid)
    change(k << 1, l, mid, pos, val);
  else
    change(k << 1 | 1, mid + 1, r, pos, val);
}
void askpre(int k, int l, int r, int x, int y, int val) {
  int mid = (l + r) >> 1;
  if (x <= l && r <= y) {
    ans = max(ans, sp.find_pre(k, val));
    return;
  }
  if (x <= mid)
    askpre(k << 1, l, mid, x, y, val);
  if (mid + 1 <= y)
    askpre(k << 1 | 1, mid + 1, r, x, y, val);
}
void asksucc(int k, int l, int r, int x, int y, int val) {
  int mid = (l + r) >> 1;
  if (x <= l && r <= y) {
    ans = min(ans, sp.find_succ(k, val));
    return;
  }
  if (x <= mid)
    asksucc(k << 1, l, mid, x, y, val);
  if (mid + 1 <= y)
    asksucc(k << 1 | 1, mid + 1, r, x, y, val);
}
int main() {
#ifdef D
  freopen("input", "r", stdin);
#endif
  n = read(), m = read();
  for (int i = 1; i <= n; i++)
    a[i] = read(), Max = max(Max, a[i]), insert(1, 1, n, i, a[i]);
  for (int i = 1; i <= m; i++) {
    opt = read();
    if (opt == 1) {
      l = read(), r = read(), k = read();
      ans = 0;
      askrank(1, 1, n, l, r, k);
      printf("%d\n", ans + 1);
    } else if (opt == 2) {
      l = read(), r = read(), k = read();
      int head = 0, tail = Max + 1;
      while (head != tail) {
        int mid = (head + tail) >> 1;
        ans = 0;
        askrank(1, 1, n, l, r, mid);
        if (ans < k)
          head = mid + 1;
        else
          tail = mid;
      }
      printf("%d\n", head - 1);
    } else if (opt == 3) {
      pos = read();
      k = read();
      change(1, 1, n, pos, k);
      a[pos] = k;
      Max = std::max(Max, k);
    } else if (opt == 4) {
      l = read();
      r = read();
      k = read();
      ans = 0;
      askpre(1, 1, n, l, r, k);
      printf("%d\n", ans);
    } else if (opt == 5) {
      l = read();
      r = read();
      k = read();
      ans = inf;
      asksucc(1, 1, n, l, r, k);
      printf("%d\n", ans);
    }
  }
}

posted on 2017-03-01 17:56  蒟蒻konjac  阅读(359)  评论(0编辑  收藏  举报