线段树模板
通用线段树模板(支持区间加法 + 区间乘法 + 区间求和):
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Lazy-propagation segment tree over a 1-based array supporting
//   range add:       a[i] += val   for i in [l, r]
//   range multiply:  a[i] *= val   for i in [l, r]
//   range sum:       sum of a[i]   for i in [l, r]
// No modulus is applied, so large data can overflow ll.
// Bounds are clamped to [1, n]; an empty range (l > r after clamping) makes
// an update a no-op and a query return 0, instead of walking into leaves
// that are not covered (which read past the end of `tree`).
class SegmentTree {
public:
    // A node covers positions [l, r]. `sum` is the current sum of that range;
    // (mul, add) is a pending tag "every child value x becomes x * mul + add"
    // that has not yet been pushed down to the children.
    struct Node {
        ll l, r;
        ll sum, add, mul;
    };
    vector<Node> tree;  // heap layout: root 1, children 2i and 2i + 1
    vector<ll> data;    // copy of the 1-based input (data[0] unused)

    // Builds the tree from input[1..n], n = input.size() - 1 (input[0] is
    // ignored). With no elements the tree stays empty; building would
    // otherwise recurse forever in build(1, 1, 0).
    SegmentTree(const vector<ll>& input) : data(input) {
        n_ = input.size() > 1 ? static_cast<int>(input.size()) - 1 : 0;
        if (n_ == 0) return;
        tree.resize(n_ * 4 + 10);
        build(1, 1, n_);
    }

    // Initializes node i for [l, r] from `data` and builds its subtree.
    // Requires 1 <= l <= r <= n.
    void build(int i, int l, int r) {
        tree[i] = {l, r, 0, 0, 1};
        if (l == r) {
            tree[i].sum = data[l];
            return;
        }
        const int mid = (l + r) >> 1;
        build(i << 1, l, mid);
        build(i << 1 | 1, mid + 1, r);
        push_up(i);
    }

    // Adds val to every element in [l, r].
    void range_add(int l, int r, ll val) {
        if (clamp_range(l, r)) add(1, l, r, val);
    }

    // Multiplies every element in [l, r] by val.
    void range_mult(int l, int r, ll val) {
        if (clamp_range(l, r)) mult(1, l, r, val);
    }

    // Returns the sum of the elements in [l, r] (0 for an empty range).
    ll range_query(int l, int r) {
        return clamp_range(l, r) ? query(1, l, r) : 0;
    }

private:
    int n_ = 0;  // number of elements; 0 means nothing was built

    // Clamps [l, r] to [1, n_]; returns false if the clamped range is empty.
    bool clamp_range(int& l, int& r) const {
        l = max(l, 1);
        r = min(r, n_);
        return l <= r;
    }

    void push_up(int i) {
        tree[i].sum = tree[i << 1].sum + tree[i << 1 | 1].sum;
    }

    // Moves node i's pending tag onto both children and resets it.
    void push_down(int i) {
        const ll mul_tag = tree[i].mul, add_tag = tree[i].add;
        apply(i << 1, mul_tag, add_tag);
        apply(i << 1 | 1, mul_tag, add_tag);
        tree[i].mul = 1;
        tree[i].add = 0;
    }

    // Applies x -> x * mul_tag + add_tag to every element under node i:
    // the sum is updated directly and the tag is composed with the existing
    // one, since (x * m1 + a1) * m2 + a2 = x * (m1 * m2) + (a1 * m2 + a2).
    void apply(int i, ll mul_tag, ll add_tag) {
        tree[i].sum = tree[i].sum * mul_tag + (tree[i].r - tree[i].l + 1) * add_tag;
        tree[i].mul *= mul_tag;
        tree[i].add = tree[i].add * mul_tag + add_tag;
    }

    // Recursive helpers; [l, r] must be non-empty and inside [1, n_].
    void add(int i, int l, int r, ll val) { update(i, l, r, 1, val); }
    void mult(int i, int l, int r, ll val) { update(i, l, r, val, 0); }

    // Applies the affine tag to the part of node i's range inside [l, r].
    void update(int i, int l, int r, ll mul_tag, ll add_tag) {
        if (tree[i].l >= l && tree[i].r <= r) {
            apply(i, mul_tag, add_tag);
            return;
        }
        push_down(i);
        if (tree[i << 1].r >= l) update(i << 1, l, r, mul_tag, add_tag);
        if (tree[i << 1 | 1].l <= r) update(i << 1 | 1, l, r, mul_tag, add_tag);
        push_up(i);
    }

    // Sum of the part of node i's range inside [l, r].
    ll query(int i, int l, int r) {
        if (tree[i].l >= l && tree[i].r <= r) return tree[i].sum;
        push_down(i);
        ll res = 0;
        if (tree[i << 1].r >= l) res += query(i << 1, l, r);
        if (tree[i << 1 | 1].l <= r) res += query(i << 1 | 1, l, r);
        return res;
    }
};
int main() {
ll n, m;
cin >> n >> m;
vector<ll> arr(n + 1); // 1-based index
for (int i = 1; i <= n; i++) {
cin >> arr[i];
}
SegmentTree seg(arr);
while (m--) {
int op;
cin >> op;
if (op == 1) {
ll l, r, val;
cin >> l >> r >> val;
seg.range_mult(l, r, val);
} else if (op == 2) {
ll l, r, val;
cin >> l >> r >> val;
seg.range_add(l, r, val);
} else if (op == 3) {
ll l, r;
cin >> l >> r;
cout << seg.range_query(l, r) << endl;
}
}
return 0;
}
通用线段树模板(支持区间加法 + 区间乘法 + 区间求和,结果对 mod 取模):
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Lazy-propagation segment tree over a 1-based array supporting range add,
// range multiply and range sum, with every stored value kept in [0, mod).
// Assumes mod >= 1 and that mod * mod fits in ll (mod <= ~3.03e9).
// Inputs and deltas are normalized into [0, mod) first, so negative values
// and huge multipliers neither yield negative residues nor overflow ll.
// Bounds are clamped to [1, n]; an empty range (l > r after clamping) makes
// an update a no-op and a query return 0.
class SegmentTree {
public:
    // A node covers positions [l, r]. `sum` is the current sum of that range
    // (mod `mod`); (mul, add) is a pending tag "every child value x becomes
    // x * mul + add" not yet pushed down to the children.
    struct Node {
        ll l, r;
        ll sum, add, mul;
    };
    vector<Node> tree;  // heap layout: root 1, children 2i and 2i + 1
    vector<ll> data;    // copy of the 1-based input (data[0] unused)
    ll mod;

    // Builds the tree from input[1..n], n = input.size() - 1 (input[0] is
    // ignored). With no elements the tree stays empty; building would
    // otherwise recurse forever in build(1, 1, 0).
    SegmentTree(const vector<ll>& input, ll mod_val) : data(input), mod(mod_val) {
        n_ = input.size() > 1 ? static_cast<int>(input.size()) - 1 : 0;
        if (n_ == 0) return;
        tree.resize(n_ * 4 + 10);
        build(1, 1, n_);
    }

    // Initializes node i for [l, r] from `data` and builds its subtree.
    // Requires 1 <= l <= r <= n.
    void build(int i, int l, int r) {
        tree[i] = {l, r, 0, 0, 1};
        if (l == r) {
            tree[i].sum = norm(data[l]);
            return;
        }
        const int mid = (l + r) >> 1;
        build(i << 1, l, mid);
        build(i << 1 | 1, mid + 1, r);
        push_up(i);
    }

    // Adds val (mod `mod`) to every element in [l, r].
    void range_add(int l, int r, ll val) {
        if (clamp_range(l, r)) add(1, l, r, norm(val));
    }

    // Multiplies every element in [l, r] by val (mod `mod`).
    void range_mult(int l, int r, ll val) {
        if (clamp_range(l, r)) mult(1, l, r, norm(val));
    }

    // Returns the sum of [l, r] modulo `mod` (0 for an empty range).
    ll range_query(int l, int r) {
        return clamp_range(l, r) ? query(1, l, r) : 0;
    }

private:
    int n_ = 0;  // number of elements; 0 means nothing was built

    // Maps any ll into [0, mod); C++ % keeps the sign of the dividend.
    ll norm(ll x) const {
        x %= mod;
        return x < 0 ? x + mod : x;
    }

    // Clamps [l, r] to [1, n_]; returns false if the clamped range is empty.
    bool clamp_range(int& l, int& r) const {
        l = max(l, 1);
        r = min(r, n_);
        return l <= r;
    }

    void push_up(int i) {
        tree[i].sum = (tree[i << 1].sum + tree[i << 1 | 1].sum) % mod;
    }

    // Moves node i's pending tag onto both children and resets it.
    void push_down(int i) {
        const ll mul_tag = tree[i].mul, add_tag = tree[i].add;
        apply(i << 1, mul_tag, add_tag);
        apply(i << 1 | 1, mul_tag, add_tag);
        tree[i].mul = 1;
        tree[i].add = 0;
    }

    // Applies x -> x * mul_tag + add_tag (mod `mod`) to every element under
    // node i; both tags must already lie in [0, mod).
    void apply(int i, ll mul_tag, ll add_tag) {
        const ll len = (tree[i].r - tree[i].l + 1) % mod;
        tree[i].sum = (tree[i].sum * mul_tag % mod + len * add_tag % mod) % mod;
        tree[i].mul = tree[i].mul * mul_tag % mod;
        tree[i].add = (tree[i].add * mul_tag % mod + add_tag) % mod;
    }

    // Recursive helpers; [l, r] non-empty inside [1, n_], val in [0, mod).
    void add(int i, int l, int r, ll val) { update(i, l, r, 1, val); }
    void mult(int i, int l, int r, ll val) { update(i, l, r, val, 0); }

    // Applies the affine tag to the part of node i's range inside [l, r].
    void update(int i, int l, int r, ll mul_tag, ll add_tag) {
        if (tree[i].l >= l && tree[i].r <= r) {
            apply(i, mul_tag, add_tag);
            return;
        }
        push_down(i);
        if (tree[i << 1].r >= l) update(i << 1, l, r, mul_tag, add_tag);
        if (tree[i << 1 | 1].l <= r) update(i << 1 | 1, l, r, mul_tag, add_tag);
        push_up(i);
    }

    // Sum (mod `mod`) of the part of node i's range inside [l, r].
    ll query(int i, int l, int r) {
        if (tree[i].l >= l && tree[i].r <= r) return tree[i].sum;
        push_down(i);
        ll res = 0;
        if (tree[i << 1].r >= l) res = (res + query(i << 1, l, r)) % mod;
        if (tree[i << 1 | 1].l <= r) res = (res + query(i << 1 | 1, l, r)) % mod;
        return res;
    }
};
int main() {
ll n, m, mod;
cin >> n >> m >> mod;
vector<ll> arr(n + 1);
for (int i = 1; i <= n; i++) {
cin >> arr[i];
}
SegmentTree seg(arr, mod);
while (m--) {
int op;
cin >> op;
if (op == 1) {
ll l, r, val;
cin >> l >> r >> val;
seg.range_mult(l, r, val);
} else if (op == 2) {
ll l, r, val;
cin >> l >> r >> val;
seg.range_add(l, r, val);
} else if (op == 3) {
ll l, r;
cin >> l >> r;
cout << seg.range_query(l, r) << endl;
}
}
return 0;
}
两个序列共用一棵线段树,分别维护 A、B 的区间和以及 A·B 的区间和(区间加 + 区间乘积和查询)。提示:AtCoder ABC357 F 题
#include<bits/stdc++.h>
typedef long long LL;
using namespace std;
const int N = 2e5 + 10;  // capacity: sequences must have fewer than N elements
const int MOD = 998244353;

// Segment tree maintaining two sequences a and b at once (AtCoder ABC357 F).
// Supports adding x to a[L..R] or to b[L..R], and querying
// sum(a[i] * b[i]) over [L, R], all modulo MOD.
// Nodes use heap indices (root 1, children 2*root and 2*root + 1); element
// positions are 1-based. For each node: sum[root][0] / sum[root][1] are the
// sums of a / b over its range, sab[root] is the sum of a*b, and
// lazy[root][k] is an addition to sequence k not yet pushed to the children.
// Every stored value lies in [0, MOD) for non-negative inputs.
class segment {
    // Heap-index helpers (these replace the former lson/rson macros, which
    // leaked into the rest of the translation unit).
    static int lson(int root) { return root << 1; }
    static int rson(int root) { return root << 1 | 1; }

    // Adds da to every a-value and db to every b-value under node `root`,
    // whose range holds `len` elements, using
    // sum((a+da)(b+db)) = sab + da*sum(b) + db*sum(a) + da*db*len.
    // da and db must already lie in [0, MOD).
    void apply(int root, LL len, LL da, LL db) {
        sab[root] = (sab[root] + da * sum[root][1] % MOD + db * sum[root][0] % MOD +
                     da * db % MOD * len % MOD) % MOD;
        sum[root][0] = (sum[root][0] + da * len) % MOD;
        sum[root][1] = (sum[root][1] + db * len) % MOD;
        lazy[root][0] = (lazy[root][0] + da) % MOD;
        lazy[root][1] = (lazy[root][1] + db) % MOD;
    }

public:
    LL sab[N << 2];
    LL sum[N << 2][2];
    LL lazy[N << 2][2];

    // Recomputes node `root` from its two children.
    void pushup(int root) {
        const int lc = lson(root), rc = rson(root);
        sab[root] = (sab[lc] + sab[rc]) % MOD;
        sum[root][0] = (sum[lc][0] + sum[rc][0]) % MOD;
        sum[root][1] = (sum[lc][1] + sum[rc][1]) % MOD;
    }

    // Builds node `root` covering [l, r] from 0-based vectors (position p
    // reads a[p - 1], b[p - 1]). Also clears lazy tags, so the global
    // instance can be rebuilt without stale tags corrupting later queries.
    void build(int root, int l, int r, const vector<int> &a, const vector<int> &b) {
        lazy[root][0] = lazy[root][1] = 0;
        if (l == r) {
            sum[root][0] = a[l - 1] % MOD;
            sum[root][1] = b[l - 1] % MOD;
            sab[root] = sum[root][0] * sum[root][1] % MOD;
            return;
        }
        const int mid = (l + r) >> 1;
        build(lson(root), l, mid, a, b);
        build(rson(root), mid + 1, r, a, b);
        pushup(root);
    }

    // Pushes node `root`'s pending tags to its children [l, mid], [mid+1, r].
    void pushdown(int root, int l, int mid, int r) {
        const LL da = lazy[root][0], db = lazy[root][1];
        if (da == 0 && db == 0) return;
        apply(lson(root), mid - l + 1, da, db);
        apply(rson(root), r - mid, da, db);
        lazy[root][0] = lazy[root][1] = 0;
    }

    // Adds val to sequence `op` (0 = a, 1 = b) on [L, R]; node `root`
    // covers [l, r]. Any other op is ignored (it would index out of bounds).
    void update(int root, int l, int r, int L, int R, LL val, int op) {
        if (op != 0 && op != 1) return;
        if (L <= l && r <= R) {
            const LL v = val % MOD;
            apply(root, r - l + 1, op == 0 ? v : 0, op == 1 ? v : 0);
            return;
        }
        const int mid = (l + r) >> 1;
        pushdown(root, l, mid, r);
        if (L <= mid)
            update(lson(root), l, mid, L, R, val, op);
        if (R > mid)
            update(rson(root), mid + 1, r, L, R, val, op);
        pushup(root);
    }

    // Returns sum(a[i] * b[i]) over [L, R] modulo MOD; node `root` covers [l, r].
    LL query(int root, int l, int r, int L, int R) {
        if (L <= l && r <= R) {
            return sab[root];
        }
        const int mid = (l + r) >> 1;
        pushdown(root, l, mid, r);
        LL ans = 0;
        if (L <= mid)
            ans += query(lson(root), l, mid, L, R);
        if (R > mid)
            ans += query(rson(root), mid + 1, r, L, R);
        return ans % MOD;
    }
} sg;
int main() {
int n, q;
cin >> n >> q;
vector<int> a(n), b(n);
for (auto &i: a)
cin >> i;
for (auto &i: b)
cin >> i;
sg.build(1, 1, n, a, b);
while (q--) {
int op;
cin >> op;
if (op == 3) {
int l, r;
cin >> l >> r;
int ans = sg.query(1, 1, n, l, r);
cout << ans << endl;
} else {
int l, r, x;
cin >> l >> r >> x;
sg.update(1, 1, n, l, r, x, op - 1);
}
}
return 0;
}

浙公网安备 33010602011771号