AT_jsc2019_final_h Distinct Integers 题解
思路
首先考虑怎么算 \([x,y]\) 中满足条件的对数 \((l, r)\)。我们可以将右端点 \(r\) 固定,求左端点 \(l\) 的方案数。设 \(pre_i\) 代表与 \(a_i\) 相等的上一个位置。我们会发现,\(l\) 必须比区间中的每一个 \(pre_i\) 都要大,才能满足区间中数互不相等的条件。所以 \(l\) 需要比这些 \(pre_i\) 的最大值还要大即可。写为式子即为:
\[\sum_{i = x}^y (i - \max_{j = x}^i pre_j)
\]
我们会发现这是一个经典问题:楼房重建。只用维护区间前缀最大值的和即可。维护每一个区间的最大值和区间前缀最大值的和。重点在于如何 push_up。
这里可以仿照楼房重建进行解决。还是维护区间最大值和区间前缀最大值的和。每一次把区间分为左区间和右区间,再根据左区间的最大值求出右区间的贡献。具体代码如下:
int calc(int p, int l, int r, int maxl)
{
if (l == r) return max(tr[p].sum, maxl); // 叶子节点
if (tr[p].maxn < maxl) return (r - l + 1) * maxl; // 整个区间的最大值都不如前面的最大值
int mid = (l + r) >> 1;
if (tr[p * 2].maxn < maxl) return maxl * (mid - l + 1) + calc(p * 2 + 1, mid + 1, r, maxl); // 右区间的左儿子比 maxl 小
return calc(p * 2, l, mid, maxl) + tr[p].sum - tr[p * 2].sum; // 有区间的左儿子不比 maxl 小
}
用 set 维护一下每一个数所在的位置即可。时间复杂度 \(\mathcal{O}(q \log^2 n)\)。
代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 5e5 + 5;
int n, q, maxx;
int a[N], pre[N];
set<int> s[N];
struct seg_tree
{
struct node
{
int l, r, maxn, sum;
} tr[4 * N];
int calc(int p, int l, int r, int maxl)
{
if (l == r) return max(tr[p].sum, maxl);
if (tr[p].maxn < maxl) return (r - l + 1) * maxl;
int mid = (l + r) >> 1;
if (tr[p * 2].maxn < maxl) return maxl * (mid - l + 1) + calc(p * 2 + 1, mid + 1, r, maxl);
return calc(p * 2, l, mid, maxl) + tr[p].sum - tr[p * 2].sum;
}
void push_up(int p, int l, int r)
{
tr[p].maxn = max(tr[p * 2].maxn, tr[p * 2 + 1].maxn);
int mid = (l + r) >> 1;
tr[p].sum = tr[p * 2].sum + calc(p * 2 + 1, mid + 1, r, tr[p * 2].maxn);
}
void build(int p, int l, int r)
{
tr[p] = {l, r, 0, 0};
if (l == r)
{
tr[p].maxn = tr[p].sum = pre[l];
return;
}
int mid = (l + r) >> 1;
build(p * 2, l, mid);
build(p * 2 + 1, mid + 1, r);
push_up(p, l, r);
}
void modify(int p, int l, int r, int k, int x)
{
if (l == r)
{
pre[l] = tr[p].maxn = tr[p].sum = x;
return;
}
int mid = (l + r) >> 1;
if (k <= mid) modify(p * 2, l, mid, k, x);
else modify(p * 2 + 1, mid + 1, r, k, x);
push_up(p, l, r);
}
int query(int p, int l, int r)
{
if (l <= tr[p].l && tr[p].r <= r)
{
int now = calc(p, tr[p].l, tr[p].r, maxx);
maxx = max(maxx, tr[p].maxn);
return now;
}
int mid = (tr[p].l + tr[p].r) >> 1, res = 0;
if (l <= mid) res += query(p * 2, l, r);
if (r > mid) res += query(p * 2 + 1, l, r);
return res;
}
} ST;
int find(int pos)
{
auto it = s[a[pos]].find(pos);
if (it == s[a[pos]].begin()) return 0;
it--;
return *it;
}
signed main()
{
scanf("%lld%lld", &n, &q);
for (int i = 1; i <= n; i++)
scanf("%lld", &a[i]), s[a[i]].insert(i);
for (int i = 0; i < n; i++)
{
if (!s[i].size()) continue;
int pos = 0;
auto it = s[i].begin();
while (it != s[i].end())
{
pre[*it] = pos;
pos = *it;
it++;
}
}
ST.build(1, 1, n);
while (q--)
{
int op, x, y;
scanf("%lld%lld%lld", &op, &x, &y);
if (op == 0)
{
x++;
vector<int> tmp;
tmp.push_back(x);
auto it = s[a[x]].find(x);
if (it != (--s[a[x]].end())) tmp.push_back(*(++it));
s[a[x]].erase(x);
a[x] = y;
s[a[x]].insert(x);
it = s[a[x]].find(x);
if (it != (--s[a[x]].end())) tmp.push_back(*(++it));
for (int i : tmp) ST.modify(1, 1, n, i, find(i));
}
else
{
int l = x + 1, r = y;
maxx = x;
int ans = (l + r) * (r - l + 1) / 2 - ST.query(1, l, r);
printf("%lld\n", ans);
}
}
return 0;
}

浙公网安备 33010602011771号