树状数组学习笔记
树状数组学习笔记
友链
树状数组是一种支持很多操作,且常数很小的数据结构,ta支持区间查询,修改与单点的查询与修改。
ta的操作基于二进制,如图:
数组 \(c\) 表示区间的和,每个 \(c_i\) 都有一个关辖的区间,那么问题来了,怎么判断管辖的区间呢?
这个时候 \(lowbit\) 就可以帮助我们:
int lowbit(int x){
// x 的二进制中,最低位的 1 以及后面所有 0 组成的数。
// lowbit(0b01011000) == 0b00001000
// ~~~~^~~~
// lowbit(0b01110010) == 0b00000010
//
return x&(-x);
}
lowbit 表示 最低位 1和后面所有 0 组成的数。
最普通的树状数组支持单点修改,去间查询,\(c_i\) 也就表示为 \([i-lowbit(i)+1,i]\) 的区间和。
code:
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
long long c[N],a[N],n,q;
long long lowbit(int x) {return (x&-x);}
void add(int x,int v){
int i=x;
while(i<=n){//定义 c[x+lowbit(x)] 包含 c[x]
c[i]+=v;
i+=lowbit(i);
}
}
long long sum(int x){
long long ans=0;
while(x>0){//枚举每个在1~x中的区间
ans+=c[x];
x-=lowbit(x);
}
return ans;
}
signed main(){
cin>>n>>q;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<=n;i++) add(i,a[i]);
while(q--){
int op;
cin>>op;
if(op==1){
int x,v;
cin>>x>>v;
add(x,v);
}else{
int l,r;
cin>>l>>r;
cout<<sum(r)-sum(l-1)<<endl;
}
}
}
加上差分以后。就有差分数组 \(d\)。
就支持区间修改,单点查询了(改两个端点就可以了)。
code:
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
long long c[N],a[N],n,q;
long long lowbit(int x) {return (x&-x);}
void add(int x,int v){
while(x<=n){
c[x]+=v;
x+=lowbit(x);
}
}
long long sum(int x){
long long ans=0;
while(x>0){
ans+=c[x];
x-=lowbit(x);
}
return ans;
}
signed main(){
cin>>n>>q;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<=n;i++) add(i,a[i]-a[i-1]);
while(q--){
int op;
cin>>op;
if(op==1){
int l,r,v;
cin>>l>>r>>v;
add(l,v);
add(r+1,-v);
}else{
int l;
cin>>l;
cout<<sum(l)<<endl;
}
}
}
但有的时候,我们既要支持区间修改,又要支持区间查询怎么办。
\(\sum_{i=1}^{k} a_i = \sum_{i=1}^{k} \sum_{j=1}^{i} d_j\)
\(= \sum_{j=1}^{k} \sum_{i=j}^{k} d[j]\)
\(= \sum_{j=1}^{k} (k - j + 1) \times d[j]\)
最后一步中,\(d_j\) 出现了 \((k - j + 1)\) 次,所以可以得到:
\((k - j + 1) × d[j] = (k+1) × d_j - j × d_j\)
维护两个树状数组就可以了,一个维护 \(d_j\),一个维护 \(d_j \times j\)。
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e6+10;
int c[N],a[N],n,q,d[N];
int lowbit(int x) {return (x&-x);}
void add(int x,int v){
int id=x;
while(x<=n){
c[x]+=v;
d[x]+=v*id;
x+=lowbit(x);
}
}
int sum(int x){
int ans=0,id=x;
while(x>0){
ans+=(id+1)*c[x]-d[x];//公式
x-=lowbit(x);
}
return ans;
}
signed main(){
cin>>n>>q;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<=n;i++) add(i,a[i]-a[i-1]);
while(q--){
int op;
cin>>op;
if(op==1){
int l,r,v;
cin>>l>>r>>v;
add(l,v);
add(r+1,-v);
}else{
int l,r;
cin>>l>>r;
cout<<sum(r)-sum(l-1)<<endl;
}
}
}
浙公网安备 33010602011771号