线段树基础总结
参考题目:A Simple Problem with Integers POJ-3468
距离上一次写这道题已经过去两个月了,前天打模拟赛时连线段树都手敲不出来了。所以这次重新来复习一下线段树。
这次主要是记录一下对线段树区间修改的理解:
一开始我们先是理解线段树的建树原理以及查询原理。利用一个二叉树建立,每个父节点可以记录左右子结点的和或者最大值,以此来维护区间内容。
查询也类似,通过递归寻找,找到所有符合条件(区间内深度最浅)结点,然后将其求和返回或者求最大值返回。
而在之后有单点修改。单点修改重建树来理解,就是递归,找到范围(l,r) ,l=r 时的结点位置更新并 push_up() 向上建树。
但是对于大量数据查询,以及面对区间修改问题时,要一个一个找到位置再更新push_up()就会非常费时间。
所以我们试想能否使用一个标记 tag 传递,在更新范围内的结点就把 tag 传递下去,然后对所有含有tag 的叶子结点更新,再push_up();
但是对于上面实际上还有更加优化的方案:即 lazy[] 标记。我们用 lazy[]标记记录每个结点修改信息,就像上面一样。但是当递归到某个完全被包含于查询区间的 部分时,我们直接对这个区间更新 即 (r-l+1) * lazy[p] . (完全被包含,所以直接乘以长度即可),单接下来我就直接返回了,不往下递归了。那没有被更新的那些子结点怎么办?
既然叫 lazy[] 就是真的懒的意思。如果你下次查的区间不包含我要更新的所有叶子结点(或者没有被更新到的子结点部分),那我直接利用已有的信息就可以返回你要的答案。
如果你查的区间包含的话,那我就在查询的同时把那些子结点更新掉,利用我之前 残留下没有更新的 lazy[] 值。所以这样大大减少了修改的时间复杂度。
code:
#include <bits/stdc++.h> #define IOS ios::sync_with_stdio(0); cin.tie(0); #define mp make_pair #define Accept 0 using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<int, int> pii; const double Pi = acos(-1.0); const double esp = 1e-9; const int inf = 0x3f3f3f3f; const int maxn = 5e5+7; const int maxm = 1e6+7; const int mod = 1e9+7; int n,m; struct segmentTree { ll arr[maxn]; ll tree[maxn<<2]; ll lazy[maxn<<2]; //结点从1开始到n void init(int n){ for(int i=1;i<=n;i++){ scanf("%lld",&arr[i]); } } void build(int l,int r,int p){ lazy[p] = 0; if(l==r) { tree[p] = arr[l]; return ;} int mid = (l+r)>>1; build(l,mid,p<<1); build(mid+1,r,p<<1|1); sum_up(p); } void sum_down(int p,int m){ if(lazy[p]){ lazy[p<<1] += lazy[p]; lazy[p<<1|1] += lazy[p]; tree[p<<1] += (ll)lazy[p]*(m-(m>>1)); tree[p<<1|1] += (ll)lazy[p]*(m>>1); lazy[p] = 0; } } void sum_up(int p){ tree[p] = tree[p<<1] + tree[p<<1|1]; } void update(int ql,int qr,int k,int l,int r,int p){ //完全包含区间,直接返回就不更新其子结点了 //直到下次需要子结点时才更新
if(ql>r||qr<l) return ; if(ql<=l&&r<=qr){ tree[p] += (ll) (r-l+1) * k; lazy[p] += k;//用lazy[p]暂时保存所有跟新值(次数) return ; } sum_down(p,r-l+1); int mid = (l+r)>>1; //部分包含左区间 if(ql<=mid) update(ql,qr,k,l,mid,p<<1); //部分包含右区间 if(qr>mid) update(ql,qr,k,mid+1,r,p<<1|1); sum_up(p); } ll query(int ql,int qr,int l,int r,int p){ ll res = 0;
if(ql>r||qr<l) return 0; if(ql<=l&&r<=qr){ return tree[p]; } sum_down(p,r-l+1); int mid = (l+r)>>1; if(ql<=mid) res += query(ql,qr,l,mid,p<<1); if(qr>mid) res += query(ql,qr,mid+1,r,p<<1|1); return res; } }seg; int main(){ scanf("%d %d",&n,&m); seg.init(n); seg.build(1,n,1); string s; for(int i=0;i<m;i++){ cin>>s; if(s[0]=='Q'){ int L,R; scanf("%d %d",&L,&R); printf("%lld\n",seg.query(L,R,1,n,1)); }else{ int L,R,k; scanf("%d %d %d",&L,&R,&k); seg.update(L,R,k,1,n,1); } } return 0; }
2020/4/13 更新部分
对线段树的模板进行简化一部分:如果在设计线段树对 区间和、区间最值、xor等,区间修改、区间add、区间乘 的操作时 可能代码量会相对较大,所以这里写一个较为 精简一点的模板
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn = 1e5+10; int n, m; ll x; //区间最小、最大、和,区间修改和增加 struct Node{ ll Min, Max, sum; Node operator + (const Node &p)const{ Node ans; ans.Min = min(Min, p.Min); ans.Max = max(Max, p.Max); ans.sum = sum + p.sum; return ans; } void add(ll x, ll num){ Min += x; Max += x; sum += x *1ll* num; } void update(ll x, ll num){ Min = x; Max = x; sum = x * 1ll* num; } }node[maxn<<2]; ll lazy[maxn<<2][3]; void build(int l, int r, int rt){ if(l == r){ scanf("%lld", &x); node[rt] = Node{x, x, x}; return ; } int mid = l+r>>1; build(l, mid, rt<<1); build(mid+1, r, rt<<1|1); node[rt] = node[rt<<1] + node[rt<<1|1]; } void pushdown(int l, int r, int rt){ if(lazy[rt][1]){ ll x = lazy[rt][1], mid = l+r>>1; node[rt<<1].update(x, mid-l+1); node[rt<<1|1].update(x, r-mid); lazy[rt<<1][0] = lazy[rt<<1|1][0] = 0; lazy[rt<<1][1] = lazy[rt<<1|1][1] = x; lazy[rt][1] = 0; } if(lazy[rt][0]){ ll x = lazy[rt][0], mid = l+r>>1; node[rt<<1].add(x, mid-l+1); node[rt<<1|1].add(x, r-mid); lazy[rt<<1][0] += x; lazy[rt<<1|1][0] += x; lazy[rt][0] = 0; } } void add(int L, int R, ll c, int l, int r, int rt, int id){ if(L <= l &&r <= R){ if(id == 1){ node[rt].add(c, r-l+1); lazy[rt][0] += c; } else{ node[rt].update(c, r-l+1); lazy[rt][1] = c; lazy[rt][0] = 0; } return; } pushdown(l, r, rt); int mid = l+r>>1; if(L <= mid) add(L, R, c, l, mid, rt<<1, id); if(R > mid) add(L, R, c, mid+1, r, rt<<1|1, id); node[rt] = node[rt<<1] +node[rt<<1|1]; } Node query(int L, int R, int l, int r, int rt){ if(L <= l &&r <= R) return node[rt]; pushdown(l, r, rt); int mid = l+r>>1; Node ans; ans.Min = 1e18; ans.Max = -1e18; ans.sum = 0; if(L <= mid) ans = ans + query(L, R, l, mid,rt<<1); if(R > mid) ans = ans + query(L, R, mid+1, r, rt<<1|1); return ans; } int main(){ scanf("%d", &n); build(1, n, 1); scanf("%d",&m); int id, l, r; while(m--){ scanf("%d%d%d", &id, &l ,&r); if(id <= 2){ scanf("%lld", &x); add(l, r, x, 1, n, 1, id); } else{ Node ans = query(l, r, 1, n, 1); printf("%lld %lld %lld\n", ans.sum, ans.Max, ans.Min); } } return 0; }