【总结】树状数组
树状数组的概念
树状数组(Binary Indexed Tree(B.I.T))是一个区间查询和单点修改复杂度都为 \(\log n\) 的数据结构。主要用于查询任意两点之间的所有元素之和。
引入
-
问题的提出
有一个一维数组,长度为 \(n\)。
对这个数组做两种操作:- 修改,对第 \(i \to j\) 之间的某元素增加 \(v\)。
- 求和,求 \(i\) 到 \(j\) 的和。
-
朴素算法
-
用
for循环从 \(i\) 到 \(j\) 依次求和,时间复杂度:\(O(n)\) -
缺陷:当数据规模极大的时候,将会变得效率低下。
-
-
前缀和
- 我们可以做到查询 \(O(1)\),
- 但是我们的修改依旧很慢
-
我们可以采用树状数组

lowbit
lowbit(i)的意思是将 \(i\) 转化成二进制数之后,只保留最低位的 \(1\) 其后面的 \(0\),截断前面的内容,然后再转成十进制数,这个数也是树状数组中 \(i\) 号位的子叶个数。
这里直接给出式子
lowbit(x) = x & (-x)
Build
我们可以采用一种类似 \(DP\) 的方式,转移每一个节点对后面节点的贡献,最后得到我们的树状数组,其时间复杂度为 \(O(n)\)
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
c[i] += a[i];
if (i + lowbit(i) <= n)
c[i + lowbit(i)] += c[i];
}
Update
该操作可以将 \(A_x + k\)。所以我们可以用这种方式进行建树,即对于每一种节点进行一次 Update,但是这样的时间复杂度是 \(O(n log n)\) 的。建议结合图分析。
void update(int x, int k) {
while (x <= n) {
c[x] += k;
x += lowbit(x);
}
}

query
该操作可以求得 \(\sum_{i = 1}^x i\),即 \(x\) 的前缀和。建议结合图分析。
int query(int x) {
int ans = 0;
while (x) {
ans += c[x];
x -= lowbit(x);
}
return ans;
}

不要问我为什么要这张图放这么多次
推广
- 单点查询,区间修改
我们可以联想到一个叫差分的东西,我们可以维护一个差分数组,将区间查询,单点修改转化为单点查询,区间修改。
- 区间修改,区间查询
我们定义序列 \(a\) 的差分为 \(b\), 我们要求一个前缀 \(r\) 求和由差分数组的定义得到 \(a_i = \sum_{j = 1} ^ i b_j\)
进行推导
我们只需分别维护他们就可以了。
具体代码如下:
#include <cstdio>
#define int long long
#define lowbit(x) (x & (-x))
const int MAXN = 1e6 + 5;
int n, q;
int a[MAXN], c1[MAXN], c2[MAXN];
void Add(int x, int v) {
int v1 = v * x;
while (x <= n) {
c1[x] += v, c2[x] += v1;
x += lowbit(x);
}
}
int Sum(int* c, int x) {
int res = 0;
while (x) {
res += c[x];
x -= lowbit(x);
}
return res;
}
void update(int l, int r, int k) {
Add(l, k), Add(r + 1, -k);
}
int query(int l, int r) {
return (r + 1) * Sum(c1, r) - l * Sum(c1, l - 1) - (Sum(c2, r) - Sum(c2, l - 1));
}
signed main() {
scanf("%lld %lld", &n, &q);
for (int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
update(i, i, a[i]);
}
while (q--) {
int op, x, y, k;
scanf("%lld %lld %lld", &op, &x, &y);
if (op == 1) {
scanf("%lld", &k);
update(x, y, k);
} else {
printf("%lld\n", query(x, y));
}
}
return 0;
}
我们可以类比普通的树状数组得到二维树状数组,但是对于这一系列操作,我们需要用到容斥原理。
- 二维树状数组:区间查询,单点修改
#include <cstdio>
#define int long long
#define lowbit(x) (x & (-x))
const int MAXN = (2 << 12) + 5;
int n, m, op, x, y, u, v, k;
int a[MAXN][MAXN];
int c[MAXN][MAXN];
void update(int x, int y, int k) {
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j))
c[i][j] += k;
}
int query(int x, int y) {
int res = 0;
for (int i = x; i > 0; i -= lowbit(i))
for (int j = y; j > 0; j -= lowbit(j))
res += c[i][j];
return res;
}
signed main() {
scanf("%lld %lld", &n, &m);
while (~scanf("%lld", &op)) {
if (op == 1) {
scanf("%lld %lld %lld", &x, &y, &k);
update(x, y, k);
} else {
scanf("%lld %lld %lld %lld", &x, &y, &u, &v);
printf("%lld\n", query(u, v) - query(u, y - 1) - query(x - 1, v) + query(x - 1, y - 1));
}
}
return 0;
}
- 二维树状数组:单点查询,区间修改
#include <cstdio>
#define int long long
#define lowbit(x) (x & (-x))
const int MAXN = 5005;
int n, m, op, x, y, u, v, k;
int a[MAXN][MAXN];
int c[MAXN][MAXN];
void update(int x, int y, int k) {
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j))
c[i][j] += k;
}
int query(int x, int y) {
int res = 0;
for (int i = x; i > 0; i -= lowbit(i))
for (int j = y; j > 0; j -= lowbit(j))
res += c[i][j];
return res;
}
signed main() {
scanf("%lld %lld", &n, &m);
while (~scanf("%lld", &op)) {
if (op == 1) {
scanf("%lld %lld %lld %lld %lld", &x, &y, &u, &v, &k);
update(u + 1, v + 1, k), update(x, y, k), update(x, v + 1, -k), update(u + 1, y, -k);
} else {
scanf("%lld %lld", &x, &y);
printf("%lld\n", query(x, y));
}
}
return 0;
}
- 二维树状数组:区间查询,区间修改
同一维树状数组,我们首先推一下式子
#include <cstdio>
#define int long long
#define lowbit(x) (x & (-x))
const int MAXN = 5005;
int n, m, op, x, y, u, v, k;
int a[MAXN][MAXN];
int c1[MAXN][MAXN], c2[MAXN][MAXN], c3[MAXN][MAXN], c4[MAXN][MAXN];
void update(int x, int y, int k) {
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j)) {
c1[i][j] += k;
c2[i][j] += x * k;
c3[i][j] += y * k;
c4[i][j] += x * y * k;
}
}
int query(int x, int y) {
int res = 0;
for (int i = x; i > 0; i -= lowbit(i))
for (int j = y; j > 0; j -= lowbit(j)) {
res += (x + 1) * (y + 1) * c1[i][j] - (y + 1) * c2[i][j] - (x + 1) * c3[i][j] + c4[i][j];
}
return res;
}
signed main() {
scanf("%lld %lld", &n, &m);
while (~scanf("%lld", &op)) {
if (op == 1) {
scanf("%lld %lld %lld %lld %lld", &x, &y, &u, &v, &k);
update(u + 1, v + 1, k), update(x, y, k), update(x, v + 1, -k), update(u + 1, y, -k);
} else {
scanf("%lld %lld %lld %lld", &x, &y, &u, &v);
printf("%lld\n", query(u, v) - query(u, y - 1) - query(x - 1, v) + query(x - 1, y - 1));
}
}
return 0;
}
例题
本文来自博客园,作者:zhou_ziyi,转载请注明原文链接:https://www.cnblogs.com/zhouziyi/p/16527213.html

浙公网安备 33010602011771号