算法学习(6):线段树
线段树
介绍
线段树(Segment Tree)几乎是算法竞赛最常用的数据结构了,它主要用于维护区间信息(要求满足结合律)。
建立
void build(ll l = 1, ll r = n, ll p = 1)
{
if (l == r) // 到达叶子节点
tree[p] = A[l]; // 用数组中的数据赋值
else
{
ll mid = (l + r) / 2;
build(l, mid, p * 2); // 先建立左右子节点
build(mid + 1, r, p * 2 + 1);
tree[p] = tree[p * 2] + tree[p * 2 + 1]; // 该节点的值等于左右子节点之和
}
}
区间修改
void update(ll l, ll r, ll d, ll p = 1, ll cl = 1, ll cr = n)
{
if (cl > r || cr < l) // 区间无交集
return; // 剪枝
else if (cl >= l && cr <= r) // 当前节点对应的区间包含在目标区间中
{
tree[p] += (cr - cl + 1) * d; // 更新当前区间的值
if (cr > cl) // 如果不是叶子节点
mark[p] += d; // 给当前区间打上标记
}
else // 与目标区间有交集,但不包含于其中
{
ll mid = (cl + cr) / 2;
mark[p * 2] += mark[p]; // 标记向下传递
mark[p * 2 + 1] += mark[p];
tree[p * 2] += mark[p] * (mid - cl + 1); // 往下更新一层
tree[p * 2 + 1] += mark[p] * (cr - mid);
mark[p] = 0; // 清除标记
update(l, r, d, p * 2, cl, mid); // 递归地往下寻找
update(l, r, d, p * 2 + 1, mid + 1, cr);
tree[p] = tree[p * 2] + tree[p * 2 + 1]; // 根据子节点更新当前节点的值
}
}
区间查询
ll query(ll l, ll r, ll p = 1, ll cl = 1, ll cr = n)
{
if (cl > r || cr < l)
return 0;
else if (cl >= l && cr <= r)
return tree[p];
else
{
ll mid = (cl + cr) / 2;
push_down(p, cr - cl + 1);
return query(l, r, p * 2, cl, mid) + query(l, r, p * 2 + 1, mid + 1, cr);
// 上一行拆成三行写就和区间修改格式一致了
}
}
模板1
#include <bits/stdc++.h>
#define MAXN 100005
using namespace std;
typedef long long ll;
inline ll read()
{
ll ans = 0;
char c = getchar();
while (!isdigit(c))
c = getchar();
while (isdigit(c))
{
ans = ans * 10 + c - '0';
c = getchar();
}
return ans;
}
ll n, m, A[MAXN], tree[MAXN * 4], mark[MAXN * 4]; // 经验表明开四倍空间不会越界
inline void push_down(ll p, ll len)
{
mark[p * 2] += mark[p];
mark[p * 2 + 1] += mark[p];
tree[p * 2] += mark[p] * (len - len / 2);
tree[p * 2 + 1] += mark[p] * (len / 2);
mark[p] = 0;
}
void build(ll l = 1, ll r = n, ll p = 1)
{
if (l == r)
tree[p] = A[l];
else
{
ll mid = (l + r) / 2;
build(l, mid, p * 2);
build(mid + 1, r, p * 2 + 1);
tree[p] = tree[p * 2] + tree[p * 2 + 1];
}
}
void update(ll l, ll r, ll d, ll p = 1, ll cl = 1, ll cr = n)
{
if (cl > r || cr < l)
return;
else if (cl >= l && cr <= r)
{
tree[p] += (cr - cl + 1) * d;
if (cr > cl)
mark[p] += d;
}
else
{
ll mid = (cl + cr) / 2;
push_down(p, cr - cl + 1);
update(l, r, d, p * 2, cl, mid);
update(l, r, d, p * 2 + 1, mid + 1, cr);
tree[p] = tree[p * 2] + tree[p * 2 + 1];
}
}
ll query(ll l, ll r, ll p = 1, ll cl = 1, ll cr = n)
{
if (cl > r || cr < l)
return 0;
else if (cl >= l && cr <= r)
return tree[p];
else
{
ll mid = (cl + cr) / 2;
push_down(p, cr - cl + 1);
return query(l, r, p * 2, cl, mid) + query(l, r, p * 2 + 1, mid + 1, cr);
}
}
int main()
{
n = read();
m = read();
for (int i = 1; i <= n; ++i)
A[i] = read();
build();
for (int i = 0; i < m; ++i)
{
ll opr = read(), l = read(), r = read();
if (opr == 1)
{
ll d = read();
update(l, r, d);
}
else
printf("%lld\n", query(l, r));
}
return 0;
}
模板2
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int MAXN = 1e5 + 5;
ll tree[MAXN << 2], mark[MAXN << 2], n, m, A[MAXN];
void push_down(int p, int len)
{
tree[p << 1] += mark[p] * (len - len / 2);
mark[p << 1] += mark[p];
tree[p << 1 | 1] += mark[p] * (len / 2);
mark[p << 1 | 1] += mark[p];
mark[p] = 0;
}
void build(int p = 1, int cl = 1, int cr = n)
{
if (cl == cr) { tree[p] = A[cl]; return; }
int mid = (cl + cr) >> 1;
build(p << 1, cl, mid);
build(p << 1 | 1, mid + 1, cr);
tree[p] = tree[p << 1] + tree[p << 1 | 1];
}
ll query(int l, int r, int p = 1, int cl = 1, int cr = n)
{
if (cl >= l && cr <= r) return tree[p];
push_down(p, cr - cl + 1);
ll mid = (cl + cr) >> 1, ans = 0;
if (mid >= l) ans += query(l, r, p << 1, cl, mid);
if (mid < r) ans += query(l, r, p << 1 | 1, mid + 1, cr);
return ans;
}
void update(int l, int r, int d, int p = 1, int cl = 1, int cr = n)
{
if (cl >= l && cr <= r) { tree[p] += d * (cr - cl + 1), mark[p] += d; return; }
push_down(p, cr - cl + 1);
int mid = (cl + cr) >> 1;
if (mid >= l) update(l, r, d, p << 1, cl, mid);
if (mid < r) update(l, r, d, p << 1 | 1, mid + 1, cr);
tree[p] = tree[p << 1] + tree[p << 1 | 1];
}
int main()
{
ios::sync_with_stdio(false);
cin >> n >> m;
for (int i = 1; i <= n; ++i)
cin >> A[i];
build();
while (m--)
{
int o, l, r, d;
cin >> o >> l >> r;
if (o == 1)
cin >> d, update(l, r, d);
else
cout << query(l, r) << '\n';
}
return 0;
}
浙公网安备 33010602011771号