题解 P3332 [ZJOI2013]K大数查询

题目描述

Link

你需要维护 \(n\) 个可重整数集,集合的编号从 \(1\)\(n\)
这些集合初始都是空集,有 \(m\) 个操作:

  • 1 l r c:表示将 \(c\) 加入到编号在 \([l,r]\) 内的集合中
  • 2 l r c:表示查询编号在 \([l,r]\) 内的集合的并集中,第 \(c\) 大的数是多少。

\(1 \leq n ,m \leq 10^5 ,1 \leq c \leq 2^{63}\)

Solution

发现插入,查询 \(k\) 大,考虑线段树套平衡树。

但是发现 \(n ,m \leq 10^5\) ,并且区间插入并不好打懒标记,比较麻烦。

并且因为平衡树极其优秀的常数,很容易 TLE 。

所以考虑转换思路:我们用权值线段树套动态开店线段树。

权值线段树维护所有权值,里面那一层线段树维护所有权值在下标内出现的次数。

对于 1 操作,我们先在外层线段树找到权值为 \(c\) 的节点,然后在经过的所有节点对应的内层线段树上让区间 \([l,r]\) 加上一,表示这些位置多了一个数。

对于 2 操作,假如我们当前在权值线段树上的节点是 \(now\) 。我们可以看一下 \(now\) 的右儿子落在下标 \([l,r]\) 范围内有几个数(就是在右儿子对应的里层线段树中查询 \([l,r]\) 的和)。如果个数 \(\geq c\) ,那么 \(c\) 大一定在右子树里,递归右儿子处理。

否则 \(c\) 大在左儿子的子树里,递归左子树。注意递归左子树的时候需要将 \(c\) 减掉右儿子 \([l,r]\) 内的数的个数(就像平衡树查排名对应的值一样)。

另外,注意到 \(c \leq 2^{63}\) ,所以我们要先读进所有的操作,然后把所有插入的数放在一起离散化。

代码如下:

#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
#define int long long
using namespace std;
inline int read() {
    int num = 0 ,f = 1; char c = getchar();
    while (!isdigit(c)) f = c == '-' ? -1 : f ,c = getchar();
    while (isdigit(c)) num = (num << 1) + (num << 3) + (c ^ 48) ,c = getchar();
    return num * f;
}
inline int min(int a ,int b) {return a < b ? a : b;}
inline int max(int a ,int b) {return a > b ? a : b;}
inline void swap(int &a ,int &b) {int t = a; a = b; b = t;}
const int N = 5e4 + 5;
struct Segment1 { //区间加,区间求和的动态开点线段树
    struct node {
        int l ,r ,sum ,add;
        node (int l = 0 ,int r = 0 ,int sum = 0 ,int add = 0) :
            l(l) ,r(r) ,sum(sum) ,add(add) {}
    }t[N * 17 * 17]; int tot; //空间记得开够
    Segment1() : tot(0) {}
    inline void update(int &now) {
        t[now].sum = t[t[now].l].sum + t[t[now].r].sum;
    }
    inline void puttag(int &now ,int l ,int r ,int k) {
        if (!now) now = ++tot; //千万不要写成了 return ;
        t[now].sum += k * (r - l + 1);
        t[now].add += k;
    }
    inline void pushdown(int &now ,int l ,int r) {
        if (t[now].add == 0) return ;
        int mid = (l + r) >> 1;
        puttag(t[now].l ,l ,mid ,t[now].add);
        puttag(t[now].r ,mid + 1 ,r ,t[now].add);
        t[now].add = 0;
    }
    inline void modify(int &now ,int l ,int r ,int ql ,int qr ,int k) {
        if (!now) now = ++tot; //没有创建记得创建
        if (ql <= l && r <= qr) {
			puttag(now ,l ,r ,k);
			return ;
		}
        pushdown(now ,l ,r);
        int mid = (l + r) >> 1;
        if (ql <= mid) modify(t[now].l ,l ,mid ,ql ,qr ,k);
        if (qr > mid) modify(t[now].r ,mid + 1 ,r ,ql ,qr ,k);
        update(now);
    }
    inline int query(int &now ,int l ,int r ,int ql ,int qr) {
        if (!now) now = ++tot;
        if (ql <= l && r <= qr) return t[now].sum;
        pushdown(now ,l ,r);
        int mid = (l + r) >> 1 ,ans = 0;
        if (ql <= mid) ans += query(t[now].l ,l ,mid ,ql ,qr);
        if (qr > mid) ans += query(t[now].r ,mid + 1 ,r ,ql ,qr);
        return ans;
    }
}u;
int n;
struct Segment {
    int t[N << 2];
    Segment () {memset(t ,0 ,sizeof(t));}
    inline void join(int now ,int l ,int r ,int ql ,int qr ,int k) {
        u.modify(t[now] ,1 ,n ,ql ,qr ,1);
        if (l == r) return ;
        int mid = (l + r) >> 1;
        if (k <= mid) join(now << 1 ,l ,mid ,ql ,qr ,k);
        else join(now << 1 | 1 ,mid + 1 ,r ,ql ,qr ,k);
    }
    inline int query(int now ,int l ,int r ,int ql ,int qr ,int k) {
        if (l == r) return l;
        int mid = (l + r) >> 1 ,res = u.query(t[now << 1 | 1] ,1 ,n ,ql ,qr);
        if (res >= k) return query(now << 1 | 1 ,mid + 1 ,r ,ql ,qr ,k);
        return query(now << 1 ,l ,mid ,ql ,qr ,k - res); //这里是 k - res ,不要写成了 k 
    }
}t;
int m ,nums[N] ,tot;
inline int find(int val) {
    return lower_bound(nums + 1 ,nums + tot + 1 ,val) - nums;
}
struct opts {
    int opt ,l ,r ,k;
    opts (int opt = 0 ,int l = 0 ,int r = 0 ,int k = 0) :
        opt(opt) ,l(l) ,r(r) ,k(k) {}
}q[N];
signed main() {
    n = read(); m = read();
    for (int i = 1; i <= m; i++) {
        int opt = read() ,x = read() ,y = read() ,k = read();
        q[i] = opts(opt ,x ,y ,k);
        if (opt == 1) nums[++tot] = k;
    }
    sort(nums + 1 ,nums + tot + 1);
    tot = unique(nums + 1 ,nums + tot + 1) - nums - 1;
    for (int i = 1; i <= m; i++) {
        if (q[i].opt == 1) t.join(1 ,1 ,tot ,q[i].l ,q[i].r ,find(q[i].k));
        else printf("%lld\n" ,nums[t.query(1 ,1 ,tot ,q[i].l ,q[i].r ,q[i].k)]);
    }
    return 0;
}
posted @ 2021-03-07 14:13  recollector  阅读(164)  评论(0编辑  收藏  举报