2025.5.2笔记
免责声明
我的代码由于大部分是复制粘贴的模板,再在其基础上改的,因此有些注释可能不合时宜。
一些代码是我很早以前写的,所以可能有许多冗余或不规范的地方。
请谅解。
如果我有写错或令人不理解的地方,请及时指出,谢谢!!!
线段树
解决单点加,区间加,区间求和。
这类问题基本上就是往线段树上去想。
大部分数据结构优化都是用线段树。
树状数组没用。——钟皓曦
基本结构:

线段树有 \(\log n\) 层。
线段树共有 \(1 + 2 + 4 + \cdots + 2^{\log n} \approx 2^{\log n} = 2n\) 个节点。
在每一个节点存一个和(sum)的信息。
区间合并:自己的 sum \(=\) 左儿子的 sum \(+\) 右儿子的 sum。
数组开空间的时候,要记得开 \(4\) 倍,因为它的标号方式不一定是到 \(2n\)。
钟皓曦的写法:
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define root 1, n, 1
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
const int maxn = 100010;
int n, m, a[maxn];
struct node // 一个线段树节点
{
int sum; // 代表区间和
int size; // 代表区间长度
int add; // 这段区间被整体加了多少
node()
{
sum = size = add = 0;
}
void init(int v) // 用一个数初始化
{
sum = v;
size = 1;
}
} z[maxn << 2]; // z[i]就代表线段树的第i个节点
node operator+(const node &l, const node &r)
{
node res;
res.sum = l.sum + r.sum;
res.size = l.size + r.size;
return res;
}
void color(int l, int r, int rt, int v) // 给l,r,rt这个节点打一个+v的懒标记
{
z[rt].add += v;
z[rt].sum += z[rt].size * v;
}
void push_col(int l, int r, int rt) // 标记下放 把标记告诉儿子
{
if (z[rt].add == 0)
return; // 没标记 不需要下放 可以不要这句话 但会慢些
int m = (l + r) >> 1;
color(lson, z[rt].add);
color(rson, z[rt].add);
z[rt].add = 0;
}
void build(int l, int r, int rt) // 建树 初始化l,r,rt这个节点
// 编号为rt的线段树节点 所对应的区间是l~r
{
if (l == r)
{
z[rt].init(a[l]);
return;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
node query(int l, int r, int rt, int nowl, int nowr)
// l,r,rt描述了一个线段树节点
// nowl nowr代表了询问的区间的左端点和右端点
{
if (nowl <= l && r <= nowr)
return z[rt];
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
{
if (m < nowr)
return query(lson, nowl, nowr) + query(rson, nowl, nowr);
else
return query(lson, nowl, nowr);
}
else
return query(rson, nowl, nowr);
}
void modify(int l, int r, int rt, int nowl, int nowr, int v)
// 把nowl~nowr这段区间全部整体+v
{
if (nowl <= l && r <= nowr) // 当前线段树节点被修改区间整体包含
{
color(l, r, rt, v); // 给l,r,rt这个节点打一个+v的懒标记
return;
}
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
modify(lson, nowl, nowr, v);
if (m < nowr)
modify(rson, nowl, nowr, v);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(root);
cin >> m;
for (int i = 1; i <= m; i++)
{
int opt;
cin >> opt;
if (opt == 1) // 询问
{
int l, r;
cin >> l >> r;
cout << query(root, l, r).sum << "\n";
}
else
{
int l, r, v;
cin >> l >> r >> v;
modify(root, l, r, v);
}
}
return 0;
}
我(王铭宇)的写法:
点击查看代码
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
using namespace std;
const int N = 10000005;
typedef long long ll;
int n;
ll a[N];
ll sum[N << 2];
ll mx[N << 2];
ll tag[N];
int ls(int x)
{
return x << 1;
}
int rs(int x)
{
return x << 1 | 1;
}
void pushup(int x)
{
sum[x] = sum[ls(x)] + sum[rs(x)];
mx[x] = max(mx[ls(x)], mx[rs(x)]);
}
void add(int x, int l, int r, ll k)
{
sum[x] += k * (r - l + 1);
mx[x] += k;
tag[x] += k;
}
void pushdown(int x, int l, int r)
{
int mid = (l + r) >> 1;
if (tag[x] != 0)
{
add(ls(x), l, mid, tag[x]);
add(rs(x), mid + 1, r, tag[x]);
tag[x] = 0;
}
}
void build(int x, int l, int r)
{
if (l == r)
{
sum[x] = mx[x] = a[l];
return ;
}
int mid = (l + r) >> 1;
build(ls(x), l, mid);
build(rs(x), mid + 1, r);
pushup(x);
}
ll query_sum(int x, int l, int r, int L, int R)
{
if (L <= l && r <= R)
return sum[x];
pushdown(x, l, r);
ll ret = 0;
int mid = (l + r) >> 1;
if (L <= mid)
{
ret += query_sum(ls(x), l, mid, L, R);
}
if (mid < R)
{
ret += query_sum(rs(x), mid + 1, r, L, R);
}
return ret;
}
ll query_max(int x, int l, int r, int L, int R)
{
if (L <= l && r <= R)
return mx[x];
pushdown(x, l, r);
ll ret = -2e9;
int mid = (l + r) >> 1;
if (L <= mid)
{
ret = max(ret, query_max(ls(x), l, mid, L, R));
}
if (mid < R)
{
ret = max(ret, query_max(rs(x), mid + 1, r, L, R));
}
return ret;
}
void add_one(int x, int l, int r, int p, int k)
{
if (l == r)
{
sum[x] += k;
mx[x] += k;
return;
}
pushdown(x, l, r);
int mid = (l + r) >> 1;
if (p <= mid)
add_one(ls(x), l, mid, p, k);
else
add_one(rs(x), mid + 1, r, p, k);
pushup(x);
}
void add_many(int x, int l, int r, int L, int R, ll k)
{
if (L <= l && R >= r)
{
add(x, l, r, k);
return ;
}
pushdown(x, l, r);
int mid = (l + r) >> 1;
if (L <= mid)
add_many(ls(x), l, mid, L, R, k);
if (mid < R)
add_many(rs(x), mid + 1, r, L, R, k);
pushup(x);
}
int main()
{
int m;
cin >> n >> m;
for (int i = 1; i <= n; ++i)
{
cin >> a[i];
}
build(1, 1, n);
while (m--)
{
int opt;
cin >> opt;
if (opt == 1)
{
int x, y;
ll k;
cin >> x >> y >> k;
add_many(1, 1, n, x, y, k);
}
if (opt == 2)
{
int x, y;
cin >> x >> y;
cout << query_sum(1, 1, n, x, y) << '\n';
}
}
return 0;
}
建议大家用钟皓曦的写法,因为这种写法使用结构体存储每个节点的信息,虽代码量可能较大,但具有通用性,可在此基础上做很小的改动来解决一些复杂的题。
假设还要询问 max 和 min 的值的话,钟神的代码:
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define root 1, n, 1
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
const int maxn = 100010;
int n, m, a[maxn];
struct node // 一个线段树节点
{ // 第一个要修改的地方:要维护的东西
int sum; // 代表区间和
int minv; // 代表区间最小值
int maxv; // 代表区间最大值
int size; // 代表区间长度
int add; // 这段区间被整体加了多少
node()
{
sum = size = add = minv = maxv = 0;
}
void init(int v) // 用一个数初始化
{ // 第二个修改的地方:怎么用一个数初始化
sum = minv = maxv = v;
size = 1;
}
} z[maxn << 2]; // z[i]就代表线段树的第i个节点
node operator+(const node &l, const node &r)
{ // 第三个需要修改的地方:左右儿子怎么合并
node res;
res.sum = l.sum + r.sum;
res.size = l.size + r.size;
res.minv = min(l.minv, r.minv);
res.maxv = max(l.maxv, r.maxv);
return res;
}
void color(int l, int r, int rt, int v) // 给l,r,rt这个节点打一个+v的懒标记
{ // 第四个需要修改的地方:怎么打标记
z[rt].add += v;
z[rt].sum += z[rt].size * v;
z[rt].minv += v;
z[rt].maxv += v;
}
void push_col(int l, int r, int rt) // 标记下放 把标记告诉儿子
{
if (z[rt].add == 0)
return; // 没标记 不需要下放 可以不要这句话 但会慢些
int m = (l + r) >> 1;
color(lson, z[rt].add);
color(rson, z[rt].add);
z[rt].add = 0;
}
void build(int l, int r, int rt) // 建树 初始化l,r,rt这个节点
// 编号为rt的线段树节点 所对应的区间是l~r
{
if (l == r)
{
z[rt].init(a[l]);
return;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
node query(int l, int r, int rt, int nowl, int nowr)
// l,r,rt描述了一个线段树节点
// nowl nowr代表了询问的区间的左端点和右端点
{
if (nowl <= l && r <= nowr)
return z[rt];
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
{
if (m < nowr)
return query(lson, nowl, nowr) + query(rson, nowl, nowr);
else
return query(lson, nowl, nowr);
}
else
return query(rson, nowl, nowr);
}
void modify(int l, int r, int rt, int nowl, int nowr, int v)
// 把nowl~nowr这段区间全部整体+v
{
if (nowl <= l && r <= nowr) // 当前线段树节点被修改区间整体包含
{
color(l, r, rt, v); // 给l,r,rt这个节点打一个+v的懒标记
return;
}
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
modify(lson, nowl, nowr, v);
if (m < nowr)
modify(rson, nowl, nowr, v);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(root);
cin >> m;
for (int i = 1; i <= m; i++)
{
int opt;
cin >> opt;
if (opt == 1) // 询问
{
int l, r;
cin >> l >> r;
cout << query(root, l, r).sum << "\n";
}
else
{
int l, r, v;
cin >> l >> r >> v;
modify(root, l, r, v);
}
}
return 0;
}
例1

不用结构体可能根本写不出来。——钟皓曦
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define root 1, n, 1
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
const int maxn = 100010;
int n, m, a[maxn];
struct node // 一个线段树节点
{
int sum; // 代表区间相邻两数差的绝对值的和
int lv; // 最左边的数是多少
int rv; // 最右边的数是多少
int add; // 这段区间被整体加了多少
int size; // 区间长度
node()
{
sum = add = 0;
}
void init(int v) // 用一个数初始化
{
sum = 0;
lv = rv = v;
size = 1;
}
} z[maxn << 2]; // z[i]就代表线段树的第i个节点
node operator+(const node &l, const node &r)
{
node res;
res.sum = l.sum + r.sum + abs(l.rv - r.lv);
res.lv = l.lv;
res.rv = r.rv;
return res;
}
void color(int l, int r, int rt, int v) // 给l,r,rt这个节点打一个+v的懒标记
{
z[rt].add += v;
z[rt].lv += v;
z[rt].rv += v;
}
void push_col(int l, int r, int rt) // 标记下放 把标记告诉儿子
{
if (z[rt].add == 0)
return; // 没标记 不需要下放 可以不要这句话 但会慢些
int m = (l + r) >> 1;
color(lson, z[rt].add);
color(rson, z[rt].add);
z[rt].add = 0;
}
void build(int l, int r, int rt) // 建树 初始化l,r,rt这个节点
// 编号为rt的线段树节点 所对应的区间是l~r
{
if (l == r)
{
z[rt].init(a[l]);
return;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
node query(int l, int r, int rt, int nowl, int nowr)
// l,r,rt描述了一个线段树节点
// nowl nowr代表了询问的区间的左端点和右端点
{
if (nowl <= l && r <= nowr)
return z[rt];
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
{
if (m < nowr)
return query(lson, nowl, nowr) + query(rson, nowl, nowr);
else
return query(lson, nowl, nowr);
}
else
return query(rson, nowl, nowr);
}
void modify(int l, int r, int rt, int nowl, int nowr, int v)
// 把nowl~nowr这段区间全部整体+v
{
if (nowl <= l && r <= nowr) // 当前线段树节点被修改区间整体包含
{
color(l, r, rt, v); // 给l,r,rt这个节点打一个+v的懒标记
return;
}
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
modify(lson, nowl, nowr, v);
if (m < nowr)
modify(rson, nowl, nowr, v);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(root);
cin >> m;
for (int i = 1; i <= m; i++)
{
int opt;
cin >> opt;
if (opt == 1) // 询问
{
int l, r;
cin >> l >> r;
cout << query(root, l, r).sum << "\n";
}
else
{
int l, r, v;
cin >> l >> r >> v;
modify(root, l, r, v);
}
}
return 0;
}
例2

\(\displaystyle\sum^r_{i = l}(x_i + v)^2 = \sum^r_{i = l}x_i^2 + 2v\sum^r_{i = l}x_i + v(r - l + 1)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define root 1, n, 1
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
const int maxn = 100010;
int n, m, a[maxn];
struct node // 一个线段树节点
{
int sum; // 代表区间和
int sum2; // 代表区间平方和
int size; // 代表区间长度
int add; // 这段区间被整体加了多少
node()
{
sum = size = add = 0;
}
void init(int v) // 用一个数初始化
{
sum = v;
sum2 = v * v;
size = 1;
}
} z[maxn << 2]; // z[i]就代表线段树的第i个节点
node operator+(const node &l, const node &r)
{
node res;
res.sum2 = l.sum2 + r.sum2;
res.sum = l.sum + r.sum;
res.size = l.size + r.size;
return res;
}
void color(int l, int r, int rt, int v) // 给l,r,rt这个节点打一个+v的懒标记
{
z[rt].add += v;
z[rt].sum2 = z[rt].sum2 + 2 * v * z[rt].sum + z[rt].size * v * v;
z[rt].sum += z[rt].size * v;
}
void push_col(int l, int r, int rt) // 标记下放 把标记告诉儿子
{
if (z[rt].add == 0)
return; // 没标记 不需要下放 可以不要这句话 但会慢些
int m = (l + r) >> 1;
color(lson, z[rt].add);
color(rson, z[rt].add);
z[rt].add = 0;
}
void build(int l, int r, int rt) // 建树 初始化l,r,rt这个节点
// 编号为rt的线段树节点 所对应的区间是l~r
{
if (l == r)
{
z[rt].init(a[l]);
return;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
node query(int l, int r, int rt, int nowl, int nowr)
// l,r,rt描述了一个线段树节点
// nowl nowr代表了询问的区间的左端点和右端点
{
if (nowl <= l && r <= nowr)
return z[rt];
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
{
if (m < nowr)
return query(lson, nowl, nowr) + query(rson, nowl, nowr);
else
return query(lson, nowl, nowr);
}
else
return query(rson, nowl, nowr);
}
void modify(int l, int r, int rt, int nowl, int nowr, int v)
// 把nowl~nowr这段区间全部整体+v
{
if (nowl <= l && r <= nowr) // 当前线段树节点被修改区间整体包含
{
color(l, r, rt, v); // 给l,r,rt这个节点打一个+v的懒标记
return;
}
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
modify(lson, nowl, nowr, v);
if (m < nowr)
modify(rson, nowl, nowr, v);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(root);
cin >> m;
for (int i = 1; i <= m; i++)
{
int opt;
cin >> opt;
if (opt == 1) // 询问
{
int l, r;
cin >> l >> r;
cout << query(root, l, r).sum << "\n";
}
else
{
int l, r, v;
cin >> l >> r >> v;
modify(root, l, r, v);
}
}
return 0;
}
例3
只有乘法的情况:
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define root 1, n, 1
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
const int maxn = 100010;
int n, m, a[maxn];
struct node // 一个线段树节点
{ // 第一个要修改的地方:标记的定义
int sum; // 代表区间和
int size; // 代表区间长度
int mul; // 这段区间被整体乘了多少
node()
{ // 第二个要修改的地方:标记的初始化
sum = size = 0;
mul = 1;
}
void init(int v) // 用一个数初始化
{
sum = v;
size = 1;
}
} z[maxn << 2]; // z[i]就代表线段树的第i个节点
node operator+(const node &l, const node &r)
{
node res;
res.sum = l.sum + r.sum;
res.size = l.size + r.size;
return res;
}
void color(int l, int r, int rt, int v)
{ // 第三个要修改的地方:打标记
z[rt].mul *= v;
z[rt].sum *= v;
}
void push_col(int l, int r, int rt) // 标记下放 把标记告诉儿子
{ // 第四个要修改的地方:下放标记
if (z[rt].mul == 1)
return; // 没标记 不需要下放 可以不要这句话 但会慢些
int m = (l + r) >> 1;
color(lson, z[rt].mul);
color(rson, z[rt].mul);
z[rt].mul = 1;
}
void build(int l, int r, int rt) // 建树 初始化l,r,rt这个节点
// 编号为rt的线段树节点 所对应的区间是l~r
{
if (l == r)
{
z[rt].init(a[l]);
return;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
node query(int l, int r, int rt, int nowl, int nowr)
// l,r,rt描述了一个线段树节点
// nowl nowr代表了询问的区间的左端点和右端点
{
if (nowl <= l && r <= nowr)
return z[rt];
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
{
if (m < nowr)
return query(lson, nowl, nowr) + query(rson, nowl, nowr);
else
return query(lson, nowl, nowr);
}
else
return query(rson, nowl, nowr);
}
void modify(int l, int r, int rt, int nowl, int nowr, int v)
// 把nowl~nowr这段区间全部整体+v
{
if (nowl <= l && r <= nowr) // 当前线段树节点被修改区间整体包含
{
color(l, r, rt, v); // 给l,r,rt这个节点打一个+v的懒标记
return;
}
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
modify(lson, nowl, nowr, v);
if (m < nowr)
modify(rson, nowl, nowr, v);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(root);
cin >> m;
for (int i = 1; i <= m; i++)
{
int opt;
cin >> opt;
if (opt == 1) // 询问
{
int l, r;
cin >> l >> r;
cout << query(root, l, r).sum << "\n";
}
else
{
int l, r, v;
cin >> l >> r >> v;
modify(root, l, r, v);
}
}
return 0;
}
如果还有加法:
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define root 1, n, 1
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
const int maxn = 100010;
int n, m, a[maxn];
struct node // 一个线段树节点
{
int sum; // 代表区间和
int size; // 代表区间长度
int add;
int mul;
// x*mul+add
node()
{
sum = size = add = 0;
mul = 1;
}
void init(int v) // 用一个数初始化
{
sum = v;
size = 1;
}
} z[maxn << 2]; // z[i]就代表线段树的第i个节点
node operator+(const node &l, const node &r)
{
node res;
res.sum = l.sum + r.sum;
res.size = l.size + r.size;
return res;
}
void color(int l, int r, int rt, int mul, int add) // 给l,r,rt这个节点打一个*mul+add的懒标记
{
z[rt].mul *= mul;
z[rt].add = z[rt].add * mul + add;
z[rt].sum = z[rt].sum * mul + add * z[rt].size;
}
void push_col(int l, int r, int rt) // 标记下放 把标记告诉儿子
{
if (z[rt].mul == 1 && z[rt].add == 0)
return; // 没标记 不需要下放 可以不要这句话 但会慢些
int m = (l + r) >> 1;
color(lson, z[rt].mul, z[rt].add);
color(rson, z[rt].mul, z[rt].add);
z[rt].mul = 1;
z[rt].add = 0;
}
void build(int l, int r, int rt) // 建树 初始化l,r,rt这个节点
// 编号为rt的线段树节点 所对应的区间是l~r
{
if (l == r)
{
z[rt].init(a[l]);
return;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
node query(int l, int r, int rt, int nowl, int nowr)
// l,r,rt描述了一个线段树节点
// nowl nowr代表了询问的区间的左端点和右端点
{
if (nowl <= l && r <= nowr)
return z[rt];
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
{
if (m < nowr)
return query(lson, nowl, nowr) + query(rson, nowl, nowr);
else
return query(lson, nowl, nowr);
}
else
return query(rson, nowl, nowr);
}
void modify(int l, int r, int rt, int nowl, int nowr, int mul, int add)
// 把nowl~nowr这段区间全部整体+v
{
if (nowl <= l && r <= nowr) // 当前线段树节点被修改区间整体包含
{
color(l, r, rt, mul, add); // 给l,r,rt这个节点打一个+v的懒标记
return;
}
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
modify(lson, nowl, nowr, mul, add);
if (m < nowr)
modify(rson, nowl, nowr, mul, add);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(root);
cin >> m;
for (int i = 1; i <= m; i++)
{
int opt;
cin >> opt;
if (opt == 1) // 询问
{
int l, r;
cin >> l >> r;
cout << query(root, l, r).sum << "\n";
}
else
{
int l, r, add, mul;
cin >> l >> r >> add >> mul;
modify(root, l, r, add, mul);
}
}
return 0;
}
一定要考虑标记生效的顺序。
拓展:区间推平一定要先推平,把懒标记删掉。
例4
给区间加上一个等差数列。
一个区间被加上了多个等差数列,它还是被加上了一个等差数列。
于是,每个节点要维护首项、公差。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define root 1, n, 1
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
const int maxn = 100010;
int n, m, a[maxn];
struct node // 一个线段树节点
{
int sum; // 代表区间和
int size; // 代表区间长度
int x, y; // 给这段区间加上了一个首项为x 公差为y的等差数列
node()
{
sum = size = x = y = 0;
}
void init(int v) // 用一个数初始化
{
sum = v;
size = 1;
}
} z[maxn << 2]; // z[i]就代表线段树的第i个节点
node operator+(const node &l, const node &r)
{
node res;
res.sum = l.sum + r.sum;
res.size = l.size + r.size;
return res;
}
void color(int l, int r, int rt, int x, int y) // 给l,r,rt这个节点加上一个首项为x公差为y的等差数列
{
z[rt].x += x;
z[rt].y += y;
z[rt].sum += (x + x + (z[rt].size - 1) * y) * z[rt].size / 2;
}
void push_col(int l, int r, int rt) // 标记下放 把标记告诉儿子
{
if (z[rt].x == 0 && z[rt].y == 0)
return; // 没标记 不需要下放 可以不要这句话 但会慢些
int m = (l + r) >> 1;
color(lson, z[rt].x, z[rt].y);
color(rson, z[rt].x + z[rt << 1].size * z[rt].y, z[rt].y);
z[rt].x = z[rt].y = 0;
}
void build(int l, int r, int rt) // 建树 初始化l,r,rt这个节点
// 编号为rt的线段树节点 所对应的区间是l~r
{
if (l == r)
{
z[rt].init(a[l]);
return;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
node query(int l, int r, int rt, int nowl, int nowr)
// l,r,rt描述了一个线段树节点
// nowl nowr代表了询问的区间的左端点和右端点
{
if (nowl <= l && r <= nowr)
return z[rt];
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
{
if (m < nowr)
return query(lson, nowl, nowr) + query(rson, nowl, nowr);
else
return query(lson, nowl, nowr);
}
else
return query(rson, nowl, nowr);
}
void modify(int l, int r, int rt, int nowl, int nowr, int x, int y)
// 把nowl~nowr这段区间全部整体+v
{
if (nowl <= l && r <= nowr) // 当前线段树节点被修改区间整体包含
{
color(l, r, rt, x, y); // 给l,r,rt这个节点打一个+v的懒标记
return;
}
push_col(l, r, rt);
int m = (l + r) >> 1;
if (nowl <= m)
modify(lson, nowl, nowr, x, y);
if (m < nowr)
modify(rson, nowl, nowr, x, y);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(root);
cin >> m;
for (int i = 1; i <= m; i++)
{
int opt;
cin >> opt;
if (opt == 1) // 询问
{
int l, r;
cin >> l >> r;
cout << query(root, l, r).sum << "\n";
}
else
{
int l, r, x, y;
cin >> l >> r >> x >> y;
modify(root, l, r, x, y);
}
}
return 0;
}
取模
不要 #define int long long,容易 TLE。

乘法一定要转 long long,尽量不要多取模。


三次方会爆 long long。

比取模快很多。

c++对负数取模还是负数。

取模很慢。
例5
可能左儿子,可能右儿子,也有可能跨越中间。
所以要维护最大字段和、最大后缀和、最大前缀和、区间和。
\(maxsum = \max\{l.maxsum, r.maxsum, l.maxsuff + r.maxpre\}\)
\(maxpre = \max\{l.maxpre, l.sum + r.maxpre\}\)
\(maxsuff = \max\{r.maxsuff, r.sum + l.maxsuff\}\)
我的远古码风:
点击查看代码
#include <iostream>
#include <vector>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <algorithm>
#define int long long
#define inf 1e18 + 7
using namespace std;
const int N = 5e4 + 10;
typedef long long ll;
ll a[N << 2];
ll sum[N << 2];
ll ans[N << 2];
ll max_sum_front[N << 2];
ll max_sum_back[N << 2];
struct Node
{
ll frt, bck, sum, ans;
};
int ls(int x)
{
return x << 1;
}
int rs(int x)
{
return x << 1 | 1;
}
void push_up(int x)
{
sum[x] = sum[ls(x)] + sum[rs(x)];
max_sum_front[x] = max(max_sum_front[ls(x)], sum[ls(x)] + max_sum_front[rs(x)]);
max_sum_back[x] = max(max_sum_back[rs(x)], sum[rs(x)] + max_sum_back[ls(x)]);
ans[x] = max(max(ans[ls(x)], ans[rs(x)]), max_sum_back[ls(x)] + max_sum_front[rs(x)]);
}
void build(int x, int l, int r)
{
if (l == r)
{
max_sum_front[x] = max_sum_back[x] = sum[x] = ans[x] = a[l];
return;
}
int mid = (l + r) >> 1;
build(ls(x), l, mid);
build(rs(x), mid + 1, r);
push_up(x);
}
void gai(int x, int l, int r, int p, ll k)
{
if (l == r && l == p)
{
sum[x] = max_sum_back[x] = max_sum_front[x] = ans[x] = k;
return;
}
int mid = (l + r) >> 1;
if (p <= mid)
gai(ls(x), l, mid, p, k);
else
gai(rs(x), mid + 1, r, p, k);
push_up(x);
}
Node query(int x, int l, int r, int L, int R)
{
if (L <= l && r <= R)
return (Node){max_sum_front[x], max_sum_back[x], sum[x], ans[x]};
int mid = (l + r) >> 1;
bool fl1 = 0, fl2 = 0;
Node a, b;
if (L <= mid)
{
fl1 = 1;
a = query(ls(x), l, mid, L, R);
}
if (R > mid)
{
fl2 = 1;
b = query(rs(x), mid + 1, r, L, R);
}
if (!fl1 && fl2)
return b;
if (!fl2 && fl1)
return a;
return (Node){max(a.frt, a.sum + b.frt), max(b.bck, b.sum + a.bck), a.sum + b.sum, max(max(a.ans, b.ans), a.bck + b.frt)};
}
signed main()
{
int n;
cin >> n;
for (int i = 1; i <= n; ++i)
cin >> a[i];
build(1, 1, n);
int q;
cin >> q;
while (q--)
{
int opt, x, y;
cin >> opt >> x >> y;
if (opt == 0)
gai(1, 1, n, x, y);
if (opt == 1)
cout << query(1, 1, n, x, y).ans << '\n';
}
return 0;
}
例6
分类讨论。
区间重合和区间不重合。
点击查看代码
#include <iostream>
#include <algorithm>
#define int long long
using std::max;
const int N = 1e4 + 10;
typedef long long ll;
struct Node
{
ll sum;
ll max_pre;
ll max_suff;
ll max_sum;
Node()
{
sum = max_pre = max_suff = max_sum = 0;
}
void init(int v)
{
sum = max_pre = max_suff = max_sum = v;
}
friend Node operator+(const Node &a, const Node &b)
{
Node ret;
ret.sum = a.sum + b.sum;
ret.max_pre = max(a.max_pre, a.sum + b.max_pre);
ret.max_suff = max(b.max_suff, b.sum + a.max_suff);
ret.max_sum = max({a.max_sum, b.max_sum, a.max_suff + b.max_pre});
return ret;
}
};
Node z[N << 2];
ll a[N];
// 构建线段树
void build(int l, int r, int rt)
{
if (l == r)
{
z[rt].init(a[l]);
return;
}
int m = (l + r) >> 1;
build(l, m, rt << 1);
build(m + 1, r, rt << 1 | 1);
z[rt] = z[rt << 1] + z[rt << 1 | 1];
}
// 查询线段树
Node query(int l, int r, int rt, int nowl, int nowr)
{
if (nowl <= l && r <= nowr)
{
return z[rt];
}
int m = (l + r) >> 1;
if (nowl <= m)
{
if (nowr > m)
{
return query(l, m, rt << 1, nowl, nowr) + query(m + 1, r, rt << 1 | 1, nowl, nowr);
}
else
{
return query(l, m, rt << 1, nowl, nowr);
}
}
else
{
return query(m + 1, r, rt << 1 | 1, nowl, nowr);
}
}
// 快速读入函数
template <typename T>
void readin(T &x)
{
x = 0;
int f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-')
f = -f;
c = getchar();
}
while (c >= '0' && c <= '9')
{
x = x * 10 + (c - '0');
c = getchar();
}
x = x * f;
}
// 快速输出函数
void write(ll x)
{
if (x < 0)
{
putchar('-');
x = -x;
}
if (x < 10)
{
putchar(x + '0');
return;
}
write(x / 10);
putchar(x % 10 + '0');
}
signed main()
{
int t;
readin(t);
while (t--)
{
int n;
readin(n);
for (int i = 1; i <= n; ++i)
{
readin(a[i]);
}
build(1, n, 1);
int m;
readin(m);
while (m--)
{
int x1, y1, x2, y2;
readin(x1);
readin(y1);
readin(x2);
readin(y2);
ll ans;
if (y1 < x2)
{
ans = query(1, n, 1, x1, y1).max_suff + query(1, n, 1, y1, x2).sum + query(1, n, 1, x2, y2).max_pre - a[y1] - a[x2];
}
else
{
ll case1 = query(1, n, 1, x1, x2).max_suff + query(1, n, 1, x2, y1).max_pre - a[x2];
ll case2 = query(1, n, 1, x1, x2).max_suff + query(1, n, 1, x2, y1).sum + query(1, n, 1, y1, y2).max_pre - a[x2] - a[y1];
ll case3 = query(1, n, 1, x2, y1).max_sum;
ll case4 = query(1, n, 1, x2, y1).max_suff + query(1, n, 1, y1, y2).max_pre - a[y1];
ans = max({case1, case2, case3, case4});
}
write(ans);
putchar('\n');
}
}
return 0;
}
例7
开方无法维护。
每个数最多 \(1e18\),最多开方 \(10\) 次以内。
开方到 \(1\) 或者 \(0\) 就不用开了。
维护区间和,区间最大值。
如果最大值已经 \(\leq 1\) 了,那就没有必要开方了。
这样复杂度就很小了。
我的远古码风:
点击查看代码
#include <iostream>
#include <vector>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1e5 + 10;
typedef long long ll;
bool is[N << 2];
ll a[N];
ll sum[N << 2];
ll cnt[N << 2];
ll readin()
{
ll ret = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
{
ret = ret * 10 + (c - '0');
c = getchar();
}
return ret * f;
}
int ls(int x)
{
return x << 1;
}
int rs(int x)
{
return x << 1 | 1;
}
void push_up(int x)
{
is[x] = (is[ls(x)] && is[rs(x)]);
sum[x] = sum[ls(x)] + sum[rs(x)];
}
void build(int x, int l, int r)
{
if (l == r)
{
is[x] = (a[l] == 0 || a[l] == 1);
sum[x] = a[l];
return ;
}
int mid = (l + r) >> 1;
build(ls(x), l, mid);
build(rs(x), mid + 1, r);
push_up(x);
}
void gai_lr(int x, int l, int r, int L, int R)
{
if (l == r)
{
sum[x] = sqrt(sum[x]);
is[x] = (sum[x] == 0 || sum[x] == 1);
return ;
}
int mid = (l + r) >> 1;
if (L <= mid)
if (!is[ls(x)])
gai_lr(ls(x), l, mid, L, R);
if (R > mid)
if (!is[rs(x)])
gai_lr(rs(x), mid + 1, r, L, R);
push_up(x);
}
ll query(int x, int l, int r, int L, int R)
{
if (L <= l && r <= R)
return sum[x];
int mid = (l + r) >> 1;
ll ret = 0;
if (L <= mid)
ret += query(ls(x), l, mid, L, R);
if (R > mid)
ret += query(rs(x), mid + 1, r, L, R);
return ret;
}
int main()
{
// freopen("in.in", "r", stdin);
// freopen("out.out", "w", stdout);
int n;
ll cn = 0;
while (~scanf("%d", &n))
{
cn++;
memset(is, 0, sizeof (is));
memset(sum, 0, sizeof (sum));
memset(cnt, 0, sizeof (cnt));
for (int i = 1; i <= n; ++i)
a[i] = readin();
build(1, 1, n);
int m;
m = readin();
printf("Case #%lld:\n", cn);
while (m--)
{
int opt, l, r;
opt = readin();
l = readin();
r = readin();
if (l > r)
swap(l, r);
if (opt == 0)
gai_lr(1, 1, n, l, r);
if (opt == 1)
cout << query(1, 1, n, l, r) << '\n';
}
cout << '\n';
}
return 0;
}
例8
即每次操作后,求逆序对的个数。
我们记录 \(f[i]\) 表示 \(i\) 后比 \(a[i]\) 小的数个数。
则答案即求 \(\sum^n_{i = 1}f[i]\)。
这道题让我们每一次都拿出来排序再放回去,这样肯定会 TLE,所以我们就不跳题目的坑,直接不排序。
每一次排序,被排序的数的 \(f\) 都会变成零,引文后面没有比它小的数了。
所以我们分治。
线段树查询 \(minv\) \(=\) 最小的 \(a[i]\),如果 \(minv\) \(\leq\) \(a[p_i]\),就分治暴力修改所有 \(a[i] \leq a[p_i]\) 的值,然后把 \(a[i]\) 修改为 \(\infty\) ;否则,就不用管了。
这是一个和 GSS4(上面)差不多的问题。
可持久化
能够访问历史版本,且强制在线。
常见:数组、并查集、平衡树、线段树。
核心:不能修改原来的值。
线段树
单点加。
查询第 \(i\) 次修改后的区间和。
每次修改 \(p\),所有包含 \(p\) 的节点都会被修改。
只有这些共 \(\log n\) 个节点是不一样的。
新建这 \(\log n\) 个节点,然后把这些点连向之前的现在缺失的左/右儿子。
这样只需要 \(\log n\) 的复杂度了。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
struct node
{
int l, r; // 左儿子 右儿子编号
int sum; // 区间和
node()
{
l = r = sum = 0;
}
} z[N * 30];
int cnt; // 总共有多少个节点
int a[N];
int root[N];
void update(int p)
{
z[p].sum = z[z[p].l].sum + z[z[p].r].sum;
}
int build(int l, int r) // 当前的区间为l~r 是这段区间对应的节点编号
{
cnt++;
int p = cnt;
if (l == r)
{
z[p].sum = a[l];
return p;
}
int m = (l + r) >> 1;
z[p].l = build(l, m);
z[p].r = build(m + 1, r);
update(p);
return p;
}
int query(int l, int r, int rt, int nowl, int nowr)
// 当前线段树节点编号为rt 对应的区间为l~r 要询问nowl~nowr这段区间的和
{
if (nowl <= l && r <= nowr)
return z[rt].sum;
int m = (l + r) >> 1;
if (nowl <= m)
{
if (m < nowr)
return query(l, m, z[rt].l, nowl, nowr) + query(m + 1, r, z[rt].r, nowl, nowr);
else
return query(l, m, z[rt].l, nowl, nowr);
}
else
return query(m + 1, r, z[rt].r, nowl, nowr);
}
int modify(int l, int r, int rt, int p, int v) // 返回修改后的新节点编号
// 当前线段树节点编号为rt 对应的区间为l~r 要把a[p]+=v
{
cnt++;
int q = cnt; // 新的节点q用于修改
z[q] = z[rt];
if (l == r)
{
z[q].sum = v;
return q;
}
int m = (l + r) >> 1;
if (p <= m) // 在左儿子
z[q].l = modify(l, m, z[q].l, p, v);
else
z[q].r = modify(m + 1, r, z[q].r, p, v);
update(q);
return q;
}
void readin(int &x)
{
x = 0;
int f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-')
f = -f;
c = getchar();
}
while (c >= '0' && c <= '9')
{
x = x * 10 + (c - '0');
c = getchar();
}
x = x * f;
}
int main()
{
int n, m;
readin(n);
readin(m);
for (int i = 1; i <= n; i++)
cin >> a[i];
root[0] = build(1, n); // root[i]代表第i次操作后的根节点是谁
for (int i = 1; i <= m; ++i)
{
int v, opt, p;
readin(v), readin(opt), readin(p);
if (opt == 1)
{
int c;
readin(c);
root[i] = modify(1, n, root[v], p, c);
}
else
{
cout << query(1, n, root[v], p, p) << '\n';
root[i] = root[v]; // 复制版本
}
}
return 0;
}
主席树(前缀值域可持久化线段树)
解决求区间第 \(k\) 小的值。
对数组的每个前缀都要维护一棵值域线段树。
值域线段树节点 \(l\sim r\) 的值为 \(k\) 表示大于等于 \(l\) 小于等于 \(r\) 的数有 \(k\) 个。
因为每一个前缀相对于上一个前缀只是加了最后一个数,所以可以用可持久化线段树来做。
\(l\sim r\) 转化成 \(1\sim l - 1\) 和 \(1\sim r\)。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
struct node
{
int l, r; // 左儿子 右儿子编号
int sum; // 区间和
node()
{
l = r = sum = 0;
}
} z[N * 30];
int cnt; // 总共有多少个节点
int a[N];
int root[N];
void update(int p)
{
z[p].sum = z[z[p].l].sum + z[z[p].r].sum;
}
int modify(int l, int r, int rt, int p, int v) // 返回修改后的新节点编号
// 当前线段树节点编号为rt 对应的区间为l~r 要把a[p]+=v
{
cnt++;
int q = cnt; // 新的节点q用于修改
z[q] = z[rt];
if (l == r)
{
z[q].sum += v;
return q;
}
int m = (l + r) >> 1;
if (p <= m) // 在左儿子
z[q].l = modify(l, m, z[q].l, p, v);
else
z[q].r = modify(m + 1, r, z[q].r, p, v);
update(q);
return q;
}
int query(int p1, int p2, int l, int r, int k)
// 当前对应的值域范围为l~r
// 要询问第k小的数
// 需要用p1和p2这两颗线段树来询问
{
if (l == r)
return l;
int m = (l + r) >> 1;
if (z[z[p2].l].sum - z[z[p1].l].sum >= k)
return query(z[p1].l, z[p2].l, l, m, k);
// z[z[p2].l].sum - z[z[p1].l].sum代表aL~aR有多少个数在[l,m]之间
else
return query(z[p1].r, z[p2].r, m + 1, r, k - (z[z[p2].l].sum - z[z[p1].l].sum));
}
void readin(int &x)
{
x = 0;
int f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-')
f = -f;
c = getchar();
}
while (c >= '0' && c <= '9')
{
x = x * 10 + (c - '0');
c = getchar();
}
x = x * f;
}
int main()
{
root[0] = 0;
int maxv = 1e9;
int n;
readin(n);
int m;
readin(m);
for (int i = 1; i <= n; ++i)
readin(a[i]);
// root[i] 代表a1~ai这些数所对应的值域线段树的根
// 值域范围是1~maxv
for (int i = 1; i <= n; ++i)
root[i] = modify(0, maxv, root[i - 1], a[i], 1);
while (m--)
{
int l, r, k;
readin(l), readin(r), readin(k);
cout << query(root[l - 1], root[r], 0, maxv, k) << '\n';
}
return 0;
}
数组
方法1
每个位置开一个 vector,存每一个修改的时间戳、值。
查询的时候二分查找。
方法2
开一个 map<pair<int, int>, int>。
这三个 int 分别表示位置、时间、值。
查询直接 lower_bound。
并查集
老师说要讲,又不讲了。。。
例9
求 \(l\sim r\) 有多少个不同的数。
强制在线。
用 \(pre[i]\) 记录每一个地方 \(i\) 上一个 \(a[i]\) 出现的位置。
则每次询问即询问 \(l\sim r\) 有多少个 \(pre[i] < l\)。
即求 \(1\sim r\) 中 \(pre[i] < l\) 的个数 \(-\) \(1\sim l - 1\) 中 \(pre[i] < l\) 的个数。
于是,用主席树维护。
例10

根号分治。
以 \(\sqrt{n}\) 为界限,如果一个点周围连了超过 \(\sqrt{n}\) 个点,就称它为大点,否则称它为小点。
如果修改的一个点是小点,那么就直接暴力。
否则,如果这个点连的是大点,显然大点的数量不会超过 \(\sqrt{n}\) 个;如果连的是小点的话,就处理出每个大点连的白小点和黑小点的数量。
例11
先用 LCA 求出 \(l, r\) 之间的路径。
看 这里
因为 \(fib\) 数列是满足不是三角形的最小数列。
ryf大佬的题单
完成情况:


浙公网安备 33010602011771号