数据结构:树状数组、线段树
树状数组/线段树都可以把原来朴素的O(n2)变为O(n*logn),用于高效计算数列的前缀和。具体主要表现为3种情况:区间修改单点查询;单点修改区间查询;区间修改区间查询,这3种情况是一个递进关系,理解规律之后就比较好记。
树状数组的具体原理见https://www.cnblogs.com/xenny/p/9739600.html,这里就不详细描述了.....树状数组的显著特点就是借助lowbit来确定修改或者查询的位置。如下是区间修改区间查询的完整板子代码:
#include<bits/stdc++.h> using namespace std; int n,m; int s1[500005],s2[500005],a[500005]; int lowbit(int x) { return x&(-x); } void update(int i,int k) { int x=i; while(i<=n){ s1[i]+=k; s2[i]+=k*(x-1); //注意 i+=lowbit(i); } } int getSum(int i) { int res=0,x=i; while(i>0){ res+=s1[i]*x-s2[i]; //注意 i-=lowbit(i); } return res; } int main() { scanf("%d%d",&n,&m); memset(a,0,sizeof a); memset(s1,0,sizeof s1); memset(s2,0,sizeof s2); for(int i=1;i<=n;i++){ scanf("%d",&a[i]); update(i,a[i]-a[i-1]); //完全版树状数组在构建时输入a[i]-a[i-1] } int s,x,y,k; while(m--){ scanf("%d",&s); if(s==1){ scanf("%d%d%d",&x,&y,&k); update(x,k); update(y+1,-k); } else if(s==2){ scanf("%d",&x); printf("%d\n",getSum(x)-getSum(x-1)); } } return 0; }
树状数组的应用包括RMQ问题,求逆序对等等。树状数组用在RMQ问题需要在查询函数改变一下,详细情况见博客RMQ问题。关于求逆序对数,将原数列从开始一个一个地加入元素到和它大小所对应的颠倒位置。因为某个较大的数先出现在数列中,所以它先被加入到树状数组中,对较小的数(树状数组中位置靠后的数)产生影响,从而达到统计逆序对的功能。随着插入新数,顺便求和,可得到逆序对数。例题洛谷P1774、P2678
#include<bits/stdc++.h> using namespace std; long long n,num[500005],loc[500005],tree[500005]; bool cmd(int a,int b) { return num[a]==num[b]?a>b:num[a]>num[b]; } int lowbit(int x) { return x&(-x); } void add(int k,int v) { while(k<=n){ tree[k]+=v; k+=lowbit(k); } } int query(int k) { int ans=0; while(k>0){ ans+=tree[k]; k-=lowbit(k); } return ans; } int main() { scanf("%lld",&n); for(int i=1;i<=n;i++){ scanf("%lld",&num[i]); loc[i]=i; } sort(loc+1,loc+n+1,cmd); long long ans=0; for(int i=1;i<=n;i++){ ans+=query(loc[i]); add(loc[i],1); } printf("%lld\n",ans); return 0; }
线段树的核心在于push_down+lazytag,当然单点修改区间查询和区间修改单点查询这两种情况是用不着的。借助push_down和lazytag,线段树可以处理更加复杂的区间维护问题。
对于单点查询区间修改,线段树上每个节点的sum是该区间内数字变动的前缀。查询时,从上到下降路过节点的sum加起来求一个前缀和。
下方是完整的线段树模板,包括区间查询区间修改和乘除法。乘除法注意lazytag的操作与加法有所不同
/* 完整的线段树 */ #include<bits/stdc++.h> using namespace std; long long n,m,p; long long input[100005]; struct node { long long l,r,sum,plz,mlz; }tree[400005]; void build(long long i,long long l,long long r) { tree[i].l=l; tree[i].r=r; tree[i].plz=0; tree[i].mlz=1; if(l==r){ tree[i].sum=input[l]%p; return; } long long mid=(l+r)>>1; build(i<<1,l,mid); build(i<<1|1,mid+1,r); tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%p; } inline void push_down(long long i) { long long k1=tree[i].mlz; long long k2=tree[i].plz; tree[i<<1].mlz=(tree[i<<1].mlz*k1)%p; tree[i<<1|1].mlz=(tree[i<<1|1].mlz*k1)%p; tree[i<<1].plz=(tree[i<<1].plz*k1+k2)%p; tree[i<<1|1].plz=(tree[i<<1|1].plz*k1+k2)%p; tree[i<<1].sum=(tree[i<<1].sum*k1+k2*(tree[i<<1].r-tree[i<<1].l+1))%p; tree[i<<1|1].sum=(tree[i<<1|1].sum*k1+k2*(tree[i<<1|1].r-tree[i<<1|1].l+1))%p; tree[i].plz=0; tree[i].mlz=1; } void add(long long i,long long l,long long r,long long k) { if(tree[i].l>=l&&tree[i].r<=r){ tree[i].plz=(tree[i].plz+k)%p; tree[i].sum=(tree[i].sum+k*(tree[i].r-tree[i].l+1))%p; return; } push_down(i); if(tree[i<<1].r>=l){ add(i<<1,l,r,k); } if(tree[i<<1|1].l<=r){ add(i<<1|1,l,r,k); } tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%p; } void mul(long long i,long long l,long long r,long long k) { if(tree[i].l>=l&&tree[i].r<=r){ tree[i].sum=(tree[i].sum*k)%p; tree[i].plz=(tree[i].plz*k)%p; tree[i].mlz=(tree[i].mlz*k)%p; return; } push_down(i); if(tree[i<<1].r>=l){ mul(i<<1,l,r,k); } if(tree[i<<1|1].l<=r){ mul(i<<1|1,l,r,k); } tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%p; } long long getsum(int i,int l,int r) { if(tree[i].l>=l&&tree[i].r<=r){ return tree[i].sum; } if(tree[i].r<l||tree[i].l>r){ return 0; } push_down(i); long long ans=0; if(tree[i<<1].r>=l){ ans+=getsum(i<<1,l,r); } if(tree[i<<1|1].l<=r){ ans+=getsum(i<<1|1,l,r); } return ans; } int main() { scanf("%lld%lld%lld",&n,&m,&p); for(int i=1;i<=n;i++){ long long temp; scanf("%lld",&temp); input[i]=temp%p; } build(1,1,n); long long s,x,y,k; while(m--){ scanf("%lld",&s); if(s==1){ scanf("%lld%lld%lld",&x,&y,&k); mul(1,x,y,k%p); } else if(s==2){ scanf("%lld%lld%lld",&x,&y,&k); add(1,x,y,k%p); } else if(s==3){ scanf("%lld%lld",&x,&y); printf("%lld\n",getsum(1,x,y)%p); } } return 0; }
(定位当前节点的左孩子和右孩子涉及到位运算,务必关注优先级和括号。或者保险一点直接i*2,i*2+1)
一些优秀的线段树题目中,线段树并不是考察的主要目标,而是求解答案的数据结构。POJ2991,考察区间之间的向量转向问题。
POJ2828,思路是关键,之后利用树状数组动态维护区间和
POJ2777,线段树每个节点的关键量变为一个状态压缩的值。查询时对每个状态压缩的值,找有几个状态
POJ2886,和POJ2828类似,需要注意细节
POJ1151,扫描线+离散化+线段树,求矩形并。将每个矩形拆为两条线段,排序,等待扫描;将线段两端离散化;开始扫描,利用线段树找出当前情况下,线段覆盖的总长度,乘高度差即可得解。这类题总是对整个区间查询,于是就没有必要建树...
#include<stdio.h> #include<algorithm> #include<string.h> using namespace std; const int maxn=250; struct seg{ double x1,x2,y; int flag; bool operator <(const seg &A) const{ return y<A.y; } }node[maxn]; int col[maxn*4]; double rec[maxn],sum[maxn*4]; void pushup(int i,int l,int r){ if(col[i]) sum[i]=rec[r+1]-rec[l]; else if(l==r) sum[i]=0; else sum[i]=sum[i*2]+sum[i*2+1]; } void update(int L,int R,int k,int l,int r,int i){ if(l>=L&&r<=R){ col[i]+=k; pushup(i,l,r); return; } int m=(l+r)/2; if(L<=m) update(L,R,k,l,m,i*2); if(R>m) update(L,R,k,m+1,r,i*2+1); pushup(i,l,r); } int main(){ int cas=1; int n; while(scanf("%d",&n)!=EOF){ if(n==0) break; int cnt=0; for(int i=1;i<=n;i++){ double a,b,c,d; scanf("%lf%lf%lf%lf",&a,&b,&c,&d); node[cnt].x1=a;node[cnt].x2=c;node[cnt].y=b;node[cnt].flag=1;rec[cnt]=a;cnt++; node[cnt].x1=a;node[cnt].x2=c;node[cnt].y=d;node[cnt].flag=-1;rec[cnt]=c;cnt++; } sort(node,node+cnt); sort(rec,rec+cnt); memset(col,0,sizeof(col)); memset(sum,0,sizeof(sum)); double ans=0; for(int i=0;i<cnt-1;i++){ int l=lower_bound(rec,rec+cnt,node[i].x1)-rec; int r=lower_bound(rec,rec+cnt,node[i].x2)-rec-1; if(l<=r) update(l,r,node[i].flag,0,cnt-1,1); ans+=sum[1]*(node[i+1].y-node[i].y); } printf("Test case #%d\n",cas++); printf("Total explored area: %.2f\n\n",ans); } }
HDU1255,在上一题的基础上,需要求的是被覆盖两次以上的区间长度,改进pushup。
void pushup(int i,int l,int r){ if(col[i]>=2){ s2[i]=s1[i]=rec[r+1]-rec[l]; } else if(col[i]==1){ s1[i]=rec[r+1]-rec[l]; if(l==r) s2[i]=0; else s2[i]=s1[i*2]+s1[i*2+1]; } else{ if(l==r) s1[i]=s2[i]=0; else{ s1[i]=s1[i*2]+s1[i*2+1]; s2[i]=s2[i*2]+s2[i*2+1]; } } }
POJ1177,求矩阵并的周长。和求矩阵并比较类似,分两次求和
#include<stdio.h> #include<algorithm> #include<cmath> #include<string.h> using namespace std; const int maxn=10005; struct seg{ int x1,x2,y,flag; bool operator < (const seg &A) const{ return y<A.y; } }node1[maxn],node2[maxn]; int n,rec1[maxn],rec2[maxn],col[maxn*4],sum[maxn*4]; void pushup(int i,int l,int r,int f){ if(col[i]){ if(f) sum[i]=rec1[r+1]-rec1[l]; else sum[i]=rec2[r+1]-rec2[l]; } else if(l==r) sum[i]=0; else sum[i]=sum[i*2]+sum[i*2+1]; } void update(int i,int l,int r,int k,int L,int R,int f){ if(l>=L&&r<=R){ col[i]+=k; pushup(i,l,r,f); return; } int m=(l+r)/2; if(m>=L) update(i*2,l,m,k,L,R,f); if(R>m) update(i*2+1,m+1,r,k,L,R,f); pushup(i,l,r,f); } int main(){ while(scanf("%d",&n)!=EOF){ int cnt=0; for(int i=0;i<n;i++){ int a,b,c,d; scanf("%d%d%d%d",&a,&b,&c,&d); node1[cnt].x1=a;node1[cnt].x2=c;node1[cnt].y=b;node1[cnt].flag=1;rec1[cnt]=a; node2[cnt].x1=b;node2[cnt].x2=d;node2[cnt].y=a;node2[cnt].flag=1;rec2[cnt++]=b; node1[cnt].x1=a;node1[cnt].x2=c;node1[cnt].y=d;node1[cnt].flag=-1;rec1[cnt]=c; node2[cnt].x1=b;node2[cnt].x2=d;node2[cnt].y=c;node2[cnt].flag=-1;rec2[cnt++]=d; } sort(node1,node1+cnt); sort(rec1,rec1+cnt); memset(col,0,sizeof(col)); memset(sum,0,sizeof(sum)); int ans=0,last=0; for(int i=0;i<cnt;i++){ int l=lower_bound(rec1,rec1+cnt,node1[i].x1)-rec1; int r=lower_bound(rec1,rec1+cnt,node1[i].x2)-rec1-1; if(l<=r) update(1,0,cnt-1,node1[i].flag,l,r,1); ans+=abs(sum[1]-last); //printf("%d %d %d %d %d\n",i,l,r,sum[1],last); last=sum[1]; } //printf("%d\n",ans); sort(node2,node2+cnt); sort(rec2,rec2+cnt); memset(col,0,sizeof(col)); memset(sum,0,sizeof(sum)); last=0; for(int i=0;i<cnt;i++){ int l=lower_bound(rec2,rec2+cnt,node2[i].x1)-rec2; int r=lower_bound(rec2,rec2+cnt,node2[i].x2)-rec2-1; if(l<=r) update(1,0,cnt-1,node2[i].flag,l,r,0); ans+=abs(sum[1]-last); last=sum[1]; } printf("%d\n",ans); } }