#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MOD = 998244353;
struct Node {
int l, r; // 区间左右端点
ll sum; // 区间和
ll lazy; // 乘法懒标记,表示该区间需要乘上的值
};
struct SegmentTree {
vector<Node> tr;
int n;
// 构造函数
SegmentTree(int size = 0) {
init(size);
}
// 初始化
void init(int size) {
n = size;
tr.resize(n * 4 + 10);
build(1, 1, n);
}
// 上传:用子节点更新父节点
void push_up(int p) {
tr[p].sum = (tr[p << 1].sum + tr[p << 1 | 1].sum) % MOD;
}
// 下传懒标记到子节点
void push_down(int p) {
if (tr[p].lazy == 1) return; // 懒标记为1,不需要下传
ll lz = tr[p].lazy;
int lc = p << 1; // 左子节点
int rc = p << 1 | 1; // 右子节点
// 更新左子节点
tr[lc].sum = tr[lc].sum * lz % MOD;
tr[lc].lazy = tr[lc].lazy * lz % MOD;
// 更新右子节点
tr[rc].sum = tr[rc].sum * lz % MOD;
tr[rc].lazy = tr[rc].lazy * lz % MOD;
// 清空当前节点的懒标记
tr[p].lazy = 1;
}
// 建树
void build(int p, int l, int r) {
tr[p].l = l;
tr[p].r = r;
tr[p].sum = 0;
tr[p].lazy = 1; // 乘法懒标记初始为1(单位元)
if (l == r) return; // 叶子节点
int mid = (l + r) >> 1;
build(p << 1, l, mid); // 建左子树
build(p << 1 | 1, mid + 1, r); // 建右子树
}
// 单点加:在位置 pos 加上 val
void add_point(int p, int pos, ll val) {
if (tr[p].l == tr[p].r) {
// 到达叶子节点
tr[p].sum = (tr[p].sum + val) % MOD;
return;
}
push_down(p); // 下传懒标记
int mid = (tr[p].l + tr[p].r) >> 1;
if (pos <= mid) add_point(p << 1, pos, val);
else add_point(p << 1 | 1, pos, val);
push_up(p); // 上传更新父节点
}
// 对外接口:单点加
void add_point(int pos, ll val) {
add_point(1, pos, val);
}
// 区间乘:将区间 [l, r] 的所有元素乘上 val
void mul_range(int p, int l, int r, ll val) {
if (l <= tr[p].l && tr[p].r <= r) {
// 当前区间完全包含在 [l, r] 中
tr[p].sum = tr[p].sum * val % MOD;
tr[p].lazy = tr[p].lazy * val % MOD;
return;
}
push_down(p); // 下传懒标记
int mid = (tr[p].l + tr[p].r) >> 1;
if (l <= mid) mul_range(p << 1, l, r, val);
if (r > mid) mul_range(p << 1 | 1, l, r, val);
push_up(p); // 上传更新父节点
}
// 对外接口:区间乘
void mul_range(int l, int r, ll val) {
if (l > r) return;
mul_range(1, l, r, val);
}
// 区间查询:查询区间 [l, r] 的和
ll query(int p, int l, int r) {
if (l <= tr[p].l && tr[p].r <= r) {
// 当前区间完全包含在 [l, r] 中
return tr[p].sum;
}
push_down(p); // 下传懒标记
int mid = (tr[p].l + tr[p].r) >> 1;
ll res = 0;
if (l <= mid) res = (res + query(p << 1, l, r)) % MOD;
if (r > mid) res = (res + query(p << 1 | 1, l, r)) % MOD;
return res;
}
// 对外接口:区间查询
ll query(int l, int r) {
if (l > r) return 0;
return query(1, l, r);
}
// 单点查询(特殊情况的区间查询)
ll query_point(int pos) {
return query(pos, pos);
}
};
// ==================== 使用示例 ====================
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n = 10;
SegmentTree seg(n); // 建立大小为 n 的线段树
// 单点加:在位置 3 加 5
seg.add_point(3, 5);
// 在位置 5 加 7
seg.add_point(5, 7);
// 区间查询 [1, 10] 的和
cout << "sum[1,10] = " << seg.query(1, 10) << endl; // 输出 12
// 区间乘:将 [5, 10] 的所有元素乘 2
seg.mul_range(5, 10, 2);
// 再次查询
cout << "sum[1,10] = " << seg.query(1, 10) << endl; // 输出 5 + 14 = 19
// 单点查询
cout << "point[5] = " << seg.query_point(5) << endl; // 输出 14
return 0;
}