线段树板子,懒标记,区间乘法,单点加法,区间求和

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const int MOD = 998244353;

struct Node {
    int l, r;       // 区间左右端点
    ll sum;         // 区间和
    ll lazy;        // 乘法懒标记,表示该区间需要乘上的值
};

struct SegmentTree {
    vector<Node> tr;
    int n;
    
    // 构造函数
    SegmentTree(int size = 0) {
        init(size);
    }
    
    // 初始化
    void init(int size) {
        n = size;
        tr.resize(n * 4 + 10);
        build(1, 1, n);
    }
    
    // 上传:用子节点更新父节点
    void push_up(int p) {
        tr[p].sum = (tr[p << 1].sum + tr[p << 1 | 1].sum) % MOD;
    }
    
    // 下传懒标记到子节点
    void push_down(int p) {
        if (tr[p].lazy == 1) return;  // 懒标记为1,不需要下传
        
        ll lz = tr[p].lazy;
        int lc = p << 1;      // 左子节点
        int rc = p << 1 | 1;  // 右子节点
        
        // 更新左子节点
        tr[lc].sum = tr[lc].sum * lz % MOD;
        tr[lc].lazy = tr[lc].lazy * lz % MOD;
        
        // 更新右子节点
        tr[rc].sum = tr[rc].sum * lz % MOD;
        tr[rc].lazy = tr[rc].lazy * lz % MOD;
        
        // 清空当前节点的懒标记
        tr[p].lazy = 1;
    }
    
    // 建树
    void build(int p, int l, int r) {
        tr[p].l = l;
        tr[p].r = r;
        tr[p].sum = 0;
        tr[p].lazy = 1;  // 乘法懒标记初始为1(单位元)
        
        if (l == r) return;  // 叶子节点
        
        int mid = (l + r) >> 1;
        build(p << 1, l, mid);       // 建左子树
        build(p << 1 | 1, mid + 1, r);  // 建右子树
    }
    
    // 单点加:在位置 pos 加上 val
    void add_point(int p, int pos, ll val) {
        if (tr[p].l == tr[p].r) {
            // 到达叶子节点
            tr[p].sum = (tr[p].sum + val) % MOD;
            return;
        }
        
        push_down(p);  // 下传懒标记
        
        int mid = (tr[p].l + tr[p].r) >> 1;
        if (pos <= mid) add_point(p << 1, pos, val);
        else add_point(p << 1 | 1, pos, val);
        
        push_up(p);  // 上传更新父节点
    }
    
    // 对外接口:单点加
    void add_point(int pos, ll val) {
        add_point(1, pos, val);
    }
    
    // 区间乘:将区间 [l, r] 的所有元素乘上 val
    void mul_range(int p, int l, int r, ll val) {
        if (l <= tr[p].l && tr[p].r <= r) {
            // 当前区间完全包含在 [l, r] 中
            tr[p].sum = tr[p].sum * val % MOD;
            tr[p].lazy = tr[p].lazy * val % MOD;
            return;
        }
        
        push_down(p);  // 下传懒标记
        
        int mid = (tr[p].l + tr[p].r) >> 1;
        if (l <= mid) mul_range(p << 1, l, r, val);
        if (r > mid) mul_range(p << 1 | 1, l, r, val);
        
        push_up(p);  // 上传更新父节点
    }
    
    // 对外接口:区间乘
    void mul_range(int l, int r, ll val) {
        if (l > r) return;
        mul_range(1, l, r, val);
    }
    
    // 区间查询:查询区间 [l, r] 的和
    ll query(int p, int l, int r) {
        if (l <= tr[p].l && tr[p].r <= r) {
            // 当前区间完全包含在 [l, r] 中
            return tr[p].sum;
        }
        
        push_down(p);  // 下传懒标记
        
        int mid = (tr[p].l + tr[p].r) >> 1;
        ll res = 0;
        if (l <= mid) res = (res + query(p << 1, l, r)) % MOD;
        if (r > mid) res = (res + query(p << 1 | 1, l, r)) % MOD;
        
        return res;
    }
    
    // 对外接口:区间查询
    ll query(int l, int r) {
        if (l > r) return 0;
        return query(1, l, r);
    }
    
    // 单点查询(特殊情况的区间查询)
    ll query_point(int pos) {
        return query(pos, pos);
    }
};

// ==================== 使用示例 ====================

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    
    int n = 10;
    SegmentTree seg(n);  // 建立大小为 n 的线段树
    
    // 单点加:在位置 3 加 5
    seg.add_point(3, 5);
    
    // 在位置 5 加 7
    seg.add_point(5, 7);
    
    // 区间查询 [1, 10] 的和
    cout << "sum[1,10] = " << seg.query(1, 10) << endl;  // 输出 12
    
    // 区间乘:将 [5, 10] 的所有元素乘 2
    seg.mul_range(5, 10, 2);
    
    // 再次查询
    cout << "sum[1,10] = " << seg.query(1, 10) << endl;  // 输出 5 + 14 = 19
    
    // 单点查询
    cout << "point[5] = " << seg.query_point(5) << endl;  // 输出 14
    
    return 0;
}
posted @ 2026-04-07 17:20  majikko  阅读(1)  评论(0)    收藏  举报