【知识点】线段树
前言
线段树用于解决需要区间查询、区间修改的问题。
例题
该题要求同时实现“加法”和“乘法”两个运算,所以需要\(add\)和\(mul\)两个懒惰标记。
过程
建树
建树流程都是一致的。
void build(i64 s, i64 t, i64 p) { //s为起点,t为终点,p为节点编号
if (s == t) { //如果是叶子,结束递归
d[p] = a[s];
mul[p] = 1;
return;
}
i64 mid = (s + t) / 2;
build(s, mid, 2 * p), build(mid + 1, t, 2 * p + 1);
d[p] = d[2 * p] + d[2 * p + 1]; //建树
mul[p] = 1;
return;
}
懒惰标记下传
注意顺序:一定要先下传\(mul\)标记,再下传\(add\)标记。
void pushdown(i64 s, i64 t, i64 p) {
i64 mid = (s + t) / 2;
//更新子节点的值
d[2 * p] = ((d[2 * p] * mul[p]) % m + (add[p] * (mid - s + 1)) % m) % m;
d[2 * p + 1] = ((d[2 * p + 1] * mul[p]) % m + (add[p] * (t - mid )) % m) % m;
//更新子节点的mul标记
mul[2 * p] = (mul[2 * p] * mul[p]) % m;
mul[2 * p + 1] = (mul[2 * p + 1] * mul[p]) % m;
//更新子节点的add标记
add[2 * p] = (add[2 * p] * mul[p] + add[p]) % m;
add[2 * p + 1] = (add[2 * p + 1] * mul[p] + add[p]) % m;
//清空父节点的标记
mul[p] = 1, add[p] = 0;
return;
}
区间更新
void muls(i64 l, i64 r, i64 k, i64 s, i64 t, i64 p) {
if (l <= s && t <= r) { //该区间被完全包含
add[p] = (add[p] * k) % m;
mul[p] = (mul[p] * k) % m;
d[p] = (d[p] * k) % m;
return;
}
pushdown(s, t, p); //下传标记
int mid = (s + t) / 2;
if (l <= mid)
muls(l, r, k, s, mid, 2 * p);
if (r > mid)
muls(l, r, k, mid + 1, t, 2 * p + 1);
d[p] = (d[2 * p] + d[2 * p + 1]) % m; //记得更新该节点的值
return;
}
void adds(i64 l, i64 r, i64 k, i64 s, i64 t, i64 p) {
if (l <= s && t <= r) {
add[p] = (add[p] + k) % m;
d[p] = (d[p] + k * (t - s + 1)) % m;
return;
}
pushdown(s, t, p);
int mid = (s + t) / 2;
if (l <= mid)
adds(l, r, k, s, mid, 2 * p);
if (r > mid)
adds(l, r, k, mid + 1, t, 2 * p + 1);
d[p] = (d[2 * p] + d[2 * p + 1]) % m;
return;
}
区间查询
被查询的点都是真实值,即其祖先都没有懒惰标记。所以访问到的非终点的节点都要记得下传标记。
i64 find(i64 l, i64 r, i64 s, i64 t, i64 p) {
if (l <= s && t <= r)
return d[p];
i64 res = 0;
pushdown(s, t, p); //下传标记
i64 mid = (s + t) / 2;
if (l <= mid)
res += find(l, r, s, mid, 2 * p);
if (r > mid)
res += find(l, r, mid + 1, t, 2 * p + 1);
return res % m;
}
AC代码
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
i64 n, q, m, opt, x, y, k;
i64 a[100100], d[400100], add[400100], mul[400100];
void build(i64 s, i64 t, i64 p) { //s为起点,t为终点,p为节点编号
if (s == t) { //如果是叶子,结束递归
d[p] = a[s];
mul[p] = 1;
return;
}
i64 mid = (s + t) / 2;
build(s, mid, 2 * p), build(mid + 1, t, 2 * p + 1);
d[p] = d[2 * p] + d[2 * p + 1]; //建树
mul[p] = 1;
return;
}
void pushdown(i64 s, i64 t, i64 p) {
i64 mid = (s + t) / 2;
//更新子节点的值
d[2 * p] = ((d[2 * p] * mul[p]) % m + (add[p] * (mid - s + 1)) % m) % m;
d[2 * p + 1] = ((d[2 * p + 1] * mul[p]) % m + (add[p] * (t - mid )) % m) % m;
//更新子节点的mul标记
mul[2 * p] = (mul[2 * p] * mul[p]) % m;
mul[2 * p + 1] = (mul[2 * p + 1] * mul[p]) % m;
//更新子节点的add标记
add[2 * p] = (add[2 * p] * mul[p] + add[p]) % m;
add[2 * p + 1] = (add[2 * p + 1] * mul[p] + add[p]) % m;
//清空父节点的标记
mul[p] = 1, add[p] = 0;
return;
}
void muls(i64 l, i64 r, i64 k, i64 s, i64 t, i64 p) {
if (l <= s && t <= r) { //该区间被完全包含
add[p] = (add[p] * k) % m;
mul[p] = (mul[p] * k) % m;
d[p] = (d[p] * k) % m;
return;
}
pushdown(s, t, p); //下传标记
int mid = (s + t) / 2;
if (l <= mid)
muls(l, r, k, s, mid, 2 * p);
if (r > mid)
muls(l, r, k, mid + 1, t, 2 * p + 1);
d[p] = (d[2 * p] + d[2 * p + 1]) % m; //记得更新该节点的值
return;
}
void adds(i64 l, i64 r, i64 k, i64 s, i64 t, i64 p) {
if (l <= s && t <= r) {
add[p] = (add[p] + k) % m;
d[p] = (d[p] + k * (t - s + 1)) % m;
return;
}
pushdown(s, t, p);
int mid = (s + t) / 2;
if (l <= mid)
adds(l, r, k, s, mid, 2 * p);
if (r > mid)
adds(l, r, k, mid + 1, t, 2 * p + 1);
d[p] = (d[2 * p] + d[2 * p + 1]) % m;
return;
}
i64 find(i64 l, i64 r, i64 s, i64 t, i64 p) {
if (l <= s && t <= r)
return d[p];
i64 res = 0;
pushdown(s, t, p); //下传标记
i64 mid = (s + t) / 2;
if (l <= mid)
res += find(l, r, s, mid, 2 * p);
if (r > mid)
res += find(l, r, mid + 1, t, 2 * p + 1);
return res % m;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n >> q >> m;
for (i64 i = 1; i <= n; i++)
cin >> a[i];
build(1, n, 1);
while (q--) {
cin >> opt;
if (opt == 1) {
cin >> x >> y >> k;
muls(x, y, k, 1, n, 1);
}
if (opt == 2) {
cin >> x >> y >> k;
adds(x, y, k, 1, n, 1);
}
if (opt == 3) {
cin >> x >> y;
cout << find(x, y, 1, n, 1) << '\n';
}
}
return 0;
}

浙公网安备 33010602011771号