线段树

线段树

基本概念:

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。(摘自百度百科)

线段树是一种完全二叉树,支持操作:

  1. 单点修改, 区间查询
  2. 区间修改, 单点查询
  3. 区间修改, 区间查询

线段树的基本操作:

建树操作(build):

每次向左右子树递归建树, 每个结点的子树构建后,注意pushup

void build(int u,int l,int r)
{
    
    if (l == r){
        tr[u] = {l,r,a[l],a[l],a[l],a[l]};          //单个结点的初始化
    }else{
        
        tr[u] = {l,r};
        
        int mid = l + r >> 1;
        build(u << 1, l , mid);
        build(u << 1 | 1, mid + 1, r);
        
        pushup(u);
    }
}

查询操作(query):

当每个结点维护的内容较多时, query()返回整个结构体, 注意根据题意选择递归方向

Node query(int u,int l,int r)
{
    if (tr[u].l >= l && tr[u].r <= r)
        return tr[u];
    
    int mid = tr[u].l + tr[u].r >> 1;
    
    if (l > mid)        //完全在右区间
        return query(u << 1 | 1, l, r);
    else if (r <= mid)  //完全在左区间
        return query(u << 1, l, r);
    else{               //横跨两个区间
        auto left = query(u << 1, l, r);
        auto right = query(u << 1 | 1, l, r);
        Node res;
		
        pushup(ans, left, right);		//用左右结点更新ans
		
	return ans;
    }

}

修改操作(modify):

每个结点的子树修改结束后注意pushup()

单点修改:

void modify(int u,int x,int v)
{
    if (tr[u].l == x && tr[u].r == x)
        tr[u] = {x, x, v, v, v, v};
    else{
        
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid)
            modify(u << 1, x, v);
        else
            modify(u << 1 | 1, x, v);
        
        pushup(u);
    }
}

修改之前pushdown, 修改后pushup

区间修改:

void modify(int u, int l, int r, int d)
{
    
    if (tr[u].l >= l && tr[u].r <= r){
        tr[u].add += d;
        tr[u].sum += (LL)d * (tr[u].r - tr[u].l + 1);
    }
        
    else{
        
        pushdown(u);
        
        int mid = tr[u].r + tr[u].l >> 1;
        if (l <= mid)  
            modify(u << 1, l, r, d);
        if (r > mid)
            modify(u << 1 | 1, l, r, d);
        
        pushup(u);
    }
}

pushup,pushdown:

pushup(u): 用\(u\)的子结点更新\(u\)的信息
pushdown(u): 用\(u\) 的信息更新 \(u\)子结点的信息

void pushup(int u)
{
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void pushdown(int u)
{
    auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
    if (tr[u].add){
       
       left.add += root.add;
       right.add += root.add;
       
       left.sum += (LL)(left.r - left.l + 1) * root.add;
       right.sum += (LL)(right.r - right.l + 1) * root.add;
       
       root.add = 0;
       
    }
}

单点修改

AcWing 1275. 最大数

算法思路:

题目要求每次在队尾加一个结点, 建树时将所有结点设为0.

#include <iostream>
#include <cstring>

using namespace std;

const int N = 2 * 1e5 + 10;

struct Node{
    int l, r;
    int v;
}tr[4 * N];
int m, p;

void pushup(int u)
{
    tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}

void build(int u, int l, int r)
{
    tr[u].l = l;
    tr[u].r = r;
    
    if (l == r)
        return ;
    
    int mid = l + r >> 1;
    
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
} 


int query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r)
        return tr[u].v;
    
    int mid = tr[u].l + tr[u].r >> 1;
    
    int ans = 0;
    if (l <= mid)
        ans = max(query(u << 1, l, r), ans);
    if (r > mid)
        ans = max(query(u << 1 | 1, l, r), ans);
    
    return ans;
}

void modify(int u, int x, int v)
{
    if (tr[u].l == x && tr[u].r == x){
        tr[u].v = v;
        return;
    }
    
    int mid = tr[u].l + tr[u].r >> 1;
    
    if (x <= mid)
        modify(u << 1, x, v);
    else
        modify(u << 1 | 1, x, v);
    
    pushup(u);
}


int main()
{
    cin >> m >> p;
    int last = 0;
    int cnt = 0;
    build(1, 1, m);
    
    char op[2];
    int x;
    while (m -- ){
        scanf("%s%d",op, &x);
        
        if (op[0] == 'Q'){
            
            last = query(1, cnt - x + 1, cnt);
            printf("%d\n",last);
            
        }else{
            
            x = (last + x) % p;
            modify(1, cnt + 1, x);
            cnt ++;
            
        }
        
    }
    
    return 0;
}

AcWing 245. 你能回答这些问题吗

算法思路:

题目要求单点修改+区间查询, 只需要pushup操作, 问题在如何确定线段树所要维护的内容:

求最大连续字段和\(v\),对于\(u\)结点的\(v\),分三种情况

  1. ==左区间的\(v\)
  2. ==右区间的\(v\)
  3. ==左区间的最大后缀和+右区间的最大前缀和

对于最大前缀和与最大后缀和, 分两种情况(以最大前缀和为例):

  1. ==左区间的最大前缀和
  2. ==左区间的和+右区间的最大前缀和

综上,需要维护信息: 区间和, 区间最大前缀和, 区间最大后缀和, 最大连续子段和

#include <iostream>
#include <cstring>
#include <queue>

using namespace std;

const int N = 500000 + 10;

struct Node{
    int l, r;
    int v;
    int ll, rr;
    int sum;
}tr[N * 4];
int n, m;
int a[N];

void pushup(Node &root, Node &left, Node &right)
{
    //auto left = &tr[u << 1], right = &tr[u << 1 | 1], root = &tr[u];
    
    root.ll = max(left.ll, left.sum + right.ll);
    
    root.sum = left.sum + right.sum;
    
    root.rr = max(right.rr, right.sum + left.rr);
    
    root.v = max(max(right.v, left.v), left.rr + right.ll);
    
}

void pushup(int u)
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r)
{
    if (l == r)
        tr[u] = {l, r, a[l], a[l], a[r], a[r]};
    else{
        tr[u] = {l, r};
        
        int mid = l + r >> 1;
        
        build (u << 1, l, mid);
        build (u << 1 | 1, mid + 1, r);
        
        pushup(u);
    }
}

Node query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r)
        return tr[u];
        
    int mid = tr[u].l + tr[u].r >> 1;
    
    if (r <= mid)
        return query(u << 1, l, r);
    else if (l > mid)
        return query(u << 1 | 1, l, r);
    else{
        auto ans1 = query(u << 1, l ,r);
        auto ans2 = query(u << 1 | 1, l, r);
        Node ans;
        pushup(ans, ans1, ans2);
        
        return ans;
    }
    
    
}

void modify(int u, int x, int v)
{
    if (tr[u].l == x && tr[u].r == x){
        tr[u].v = v;
        tr[u].sum = v;
        tr[u].ll = v;
        tr[u].rr = v;
        return ;
    }
    
    int mid = tr[u].l + tr[u].r >> 1;
    
    if (x <= mid)
        modify(u << 1, x, v);
    else
        modify(u << 1 | 1, x, v);
    
    pushup(u);
}


int main()
{
    cin >> n >> m;
    
    for (int i = 1; i <= n; i ++ )
        scanf("%d",&a[i]);
    
    build(1, 1, n);
    
    int op;
    int x, y;
        
    while (m -- ){
        scanf("%d%d%d",&op, &x, &y);
        
        if (op == 1){
            if (x > y)
                swap(x, y);
            auto ans = query(1, x, y);
            
            printf("%d\n", ans.v);
            
        }else{
            
            modify(1, x, y);
        }
        
    }
    
    return 0;
    
    
}

AcWing 246. 区间最大公约数

算法思路:

题目涉及区间修改和区间查询, 区间整体加减, 通过差分数组转化为单点的修改

一、差分数组:

\(a[]\)为原数组, \(b[]\)为差分数组:

  1. \(b[i] = a[i] - a[i - 1]\)
  2. \(a[i] = b[1] + b[2] + ... + b[i]\)
  3. \(a[ ]\)的区间\([L,R]\)\(d\) 等价于: \(b[ ]\)\(b[l] + d, b[r + 1] - d\)

二、最大公约数的性质:

\(gcd(a_{1}, a_{2}, a_{3} ... a_{n}) = gcd(a_{1}, a_{2} - a_{1}, a_{3} - a_{2} ... a_{n} - a_{n - 1})\)

综上, 线段树维护信息: 差分数组的前缀和, 区间的最大公约数. (所有信息都是关于差分数组的)
对于查询区间\([L, R]\): $ ans = gcd(a_{L}, a_{L + 1} - a_{L} ... a_{R} - a_{R - 1}) = gcd( 差分数组前缀和, gcd( b_{L + 1}, b_{L + 2} ... b_{R}) $

#include <iostream>
#include <cstring>
#include <queue>

using namespace std;
typedef long long LL;
const int N = 500010;

struct Node{
    int l, r;
    LL sum, v;
}tr[N * 4];

LL a[N];
LL b[N];
int n, m;

LL gcd(LL a, LL b)
{
    return b ? gcd(b, a % b) : a;
}

void pushup(Node &root, Node &ll, Node &rr)
{
    root.sum = ll.sum + rr.sum;
    root.v = gcd(ll.v, rr.v);
}

void pushup(int u)
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r)
{
    if (l == r){
        tr[u] = {l, l, b[l], b[l]};
        return;
    }
    
    tr[u].l = l;
    tr[u].r = r;
    
    int mid = l + r >> 1;
    
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    
    pushup(u);
    
    
}


void modify(int u, int x, LL v)
{
    
    if (tr[u].l == x && tr[u].r == x){
        tr[u].sum += v;
        tr[u].v += v;
        return;
    }
    
    int mid = tr[u].l + tr[u].r >> 1;
    
    if (x <= mid)
        modify(u << 1, x, v);
    else
        modify(u << 1 | 1, x, v);
        
    pushup(u);
    
}

Node query(int u, int l, int r)
{
    
    if (tr[u].l >= l && tr[u].r <= r)
        return tr[u];
    
    int mid = tr[u].l + tr[u].r >> 1;
    
    if (r <= mid)
        return query(u << 1, l, r);
    else if (l > mid)
        return query(u << 1 | 1, l, r);
    else{
        auto ans1 = query(u << 1, l, r);
        auto ans2 = query(u << 1 | 1, l, r);
        
        Node ans;
        
        pushup(ans, ans1, ans2);
        
        return ans;
    }
    
}

int main()
{
    cin >> n >> m;
    
    for (int i = 1; i <= n; i ++ )
        cin >> a[i];
    for (int i = 1; i <= n; i ++ )
        b[i] = a[i] - a[i - 1];
    
    build(1, 1, n);
    
    char op[4];
    int l, r;
    LL x;
    while (m -- ){
        scanf("%s%d%d",op, &l, &r);
        
        if (op[0] == 'Q'){
            auto it1 = query(1, 1, l);
            auto it2 = query(1, l + 1, r);
            printf("%lld\n", abs(gcd(it1.sum, it2.v)));
            
        }else{
            scanf("%lld",&x);
            modify(1, l, x);
            if (r + 1 <= n) modify(1, r + 1, -x);
        }
    }
    
    
    return 0;
    
}
posted @ 2021-03-16 20:39  lhqwd  阅读(33)  评论(0)    收藏  举报