AT_jsc2019_final_h Distinct Integers 题解

题目传送门

思路

首先考虑怎么算 \([x,y]\) 中满足条件的对数 \((l, r)\)。我们可以将右端点 \(r\) 固定,求左端点 \(l\) 的方案数。设 \(pre_i\) 代表与 \(a_i\) 相等的上一个位置。我们会发现,\(l\) 必须比区间中的每一个 \(pre_i\) 都要大,才能满足区间中数互不相等的条件。所以 \(l\) 需要比这些 \(pre_i\) 的最大值还要大即可。写为式子即为:

\[\sum_{i = x}^y (i - \max_{j = x}^i pre_j) \]

我们会发现这是一个经典问题:楼房重建。只用维护区间前缀最大值的和即可。维护每一个区间的最大值和区间前缀最大值的和。重点在于如何 push_up

这里可以仿照楼房重建进行解决。还是维护区间最大值和区间前缀最大值的和。每一次把区间分为左区间和右区间,再根据左区间的最大值求出右区间的贡献。具体代码如下:

int calc(int p, int l, int r, int maxl)
{
    if (l == r) return max(tr[p].sum, maxl); // 叶子节点
    if (tr[p].maxn < maxl) return (r - l + 1) * maxl; // 整个区间的最大值都不如前面的最大值
    int mid = (l + r) >> 1;
    if (tr[p * 2].maxn < maxl) return maxl * (mid - l + 1) + calc(p * 2 + 1, mid + 1, r, maxl); // 右区间的左儿子比 maxl 小
    return calc(p * 2, l, mid, maxl) + tr[p].sum - tr[p * 2].sum; // 有区间的左儿子不比 maxl 小
}

set 维护一下每一个数所在的位置即可。时间复杂度 \(\mathcal{O}(q \log^2 n)\)

代码

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N = 5e5 + 5;

int n, q, maxx;
int a[N], pre[N];
set<int> s[N];

struct seg_tree
{
	struct node
	{
		int l, r, maxn, sum;
	} tr[4 * N];
	int calc(int p, int l, int r, int maxl)
	{
		if (l == r) return max(tr[p].sum, maxl);
		if (tr[p].maxn < maxl) return (r - l + 1) * maxl;
		int mid = (l + r) >> 1;
		if (tr[p * 2].maxn < maxl) return maxl * (mid - l + 1) + calc(p * 2 + 1, mid + 1, r, maxl);
		return calc(p * 2, l, mid, maxl) + tr[p].sum - tr[p * 2].sum;
	}
	void push_up(int p, int l, int r)
	{
		tr[p].maxn = max(tr[p * 2].maxn, tr[p * 2 + 1].maxn);
		int mid = (l + r) >> 1;
		tr[p].sum = tr[p * 2].sum + calc(p * 2 + 1, mid + 1, r, tr[p * 2].maxn);
	}
	void build(int p, int l, int r)
	{
		tr[p] = {l, r, 0, 0};
		if (l == r)
		{
			tr[p].maxn = tr[p].sum = pre[l];
			return;
		}
		int mid = (l + r) >> 1;
		build(p * 2, l, mid);
		build(p * 2 + 1, mid + 1, r);
		push_up(p, l, r);
	}
	void modify(int p, int l, int r, int k, int x)
	{
		if (l == r)
		{
			pre[l] = tr[p].maxn = tr[p].sum = x;
			return;
		}
		int mid = (l + r) >> 1;
		if (k <= mid) modify(p * 2, l, mid, k, x);
		else modify(p * 2 + 1, mid + 1, r, k, x);
		push_up(p, l, r);
	}
	int query(int p, int l, int r)
	{
		if (l <= tr[p].l && tr[p].r <= r)
		{
			int now = calc(p, tr[p].l, tr[p].r, maxx);
			maxx = max(maxx, tr[p].maxn);
			return now;
		}
		int mid = (tr[p].l + tr[p].r) >> 1, res = 0;
		if (l <= mid) res += query(p * 2, l, r);
		if (r > mid) res += query(p * 2 + 1, l, r);
		return res;
	}
} ST;

int find(int pos)
{
	auto it = s[a[pos]].find(pos);
	if (it == s[a[pos]].begin()) return 0;
	it--;
	return *it;
}

signed main()
{
	scanf("%lld%lld", &n, &q);
	for (int i = 1; i <= n; i++)
		scanf("%lld", &a[i]), s[a[i]].insert(i);
	for (int i = 0; i < n; i++)
	{
		if (!s[i].size()) continue;
		int pos = 0;
		auto it = s[i].begin();
		while (it != s[i].end())
		{
			pre[*it] = pos;
			pos = *it;
			it++;
		}
	}
	ST.build(1, 1, n);
	while (q--)
	{
		int op, x, y;
		scanf("%lld%lld%lld", &op, &x, &y);
		if (op == 0)
		{
			x++;
			vector<int> tmp;
			tmp.push_back(x);
			auto it = s[a[x]].find(x);
			if (it != (--s[a[x]].end())) tmp.push_back(*(++it));
			s[a[x]].erase(x);
			a[x] = y;
			s[a[x]].insert(x);
			it = s[a[x]].find(x);
			if (it != (--s[a[x]].end())) tmp.push_back(*(++it));
			for (int i : tmp) ST.modify(1, 1, n, i, find(i));
		}
		else
		{
			int l = x + 1, r = y;
			maxx = x;
			int ans = (l + r) * (r - l + 1) / 2 - ST.query(1, l, r);
			printf("%lld\n", ans);
		}
	}
	return 0;
}
posted @ 2026-01-10 16:29  lucasincyber  阅读(9)  评论(0)    收藏  举报