ABC306 E - Best Performances 题解 离散化+线段树/splay tree

题目链接:https://atcoder.jp/contests/abc306/tasks/abc306_e

题目大意:

有一个长度为 \(N\) 的序列 \(A = (A_1, A_2, \ldots, A_N)\),以及一个整数 \(K\)

初始时序列 \(A\) 的所有元素的数值均为 \(0\)

\(Q\) 次操作,每次操作给你两个整数 \(X_i\)\(Y_i\),你需要将序列 \(A\) 中第 \(X_i\) 个元素的数值修改为 \(Y_i\)(即 \(A_{X_i} \leftarrow Y_i\)),然后输出一个整数,这个整数的数值为序列 \(A\) 中最小的 \(K\) 个数之和。

解题思路:

线段树离散化之后,或者 splay tree 都是基本操作。

示例程序1(线段树 + 离散化):

#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;
int n, M, K, Q, x[maxn], y[maxn], a[maxn];
vector<int> vec;

int lsh(int val) {
	return lower_bound(vec.begin(), vec.end(), val) - vec.begin() + 1;
}

// 线段树
int tr_cnt[maxn<<2];
long long tr_sum[maxn<<2];
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1
void push_up(int rt) {
	tr_cnt[rt] = tr_cnt[rt<<1] + tr_cnt[rt<<1|1];
	tr_sum[rt] = tr_sum[rt<<1] + tr_sum[rt<<1|1];
}
// 离散化之后的数字是p,增加了 c 个(+1 或者 -1) 
void add(int p, int c, int l, int r, int rt) {
//	printf("add [%d, %d] [%d, %d] %d\n", p, c, l, r, rt);
	if (l == r) {
		tr_cnt[rt] += c;
		tr_sum[rt] += (long long) c * vec[p-1];
		return;
	}
	int mid = (l + r) / 2;
	(p <= mid) ? add(p, c, lson) : add(p, c, rson);
	push_up(rt);  
}
// 前k个数 
long long query(int k, int l, int r, int rt) {
//	printf("query %d [%d , %d] %d\n", k, l, r, rt);
	if (l == r) {
		assert(tr_cnt[rt] >= k);
		return (long long) k * vec[l-1];
	}
	if (tr_cnt[rt] == k)
		return tr_sum[rt];
	int mid = (l + r) / 2;
	if (tr_cnt[rt<<1|1] >= k)
		return query(k, rson);
	return tr_sum[rt<<1|1] + query(k - tr_cnt[rt<<1|1], lson);
}

int main()
{
	scanf("%d%d%d", &n, &K, &Q);
	vec.push_back(0);
	for (int i = 0; i < Q; i++) {
		scanf("%d%d", x+i, y+i);
		vec.push_back(y[i]);
	}
	sort(vec.begin(), vec.end());
	vec.erase(unique(vec.begin(), vec.end()), vec.end());
	M = vec.size();
	add(1, n, 1, M, 1);
	for (int i = 0; i < Q; i++) {
		int p = x[i], q = y[i]; // a[p] = q
		add(lsh(a[p]), -1, 1, M, 1);
		a[p] = q;
		add(lsh(a[p]), 1, 1, M, 1);
		printf("%lld\n", query(K, 1, M, 1));
	}
	return 0;
}

示例程序2(splay tree):

#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;
int n, M, K, Q, a[maxn];

struct Node {
	int s[2], p, v;	// s[0]左儿子编号,s[1]右儿子编号,p父节点编号,v数值
	int sz, cnt; // 子树大小
	long long sum; // 子树数值之和
	
	Node() {}
	Node(int _v, int _p) {
		v = _v;
		p = _p;
		s[0] = s[1] = 0;
		sz = cnt = 1;
		sum = _v;
	}
} tr[maxn];
int root, idx;

void push_up(int x) {
	int ls = tr[x].s[0], rs = tr[x].s[1];
	tr[x].sz = tr[ls].sz + tr[rs].sz + tr[x].cnt;
	tr[x].sum = tr[ls].sum + tr[rs].sum + (long long) tr[x].cnt * tr[x].v;
}

void f_s(int p, int u, bool k) {
    tr[p].s[k] = u;
    tr[u].p = p;
}

void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    bool k = tr[y].s[1] == x;
    f_s(z, x, tr[z].s[1]==y);
    f_s(y, tr[x].s[k^1], k);
    f_s(x, y, k^1);
    push_up(y), push_up(x);
}

// 旋转到 x 的父节点为 k 为止(若k为0,则 x 将旋转到根节点) 
void splay(int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            (tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
        rot(x);
    }
    if (!k) root = x;
}

// 插入一个数值为 v 的节点 
void ins(int v) {
    int u = root, p = 0;
    while (u) {
    	if (tr[u].v == v)
    		break;
        p = u, u = tr[u].s[v > tr[u].v];
    }
    if (u) {
    	tr[u].cnt++;
    	push_up(u);
	}
    else {
		tr[u = ++idx] = Node(v, p);
    	if (p) tr[p].s[v > tr[p].v] = u;
	}
    splay(u, 0);
}

// 找前驱:找数值 < v 的最大的那个数
int get_pre(int v) {
	int u = root, res = -1;
	while (u) {
		if (tr[u].v < v) res = u, u = tr[u].s[1];
		else u = tr[u].s[0];
	}
	return res;
}

// 找数值等于 v 的最前面(中序遍历序号最小)那个点
int get_point(int v) {
	int u = root, res = -1;
	while (u) {
		if (tr[u].v >= v) res = u, u = tr[u].s[0];
		else u = tr[u].s[1];
	}
	return res;
}  

// 删除一个数值为 v 的节点 
void del(int v) {
	int u1 = get_pre(v);	// 找前驱 
	splay(u1, 0);
	int u2 = get_point(v); // 查找一个数值为 v 的节点
	splay(u2, u1);
	if (tr[u2].cnt > 1) {
		tr[u2].cnt--;
		push_up(u2);
	}
	else
		f_s(u1, tr[u2].s[1], 1);
	push_up(u1);
}

long long query(int k) {
	int u = root;
	long long res = 0;
	if (tr[u].sz <= k) return tr[u].sum;
	while (u) {
		int ls = tr[u].s[0], rs = tr[u].s[1];
		if (tr[rs].sz >= k)
			u = rs;
		else {
			res += tr[rs].sum;
			k -= tr[rs].sz;
			if (k <= tr[u].cnt) {
				res += (long long) tr[u].v * k;
				break;
			}
			else {
				res += (long long) tr[u].v * tr[u].cnt;
				k -= tr[u].cnt;
			}
			u = ls;
		}		
	}
	return res;
}

int main()
{
	scanf("%d%d%d", &n, &K, &Q);
	ins(0);
	for (int i = 0; i < Q; i++) {
		int x, y;
		scanf("%d%d", &x, &y);
		if (a[x])
			del(a[x]);	// 删除 delete 
		a[x] = y;
		if (a[x])
			ins(a[x]);	// 插入 insert
		printf("%lld\n", query(K));
	}
	return 0;
}
posted @ 2024-02-20 16:27  quanjun  阅读(6)  评论(0编辑  收藏  举报