树状数组进阶
一、 区间修改+单点查询
分析
我们已经知道了如何进行树状数组的单点修改+区间查询,现在要进行区间修改,很容易想到暴力,若左边界为 \(l\) ,右边界为 \(r\) ,进行 \(r - l + 1\) 次单点修改操作修改区间内每个数的值,时间复杂度不可接受。
考虑优化,因为进行的是区间操作,很容易想到差分,开一个数组 d 储存原数组的差分信息,用树状数组维护它的前缀和,进行区间修改操作 1 l r k (含义:将区间 \([l,r]\) 内每个数加上 \(k\))时,只用将 \(l\) 和 \(r + 1\) 的位置进行单点修改操作 add(l,k) 和 add(r+1,-k) 即可。
对于单点查询的操作,根据差分的性质,可得 \(a_i = \sum_{j=1}^{i}d_j\) ,正好是树状数组维护的差分数组的前缀和信息,所以进行查询操作 query(i) 即可。
时间复杂度
- 区间修改 \(O(log\ n)\)
- 单点查询 \(O(log\ n)\)
Code
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 10;
int d[N], t[N];
int n, m;
int lowbit(int x) { return x & -x; }
void build() {
for (int i = 1; i <= n; ++i) {
t[i] += d[i];
int j = i + lowbit(i);
if (j <= n) t[j] += t[i];
}
}
void update(int i, int x) {
while (i <= n) {
t[i] += x;
i += lowbit(i);
}
}
int query(int i) {
int ans = 0;
while (i > 0) {
ans += t[i];
i -= lowbit(i);
}
return ans;
}
int main() {
cin >> n >> m;
for (int i = 1, prev = 0, x; i <= n; ++i) {
cin >> x;
d[i] = x - prev;
prev = x;
}
build();
while (m--) {
int op;
cin >> op;
if (op == 1) {
int x, y, k;
cin >> x >> y >> k;
update(x, k);
update(y + 1, -k);
} else {
int x;
cin >> x;
cout << query(x) << '\n';
}
}
return 0;
}
二、 区间修改+区间查询
例题:LOJ #132. 树状数组 3 :区间修改,区间查询
分析
我们已经实现区间修改,如何进行区间查询呢?令 \(f(k)\) 为要查询 \(k\) 位置上的前缀和,即 \(f(k)=\sum_{i=0}^ka_i\) ,那么
\[f(k)=\sum_{i=1}^{k}\sum_{j=1}^id_j,
\]
可以发现 \(d_j\) 被加了 \(k-j+1\) 次,所以
\[\begin{align*}
f(k) &= \sum_{i=1}^k(k-i+1)d_i \\
&= \sum_{i=1}^k[(k+1)d_i - id_i] \\
&= (k+1)\sum_{i=1}^kd_i-\sum_{i=1}^kid_i
\end{align*}
\]
现在两项都是前缀和形式,我们已经维护了 \(\sum_{i=1}^kd_i\) ,但无法推导出\(\sum_{i=1}^kid_i\) ,所以可以再开一个树状数组维护 \(\sum_{i=1}^kid_i\) ,这样使用两个树状数组维护差分信息即可实现区间修改+区间查询。
时间复杂度
- 区间修改 \(O(log\ n)\)
- 单点查询 \(O(log\ n)\)
Code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 1e6 + 10;
ll a[N], d[N], t1[N], t2[N];
int n, m;
int lowbit(int x) { return x & -x; }
void build() {
for (int i = 1; i <= n; ++i) {
t1[i] += d[i];
t2[i] += i * d[i];
int j = i + lowbit(i);
if (j <= n) {
t1[j] += t1[i];
t2[j] += t2[i];
}
}
}
void update(int i, ll x) {
for (int j = i; j <= n; j += lowbit(j)) {
t1[j] += x;
t2[j] += i * x;
}
}
ll query(int i) {
ll ans = 0;
for (int j = i; j > 0; j -= lowbit(j)) ans += (i + 1LL) * t1[j] - t2[j];
return ans;
}
int main() {
cin >> n >> m;
for (int i = 1, prev = 0, x; i <= n; ++i) {
cin >> x;
d[i] = x - prev;
prev = x;
}
build();
while (m--) {
int op;
cin >> op;
if (op == 1) {
int l, r, k;
cin >> l >> r >> k;
update(l, k);
update(r + 1, -k);
} else {
int l, r;
cin >> l >> r;
cout << query(r) - query(l - 1) << '\n';
}
}
return 0;
}
未完...

浙公网安备 33010602011771号