算法学习(6):线段树

线段树

介绍

线段树(Segment Tree)几乎是算法竞赛最常用的数据结构了,它主要用于维护区间信息(要求满足结合律)。

建立

void build(ll l = 1, ll r = n, ll p = 1)
{
    if (l == r) // 到达叶子节点
        tree[p] = A[l]; // 用数组中的数据赋值
    else
    {
        ll mid = (l + r) / 2;
        build(l, mid, p * 2); // 先建立左右子节点
        build(mid + 1, r, p * 2 + 1);
        tree[p] = tree[p * 2] + tree[p * 2 + 1]; // 该节点的值等于左右子节点之和
    }
}

区间修改

void update(ll l, ll r, ll d, ll p = 1, ll cl = 1, ll cr = n)
{
    if (cl > r || cr < l) // 区间无交集
        return; // 剪枝
    else if (cl >= l && cr <= r) // 当前节点对应的区间包含在目标区间中
    {
        tree[p] += (cr - cl + 1) * d; // 更新当前区间的值
        if (cr > cl) // 如果不是叶子节点
            mark[p] += d; // 给当前区间打上标记
    }
    else // 与目标区间有交集,但不包含于其中
    {
        ll mid = (cl + cr) / 2;
        mark[p * 2] += mark[p]; // 标记向下传递
        mark[p * 2 + 1] += mark[p];
        tree[p * 2] += mark[p] * (mid - cl + 1); // 往下更新一层
        tree[p * 2 + 1] += mark[p] * (cr - mid);
        mark[p] = 0; // 清除标记
        update(l, r, d, p * 2, cl, mid); // 递归地往下寻找
        update(l, r, d, p * 2 + 1, mid + 1, cr);
        tree[p] = tree[p * 2] + tree[p * 2 + 1]; // 根据子节点更新当前节点的值
    }
}

区间查询

ll query(ll l, ll r, ll p = 1, ll cl = 1, ll cr = n)
{
    if (cl > r || cr < l)
        return 0;
    else if (cl >= l && cr <= r)
        return tree[p];
    else
    {
        ll mid = (cl + cr) / 2;
        push_down(p, cr - cl + 1);
        return query(l, r, p * 2, cl, mid) + query(l, r, p * 2 + 1, mid + 1, cr); 
        // 上一行拆成三行写就和区间修改格式一致了
    }
}

模板1

#include <bits/stdc++.h>
#define MAXN 100005
using namespace std;
typedef long long ll;
inline ll read()
{
    ll ans = 0;
    char c = getchar();
    while (!isdigit(c))
        c = getchar();
    while (isdigit(c))
    {
        ans = ans * 10 + c - '0';
        c = getchar();
    }
    return ans;
}
ll n, m, A[MAXN], tree[MAXN * 4], mark[MAXN * 4]; // 经验表明开四倍空间不会越界
inline void push_down(ll p, ll len)
{
    mark[p * 2] += mark[p];
    mark[p * 2 + 1] += mark[p];
    tree[p * 2] += mark[p] * (len - len / 2);
    tree[p * 2 + 1] += mark[p] * (len / 2);
    mark[p] = 0;
}
void build(ll l = 1, ll r = n, ll p = 1)
{
    if (l == r)
        tree[p] = A[l];
    else
    {
        ll mid = (l + r) / 2;
        build(l, mid, p * 2);
        build(mid + 1, r, p * 2 + 1);
        tree[p] = tree[p * 2] + tree[p * 2 + 1];
    }
}
void update(ll l, ll r, ll d, ll p = 1, ll cl = 1, ll cr = n)
{
    if (cl > r || cr < l)
        return;
    else if (cl >= l && cr <= r)
    {
        tree[p] += (cr - cl + 1) * d;
        if (cr > cl)
            mark[p] += d;
    }
    else
    {
        ll mid = (cl + cr) / 2;
        push_down(p, cr - cl + 1);
        update(l, r, d, p * 2, cl, mid);
        update(l, r, d, p * 2 + 1, mid + 1, cr);
        tree[p] = tree[p * 2] + tree[p * 2 + 1];
    }
}
ll query(ll l, ll r, ll p = 1, ll cl = 1, ll cr = n)
{
    if (cl > r || cr < l)
        return 0;
    else if (cl >= l && cr <= r)
        return tree[p];
    else
    {
        ll mid = (cl + cr) / 2;
        push_down(p, cr - cl + 1);
        return query(l, r, p * 2, cl, mid) + query(l, r, p * 2 + 1, mid + 1, cr);
    }
}
int main()
{
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i)
        A[i] = read();
    build();
    for (int i = 0; i < m; ++i)
    {
        ll opr = read(), l = read(), r = read();
        if (opr == 1)
        {
            ll d = read();
            update(l, r, d);
        }
        else
            printf("%lld\n", query(l, r));
    }
    return 0;
}

模板2

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int MAXN = 1e5 + 5;
ll tree[MAXN << 2], mark[MAXN << 2], n, m, A[MAXN];
void push_down(int p, int len)
{
    tree[p << 1] += mark[p] * (len - len / 2);
    mark[p << 1] += mark[p];
    tree[p << 1 | 1] += mark[p] * (len / 2);
    mark[p << 1 | 1] += mark[p];
    mark[p] = 0;
}
void build(int p = 1, int cl = 1, int cr = n)
{
    if (cl == cr) { tree[p] = A[cl]; return; }
    int mid = (cl + cr) >> 1;
    build(p << 1, cl, mid);
    build(p << 1 | 1, mid + 1, cr);
    tree[p] = tree[p << 1] + tree[p << 1 | 1];
}
ll query(int l, int r, int p = 1, int cl = 1, int cr = n)
{
    if (cl >= l && cr <= r) return tree[p];
    push_down(p, cr - cl + 1);
    ll mid = (cl + cr) >> 1, ans = 0;
    if (mid >= l) ans += query(l, r, p << 1, cl, mid);
    if (mid < r) ans += query(l, r, p << 1 | 1, mid + 1, cr);
    return ans;
}
void update(int l, int r, int d, int p = 1, int cl = 1, int cr = n)
{
    if (cl >= l && cr <= r) { tree[p] += d * (cr - cl + 1), mark[p] += d; return; }
    push_down(p, cr - cl + 1);
    int mid = (cl + cr) >> 1;
    if (mid >= l) update(l, r, d, p << 1, cl, mid);
    if (mid < r) update(l, r, d, p << 1 | 1, mid + 1, cr);
    tree[p] = tree[p << 1] + tree[p << 1 | 1];
}
int main()
{
    ios::sync_with_stdio(false);
    cin >> n >> m;
    for (int i = 1; i <= n; ++i)
        cin >> A[i];
    build();
    while (m--)
    {
        int o, l, r, d;
        cin >> o >> l >> r;
        if (o == 1)
            cin >> d, update(l, r, d);
        else
            cout << query(l, r) << '\n';
    }
    return 0;
}

posted on 2021-05-05 17:37  小星◎  阅读(99)  评论(0)    收藏  举报

导航