线段树模板

通用线段树模板(支持区间加法 + 区间乘法 + 区间求和,不取模;注意结果可能超出 long long 范围):

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

class SegmentTree {
public:
    // Node covering positions [l, r]. `sum` is the range sum with this node's
    // own tag already applied; (mul, add) is the tag still owed to the children,
    // meaning "x -> x * mul + add" for every element below this node.
    struct Node {
        ll l, r;
        ll sum, add, mul;
    };

    vector<Node> tree;
    vector<ll> data;

    // Builds the tree over input[1..n]; input[0] is padding (1-based indexing).
    // When n == 0 (input has fewer than two entries) the tree stays empty:
    // updates are no-ops and queries return 0.
    // NOTE: nothing is reduced modulo anything, so sums overflow `ll`
    // (undefined behaviour) once the true values exceed 2^63 - 1.
    SegmentTree(const vector<ll>& input) {
        data = input;
        n_ = input.empty() ? 0 : static_cast<int>(input.size()) - 1;
        if (n_ >= 1) {
            tree.resize(n_ * 4 + 10);
            build(1, 1, n_);
        }
    }

    // Recursively initialises node i to cover [l, r] with the identity tag
    // (mul = 1, add = 0); leaves take their value from data[l].
    void build(int i, int l, int r) {
        tree[i] = {l, r, 0, 0, 1};
        if (l == r) {
            tree[i].sum = data[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(i << 1, l, mid);
        build(i << 1 | 1, mid + 1, r);
        push_up(i);
    }

    // Adds val to every element in [l, r], clamped to [1, n].
    void range_add(int l, int r, ll val) {
        if (clip(l, r)) add(1, l, r, val);
    }

    // Multiplies every element in [l, r], clamped to [1, n], by val.
    void range_mult(int l, int r, ll val) {
        if (clip(l, r)) mult(1, l, r, val);
    }

    // Returns the sum over [l, r], clamped to [1, n]; 0 if nothing is left.
    ll range_query(int l, int r) {
        return clip(l, r) ? query(1, l, r) : 0;
    }

private:
    int n_ = 0;  // number of real elements, i.e. input.size() - 1

    // Intersects [l, r] with [1, n_] in place; returns false if empty.
    // The recursive helpers assume the root overlaps the request. Without this,
    // they could reach a non-overlapping leaf and push_down past the end of `tree`.
    bool clip(int& l, int& r) const {
        l = max(l, 1);
        r = min(r, n_);
        return l <= r;
    }

    // Recomputes node i's sum from its two children.
    void push_up(int i) {
        tree[i].sum = tree[i << 1].sum + tree[i << 1 | 1].sum;
    }

    // Hands node i's pending tag to both children and resets it to identity.
    void push_down(int i) {
        ll tag_mul = tree[i].mul, tag_add = tree[i].add;
        apply(i << 1, tag_mul, tag_add);
        apply(i << 1 | 1, tag_mul, tag_add);
        tree[i].mul = 1;
        tree[i].add = 0;
    }

    // Applies "x -> x * mul + add" to all elements under node i. The new tag is
    // composed after the existing one: (x * m0 + a0) * mul + add.
    void apply(int i, ll mul, ll add) {
        tree[i].sum = tree[i].sum * mul + (tree[i].r - tree[i].l + 1) * add;
        tree[i].mul *= mul;
        tree[i].add = tree[i].add * mul + add;
    }

    // Range add on the subtree rooted at i; [l, r] must overlap node i.
    void add(int i, int l, int r, ll val) {
        if (tree[i].l >= l && tree[i].r <= r) {
            apply(i, 1, val);
            return;
        }
        push_down(i);
        if (tree[i << 1].r >= l) add(i << 1, l, r, val);
        if (tree[i << 1 | 1].l <= r) add(i << 1 | 1, l, r, val);
        push_up(i);
    }

    // Range multiply on the subtree rooted at i; [l, r] must overlap node i.
    void mult(int i, int l, int r, ll val) {
        if (tree[i].l >= l && tree[i].r <= r) {
            apply(i, val, 0);
            return;
        }
        push_down(i);
        if (tree[i << 1].r >= l) mult(i << 1, l, r, val);
        if (tree[i << 1 | 1].l <= r) mult(i << 1 | 1, l, r, val);
        push_up(i);
    }

    // Range sum on the subtree rooted at i; [l, r] must overlap node i.
    ll query(int i, int l, int r) {
        if (tree[i].l >= l && tree[i].r <= r)
            return tree[i].sum;
        push_down(i);
        ll res = 0;
        if (tree[i << 1].r >= l) res += query(i << 1, l, r);
        if (tree[i << 1 | 1].l <= r) res += query(i << 1 | 1, l, r);
        return res;
    }
};
int main() {
    ll n, m;
    cin >> n >> m;
    vector<ll> arr(n + 1); // 1-based index
    for (int i = 1; i <= n; i++) {
        cin >> arr[i];
    }

    SegmentTree seg(arr);
    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {
            ll l, r, val;
            cin >> l >> r >> val;
            seg.range_mult(l, r, val);
        } else if (op == 2) {
            ll l, r, val;
            cin >> l >> r >> val;
            seg.range_add(l, r, val);
        } else if (op == 3) {
            ll l, r;
            cin >> l >> r;
            cout << seg.range_query(l, r) << endl;
        }
    }

    return 0;
}

通用线段树模板(支持区间加法 + 区间乘法 + 区间求和,所有结果对 mod 取模):


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

class SegmentTree {
public:
    // Node covering positions [l, r]. `sum` is the range sum (mod `mod`) with this
    // node's own tag already applied; (mul, add) is the tag still owed to the
    // children, meaning "x -> x * mul + add (mod `mod`)".
    struct Node {
        ll l, r;
        ll sum, add, mul;
    };

    vector<Node> tree;
    vector<ll> data;
    ll mod;

    // Builds the tree over input[1..n] (input[0] is padding) with all arithmetic
    // done modulo mod_val. Assumes mod_val >= 1 and that the product of two
    // residues fits in `ll` (mod_val up to roughly 3e9).
    // When n == 0 the tree stays empty: updates are no-ops and queries return 0.
    SegmentTree(const vector<ll>& input, ll mod_val) {
        mod = mod_val;
        data = input;
        n_ = input.empty() ? 0 : static_cast<int>(input.size()) - 1;
        if (n_ >= 1) {
            tree.resize(n_ * 4 + 10);
            build(1, 1, n_);
        }
    }

    // Recursively initialises node i over [l, r] with the identity tag; leaves
    // store data[l] reduced into [0, mod), so negative inputs are handled.
    void build(int i, int l, int r) {
        tree[i] = {l, r, 0, 0, 1};
        if (l == r) {
            tree[i].sum = norm(data[l]);
            return;
        }
        int mid = (l + r) >> 1;
        build(i << 1, l, mid);
        build(i << 1 | 1, mid + 1, r);
        push_up(i);
    }

    // Adds val (any sign or size; reduced mod `mod` first) to [l, r] clamped to [1, n].
    void range_add(int l, int r, ll val) {
        if (clip(l, r)) add(1, l, r, norm(val));
    }

    // Multiplies [l, r] clamped to [1, n] by val (reduced mod `mod` first).
    void range_mult(int l, int r, ll val) {
        if (clip(l, r)) mult(1, l, r, norm(val));
    }

    // Returns the sum over [l, r] clamped to [1, n], in [0, mod); 0 if empty.
    ll range_query(int l, int r) {
        return clip(l, r) ? query(1, l, r) : 0;
    }

private:
    int n_ = 0;  // number of real elements, i.e. input.size() - 1

    // Maps v to its residue in [0, mod). C++ `%` keeps the sign of v, hence the fix-up.
    // Reducing up front also keeps tree[i].sum * mul below `ll` overflow.
    ll norm(ll v) const {
        v %= mod;
        return v < 0 ? v + mod : v;
    }

    // Intersects [l, r] with [1, n_] in place; returns false if empty.
    // The recursive helpers assume the root overlaps the request. Without this,
    // they could reach a non-overlapping leaf and push_down past the end of `tree`.
    bool clip(int& l, int& r) const {
        l = max(l, 1);
        r = min(r, n_);
        return l <= r;
    }

    // Recomputes node i's sum from its two children.
    void push_up(int i) {
        tree[i].sum = (tree[i << 1].sum + tree[i << 1 | 1].sum) % mod;
    }

    // Hands node i's pending tag to both children and resets it to identity.
    void push_down(int i) {
        ll tag_mul = tree[i].mul, tag_add = tree[i].add;
        apply(i << 1, tag_mul, tag_add);
        apply(i << 1 | 1, tag_mul, tag_add);
        tree[i].mul = 1;
        tree[i].add = 0;
    }

    // Applies "x -> x * mul + add" (mod `mod`) to all elements under node i,
    // composing after the existing tag: (x * m0 + a0) * mul + add.
    void apply(int i, ll mul, ll add) {
        tree[i].sum = (tree[i].sum * mul % mod + (tree[i].r - tree[i].l + 1) * add % mod) % mod;
        tree[i].mul = tree[i].mul * mul % mod;
        tree[i].add = (tree[i].add * mul % mod + add) % mod;
    }

    // Range add on the subtree rooted at i; [l, r] must overlap node i.
    void add(int i, int l, int r, ll val) {
        if (tree[i].l >= l && tree[i].r <= r) {
            apply(i, 1, val);
            return;
        }
        push_down(i);
        if (tree[i << 1].r >= l) add(i << 1, l, r, val);
        if (tree[i << 1 | 1].l <= r) add(i << 1 | 1, l, r, val);
        push_up(i);
    }

    // Range multiply on the subtree rooted at i; [l, r] must overlap node i.
    void mult(int i, int l, int r, ll val) {
        if (tree[i].l >= l && tree[i].r <= r) {
            apply(i, val, 0);
            return;
        }
        push_down(i);
        if (tree[i << 1].r >= l) mult(i << 1, l, r, val);
        if (tree[i << 1 | 1].l <= r) mult(i << 1 | 1, l, r, val);
        push_up(i);
    }

    // Range sum (mod `mod`) on the subtree rooted at i; [l, r] must overlap node i.
    ll query(int i, int l, int r) {
        if (tree[i].l >= l && tree[i].r <= r)
            return tree[i].sum;
        push_down(i);
        ll res = 0;
        if (tree[i << 1].r >= l) res = (res + query(i << 1, l, r)) % mod;
        if (tree[i << 1 | 1].l <= r) res = (res + query(i << 1 | 1, l, r)) % mod;
        return res;
    }
};
int main() {
    ll n, m, mod;
    cin >> n >> m >> mod;
    vector<ll> arr(n + 1); 
    for (int i = 1; i <= n; i++) {
        cin >> arr[i];
    }

    SegmentTree seg(arr, mod);
    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {
            ll l, r, val;
            cin >> l >> r >> val;
            seg.range_mult(l, r, val);
        } else if (op == 2) {
            ll l, r, val;
            cin >> l >> r >> val;
            seg.range_add(l, r, val);
        } else if (op == 3) {
            ll l, r;
            cin >> l >> r;
            cout << seg.range_query(l, r) << endl;
        }
    }

    return 0;
}

两个序列 A、B 共用一棵线段树(区间给 A 或 B 加值,查询区间内 Σ A_i·B_i 对 998244353 取模)提示:AtCoder ABC357 F 题

#include<bits/stdc++.h>

typedef long long LL;
using namespace std;
const int N = 2e5 + 10;
const int MOD = 998244353;

class segment {
public:
    // Per node (each node covers some range [l, r]):
    //   sab[x]     - sum of A_i * B_i over the range, mod MOD
    //   sum[x][0]  - sum of A_i over the range, mod MOD
    //   sum[x][1]  - sum of B_i over the range, mod MOD
    //   lazy[x][k] - amount already added to node x but not yet to its children,
    //                for sequence A (k = 0) or B (k = 1)
    // Storage is 4 * N nodes, so the tree size n must not exceed N.
    LL sab[N << 2];
    LL sum[N << 2][2];
    LL lazy[N << 2][2];

    // Recomputes node `root` from its two children.
    void pushup(int root) {
        const int ls = root << 1, rs = root << 1 | 1;
        sab[root] = (sab[ls] + sab[rs]) % MOD;
        sum[root][0] = (sum[ls][0] + sum[rs][0]) % MOD;
        sum[root][1] = (sum[ls][1] + sum[rs][1]) % MOD;
    }

    // Builds node `root` over [l, r] from the 0-based vectors a and b: tree
    // position p holds a[p - 1], b[p - 1]. Every visited node's lazy tags are
    // cleared, so rebuilding (e.g. for another test case) cannot apply stale tags.
    void build(int root, int l, int r, vector<int> &a, vector<int> &b) {
        lazy[root][0] = lazy[root][1] = 0;
        if (l == r) {
            const LL x = norm(a[l - 1]), y = norm(b[l - 1]);
            sab[root] = x * y % MOD;
            sum[root][0] = x;
            sum[root][1] = y;
            return;
        }
        int mid = (l + r) >> 1;
        build(root << 1, l, mid, a, b);
        build(root << 1 | 1, mid + 1, r, a, b);
        pushup(root);
    }

    // Pushes root's pending tags down to its children [l, mid] and [mid + 1, r].
    void pushdown(int root, int l, int mid, int r) {
        const LL da = lazy[root][0], db = lazy[root][1];
        if (!da && !db) return;
        apply_tag(root << 1, mid - l + 1, da, db);
        apply_tag(root << 1 | 1, r - mid, da, db);
        lazy[root][0] = lazy[root][1] = 0;
    }

    // Adds val to A (op == 0) or B (op == 1) at every position in [L, R].
    // `root` covers [l, r]. val may be any LL; it is reduced mod MOD first.
    void update(int root, int l, int r, int L, int R, LL val, int op) {
        if (L <= l && r <= R) {
            const LL v = norm(val);
            if (op == 0)
                apply_tag(root, r - l + 1, v, 0);
            else
                apply_tag(root, r - l + 1, 0, v);
            return;
        }
        int mid = (l + r) >> 1;
        pushdown(root, l, mid, r);
        if (L <= mid)
            update(root << 1, l, mid, L, R, val, op);
        if (R > mid)
            update(root << 1 | 1, mid + 1, r, L, R, val, op);
        pushup(root);
    }

    // Returns sum of A_i * B_i over [L, R] mod MOD; `root` covers [l, r].
    LL query(int root, int l, int r, int L, int R) {
        if (L <= l && r <= R) {
            return sab[root];
        }
        int mid = (l + r) >> 1;
        pushdown(root, l, mid, r);
        LL ans = 0;
        if (L <= mid)
            ans += query(root << 1, l, mid, L, R);
        if (R > mid)
            ans += query(root << 1 | 1, mid + 1, r, L, R);
        return ans % MOD;
    }

private:
    // Maps v to its residue in [0, MOD). C++ `%` keeps the sign of v.
    static LL norm(LL v) {
        v %= MOD;
        return v < 0 ? v + MOD : v;
    }

    // Adds da to every A_i and db to every B_i under node x, which covers `len`
    // positions, and records the tags for x's children. It uses
    // (A + da)(B + db) = AB + da*B + db*A + da*db, summed over len positions.
    // sab is updated first because it needs the old sums.
    void apply_tag(int x, LL len, LL da, LL db) {
        len %= MOD;
        sab[x] = (sab[x] + da * sum[x][1] % MOD + db * sum[x][0] % MOD +
                  da * db % MOD * len % MOD) % MOD;
        sum[x][0] = (sum[x][0] + da * len) % MOD;
        sum[x][1] = (sum[x][1] + db * len) % MOD;
        lazy[x][0] = (lazy[x][0] + da) % MOD;
        lazy[x][1] = (lazy[x][1] + db) % MOD;
    }
} sg;

int main() {
    int n, q;
    cin >> n >> q;
    vector<int> a(n), b(n);
    for (auto &i: a)
        cin >> i;
    for (auto &i: b)
        cin >> i;
    sg.build(1, 1, n, a, b);
    while (q--) {
        int op;
        cin >> op;
        if (op == 3) {
            int l, r;
            cin >> l >> r;
            int ans = sg.query(1, 1, n, l, r);
            cout << ans << endl;
        } else {
            int l, r, x;
            cin >> l >> r >> x;
            sg.update(1, 1, n, l, r, x, op - 1);
        }
    }
    return 0;
}

posted @ 2025-05-14 19:33  cloudbless  阅读(133)  评论(0)    收藏  举报