洛谷P3466 [POI 2008] KLO-Building blocks 题解 FHQ Treap

题目链接:https://www.luogu.com.cn/problem/P3466

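解题思路完全来自 CodyTheWolf 大佬的博客。简要思路：把一段长度为 $k$ 的区间全部改成同一高度 $t$ 的代价是 $\sum |h_i - t|$，当 $t$ 取该区间的中位数时代价最小。用 FHQ Treap 维护滑动窗口内的高度多重集；设中位数为 $t$，按 $t$ 把树分裂出严格小于 $t$ 的部分 $L$ 和严格大于 $t$ 的部分 $R$，则代价为 $t\cdot|L| - \mathrm{sum}(L) + \mathrm{sum}(R) - t\cdot|R|$。对所有窗口取最小值即可。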

示例程序:

#include <bits/stdc++.h>
using namespace std;
// Capacity of the node pool `tr` and of the height array: slot 0 is the null
// node, slots 1..n hold inserted heights (presumably sized for n <= 1e5 — the
// problem's limit; confirm against the statement).
constexpr int maxn = 100000 + 5;

// Source of random treap priorities, seeded from the wall clock.
std::mt19937 rng(static_cast<std::mt19937::result_type>(std::time(nullptr)));

// One treap node in the static pool `tr`. Children are pool indices; index 0
// means "no child", so every member defaults to 0 and tr[0] acts as an empty
// subtree (sz == 0, sum == 0).
struct Node {
	int ls = 0, rs = 0;     // left / right child indices into tr (0 = none)
	int key = 0;            // the stored height (BST key)
	std::uint32_t pri = 0;  // random heap priority compared by merge; unsigned
	                        // so the full 32-bit mt19937 output fits exactly
	int sz = 0;             // number of nodes in this subtree
	long long sum = 0;      // sum of keys in this subtree
	Node() = default;
	// A fresh leaf holding `_key`; explicit so ints never silently become nodes.
	explicit Node(int _key)
		: key(_key), pri(static_cast<std::uint32_t>(rng())), sz(1), sum(_key) {}
};

Node tr[maxn];    // node pool; nodes are allocated by bumping idx and never freed
int rt = 0;       // root index of the treap (0 = empty)
int idx = 0;      // last pool slot handed out

void push_up(int u) {
	int ls = tr[u].ls, rs = tr[u].rs;
	tr[u].sz = tr[ls].sz + tr[rs].sz + 1;
	tr[u].sum = tr[ls].sum + tr[rs].sum + tr[u].key;
}

// Splits the treap rooted at u by value: every node with key <= x ends up in
// the treap returned through L, every node with key > x in the one through R.
// Subtree aggregates are refreshed on the way back up.
void split(int u, int x, int &L, int &R) {
	if (u == 0) {
		L = 0;
		R = 0;
		return;
	}
	if (x < tr[u].key) {
		// u and its whole right subtree are > x; only the left subtree needs cutting.
		R = u;
		split(tr[u].ls, x, L, tr[u].ls);
	} else {
		// u and its whole left subtree are <= x; only the right subtree needs cutting.
		L = u;
		split(tr[u].rs, x, tr[u].rs, R);
	}
	push_up(u);
}

// Joins two treaps whose keys are already ordered (everything in L precedes
// everything in R, as produced by split). The root with the larger priority
// stays on top; ties go to R. Returns the root index of the joined treap.
int merge(int L, int R) {
	if (L == 0) return R;
	if (R == 0) return L;
	if (tr[L].pri <= tr[R].pri) {
		// R stays the root; L is folded into R's left side.
		tr[R].ls = merge(L, tr[R].ls);
		push_up(R);
		return R;
	}
	// L stays the root; R is folded into L's right side.
	tr[L].rs = merge(tr[L].rs, R);
	push_up(L);
	return L;
}

// Adds height x to the multiset held in the treap rooted at rt, using the next
// free slot of the node pool.
void ins(int x) {
	int lo, hi;
	split(rt, x, lo, hi);        // lo: keys <= x, hi: keys > x
	const int fresh = ++idx;     // claim the next pool slot
	tr[fresh] = Node(x);
	rt = merge(merge(lo, fresh), hi);
}

// Removes a single occurrence of height x from the treap rooted at rt; other
// copies of x are kept. If x is absent the treap is left as it was.
void del(int x) {
	int below, above, same;
	split(rt, x, below, above);          // below: keys <= x, above: keys > x
	split(below, x - 1, below, same);    // below: keys <  x, same:  keys == x
	// Discard the root of the "== x" part by joining its two children.
	same = merge(tr[same].ls, tr[same].rs);
	rt = merge(merge(below, same), above);
}

// Returns the index of the node holding the k-th smallest key (1-based) in the
// treap rooted at u, or 0 when k is outside [1, size].
int kth(int u, int k) {
	for (; u != 0; ) {
		const int leftSize = tr[tr[u].ls].sz;
		if (k <= leftSize) {
			// The answer lies entirely in the left subtree.
			u = tr[u].ls;
			continue;
		}
		if (k == leftSize + 1) return u;
		// Skip the left subtree and u itself, then continue on the right.
		k -= leftSize + 1;
		u = tr[u].rs;
	}
	return 0;
}

// Problem input: n heights h[1..n] and the window length K.
int n, K;
int h[maxn];
// Best window found so far: its 1-based start and the height it is flattened to.
int ansp, ansmid;
// Best cost found so far; starts at 10^18 as a "nothing found yet" sentinel.
long long ans = 1000000000000000000LL;

// Cost of making every height currently stored in the treap equal to their
// (lower) median: sum over stored keys of |key - median|. The treap is split
// around the median and reassembled, so rt is restored before returning.
long long cal() {
    const int mid = tr[kth(rt, (K + 1) / 2)].key;
    int lo, eq, hi;
    split(rt, mid, lo, hi);        // lo: keys <= mid, hi: keys > mid
    split(lo, mid - 1, lo, eq);    // lo: keys <  mid, eq: keys == mid (cost 0)
    const long long raiseCost = (long long) mid * tr[lo].sz - tr[lo].sum;
    const long long lowerCost = tr[hi].sum - (long long) mid * tr[hi].sz;
    rt = merge(merge(lo, eq), hi);
    return raiseCost + lowerCost;
}

int main() {
    scanf("%d%d", &n, &K);
    for (int i = 1; i <= n; i++) {
        scanf("%d", h+i);
        ins(h[i]);
        if (i > K)
            del(h[i-K]);
        if (i < K) continue;
        long long tmp = cal();
        if (tmp < ans)
            ans = tmp, ansp = i - K + 1, ansmid = tr[ kth(rt, (K+1)/2) ].key;
    }
    for (int i = 0; i < K; i++)
        h[ansp + i] = ansmid;
    printf("%lld\n", ans);
    for (int i = 1; i <= n; i++)
        printf("%d\n", h[i]);
	return 0;
}
posted @ 2025-04-17 16:38  quanjun  阅读(11)  评论(0)    收藏  举报