「题解」洛谷 P5268 一个简单的询问
描述
给定一个长度为 \(n\) 的序列 \(a\),\(q\) 次询问。每次询问给定 \(l_1,r_1,l_2,r_2\),求 \(\sum\limits_{x=0}^\infty \text{get}(l_1,r_1,x)\cdot\text{get}(l_2,r_2,x)\)。
\(\text{get}(l,r,x)\) 表示计算区间 \([l,r]\) 中,数字 \(x\) 出现了多少次。
思路
考虑容斥,即 \(\text{get}(l,r,x)=\text{get}(1,r,x)-\text{get}(1,l-1,x)\)。
再来化简式子:
记 \(\text{g}(p)=\text{get}(1,p,x)\),即 \(1\sim p\) 中 \(x\) 的出现次数。
展开,得:
这样就把一次询问强行拆成了 \(4\) 次询问,\(4\) 次询问的和(差)即为答案。
记二元组 \((l,r)\) 表示一小次询问 \(\pm\sum\limits_{x=0}^\infty\text{g}(l)\cdot\text{g}(r)\),即所有数在 \(1\sim l\) 的出现次数 \(\times\) 在 \(1\sim r\) 的出现次数。
发现从 \((l,r)\) 转移到 \((l',r')\) 的暴力转移的时间复杂度是 \(\mathcal{O}(\text{abs}(l'-l)+\text{abs}(r'-r))\),可以直接莫队。
在这题中,\(l\) 和 \(r\) 的意义相较于以往莫队模板不再是狭意义上的「区间」,而变成了更广意义上的「区间」。即莫队的区间 \([l,r]\) 被抽象为形如 \((x,y)\) 的二元组,只要可以在 \(\mathcal{O}(\text{abs}(x'-x)+\text{abs}(y'-y))\) 的时间复杂度内从 \((x,y)\) 转移到 \((x',y')\),就可以莫队求解。
代码
#include <bits/stdc++.h>
using namespace std;
#define re register
#define int unsigned
#define MAXN 100010
int n, k, m, a[MAXN], len, sum, bel[MAXN], geta[MAXN], getl[MAXN], getr[MAXN], ans[MAXN];
struct ques {
int i, l, r, z;
ques() {}
ques(const int &a, const int &b, const int &c, const int &d) {
i = a, l = b, r = c, z = d;
}
} q[MAXN << 2];
inline bool cmp(const ques &a, const ques &b) {
if (bel[a.l] != bel[b.l]) return a.l < b.l;
if (bel[a.l] & 1) return a.r < b.r;
else return a.r > b.r;
}
inline void update_r(const int &c) {
sum -= geta[c]; geta[c] += getl[c]; getr[c]++; sum += geta[c];
}
inline void update_l(const int &c) {
sum -= geta[c]; geta[c] += getr[c]; getl[c]++; sum += geta[c];
}
inline void remove_r(const int &c) {
sum -= geta[c]; geta[c] -= getl[c]; getr[c]--; sum += geta[c];
}
inline void remove_l(const int &c) {
sum -= geta[c]; geta[c] -= getr[c]; getl[c]--; sum += geta[c];
}
signed main() {
cin >> n;
for (register int i = 1; i <= n; i++) cin >> a[i];
cin >> k;
for (register int i = 1; i <= k; i++) {
int l1, r1, l2, r2;
cin >> l1 >> r1 >> l2 >> r2;
q[++m] = ques(i, r1, r2, 1), q[++m] = ques(i, r1, l2 - 1, -1), q[++m] = ques(i, l1 - 1, r2, -1), q[++m] = ques(i, l1 - 1, l2 - 1, 1);
}
len = n / sqrt(n * 2 / 3);
for (register int i = 1; i <= n; i++) bel[i] = (i - 1) / len + 1;
sort(q + 1, q + m + 1, cmp);
int l = 0, r = 0;
for (register int i = 1; i <= m; i++) {
for (register int j = l + 1; j <= q[i].l; j++) update_l(a[j]);
for (register int j = l; j > q[i].l; j--) remove_l(a[j]);
for (register int j = r + 1; j <= q[i].r; j++) update_r(a[j]);
for (register int j = r; j > q[i].r; j--) remove_r(a[j]);
l = q[i].l, r = q[i].r;
ans[q[i].i] += q[i].z * sum;
}
for (register int i = 1; i <= k; i++) cout << ans[i] << endl;
return 0;
}