【知识点】线段树

前言

线段树用于解决需要区间查询、区间修改的问题。

定义可以参考Morphis老师的博客:线段树博客

例题

线段树例题

该题要求同时实现“加法”和“乘法”两个运算,所以需要\(add\)\(mul\)两个懒惰标记。

过程

建树

建树流程都是一致的。

void build(i64 s, i64 t, i64 p) {	//s为起点,t为终点,p为节点编号
	if (s == t) {	//如果是叶子,结束递归
		d[p] = a[s];
		mul[p] = 1;
		return;
	}
	i64 mid = (s + t) / 2;
	build(s, mid, 2 * p), build(mid + 1, t, 2 * p + 1);
	d[p] = d[2 * p] + d[2 * p + 1];	//建树
	mul[p] = 1;
	return;
}

懒惰标记下传

注意顺序:一定要先下传\(mul\)标记,再下传\(add\)标记。

void pushdown(i64 s, i64 t, i64 p) {
	i64 mid = (s + t) / 2;
	//更新子节点的值
	d[2 * p] = ((d[2 * p] * mul[p]) % m + (add[p] * (mid - s + 1)) % m) % m;
	d[2 * p + 1] = ((d[2 * p + 1] * mul[p]) % m + (add[p] * (t - mid )) % m) % m;
	//更新子节点的mul标记
	mul[2 * p] = (mul[2 * p] * mul[p]) % m;
	mul[2 * p + 1] = (mul[2 * p + 1] * mul[p]) % m;
	//更新子节点的add标记
	add[2 * p] = (add[2 * p] * mul[p] + add[p]) % m;
	add[2 * p + 1] = (add[2 * p + 1] * mul[p] + add[p]) % m;
	//清空父节点的标记
	mul[p] = 1, add[p] = 0;
	return;
}

区间更新

void muls(i64 l, i64 r, i64 k, i64 s, i64 t, i64 p) {
	if (l <= s && t <= r) {	//该区间被完全包含
		add[p] = (add[p] * k) % m;
		mul[p] = (mul[p] * k) % m;
		d[p] = (d[p] * k) % m;
		return;
	}
	pushdown(s, t, p);	//下传标记
	int mid = (s + t) / 2;
	if (l <= mid)
		muls(l, r, k, s, mid, 2 * p);
	if (r > mid)
		muls(l, r, k, mid + 1, t, 2 * p + 1);
	d[p] = (d[2 * p] + d[2 * p + 1]) % m;	//记得更新该节点的值
	return;
}

void adds(i64 l, i64 r, i64 k, i64 s, i64 t, i64 p) {
	if (l <= s && t <= r) {
		add[p] = (add[p] + k) % m;
		d[p] = (d[p] + k * (t - s + 1)) % m;
		return;
	}
	pushdown(s, t, p);
	int mid = (s + t) / 2;
	if (l <= mid)
		adds(l, r, k, s, mid, 2 * p);
	if (r > mid)
		adds(l, r, k, mid + 1, t, 2 * p + 1);
	d[p] = (d[2 * p] + d[2 * p + 1]) % m;
	return;
}

区间查询

被查询的点都是真实值,即其祖先都没有懒惰标记。所以访问到的非终点的节点都要记得下传标记。

i64 find(i64 l, i64 r, i64 s, i64 t, i64 p) {
	if (l <= s && t <= r)
		return d[p];
	i64 res = 0;
	pushdown(s, t, p);	//下传标记
	i64 mid = (s + t) / 2;
	if (l <= mid)
		res += find(l, r, s, mid, 2 * p);
	if (r > mid)
		res += find(l, r, mid + 1, t, 2 * p + 1);
	return res % m;
}

AC代码

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

i64 n, q, m, opt, x, y, k;
i64 a[100100], d[400100], add[400100], mul[400100];

void build(i64 s, i64 t, i64 p) {	//s为起点,t为终点,p为节点编号
	if (s == t) {	//如果是叶子,结束递归
		d[p] = a[s];
		mul[p] = 1;
		return;
	}
	i64 mid = (s + t) / 2;
	build(s, mid, 2 * p), build(mid + 1, t, 2 * p + 1);
	d[p] = d[2 * p] + d[2 * p + 1];	//建树
	mul[p] = 1;
	return;
}

void pushdown(i64 s, i64 t, i64 p) {
	i64 mid = (s + t) / 2;
	//更新子节点的值
	d[2 * p] = ((d[2 * p] * mul[p]) % m + (add[p] * (mid - s + 1)) % m) % m;
	d[2 * p + 1] = ((d[2 * p + 1] * mul[p]) % m + (add[p] * (t - mid )) % m) % m;
	//更新子节点的mul标记
	mul[2 * p] = (mul[2 * p] * mul[p]) % m;
	mul[2 * p + 1] = (mul[2 * p + 1] * mul[p]) % m;
	//更新子节点的add标记
	add[2 * p] = (add[2 * p] * mul[p] + add[p]) % m;
	add[2 * p + 1] = (add[2 * p + 1] * mul[p] + add[p]) % m;
	//清空父节点的标记
	mul[p] = 1, add[p] = 0;
	return;
}

void muls(i64 l, i64 r, i64 k, i64 s, i64 t, i64 p) {
	if (l <= s && t <= r) {	//该区间被完全包含
		add[p] = (add[p] * k) % m;
		mul[p] = (mul[p] * k) % m;
		d[p] = (d[p] * k) % m;
		return;
	}
	pushdown(s, t, p);	//下传标记
	int mid = (s + t) / 2;
	if (l <= mid)
		muls(l, r, k, s, mid, 2 * p);
	if (r > mid)
		muls(l, r, k, mid + 1, t, 2 * p + 1);
	d[p] = (d[2 * p] + d[2 * p + 1]) % m;	//记得更新该节点的值
	return;
}

void adds(i64 l, i64 r, i64 k, i64 s, i64 t, i64 p) {
	if (l <= s && t <= r) {
		add[p] = (add[p] + k) % m;
		d[p] = (d[p] + k * (t - s + 1)) % m;
		return;
	}
	pushdown(s, t, p);
	int mid = (s + t) / 2;
	if (l <= mid)
		adds(l, r, k, s, mid, 2 * p);
	if (r > mid)
		adds(l, r, k, mid + 1, t, 2 * p + 1);
	d[p] = (d[2 * p] + d[2 * p + 1]) % m;
	return;
}

i64 find(i64 l, i64 r, i64 s, i64 t, i64 p) {
	if (l <= s && t <= r)
		return d[p];
	i64 res = 0;
	pushdown(s, t, p);	//下传标记
	i64 mid = (s + t) / 2;
	if (l <= mid)
		res += find(l, r, s, mid, 2 * p);
	if (r > mid)
		res += find(l, r, mid + 1, t, 2 * p + 1);
	return res % m;
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);
	cin >> n >> q >> m;
	for (i64 i = 1; i <= n; i++)
		cin >> a[i];
	build(1, n, 1);
	while (q--) {
		cin >> opt;
		if (opt == 1) {
			cin >> x >> y >> k;
			muls(x, y, k, 1, n, 1);
		}
		if (opt == 2) {
			cin >> x >> y >> k;
			adds(x, y, k, 1, n, 1);
		}
		if (opt == 3) {
			cin >> x >> y;
			cout << find(x, y, 1, n, 1) << '\n';
		}
	}
	return 0;
}
posted @ 2025-08-14 14:19  Alkaid16  阅读(6)  评论(0)    收藏  举报