树状数组
树状数组(Binary Indexed Tree,BIT)是一种用于维护 \(n\) 个元素的前缀信息的数据结构。
以前缀和为例,对于数列 \(a\),可以将其存储为前缀和数组 \(s\) 的形式,其中 \(s_i = \sum \limits_{j=1}^i a_j\)。那么通过前缀和数组,就可以快速求出原数组中给定区间中数字的和:对于区间 \([l,r]\),区间和为 \(s_r - s_{l-1}\),其中假设 \(s_0 = 0\)。
显然,对于长度为 \(n\) 的数列,前缀和需要用长度为 \(n\) 的数组进行存储。而当数列 \(a\) 发生变化时,要使得 \(s\) 数组的内容仍能够正确对应数列 \(a\) 的前缀和,就需要对 \(s\) 的值进行修改,即使数列中只有一个数发生变化,也可能需要修改 \(s\) 数组的多个值,才能保证整个数组仍然存储的是 \(a\) 的前缀和。
类似地,对于长度为 \(n\) 的数列,树状数组也会使用长度为 \(n\) 的数组来进行存储。在这个数组中,每个位置存储的内容则稍微有些复杂。
例题:P3374 【模板】树状数组 1
已知一个数列,需要支持两种操作:
1. 将某一个数加上 \(x\);
2. 求出某区间中每一个数的和并输出。
数列长度和操作个数均不超过 \(5 \times 10^5\)。
分析:如果使用朴素的做法,将这个数列保存在一个数组 \(a\) 中,那么对于第二种操作,需要将查询区间内的每一个数依次加起来。如果这样做,那么最坏情况下每一次操作就要遍历整个数组,导致超时。
另一种想法是,通过将数列存储为前缀和数组 \(s\) 的形式,那么就可以快速求出给定区间的和;然而,对于第一种操作,在最坏情况下则又需要修改整个数组,同样会导致超时。
那么有没有方法可以结合两种做法的优势,使得两个操作均使用较低的时间复杂度来完成呢?这里就可以用到树状数组。
对于任何一种数据结构,可以将其抽象为一个黑匣子:黑匣子里面存储的是数据,可以向其提供支持的操作,包括修改操作和查询操作。当向其支持查询操作时,其需要通过保存的数据计算出需要的结果然后返回;当向其提供修改操作时,黑匣子需要更新其内部的数据,来保证对于之后的查询操作,黑匣子仍能够返回正确的结果。能否解决问题取决于这个黑匣子是否能以及能以何种复杂度实现这些操作;而如何实现这样一个黑匣子,则是我们的任务。
在这个问题中,黑匣子需要维护一个数列,需要支持的有单点修改操作和区间查询操作。
和前缀和类似,树状数组每个位置保存的也是原数组中某一段区间的和。为了准确说明每个位置分别保存的是哪一段区间,首先引入一个函数 lowbit(x),它的值是 \(x\) 的二进制表达式中最低位的 \(1\) 所对应的值。例如,\(6\) 的二进制表示为 \((110)_2\),最低位的 \(1\) 为第二个 \(1\),其对应的值为 \((10)_2=(2)_{10}\),故 \(lowbit(6)=2\);\(20\) 的二进制表示为 \((10100)_2\),最低位的 \(1\) 为第二个 \(1\),其对应的值为 \((100)_2=(4)_{10}\),故 \(lowbit(20)=4\)。
在常见的计算机中,有符号数采用补码表示,而在补码表示下,\(lowbit(x)\) 有一种简单的表达方法:lowbit(x)=x&(-x),其中 & 为按位与。由于 \(-x\) 的补码为 \(x\) 按位取反后再加 \(1\),考虑 \(x\) 和 \(-x\) 的二进制表示,\(x\) 末尾的若干 \(0\) 在取反后变成 \(1\),加上 \(1\) 后变成 \(0\);\(x\) 最低位的 \(1\) 在取反后变成 \(0\),得到进位后变成 \(1\);比该位更高的不会得到进位,维持取反的状态。因此,在按位与的过程中,只有那一位得到的结果为 \(1\),其余都为 \(0\)。
x 二进制表示:0101...1000...0
-x 的反码表示:1010...0111...1
-x 的补码表示:1010...1000...0
那么,假设树状数组使用数组 \(c\) 来进行存储,原来的 \(n\) 个数分别为 \(a_1\) 到 \(a_n\),则 \(c_i = \sum \limits_{j=i-lowbit(i)+1}^i a_j\)。换句话说,树状数组中每个位置保存的是其向前 \(lowbit\) 长度的区间和。

这样做有什么好处呢?考虑假设已经有了这样一个数组 \(c\),如何用它实现前缀和查询操作。假设要求 \(a_1\) 到 \(a_i\) 的前缀和 \(s_i\),可以先将 \(c_i\) 加入答案,那么剩下的部分就是 \(a_1\) 到 \(a_{i-lowbit(i)}\),换句话说,问题变成了求 \(s_{i-lowbit(i)}\)。那么接下来又可以将 \(c_{i-lowbit(i)}\) 加入答案,不断重复操作,直到问题变成求 \(s_0\) 为止,那么此时就已经得到 \(s_i\) 了。示例代码如下:
int query(int x) {
int res = 0;
while (x > 0) {
res += c[x]; x -= lowbit(x); // 从大到小将需要的值求和
}
return res;
}
这个过程的每一步中,把一个数 \(x\) 变成 \(x-lowbit(x)\),结合之前说的 \(lowbit\) 的含义,可以发现实际上是在不断地去掉 \(i\) 的二进制表示中最低位的 \(1\)。由于一个数 \(i\) 的二进制表示的位数不超过 \(\log i\),故每一次查询的时间复杂度为 \(O(\log n)\)。
接下来再考虑单点修改操作。假设修改的数是 \(a_i\),由于可能有多个位置对应的区间包含 \(a_i\),对于这些位置都要进行修改。
例如,要查询 \(s_{14}\) 的值,可以发现 \(s_{14}=c_{14}+c_{12}+c_8=64\);如果要修改 \(a_3\) 的值,则需修改所有包含 \(a_3\) 的区间值,也就是 \(a_3,a_4\) 和 \(a_8\)。

有哪些位置需要包含 \(a_i\) 呢?先考虑几个结论,假设一个位置 \(c_j\) 包含 \(a_i\),那么有:
- \(j \ge i\)。这一点很显然,因为一个位置只会包含它前面的数。
- \(lowbit(j) \ge lowbit(i)\),当且仅当 \(j=i\) 时取等号。
- \(lowbit\) 的值相等的位置不会包含同一个数。
综合以上的结论,可以按 \(lowbit\) 从小到大的顺序找出满足条件的 \(j\)。
首先,\(i\) 是第一个满足条件的 \(j\),记为 \(j_0=i\)。
下一个 \(j\) 需要比 \(i\) 大,且 \(lowbit\) 也要更大,即二进制表示中末尾的 \(0\) 更多,因此至少需要把最后一个 \(1\) 变成 \(0\),也就是至少加上 \(lowbit(j_0)\);由于 \(lowbit(j_0)<lowbit(j_0+lowbit(j_0))\),而 \(j_0=i\),所以 \(i\) 显然在 \(j_0+lowbit(j_0)\) 对应的区间内,也就是说 \(j_0+lowbit(j_0)\) 就是下一个 \(j\),记为 \(j_1=j_0+lowbit(j_0)\)。
再下一个 \(j\) 又可以通过 \(j_1+lowbit(j_1)\) 得到,由于 \(lowbit\) 是翻倍增长的,所以 \(lowbit(j_0)+lowbit(j_1)\) 仍然小于 \(lowbit(j_1)+lowbit(j_1)\),意味着 \(i\) 也在 \(j_1+lowbit(j_1)\) 所对应的区间内,即 \(j_2=j_1+lowbit(j_1)\)。以此类推,即可得到所有需要修改的位置。示例代码如下:
void add(int x, int y) {
while (x <= n) {
c[x] += y; x += lowbit(x); // 从小到大修改需要修改的位置
}
}
由于 \(lowbit\) 的值只有不超过 \(\log n\) 种,一次修改中一个 \(lowbit\) 值最多只会对应一个需要的位置,所以每一次修改的时间复杂度也为 \(O(\log n)\)。
至此,我们知道树状数组可以维护一个数列,并以 \(O(\log n)\) 的时间复杂度进行单点修改操作和前缀和查询操作。对于本题,要实现的是区间和查询操作,可以通过前缀和查询操作来实现:对于 \([l,r]\) 的查询,只需要用 \([1,r]\) 的和减去 \([1,l-1]\) 的和即可。
#include <cstdio>
typedef long long LL;
const int MAXN = 5e5 + 5;
LL a[MAXN];
int n, m;
int lowbit(int x) {
return x & -x;
}
LL query(int x) {
LL ret = 0;
while (x > 0) {
ret += a[x];
x -= lowbit(x);
}
return ret;
}
void update(int x, LL d) {
while (x <= n) {
a[x] += d;
x += lowbit(x);
}
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
int x;
scanf("%d", &x);
update(i, x);
}
while (m--) {
int op, x, y;
scanf("%d%d%d", &op, &x, &y);
if (op == 1) update(x, y);
else printf("%lld\n", query(y) - query(x - 1));
}
return 0;
}
例题:P3368 【模板】树状数组 2
已知一个数列,需要进行两种操作:将区间 \([x,y]\) 每一个数加上 \(x\);或者求出某一个数的值。
数列长度和操作个数均不超过 \(5 \times 10^5\)。
分析:和上个问题相反,这里需要对于数列实现区间加法的修改操作和单点的查询操作。乍一看好像没法使用树状数组,但实际上只需要进行一些小处理,就能把这个问题变得和上个问题相同。
对数组进行差分操作:假设原来的数列为 \(a\),令 \(b_i=a_i-a_{i-1}\),那么 \(a_i=\sum \limits_{j=1}^i b_j\),即 \(a\) 是 \(b\) 的前缀和数组。当 \(b_i\) 增加 \(x\) 时,意味着 \(a_i\) 到 \(a_n\) 都会增加 \(x\)。那么,对于 \(b\) 数组而言,第一个操作的效果为:假设要将区间 \([l,r]\) 的数增加 \(x\),则 \(b_l\) 增加 \(x\),\(b_{r+1}\) 减少 \(x\);第二个操作的效果为:求出 \(b\) 的某个前缀和。这样一来,\(b\) 数组就可以用树状数组进行维护。
#include <cstdio>
const int MAXN = 5e5 + 5;
int a[MAXN], n;
int lowbit(int x) {
return x & -x;
}
int query(int x) {
int ret = 0;
while (x > 0) {
ret += a[x];
x -= lowbit(x);
}
return ret;
}
void update(int x, int d) {
while (x <= n) {
a[x] += d;
x += lowbit(x);
}
}
int main()
{
int m, pre = 0;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
int x;
scanf("%d", &x);
update(i, x - pre);
pre = x;
}
while (m--) {
int op;
scanf("%d", &op);
if (op == 1) {
int x, y, k;
scanf("%d%d%d", &x, &y, &k);
update(x, k); update(y + 1, -k);
} else {
int x;
scanf("%d", &x);
printf("%d\n", query(x));
}
}
return 0;
}
例题:P1908 逆序对
对于给定的一段正整数序列,逆序对就是序列中 \(a_i>a_j\) 且 \(i<j\) 的有序对。给定长度为 \(n\) 的正整数序列,求逆序对数。其中 \(n \le 5 \times 10^5\)。
分析:考虑朴素的做法,枚举 \(i\),再枚举比 \(i\) 大的位置 \(j\),统计 \(a_j<a_i\) 的数量。假设把所有 \(j>i\) 中 \(a_j=k\) 的数量记为 \(cnt_k\),那么也就是统计 \(s_{a_i-1}=\sum \limits_{k=1}^{a_i-1} cnt_k\)。也就是说,查询的是一个数列 \(cnt\) 的前缀和。如果按照从大到小的位置枚举 \(i\),那么每当 \(i\) 前进一步,可用的 \(j\) 就增加一个,需要将 \(cnt_{a_j}\) 增加 \(1\)。可以发现,这是不断地在对数列 \(cnt\) 进行前缀和查询和单点修改操作,因此可以用树状数组维护数列 \(cnt\)。
但是还有一个问题:数列 \(cnt\) 的长度是多少呢?由于 \(a\) 中的元素可以很大,所以 \(cnt\) 的下标也可以很大。为了解决这个问题,可以用到离散化的思想。由于 \(cnt\) 数组开始时全为 \(0\),总共会进行 \(n\) 次修改,也就是说最多只有 \(n\) 个位置不是 \(0\)。因此可以只记录这些可能非 \(0\) 的位置。具体而言,首先将序列排序并去重,在这个序列上利用 std::lower_bound(),可以快速求出原数列中一个数是数列中的第几小。那么 \(cnt_k\) 可以表示序列中第 \(k\) 小的数的个数。这样一来,\(cnt\) 的长度就最多是 \(n\) 了。
#include <cstdio>
#include <vector>
#include <algorithm>
using std::lower_bound;
using std::sort;
using std::unique;
using std::vector;
typedef long long LL;
const int N = 5e5 + 5;
int n, a[N], bit[N], bound;
vector<int> data;
int discretization(int x) { // 求出x是第几小
return lower_bound(data.begin(), data.end(), x) - data.begin() + 1;
}
int lowbit(int x) {
return x & -x;
}
void add(int x) {
while (x <= bound) {
bit[x]++; x += lowbit(x);
}
}
int query(int x) {
int res = 0;
while (x > 0) {
res += bit[x]; x -= lowbit(x);
}
return res;
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]); data.push_back(a[i]);
}
// 离散化的准备工作
sort(data.begin(), data.end());
data.erase(unique(data.begin(), data.end()), data.end());
bound = data.size();
LL ans = 0;
for (int i = n; i >= 1; i--) {
ans += query(discretization(a[i]) - 1);
add(discretization(a[i]));
}
printf("%lld\n", ans);
}
习题:P5459 [BJOI2016] 回转寿司
给定一个长度为 \(n\) 的序列 \(a\),从中选出一段连续子序列 \([l,r]\),使得 \(L \le \sum \limits_{i=l}^r a_i \le R\),求方案数。
数据范围:\(1 \le n \le 10^5, |a_i| \le 10^5, 1 \le L,R \le 10^9\)。
解题思路
枚举 \(r = 1 \sim x\),求出对于每个 \(r\) 有多少 \(l\) 符合条件,累加即为答案。
先预处理出前缀和数组 \(sum\),那么 \(\sum \limits_{i=l}^r a_i\) 的值为 \(sum_r - sum_{l-1}\),当且仅当 \(L \le sum_r - sum_{l-1} \le R\) 时 \(l\) 符合条件。将式子变形,可得 \(sum_r - R \le sum_{l-1} \le sum_r - L\)。
所以只需要找到在 \(r\) 前面有多少个 \(sum_{l-1}\) 在 \([sum_r-R,sum_r-L]\) 这个值域范围内。这个问题可以对数据离散化后用树状数组维护,时间复杂度为 \(O(n \log n)\)。
参考代码
#include <cstdio>
#include <algorithm>
#include <vector>
typedef long long LL;
using std::sort;
using std::lower_bound;
using std::unique;
using std::vector;
const int N = 1e5 + 5;
int a[N], bit[N * 3], bound;
LL sum[N];
vector<LL> data;
int discretization(LL x) {
return lower_bound(data.begin(), data.end(), x) - data.begin() + 1;
}
int lowbit(int x) {
return x & -x;
}
void add(int x) {
while (x <= bound) {
bit[x]++;
x += lowbit(x);
}
}
int query(int x) {
int res = 0;
while (x > 0) {
res += bit[x];
x -= lowbit(x);
}
return res;
}
int main()
{
int n, l, r; scanf("%d%d%d", &n, &l, &r);
data.push_back(0);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]); sum[i] = sum[i - 1] + a[i]; // 预处理前缀和
data.push_back(sum[i]);
data.push_back(sum[i] - l);
data.push_back(sum[i] - r);
}
sort(data.begin(), data.end());
data.erase(unique(data.begin(), data.end()), data.end());
bound = data.size();
LL ans = 0;
add(discretization(0)); // sum[0]计数加1
for (int i = 1; i <= n; i++) { // 枚举右端点
int q1 = query(discretization(sum[i] - l));
int q2 = query(discretization(sum[i] - r) - 1);
ans += q1 - q2; // 累加在值域范围内的方案数
add(discretization(sum[i])); // sum[i]计数加1
}
printf("%lld\n", ans);
return 0;
}
习题:P6186 [NOI Online #1 提高组] 冒泡排序
给定一个长度为 \(n\) 的排列 \(p\),\(m\) 个操作,需要支持两种操作:交换 \(p_x\) 和 \(p_{x+1}\);查询数组经过 \(k\) 轮冒泡排序后的逆序对个数。
数据范围:\(n,m \le 2 \times 10^5; 1 \le p_i \le n\)。
解题思路
设 \(f_i\) 表示在数字 \(i\) 左侧的比其大的数的个数,那么逆序对个数就是 \(\sum \limits_{i=1}^n f_i\)。
每经过一轮冒泡排序,若原本 \(f_i>0\),则一轮过后 \(f_i\) 会减一,否则保持不变,即等于 \(0\)。想象一下一轮冒泡排序的过程:如果 \(i\) 左边没有更大的数,则这个数左边的数不会跟它发生交换,则 \(f_i\) 仍等于 \(0\);如果左边有更大的数,则一轮冒泡过程中那个更大的数会和 \(i\) 发生交换从而使得 \(f_i\) 减一,并且在这一轮后面的过程中 \(i\) 的位置就不变了。
由上可知,经过 \(k\) 轮冒泡排序之后对逆序对还有贡献的是原本 \(f_i>k\) 的数。则答案为 \((\sum \limits_{f_i>k} f_i) - cnt \times k\),其中 \(cnt\) 代表满足 \(f_i>k\) 的 \(i\) 的个数。这正好是两种不同的前缀和(\(f_i\) 的前缀和以及 \(f_i\) 的个数的前缀和),可以通过树状数组维护。
针对交换操作,如果左小右大,则左边那个数对应的 \(f\) 在交换后会加一,如果左大右小,则右边那个数对应的 \(f\) 在交换后会减一,将其转化为相应树状数组上的更新操作即可。
参考代码
#include <cstdio>
#include <algorithm>
using std::swap;
using std::min;
typedef long long LL;
const int N = 2e5 + 5;
int n, p[N], f[N];
// 树状数组inv用于求一开始的f[i]
// 树状数组cnt用于维护f[i]的个数的前缀和
// 树状数组sum用于维护f[i]的前缀和
LL sum[N], cnt[N], inv[N];
int lowbit(int x) {
return x & -x;
}
void update(LL bit[], int x, int delta) {
while (x <= n) {
bit[x] += delta; x += lowbit(x);
}
}
LL query(LL bit[], int x) {
LL res = 0;
while (x > 0) {
res += bit[x]; x -= lowbit(x);
}
return res;
}
int main()
{
int m; scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &p[i]);
f[p[i]] = query(inv, n) - query(inv, p[i]);
if (f[p[i]] > 0) {
update(sum, f[p[i]], f[p[i]]);
update(cnt, f[p[i]], 1);
}
update(inv, p[i], 1);
}
while (m--) {
int t, c; scanf("%d%d", &t, &c);
if (t == 1) {
int i = p[c] < p[c + 1] ? p[c] : p[c + 1];
// 注意不要忘了判f[i]>0
if (f[i] > 0) {
update(sum, f[i], -f[i]);
update(cnt, f[i], -1);
}
f[i] += p[c] < p[c + 1] ? 1 : -1;
if (f[i] > 0) {
update(sum, f[i], f[i]);
update(cnt, f[i], 1);
}
swap(p[c], p[c + 1]);
} else {
c = min(c, n - 1); // 冒泡排序n-1轮过后足够完成排序
LL ans = query(sum, n) - query(sum, c) - (query(cnt, n) - query(cnt, c)) * c;
printf("%lld\n", ans);
}
}
return 0;
}
习题:P4648 [IOI 2007] pairs 动物对数
解题思路
情况一:\(B = 1\)(一维空间)
问题:在一条直线上,给定 \(N\) 个点,找出所有距离小于等于 \(D\) 的点对。
解题思路:
- 排序:首先,将所有 \(N\) 个点的坐标从小到大进行排序。
- 双指针:遍历排序后的每一个点 \(x_i\),对于每个 \(x_i\),想找到有多少个在它之前(即坐标更小)的点 \(x_j\) 满足 \(x_i - x_j \le D\)。
- 使用一个指针 \(j\),它指向满足条件的点中坐标最小的那个。当 \(i\) 向右移动时,\(j\) 也会向右移动,不可能后退。
- 对于当前的 \(i\),移动 \(j\) 直到 \(x_j\) 刚好满足 \(x_i - x_j \le D\)。那么,所有在 \(j\) 和 \(i\) 之间的点(不包括 \(i\))都满足条件。这样的点的数量就是 \(i-j\)。
- 将这个数量累加到总答案中。遍历完所有 \(i\) 后,就得到了最终结果。
情况二:\(B = 2\)(二维空间)
问题:在一个平面上,给定 \(N\) 个点,找出所有曼哈顿距离(\(|x_1 - x_2| + |y_1 - y_2|\))小于等于 \(D\) 的点对。
解题思路:
- 坐标变换:直接处理曼哈顿距离 \(|x_1 - x_2| + |y_1 - y_2| \le D\) 比较复杂。使用一个巧妙的技巧:将坐标系旋转 \(45\) 度。令新坐标为 \(u = x + y\) 和 \(v = x - y\)。经过变换后,原来的曼哈顿距离条件就等价于新坐标系下的切比雪夫距离条件:\(\max (|u_1 - u_2|, |v_1 - v_2|) \le D\)。这又等价于 \(|u_1 - u_2| \le D\) 和 \(|v_1 - v_2| \le D\) 两个不等式同时成立。
- 问题转化为了:对于每个点 \((u_i, v_i)\),寻找有多少个其他的点 \((u_j, v_j)\) 满足 \(u_i - D \le u_j \le u_i + D\) 且 \(v_i - D \le v_j \le v_i + D\)。这是一个经典的二维数点问题。
- 将所有点按 \(u\) 排序。
- 从左到右遍历每个点 \(p_i\)(按 \(u\) 坐标)。
- 使用双指针 \(j\) 维护一个“活动窗口”,确保窗口内所有点的 \(u\) 坐标都满足 \(u_i - u_j \le D\)。对于 \(u\) 坐标太小的点(即 \(u_j \lt u_i - D\)),将其从窗口中移除。
- 对于在窗口内的点,需要快速统计有多少个点的 \(v\) 坐标落在 \([v_i - D, v_i + D]\) 区间内。这个任务可以通过树状数组高效完成。
- 在处理点 \(p_i\) 时,先将不满足 \(u\) 坐标条件的点从树状数组中删除,然后查询树状数组中满足 \(v\) 坐标条件的点的数量,计入总答案。最后,将当前点 \(p_i\) 的 \(v\) 坐标加入树状数组。
情况三:\(B = 3\)(三维空间)
问题:在三维空间中,给定 \(N\) 个点,找出所有曼哈顿距离(\(|x_1 - x_2| + |y_1 - y_2| + |z_1 - z_2|\))小于等于 \(D\) 的点对。
解题思路:
- 利用数据范围:这一问的关键在于题目给出的坐标范围非常小(\(M \le 75\))。
- 坐标变换 + 降维打击:类似于二维情况,对 \(x,y\) 坐标进行变换:\(u = x + y, v = x - y\)。距离条件变为 \(\max (|u_1 - u_2|, |v_1 - v_2|) + |z_1 - z_2| \le D\)。
- 分层处理 + 二维前缀和:
- 可以把三维空间看成是按 \(z\) 坐标分成的很多个独立的二维 \((u,v)\) 平面。
- 由于 \(u,v,z\) 的坐标范围都很小,可以创建一个三维数组 \(sum_{z, u, v}\) 来记录每个坐标上的点的数量。
- 对每个 \(z\) 平面,预处理出它的二维前缀和。这样,就可以在 \(O(1)\) 时间内查询出任意一个矩形区域 \((u_1,v_1)\) 到 \((u_2,v_2)\) 内点的总数。
- 分类统计:遍历每一个点 \(p_i = (u_i, v_i, z_i)\),计算能与它配对的点数。
- 对于和 \(p_i\) 在同一个 \(z\) 平面内的点 \(p_j\),它们 \(z\) 坐标的距离为 \(0\)。所以 \(p_j\) 必须满足 \(\max(|u_i - u_j|, |v_i - v_j|) \le D\)。利用处理好的二维前缀和,可以快速查出在 \(z_i\) 平面上,以 \((u_i,v_i)\) 为中心、范围为 \(D\) 的正方形内有多少个点。
- 对于在不同 \(z\) 平面内的点 \(p_j\),设 \(z\) 方向的距离为 \(d_z = |z_i - z_j|\),那么 \(p_j\) 必须满足 \(\max(|u_i-u_j|,|v_i-v_j|) \le D-d_z\)。遍历所有 \(z_j\) 不等于 \(z_i\) 的平面,利用二维前缀和计算出每个平面上满足条件的点的数量,然后累加起来。
- 去重:在统计时,为了避免重复计数(例如 \((A,B)\) 和 \((B,A)\) 被算作两对),对同一平面的点对数最后除以 \(2\),对不同平面的点对只计算 \(z_j \lt z_i\) 的情况。
参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
const int MAXN = 100005;
int b, n, d, m;
int lowbit(int x) {
return x & -x;
}
namespace b1 {
int x[MAXN];
void solve() {
for (int i = 1; i <= n; i++) scanf("%d", &x[i]);
sort(x + 1, x + n + 1);
LL ans = 0;
int j = 1;
for (int i = 1; i <= n; i++) {
while (j < i && x[j] + d < x[i]) j++;
ans += i - j;
}
printf("%lld\n", ans);
}
};
namespace b2 {
struct Point {
int x, y;
bool operator<(const Point& other) const {
return x < other.x;
}
} p[MAXN];
LL c[MAXN * 3];
void update(int x, int d) {
while (x < MAXN * 3) {
c[x] += d;
x += lowbit(x);
}
}
LL query(int x) {
if (x <= 0) return 0;
if (x >= MAXN * 3) return c[MAXN * 3 - 1];
LL ret = 0;
while (x > 0) {
ret += c[x];
x -= lowbit(x);
}
return ret;
}
void solve() {
for (int i = 1; i <= n; i++) {
int x, y;
scanf("%d%d", &x, &y);
p[i].x = x + y + MAXN; p[i].y = x - y + MAXN;
}
sort(p + 1, p + n + 1);
LL ans = 0;
int j = 1;
for (int i = 1; i <= n; i++) {
while (j < i && p[j].x + d < p[i].x) {
update(p[j].y, -1);
j++;
}
ans += query(p[i].y + d) - query(p[i].y - d - 1);
update(p[i].y, 1);
}
printf("%lld\n", ans);
}
};
namespace b3 {
struct Point {
int x, y, z;
} p[MAXN];
LL sum[100][300][300];
LL query(int z, int x, int y) {
if (x <= 0 || y <= 0) return 0;
if (x >= 300) x = 299;
if (y >= 300) y = 299;
return sum[z][x][y];
}
LL calc(int z, int x, int y, int d) {
return query(z, x + d, y + d) - query(z, x - d - 1, y + d) - query(z, x + d, y - d - 1) + query(z, x - d - 1, y - d - 1);
}
void solve() {
for (int i = 1; i <= n; i++) {
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
p[i].x = x + y + 100; p[i].y = x - y + 100; p[i].z = z;
sum[p[i].z][p[i].x][p[i].y]++;
}
for (int i = 1; i <= 75; i++)
for (int j = 1; j < 300; j++)
for (int k = 1; k < 300; k++)
sum[i][j][k] += sum[i][j - 1][k] + sum[i][j][k - 1] - sum[i][j - 1][k - 1];
LL ans1 = 0, ans2 = 0;
for (int i = 1; i <= n; i++) {
for (int j = max(p[i].z - d, 1); j < p[i].z; j++) ans1 += calc(j, p[i].x, p[i].y, d - p[i].z + j);
ans2 += calc(p[i].z, p[i].x, p[i].y, d) - 1;
}
printf("%lld\n", ans1 + ans2 / 2);
}
};
int main()
{
scanf("%d%d%d%d", &b, &n, &d, &m);
if (b == 1) b1::solve();
else if (b == 2) b2::solve();
else b3::solve();
return 0;
}
树状数组优化 DP
如果将树状数组代码中的求和改为取 max 或取 min,则树状数组可以用来维护前缀最大或最小值,从而帮助优化一些 DP 问题。
例题:P3431 [POI 2005] AUT-The Bus
在一个二维平面上给定 \(k\) 个点,每个点有一个坐标 \((x, y)\) 以及点权 \(p\),从左下角 \((1,1)\) 走到右上角 \((n,m)\),只能向上或向右走,求经过的点权和的最大值,\(k \le 10^5\)。
解题思路
若某个点为点 \(i\),设 \(dp_i\) 表示 \((1,1)\) 到 \((x_i, y_i)\) 点权和的最大值,则有 \(dp_i = \max \{ dp_j \} + p_i\),其中点 \(j\) 需要满足 \(x_j \le x_i\) 并且 \(y_j \le y_i\),也就是点 \(j\) 在点 \(i\) 的左下方。
为了保证计算某个点 \(i\) 时其左下方的所有点都已计算过,可以对输入的点以横坐标为第一关键字,纵坐标为第二关键字进行排序,则排序后按顺序扫描即满足之前的点一定是在左边的。此时要求出该点下方(即 \(y_j \le y_i\))的 \(dp_j\) 的最大值,正好是一个前缀最大值,所以可以用树状数组来维护。
时间复杂度 \(O(k \log k)\)。
参考代码
#include <cstdio>
#include <algorithm>
#include <vector>
using ll = long long;
const int K = 100005;
struct Point {
int x, y, p;
};
Point a[K];
int k;
ll c[K], dp[K];
std::vector<int> num;
int discretize(int x) {
return std::lower_bound(num.begin(), num.end(), x) - num.begin() + 1;
}
int lowbit(int x) {
return x & -x;
}
void update(int x, ll val) {
while (x <= k) {
c[x] = std::max(c[x], val);
x += lowbit(x);
}
}
ll query(int x) {
ll res = 0;
while (x > 0) {
res = std::max(res, c[x]);
x -= lowbit(x);
}
return res;
}
int main()
{
int n, m; scanf("%d%d%d", &n, &m, &k);
for (int i = 1; i <= k; i++) {
scanf("%d%d%d", &a[i].x, &a[i].y, &a[i].p);
num.push_back(a[i].y);
}
std::sort(num.begin(), num.end());
num.erase(std::unique(num.begin(), num.end()), num.end());
std::sort(a + 1, a + k + 1, [](const Point& lhs, const Point& rhs) {
return lhs.x != rhs.x ? lhs.x < rhs.x : lhs.y < rhs.y;
});
ll ans = 0;
for (int i = 1; i <= k; i++) {
a[i].y = discretize(a[i].y);
dp[i] = query(a[i].y) + a[i].p;
ans = std::max(ans, dp[i]);
update(a[i].y, dp[i]);
}
printf("%lld\n", ans);
return 0;
}
习题:P6007 [USACO20JAN] Springboards G
在一个二维平面上给定 \(p\) 对点,每对点有坐标 \((x_1, y_1)\) 和 \((x_2, y_2)\),表示从前者可以不需要行走瞬移到后者,从左下角 \((0,0)\) 走到右上角 \((n,n)\),只能向上或向右走,求最小的行走距离,\(p \le 10^5\)。
解题思路
设到点 \(i\) 的最小行走距离是 \(dp_i\),则有 \(dp_i = \min \{ dp_j + x_i - x_j + y_i - y_j \}\),其中 \(j\) 在 \(i\) 的左下方。在计算 \(dp_i\) 时,\(x_i\) 和 \(y_i\) 是两个定值,可以拆到括号外面,也就是 \(dp_i = \min \{ dp_j - x_j - y_j \} + x_i + y_i\),于是和上一题类似,只不过相当于需要维护 \(dp - x - y\) 的最小值。
而如果点 \(i\) 是跳板的右上端点,还有一种情况是 \(dp_i = dp_j\),这里的点 \(j\) 指的是该跳板的左下端点。
为了在点排序后能维持之前的跳板关系,可以使用索引排序。
时间复杂度 \(O(p \log p)\)。
参考代码
#include <cstdio>
#include <vector>
#include <algorithm>
using ll = long long;
const int P = 200005;
int n, p, x[P], y[P], idx[P], from[P];
ll c[P], dp[P];
std::vector<int> num;
int discretize(int x) {
return std::lower_bound(num.begin(), num.end(), x) - num.begin() + 1;
}
int lowbit(int x) {
return x & -x;
}
void update(int x, ll val) {
while (x <= 2 * p) {
c[x] = std::min(c[x], val);
x += lowbit(x);
}
}
ll query(int x) {
ll res = 2 * n;
while (x > 0) {
res = std::min(res, c[x]);
x -= lowbit(x);
}
return res;
}
int main()
{
scanf("%d%d", &n, &p);
for (int i = 1; i <= p; i++) {
int x1, y1, x2, y2;
scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
x[i] = x1; y[i] = y1;
x[i + p] = x2; y[i + p] = y2; from[i + p] = i;
num.push_back(y1); num.push_back(y2);
idx[i] = i; idx[i + p] = i + p;
}
std::sort(num.begin(), num.end());
num.erase(std::unique(num.begin(), num.end()), num.end());
std::sort(idx + 1, idx + 2 * p + 1, [](int i, int j) {
return x[i] != x[j] ? x[i] < x[j] : y[i] < y[j];
});
for (int i = 1; i <= 2 * p; i++) c[i] = 2 * n;
ll ans = 2 * n;
for (int i = 1; i <= 2 * p; i++) {
int cur = idx[i];
if (x[cur] > n || y[cur] > n) continue;
dp[cur] = x[cur] + y[cur]; // 从(0,0)直接走过来
int d = discretize(y[cur]);
dp[cur] = std::min(dp[cur], query(d) + x[cur] + y[cur]);
if (from[cur] != 0) { // 如果是某个跳板的右上端点
dp[cur] = std::min(dp[cur], dp[from[cur]]);
}
ans = std::min(ans, n - x[cur] + n - y[cur] + dp[cur]);
update(d, dp[cur] - x[cur] - y[cur]);
}
printf("%lld\n", ans);
return 0;
}

浙公网安备 33010602011771号