算法竞赛模板

ExtractStars的算法竞赛模板 v4.0

缺省部

#include <bits/stdc++.h>

using namespace std;

using ll = long long;
using ull= unsigned long long;
using ld = ld;

#define inf 0x3f3f3f3f
#define infll 0x3f3f3f3f3f3f3f3fLL

void solve()
{
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    // freopen("test.in", "r", stdin);
    // freopen("test.out", "w", stdout);
    int _ = 1;
    cin >> _;
    while (_--)
    {
        solve();
    }
    return 0;
}

一、STL库

容器

vector

变长数组,支持动态增删元素:

vector<int> v; // 声明一个空的vector

vector<vector<int>> vv(n,vector<int>(m)); // 创建一个n*m的二维数组

v.push_back(); // 向尾部插入元素
v.pop_back();   // 弹出尾部元素

int size = v.size(); // 获取元素个数
bool isEmpty = v.empty(); // 判断是否为空
v.clear(); // 清空vector
int frontElement = v.front(); // 获取首元素
int backElement = v.back();   // 获取尾元素
int capacity = v.capacity(); // 获取容量

v.resize(size + 1); // 改变大小为 size+1,并在末尾填充默认值
v.resize(size - 1); // 改变大小为 size-1

v.assign({1, 2, 3}); // 用列表 {1, 2, 3} 中的元素替换向量中的元素
v.insert(v.begin() + 1, 4); // 在第二个位置插入元素 4

v.erase(v.begin() + 1); // 删除第二个位置的元素
v.erase(v.begin() + l, v.begin() + r); // 删除区间[l,r)的元素
vector<int> v2;
v2.swap(v); // 交换 v 和 v2 的元素

pair<int, int>

存储两个元素的组合:

支持比较运算,以first为第一关键字,以second为第二关键字(字典序)

pair<int, int> p = make_pair(1, 2);
int firstElement = p.first;  // 获取第一个元素
int secondElement = p.second; // 获取第二个元素

string

字符串类型:

string s; // 声明一个空的字符串
string s2 = "hello"; // 初始化一个字符串
string s3(5, 'a'); // 初始化一个包含5个字符 'a' 的字符串

int length = s.size(); // 获取字符串长度
bool isEmpty = s.empty(); // 判断是否为空

char& frontChar = s.front(); // 获取首字符的引用
char& backChar = s.back();   // 获取尾字符的引用
const char* c_str = s.c_str(); // 返回指向以 null 结尾的字符数组的指针

s.append("world"); // 在字符串末尾添加 "world"
s.push_back('!'); // 在字符串末尾添加一个字符 '!'
s.insert(5, " "); // 在位置 5 插入一个空格
s.clear(); // 清空字符串
s.erase(5, 1); // 从位置 5 开始删除一个字符
s.replace(5, 1, ","); // 从位置 5 开始替换一个字符为 ","

string s5 = s.substr(5); // 从位置 5 开始到末尾的子串
string substr = s.substr(0, 5); // 返回从位置 0 开始长度为 5 的子串

size_t found = s.find("l"); // 返回第一个 "l" 的位置
size_t rfound = s.rfind("l"); // 返回最后一个 "l" 的位置

s.swap(s2); // 交换 s 和 s2 的内容
string s4 = s + s2; // 拼接 s 和 s2

queue

队列:

// 声明一个空的队列
queue<int> q; 

// 向队尾插入一个元素
q.push(1); 

// 弹出队头元素
q.pop(); 

// 获取队列中元素个数
int size = q.size(); 

// 判断队列是否为空
bool isEmpty = q.empty(); 

// 获取队头元素
int frontElement = q.front();

// 获取队尾元素
int backElement = q.back();   

deque

双端队列

// 创建一个空的双端队列
deque<int> dq1;

// 创建一个包含数组所有元素的双端队列
vector<int> arr = {3, 1, 4, 1, 5, 9};
deque<int> dq2(arr.begin(), arr.end());

// 在队尾插入一个元素
dq1.push_back(10);

// 在队头插入一个元素
dq1.push_front(20);

// 删除队尾元素
dq1.pop_back();

// 删除队头元素
dq1.pop_front();

// 获取队头元素
int frontElement = dq1.front();

// 获取队尾元素
int backElement = dq1.back();

// 获取双端队列中元素个数
int size = dq1.size();

// 判断双端队列是否为空
bool isEmpty = dq1.empty();

priority_queue

优先队列,默认为大根堆:

// 创建一个空的大根堆
priority_queue<int> pq1;

// 创建一个空的小根堆
priority_queue<int, vector<int>, greater<int>> pq2;

// 创建一个包含数组所有元素的大根堆
vector<int> arr = {3, 1, 4, 1, 5, 9};
priority_queue<int> pq3(arr.begin(), arr.end());

// 创建一个包含数组所有元素的小根堆
priority_queue<int, vector<int>, greater<int>> pq4(arr.begin(), arr.end());

// 插入一个元素
pq1.push(10);

// 弹出堆顶元素
pq1.pop();

// 获取堆顶元素
int topElement = pq1.top();

// 获取堆中元素个数
int size = pq1.size();

// 判断堆是否为空
bool isEmpty = pq1.empty();

stack

栈:

stack<int> s; // 声明一个空的栈
s.push(1); // 向栈顶插入一个元素
s.pop(); // 弹出栈顶元素
int size = s.size(); // 获取栈中元素个数
bool isEmpty = s.empty(); // 判断栈是否为空
int topElement = s.top(); // 获取栈顶元素

set

基于红黑树的有序集合:

// 创建一个空的集合
set<int> s1;

// 创建一个包含数组所有元素的集合
vector<int> arr = {3, 1, 4, 1, 5, 9};
set<int> s2(arr.begin(), arr.end());

// 插入一个元素
s1.insert(10);

// 删除10
s1.erase(10);

// 查找一个元素
auto it = s1.find(5);

// 获取集合中元素个数
int size = s1.size();

// 判断集合是否为空
bool isEmpty = s1.empty();

// 清空集合
s1.clear();

// 返回集合起始迭代器
auto begin = s1.begin();

// 返回集合末尾迭代器
auto end = s1.end();

// 返回集合中大于某个值的第一个元素的迭代器
auto = s1.upper_bound(5);

// 返回集合中大于等于某个值的第一个元素的迭代器
auto lower = s1.upper_bound(5);

// 返回集合中某个值出现的次数
int count = s1.count(5);

// 将集合合并
s1.merge(s2);

// 交换两个集合
s1.swap(s2);

unordered_set

基于哈希表的无序集合,操作类似set:

// 创建一个空的无序集合
unordered_set<int> s1;

// 创建一个包含数组所有元素的无序集合
vector<int> arr = {3, 1, 4, 1, 5, 9};
auto s2(arr.begin(), arr.end());

// 插入一个元素
s1.insert(10);

// 删除一个元素
s1.erase(10);

// 查找一个元素
auto it = s1.find(5);

// 获取集合中元素个数
int size = s1.size();

// 判断集合是否为空
bool isEmpty = s1.empty();

map

// 创建一个空的map
map<int, string> m;

// 使用数组初始化列表创建map
map<int, string> m = {{1, "One"}, {2, "Two"}, {3, "Three"}};

// 插入元素
m.insert(make_pair(4, "Four"));

// 访问元素
string value = m[1]; // 如果键不存在,会插入一个默认值

// 查找元素
map<int, string>::iterator it = m.find(2);

// 删除元素
m.erase(3);

// 遍历map
for (auto& pair : m) 
{
    cout << pair.first << ": " << pair.second << endl;
}

// 判断是否存在某个键
if (m.count(4) > 0) 
{
    cout << "Key 4 exists" << endl;
}

// 获取map的大小
int size = m.size();

// 清空map
m.clear();

bitset

位运算:

// 创建一个大小为10的bitset,初始值都为0
bitset<10> b1;

// 使用整数初始化bitset
bitset<10> b2(45); // 二进制表示为 101101

// 使用字符串初始化bitset
bitset<10> b3(string("1010101010"));

// 使用另一个bitset初始化bitset
bitset<10> b4(b3);

// 访问位
bool bit = b1[0]; // 获取第0位的值

// 设置位
b1.set(0, true); // 将第0位设置为1

// 清除位
b1.reset(0); // 将第0位清零

// 翻转位
b1.flip(0); // 将第0位翻转

// 获取位数
int size = b1.size(); // 返回10,即位数

// 获取字符串表示
string s = b1.to_string(); // 返回 "0000000000"

// 获取整数表示
unsigned long long value = b1.to_ullong(); // 返回0

// 判断是否有1
bool anyOne = b1.any();

// 判断是否全为0
bool noneOne = b1.none();

// 设置所有位为1
b1.set();

// 将第5位设置为1
b1.set(5, true);

// 将第3位设置为0
b1.set(3, false);

// 设置第7位为0
b1.reset(7);

// 翻转所有位
b1.flip();

// 翻转第2位
b1.flip(2);

// 检查第4位是否为1
bool test = b1.test(4);

// 获取1的个数
int count = b1.count();

// 将所有位清零
b1.reset();

algorithm库

非修改序列操作

  1. all_of(first, last, pred)
    检查 [first, last) 区间内的所有元素是否都满足谓词 pred,如果是,则返回 true;否则返回 false
  2. any_of(first, last, pred)
    检查 [first, last) 区间内的任意一个元素是否满足谓词 pred,如果有,则返回 true;否则返回 false
  3. none_of(first, last, pred)
    检查 [first, last) 区间内的所有元素是否都不满足谓词 pred,如果是,则返回 true;否则返回 false
  4. for_each(first, last, func)
    [first, last) 区间内的每个元素都调用函数 func
  5. for_each_n(first, n, func)
    对从 first 开始的 n 个元素都调用函数 func
  6. count(first, last, value)
    统计 [first, last) 区间内值等于 value 的元素个数。
  7. count_if(first, last, pred)
    统计 [first, last) 区间内满足谓词 pred 的元素个数。
  8. mismatch(first1, last1, first2)
    [first1, last1)[first2, last2) 中找到第一组不匹配的元素,并返回一个 pair,其中 pair.first 是第一个序列中不匹配的元素,pair.second 是第二个序列中对应的元素。
  9. find(first, last, value)
    [first, last) 区间内查找值为 value 的元素,如果找到,则返回指向该元素的迭代器;否则返回 last
  10. find_if(first, last, pred)
    [first, last) 区间内查找第一个满足谓词 pred 的元素,如果找到,则返回指向该元素的迭代器;否则返回 last
  11. find_if_not(first, last, pred)
    [first, last) 区间内查找第一个不满足谓词 pred 的元素,如果找到,则返回指向该元素的迭代器;否则返回 last
  12. find_end(first1, last1, first2, last2)
    [first1, last1) 区间内查找最后一组匹配 [first2, last2) 区间的子序列,并返回一个迭代器,指向匹配的起始位置;如果没有找到,则返回 last1
  13. find_first_of(first1, last1, first2, last2)
    [first1, last1) 区间内查找第一个与 [first2, last2) 中任意一个元素相等的元素,并返回指向该元素的迭代器;如果没有找到,则返回 last1
  14. adjacent_find(first, last)
    [first, last) 区间内查找相邻重复的元素,并返回指向第一对相邻重复元素的第一个元素的迭代器;如果没有找到,则返回 last
  15. adjacent_find(first, last, binary_pred)
    [first, last) 区间内查找相邻元素满足二元谓词 binary_pred 的第一组元素,并返回指向该组元素的第一个元素的迭代器;如果没有找到,则返回 last
  16. search(first1, last1, first2, last2)
    [first1, last1) 区间内查找第一次出现 [first2, last2) 中的子序列,并返回一个迭代器,指向匹配的起始位置;如果没有找到,则返回 last1
  17. search_n(first, last, count, value)
    [first, last) 区间内查找第一组连续 count 个值为 value 的元素,并返回一个迭代器,指向该组元素的第一个元素;如果没有找到,则返回 last
  18. search_n(first, last, count, value, pred)
    [first, last) 区间内查找第一组连续 count 个满足谓词 pred 的元素,并返回一个迭代器,指向该组元素的第一个元素;如果没有找到,则返回 last

修改序列操作

  1. copy(first1, last1, first2)
    复制 [first1, last1) 区间内的元素到以 first2 开始的目标区间。
  2. copy_if(first, last, result, pred)
    复制满足谓词 pred 的元素到目标区间。
  3. copy_n(first, n, result)
    复制从 first 开始的 n 个元素到目标区间。
  4. copy_backward(first1, last1, last2)
    从后向前复制 [first1, last1) 区间内的元素到以 last2 结束的目标区间。
  5. move(first, last, result)
    移动 [first, last) 区间内的元素到目标区间,源区间的元素可能会被修改为不确定状态。
  6. move_backward(first1, last1, last2)
    从后向前移动 [first1, last1) 区间内的元素到以 last2 结束的目标区间,源区间的元素可能会被修改为不确定状态。
  7. fill(first, last, value)
    [first, last) 区间内的所有元素设置为 value
  8. fill_n(first, n, value)
    将从 first 开始的 n 个元素设置为 value
  9. transform(first1, last1, first2, result, op)
    [first1, last1) 区间内的元素和 [first2, ......) 区间内的元素按照二元操作函数 op 进行操作,并将结果存储到目标区间。
  10. generate(first, last, gen)
    使用生成器函数 gen 生成 [first, last) 区间内的元素。
  11. generate_n(first, n, gen)
    使用生成器函数 gen 生成从 first 开始的 n 个元素。
  12. remove(first, last, value)
    [first, last) 区间内移除所有等于 value 的元素,并返回指向新的逻辑结尾的迭代器。
  13. remove_if(first, last, pred)
    [first, last) 区间内移除所有满足谓词 pred 的元素,并返回指向新的逻辑结尾的迭代器。
  14. remove_copy(first, last, result, value)
    复制 [first, last) 区间内除了值为 value 的元素到目标区间,并返回指向新的逻辑结尾的迭代器。
  15. remove_copy_if(first, last, result, pred)
    复制 [first, last) 区间内除了满足谓词 pred 的元素到目标区间,并返回指向新的逻辑结尾的迭代器。
  16. replace(first, last, old_value, new_value)
    [first, last) 区间内所有等于 old_value 的元素替换为 new_value
  17. replace_if(first, last, pred, new_value)
    [first, last) 区间内所有满足谓词 pred 的元素替换为 new_value
  18. replace_copy(first, last, result, old_value, new_value)
    复制 [first, last) 区间内的元素到目标区间,并将所有等于 old_value 的元素替换为 new_value
  19. replace_copy_if(first, last, result, pred, new_value)
    复制 [first, last) 区间内的元素到目标区间,并将所有满足谓词 pred 的元素替换为 new_value
  20. swap_ranges(first1, last1, first2)
    [first1, last1) 区间内的元素和 [first2, ...) 区间内的元素进行交换。
  21. reverse(first, last)
    [first, last) 区间内的元素反转。
  22. reverse_copy(first, last, result)
    [first, last) 区间内的元素反转后复制到目标区间。
  23. rotate(first, middle, last)
    [first, last) 区间内的元素进行循环左移,其中 middle 为旋转点。
  24. rotate_copy(first, middle, last, result)
    [first, last) 区间内的元素以 middle 为旋转点进行循环左移后复制到目标区间。
  25. unique(first, last)
    移除 [first, last) 区间中相邻重复的元素,并返回指向新的逻辑结尾的迭代器。
  26. unique_copy(first, last, result)
    复制 [first, last) 区间中相邻重复的元素到目标区间,并返回指向新的逻辑结尾的迭代器。
  27. shuffle(first, last, g)
    [first, last) 区间内的元素随机重排,使用生成器 g

排序和相关操作

  1. sort(first, last)
    [first, last) 区间内的元素进行升序排序。
  2. stable_sort(first, last)
    [first, last) 区间内的元素进行稳定升序排序。
  3. partial_sort(first, middle, last)
    [first, last) 区间内的元素进行部分排序,保证 [first, middle) 内的元素是升序的。
  4. partial_sort_copy(first1, last1, first2, last2)
    [first1, last1) 区间内的元素拷贝到 [first2, ...)区间内,并对拷贝的元素进行部分排序,保证 [first2, ...) 区间内的元素是升序的,最多拷贝 last2 - first2 个元素。
  5. is_sorted(first, last)
    检查 [first, last) 区间内的元素是否已经按升序排列。
  6. is_sorted_until(first, last)
    返回 [first, last) 区间内第一个不满足升序排列的元素的迭代器。
  7. nth_element(first, nth, last)
    重新排列 [first, last) 区间内的元素,使得第 nth - first 个元素是整个区间中第 nth - first 小的元素,且比它小的元素都排在它的前面,比它大的元素都排在它的后面。

二分查找

  1. lower_bound(first, last, value)
    [first, last) 区间内查找第一个大于等于 value 的元素,并返回指向该元素的迭代器。
  2. upper_bound(first, last, value)
    [first, last) 区间内查找第一个大于 value 的元素,并返回指向该元素的迭代器。
  3. binary_search(first, last, value)
    [first, last) 区间内查找是否存在值为 value 的元素,如果存在则返回 true;否则返回 false

最小值和最大值

  1. min(a, b) / max(a, b)
    返回 ab 中的最小值和最大值。
  2. min_element(first, last) / max_element(first, last)
    返回 [first, last) 区间内的最小值和最大值的迭代器。
  3. minmax_element(first, last)
    返回一个 pair,其中 pair.first 是最小值的迭代器,pair.second 是最大值的迭代器。

数值操作

  1. accumulate(first, last, init)
    [first, last) 区间内的元素进行累加,初始值为 init
  2. inner_product(first1, last1, first2, init)
    计算 [first1, last1)[first2, ...) 区间内对应元素的内积,并加上初始值 init
  3. partial_sum(first, last, result)
    计算部分和,将 [first, last) 区间内的元素累加到目标区间 result
  4. adjacent_difference(first, last, result)
    计算相邻元素的差值,将 [first, last) 区间内的元素的相邻元素之差存储到目标区间 result 中。

迭代器操作

函数 参数及说明 功能描述
advance(it, n) it 表示某个迭代器,n 为整数。 it 迭代器前进或后退 n 个位置。
distance(first, last) firstlast 都是迭代器。 计算 firstlast 之间的距离。
begin(cont) cont 表示某个容器。 返回一个指向 cont 容器中第一个元素的迭代器。
end(cont) cont 表示某个容器。 返回一个指向 cont 容器中最后一个元素之后位置的迭代器。
prev(it) it 为指定的迭代器。注意,it 至少为双向迭代器。 返回一个指向上一个位置处的迭代器。
next(it) it 为指定的迭代器。注意,it 最少为前向迭代器。 返回一个指向下一个位置处的迭代器。

其他操作

  1. swap(a, b)
    交换 ab 的值。
  2. next_permutation(first, last) / prev_permutation(first, last)
    [first, last) 区间内的元素重新排列为下一个或上一个字典序排列,如果已经是最后一个或第一个排列,则返回 false
  3. random_shuffle(first, last)
    [first, last) 区间内的元素进行随机重排。
  4. sort_heap(first, last)
    [first, last) 区间内的元素转换为堆序列

二、数据结构

并查集

朴素并查集

算法介绍

并查集用父指针维护若干不交集,按秩合并与路径压缩可以在近乎常数的均摊时间内完成合并与查询。find 返回代表元并在回溯过程中压缩路径,unite 将两个集合的小树挂到大树以减少高度。赛时常配合排序或扫描线解决连通性判定与合并类问题。

常见例题

题目:给定 n 个点与 q 次操作,操作分为两种,合并集合和询问两个点是否连通,要求在线输出所有连通性答案。

做法:用并查集维护集合代表,遇到合并操作就合并,遇到询问操作就比较代表是否相同并输出答案。

代码
// 朴素并查集,按大小合并与路径压缩
// 功能:维护不交集的合并与连通性查询;支持带参数构造与延后init
// 复杂度:单次摊还近乎O(1)
struct DisjointSet
{
    int n;                  // 元素个数
    vector<int> parent;     // 父指针
    vector<int> compSize;   // 以代表为根的集合大小

    // 构造函数,若给定规模则直接完成初始化
    DisjointSet(int n_ = 0)
    {
        if (n_) init(n_);
        else n = 0;
    }

    // 初始化,将每个元素设为独立集合
    void init(int n_)
    {
        n = n_;
        parent.resize(n + 1);
        compSize.assign(n + 1, 1);
        for (int i = 1; i <= n; i++) parent[i] = i;
    }

    // 查询并返回x的代表元,带路径压缩
    int find(int x)
    {
        while (x != parent[x]) x = parent[x] = parent[parent[x]];
        return x;
    }

    // 合并a与b所在集合,返回是否发生合并
    bool unite(int a, int b)
    {
        int x = find(a), y = find(b);
        if (x == y) return false;
        if (compSize[x] < compSize[y]) swap(x, y);
        parent[y] = x, compSize[x] += compSize[y];
        return true;
    }

    // 查询a与b是否连通
    bool same(int a, int b)
    {
        return find(a) == find(b);
    }

    // 返回元素r所在集合大小
    int size(int r)
    {
        return compSize[find(r)];
    }
};


可撤销并查集

算法介绍

可撤销并查集通过保存每次修改的快照来支持回滚,核心思想是不做路径压缩,只按大小合并,并把被修改的父指针与大小信息压栈,回滚时弹栈恢复。它与分治或线段树分配边的框架组合,可以离线处理带删除的连边问题。

常见例题

题目:在一个包含 n 个点的动态图中,给定 m 条边的添加与删除序列,并给定若干时刻的询问连通块数量,要求按时间顺序输出答案。

做法:将每条边的生存区间分配到线段树节点,在分治遍历时把该节点负责的所有边合并进可撤销并查集,进入子区间递归,回溯时回滚到进入前的栈高度,这样每条边被合并的次数为其被覆盖的结点数,整体复杂度近似为 O((n + m) log m)。

代码
// 可撤销并查集(Rollback DSU),不做路径压缩,仅按大小合并
// 功能:支持合并、查询、快照与回滚;构造时可直接指定规模
// 复杂度:合并与回滚摊还近乎O(1)
struct RollbackDSU
{
    int n;                                  // 元素个数
    vector<int> parent;                     // 父指针
    vector<int> compSize;                   // 集合大小
    vector<pair<int,int>> history;          // 变更栈:(who, oldParent) 或 (who, -oldSize)
    int comps;                               // 当前连通块数量

    // 构造函数,若给定规模则直接初始化
    RollbackDSU(int n_ = 0)
    {
        if (n_) init(n_);
        else n = 0, comps = 0;
    }

    // 初始化n个独立集合
    void init(int n_)
    {
        n = n_;
        parent.resize(n + 1);
        compSize.assign(n + 1, 1);
        for (int i = 1; i <= n; i++) parent[i] = i;
        history.clear();
        comps = n;
    }

    // 查找代表元,不做路径压缩以便回滚
    int find(int x)
    {
        while (x != parent[x]) x = parent[x];
        return x;
    }

    // 保存当前历史栈高度作为快照
    int snapshot()
    {
        return (int)history.size();
    }

    // 回滚到指定快照高度
    void rollback(int snap)
    {
        while ((int)history.size() > snap)
        {
            auto [who, val] = history.back(); history.pop_back();
            if (val >= 0) parent[who] = val;
            else compSize[who] = -val;
        }
    }

    // 合并a与b所在集合,返回是否发生合并
    bool unite(int a, int b)
    {
        int x = find(a), y = find(b);
        if (x == y) return false;
        if (compSize[x] < compSize[y]) swap(x, y);
        history.emplace_back(y, parent[y]); parent[y] = x;
        history.emplace_back(x, -compSize[x]); compSize[x] += compSize[y];
        comps--;
        return true;
    }

    // 返回当前连通块数量
    int count()
    {
        return comps;
    }

    // 判断是否连通
    bool same(int a, int b)
    {
        return find(a) == find(b);
    }
};


树状数组

一维树状数组

算法介绍

树状数组用低位元操作维护前缀信息,支持单点修改与前缀查询,区间查询通过差前缀实现。模板通常提供 add 和 sum 两个基本操作,并在 select 中利用二进制提升按权选择第 k 小位置。

常见例题

题目:给定长度为 n 的数组和 q 次操作,操作一为在位置 p 上加上 v,操作二为询问区间 [l, r] 的元素和。

做法:用一维树状数组维护前缀和,区间和由 sum(r) 减去 sum(l − 1) 得到,所有操作在对数时间内完成。

代码
// 一维树状数组(Fenwick),支持单点加、前缀和、区间和、按权选择
// 功能:维护加法型前缀信息;构造时可直接给定规模
// 复杂度:单次操作O(log n)
template <typename T>
struct Fenwick
{
    int n;           // 大小
    vector<T> bit;   // 树状数组,下标从1开始

    // 构造函数,若给定规模则完成初始化
    Fenwick(int n_ = 0)
    {
        if (n_) init(n_);
        else n = 0;
    }

    // 初始化大小为n_,元素清零
    void init(int n_)
    {
        n = n_;
        bit.assign(n + 1, T{});
    }

    // 在位置x增加值v
    void add(int x, T v)
    {
        for (int i = x; i <= n; i += i & -i) bit[i] = bit[i] + v;
    }

    // 查询前缀和sum[1..x]
    T sum(int x)
    {
        T res{};
        for (int i = x; i > 0; i -= i & -i) res = res + bit[i];
        return res;
    }

    // 查询区间[l..r]之和
    T rangeSum(int l, int r)
    {
        if (l > r) return T{};
        return sum(r) - sum(l - 1);
    }

    // 选择满足前缀和<=k的最大下标,要求所有值非负
    int select(T k)
    {
        int x = 0;
        T cur{};
        for (int pw = 1 << __lg(n); pw; pw >>= 1)
        {
            int nx = x + pw;
            if (nx <= n && cur + bit[nx] <= k) x = nx, cur = cur + bit[nx];
        }
        return x;
    }
};


二维树状数组

算法介绍

二维树状数组在两个维度上同时使用低位元操作,支持点更新与子矩形前缀查询,进而得到任意轴对齐矩形的区间和。内存为 O(nm),操作为 O(log n log m),适合中等规模的二维求和问题。

常见例题

题目:给定一个 n 行 m 列的整数矩阵,有 q 次操作,操作一为把位置 (x, y) 的值加上 v,操作二为询问子矩形 [x1, y1] 到 [x2, y2] 的元素和。

做法:用二维树状数组实现点更新与前缀矩形和的查询,答案由四个前缀的容斥组合得到。

代码
// 二维树状数组(Fenwick 2D),支持点加与子矩形求和
// 功能:维护矩阵轴对齐子矩形和;构造时可直接给定行列
// 复杂度:单次操作O(log n log m)
template <typename T>
struct Fenwick2D
{
    int n, m;                       // 行列尺寸
    vector<vector<T>> bit;          // 二维树状数组,下标从1开始

    // 构造函数,若给定尺寸则直接初始化
    Fenwick2D(int n_ = 0, int m_ = 0)
    {
        if (n_ && m_) init(n_, m_);
        else n = 0, m = 0;
    }

    // 初始化为n_行m_列
    void init(int n_, int m_)
    {
        n = n_, m = m_;
        bit.assign(n + 1, vector<T>(m + 1, T{}));
    }

    // 在坐标(x,y)加上v
    void add(int x, int y, T v)
    {
        for (int i = x; i <= n; i += i & -i)
            for (int j = y; j <= m; j += j & -j)
                bit[i][j] = bit[i][j] + v;
    }

    // 查询前缀子矩形[1..x][1..y]之和
    T sum(int x, int y)
    {
        T res{};
        for (int i = x; i > 0; i -= i & -i)
            for (int j = y; j > 0; j -= j & -j)
                res = res + bit[i][j];
        return res;
    }

    // 查询子矩形[x1..x2][y1..y2]之和
    T rangeSum(int x1, int y1, int x2, int y2)
    {
        if (x1 > x2 || y1 > y2) return T{};
        return sum(x2, y2) - sum(x1 - 1, y2) - sum(x2, y1 - 1) + sum(x1 - 1, y1 - 1);
    }
};


线段树

朴素线段树

算法介绍

在线段树上维护区间信息,支持单点修改与区间查询。将数组建成一棵完全二叉树形结构,每个结点覆盖一个区间 [l,r],结点信息由左右儿子合并得到。查询时在 [ql,qr] 与当前结点 [l,r] 的相交关系下递归;单点修改时自顶向下找到叶子并回溯更新父结点。

常见例题

题目:给定长度为 n 的数组 a,q 次操作,操作一为 1 x v 表示将 a[x] 赋值为 v,操作二为 2 l r 查询区间 [l,r] 的区间和、最小值与最大值。

做法:用朴素线段树维护 Info = {sum, min, max}。单点修改用 modify(pos, val)。区间查询用 rangeQuery(l, r) 返回 Info,再输出 sum/min/max 即可。由于所有区间都是闭区间,递归时严格用 [l,r] 与 [ql,qr] 比较,相等或包含时直接返回整段信息,完全不相交返回空信息。

代码
// 朴素线段树
// 约定:Info 需要支持 Info()+Info 合并,且提供静态的空信息构造 empty()
// 下面附了一个示例 Info,用于区间和/最小值/最大值
template <typename T>
struct Info
{
    T sum, mn, mx;

    // 构造与重置
    Info() : sum(0), mn(numeric_limits<T>::max()), mx(numeric_limits<T>::min()) {}
    Info(T v) : sum(v), mn(v), mx(v) {}
    static Info empty() { return {}; }

    // 合并两个区间的信息
    friend Info operator+(const Info &a, const Info &b)
    {
        if (a.mn == numeric_limits<T>::max())
            return b;
        if (b.mn == numeric_limits<T>::max())
            return a;
        Info c;
        c.sum = a.sum + b.sum;
        c.mn = min(a.mn, b.mn);
        c.mx = max(a.mx, b.mx);
        return c;
    }
};

template <typename Info>
struct SegmentTree
{
    int n;           // 维护的元素个数
    vector<Info> tr; // 线段树结点信息

    // 构造函数:给定长度,初值为 Info()
    SegmentTree(int n_ = 0) : n(0)
    {
        if (n_)
            init(n_);
    }

    // 构造函数:用数组初始化
    template <typename T>
    SegmentTree(const vector<T> &a) { init(a); }

    // 功能:按长度初始化为 n 个元素,初值为 Info()
    void init(int n_)
    {
        n = n_;
        tr.assign(4 * n + 4, Info::empty());
    }

    // 功能:用数组 a 初始化
    template <typename T>
    void init(const vector<T> &a)
    {
        n = (int)a.size();
        tr.assign(4 * n + 4, Info::empty());
        build(1, 0, n - 1, a);
    }

    // 功能:自底向上合并信息
    void pull(int p)
    {
        tr[p] = tr[p << 1] + tr[p << 1 | 1];
    }

    // 功能:建树,覆盖区间 [l,r]
    template <typename T>
    void build(int p, int l, int r, const vector<T> &a)
    {
        if (l == r)
        {
            tr[p] = Info(a[l]);
            return;
        }
        int m = (l + r) >> 1;
        build(p << 1, l, m, a), build(p << 1 | 1, m + 1, r, a);
        pull(p);
    }

    // 功能:单点赋值,将 pos 位置改成 v;覆盖区间 [l,r]
    void modify(int p, int l, int r, int pos, const Info &v)
    {
        if (l == r)
        {
            tr[p] = v;
            return;
        }
        int m = (l + r) >> 1;
        if (pos <= m)
            modify(p << 1, l, m, pos, v);
        else
            modify(p << 1 | 1, m + 1, r, pos, v);
        pull(p);
    }

    // 外部接口:单点赋值
    void modify(int pos, const Info &v) { modify(1, 0, n - 1, pos, v); }

    // 功能:查询区间 [ql,qr] 的聚合信息;当前结点覆盖 [l,r]
    Info rangeQuery(int p, int l, int r, int ql, int qr)
    {
        if (qr < l || r < ql)
            return Info::empty();
        if (ql <= l && r <= qr)
            return tr[p];
        int m = (l + r) >> 1;
        return rangeQuery(p << 1, l, m, ql, qr) + rangeQuery(p << 1 | 1, m + 1, r, ql, qr);
    }

    // 外部接口:区间查询
    Info rangeQuery(int l, int r) { return rangeQuery(1, 0, n - 1, l, r); }

    // 功能:在区间 [ql,qr] 上二分找第一个使 pred(Info) 为真的位置
    // 语义:pred 对“整段信息”判定是否存在可行解;若整段都不满足直接剪枝
    template <typename F>
    int findFirst(int p, int l, int r, int ql, int qr, F pred)
    {
        if (qr < l || r < ql || !pred(tr[p]))
            return -1;
        if (l == r)
            return l;
        int m = (l + r) >> 1, res = findFirst(p << 1, l, m, ql, qr, pred);
        if (res != -1)
            return res;
        return findFirst(p << 1 | 1, m + 1, r, ql, qr, pred);
    }

    // 外部接口:找第一个满足条件的位置
    template <typename F>
    int findFirst(int l, int r, F pred) { return findFirst(1, 0, n - 1, l, r, pred); }

    // 功能:在区间 [ql,qr] 上二分找最后一个使 pred(Info) 为真的位置
    template <typename F>
    int findLast(int p, int l, int r, int ql, int qr, F pred)
    {
        if (qr < l || r < ql || !pred(tr[p]))
            return -1;
        if (l == r)
            return l;
        int m = (l + r) >> 1, res = findLast(p << 1 | 1, m + 1, r, ql, qr, pred);
        if (res != -1)
            return res;
        return findLast(p << 1, l, m, ql, qr, pred);
    }

    // 外部接口:找最后一个满足条件的位置
    template <typename F>
    int findLast(int l, int r, F pred) { return findLast(1, 0, n - 1, l, r, pred); }
};

懒标记线段树

算法介绍

当需要对整段 [l,r] 进行区间操作时(如区间加、区间赋值、区间取 min 等),在线段树上叠加懒标记。若当前整段完全被修改区间覆盖,直接对结点信息进行一次“打标应用”,并把标记累积到当前结点;下推时再把标记分发给左右儿子。查询逻辑与朴素一致,但在向下递归前要先 push 把懒标记下传,保证子树信息正确。为配合原版模板功能,保留对整段应用标签的 rangeApply、以及在区间上用结点聚合信息二分定位的 findFirst / findLast。

常见例题

题目:给定长度为 n 的数组 a,q 次操作,操作一为 1 l r x 令区间 [l,r] 全部加上 x,操作二为 2 l r 查询 [l,r] 的区间和。额外要求支持查询“区间 [l,r] 内从左到右第一个前缀和超过 S 的位置”。

做法:用懒标树,Info 维护 sum 与 len,Tag 维护 add。区间加时整段 sum 增加 add×len。区间和直接返回。要二分位置时,pred(Info) 可以判断“这段的 sum 是否 > S”,findFirst(l, r, pred) 即可在 O(log n) 内返回答案。

代码
// 懒标记线段树
// 约定:Info 需要提供 apply(Tag) 与合并 operator+;Tag 需要提供 apply(Tag) 的“自合并”
// 下面附了一个常用示例:区间加 + 区间和
template <typename T>
struct Info
{
    T sum;
    int len;

    Info() : sum(0), len(0) {}
    Info(T v, int l) : sum(v), len(l) {}
    static Info empty() { return {}; }

    void apply(const T &add) { sum += add * len; } // 对整段加 add

    friend Info operator+(const Info &a, const Info &b)
    {
        if (a.len == 0)
            return b;
        if (b.len == 0)
            return a;
        return Info(a.sum + b.sum, a.len + b.len);
    }
};

template <typename T>
struct Tag
{
    T add;
    Tag() : add(0) {}
    explicit Tag(T a) : add(a) {}

    void apply(const Tag &t) { add += t.add; } // 累加标记
};

template <typename Info, typename Tag>
struct LazySegmentTree
{
    int n;           // 维护的元素个数
    vector<Info> tr; // 结点信息
    vector<Tag> tg;  // 懒标记

    // 构造函数:给定长度,初值为 Info()
    LazySegmentTree(int n_ = 0) : n(0)
    {
        if (n_)
            init(n_);
    }

    // 构造函数:用数组初始化
    template <typename T>
    LazySegmentTree(const vector<T> &a) { init(a); }

    // 功能:按长度初始化为 n 个元素,初值为 Info()
    void init(int n_)
    {
        n = n_;
        tr.assign(4 * n + 4, Info::empty());
        tg.assign(4 * n + 4, Tag());
    }

    // 功能:用数组 a 初始化
    template <typename T>
    void init(const vector<T> &a)
    {
        n = (int)a.size();
        tr.assign(4 * n + 4, Info::empty());
        tg.assign(4 * n + 4, Tag());
        build(1, 0, n - 1, a);
    }

    // 功能:自底向上合并信息
    void pull(int p) { tr[p] = tr[p << 1] + tr[p << 1 | 1]; }

    // 功能:把标记 t 应用到结点 p
    void apply(int p, const Tag &t)
    {
        // Info 上的 apply 需要你自己在 Info 中定义
        tr[p].apply(t.add);
        tg[p].apply(t);
    }

    // 功能:把结点 p 的懒标记下传到两个儿子
    void push(int p)
    {
        if (tg[p].add != 0)
        {
            apply(p << 1, tg[p]), apply(p << 1 | 1, tg[p]);
            tg[p] = Tag();
        }
    }

    // 功能:建树,覆盖区间 [l,r]
    template <typename T>
    void build(int p, int l, int r, const vector<T> &a)
    {
        if (l == r)
        {
            tr[p] = Info(a[l], 1);
            return;
        }
        int m = (l + r) >> 1;
        build(p << 1, l, m, a), build(p << 1 | 1, m + 1, r, a);
        pull(p);
    }

    // 功能:单点赋值,将 pos 位置改成 v;覆盖区间 [l,r]
    void modify(int p, int l, int r, int pos, const Info &v)
    {
        if (l == r)
        {
            tr[p] = v;
            return;
        }
        int m = (l + r) >> 1;
        push(p);
        if (pos <= m)
            modify(p << 1, l, m, pos, v);
        else
            modify(p << 1 | 1, m + 1, r, pos, v);
        pull(p);
    }

    // 外部接口:单点赋值
    void modify(int pos, const Info &v) { modify(1, 0, n - 1, pos, v); }

    // 功能:对区间 [ql,qr] 应用标签 t;当前结点覆盖 [l,r]
    void rangeApply(int p, int l, int r, int ql, int qr, const Tag &t)
    {
        if (qr < l || r < ql)
            return;
        if (ql <= l && r <= qr)
        {
            apply(p, t);
            return;
        }
        int m = (l + r) >> 1;
        push(p);
        rangeApply(p << 1, l, m, ql, qr, t), rangeApply(p << 1 | 1, m + 1, r, ql, qr, t);
        pull(p);
    }

    // 外部接口:区间打标
    void rangeApply(int l, int r, const Tag &t) { rangeApply(1, 0, n - 1, l, r, t); }

    // 功能:查询区间 [ql,qr] 的聚合信息;当前结点覆盖 [l,r]
    Info rangeQuery(int p, int l, int r, int ql, int qr)
    {
        if (qr < l || r < ql)
            return Info::empty();
        if (ql <= l && r <= qr)
            return tr[p];
        int m = (l + r) >> 1;
        push(p);
        return rangeQuery(p << 1, l, m, ql, qr) + rangeQuery(p << 1 | 1, m + 1, r, ql, qr);
    }

    // 外部接口:区间查询
    Info rangeQuery(int l, int r) { return rangeQuery(1, 0, n - 1, l, r); }

    // 功能:在 [ql,qr] 内用整段信息二分找第一个满足 pred 的位置
    template <typename F>
    int findFirst(int p, int l, int r, int ql, int qr, F pred)
    {
        if (qr < l || r < ql || !pred(tr[p]))
            return -1;
        if (l == r)
            return l;
        int m = (l + r) >> 1;
        push(p);
        int res = findFirst(p << 1, l, m, ql, qr, pred);
        if (res != -1)
            return res;
        return findFirst(p << 1 | 1, m + 1, r, ql, qr, pred);
    }

    // 外部接口:找第一个满足条件的位置
    template <typename F>
    int findFirst(int l, int r, F pred) { return findFirst(1, 0, n - 1, l, r, pred); }

    // 功能:在 [ql,qr] 内用整段信息二分找最后一个满足 pred 的位置
    template <typename F>
    int findLast(int p, int l, int r, int ql, int qr, F pred)
    {
        if (qr < l || r < ql || !pred(tr[p]))
            return -1;
        if (l == r)
            return l;
        int m = (l + r) >> 1;
        push(p);
        int res = findLast(p << 1 | 1, m + 1, r, ql, qr, pred);
        if (res != -1)
            return res;
        return findLast(p << 1, l, m, ql, qr, pred);
    }

    // 外部接口:找最后一个满足条件的位置
    template <typename F>
    int findLast(int l, int r, F pred) { return findLast(1, 0, n - 1, l, r, pred); }
};

动态开点线段树

算法介绍

动态开点线段树在需要访问某个子区间时才分配对应的结点,以指针或下标形式把不存在的子树延迟创建,它能在离散化困难或值域极大时以 O(k log U) 的时间与 O(k) 的空间处理 k 次有效访问,其中 U 是整个值域长度。所有区间语义统一为闭区间 [l,r],根结点覆盖外部给定的整段边界 [L,R],左右儿子分别覆盖 [l,mid] 与 [mid+1,r]。

常见例题

题目:给定 q 次操作,初始数组视为全零,坐标范围为 [1,10^18]。操作一是把区间 [l,r] 内每个元素加上 v,操作二是询问区间 [l,r] 的区间和。

做法:建立一个覆盖 [1,10^18] 的动态开点懒标树,遇到区间加时在完全覆盖的结点直接打标并累加,不完全覆盖时先 push 下传再递归左右儿子,回溯时 pull 合并;区间查询在完全覆盖时直接返回结点信息,不相交返回空信息,部分相交则先 push 再递归合并。

代码
// 动态开点线段树,示例:区间加 + 区间和
// 设计说明:Info 负责“如何合并”和“在长度len上如何应用标记”,Tag 负责“如何与另一个标记合并”与“是否为空”
// 构造函数:可指定值域 [L,R];所有函数均写明用途;单行语句尽量去花括号
struct Tag
{
    ll add; // 懒标记中的增量
    Tag(ll v = 0) : add(v) {}
    void apply(const Tag &o) { add += o.add; }
    bool isNeutral() const { return add == 0; }
};

struct Info
{
    ll sum;
    int len; // 区间和与区间长度
    Info(ll s = 0, int l = 0) : sum(s), len(l) {}
    static Info empty() { return {}; }
    void apply(const Tag &t) { sum += t.add * 1LL * len; }
    friend Info operator+(const Info &a, const Info &b)
    {
        if (a.len == 0)
            return b;
        if (b.len == 0)
            return a;
        return Info(a.sum + b.sum, a.len + b.len);
    }
};

template <class Info, class Tag>
struct DynamicSegTree
{
    struct Node
    {
        Info info;
        Tag tag;
        Node *ls, *rs;

        // 功能:结点构造,默认空信息与空标记
        Node() : info(Info::empty()), tag(Tag()), ls(nullptr), rs(nullptr) {}
    };

    Node *root;
    ll L, R; // 根指针与全局覆盖边界

    // 构造:给定整段边界 [L_,R_],根为空,按需开点
    DynamicSegTree(ll L_ = 1, ll R_ = 1) : root(nullptr), L(L_), R(R_) {}

    // 功能:把标记 v 作用到结点 u 表示的整段 [l,r]
    void apply(Node *u, ll l, ll r, const Tag &v)
    {
        if (!u)
            return;
        Info cur = u->info;
        if (cur.len == 0)
            u->info.len = int(r - l + 1);
        u->info.apply(v);
        u->tag.apply(v);
    }

    // 功能:下推结点 u 的懒标记到左右儿子
    void push(Node *u, ll l, ll r)
    {
        if (!u || u->tag.isNeutral() || l == r)
            return;
        ll m = (l + r) >> 1;
        if (!u->ls)
            u->ls = new Node();
        if (!u->rs)
            u->rs = new Node();
        apply(u->ls, l, m, u->tag), apply(u->rs, m + 1, r, u->tag);
        u->tag = Tag();
    }

    // 功能:自底向上合并左右儿子的区间信息
    void pull(Node *u)
    {
        Info Lc = u->ls ? u->ls->info : Info::empty();
        Info Rc = u->rs ? u->rs->info : Info::empty();
        u->info = Lc + Rc;
    }

    // 功能:在区间 [ql,qr] 上施加标记 v(外部闭区间,需保证 ql<=qr 且在 [L,R] 内)
    void rangeApply(ll ql, ll qr, const Tag &v) { rangeApply(root, L, R, ql, qr, v); }

    // 功能:查询区间 [ql,qr] 的聚合信息(外部闭区间)
    Info rangeQuery(ll ql, ll qr) { return rangeQuery(root, L, R, ql, qr); }

    // 功能:单点赋值,把位置 x 的信息改为 infoX(常用在需要精确覆盖一个点时)
    void modifyPoint(ll x, const Info &infoX) { modifyPoint(root, L, R, x, infoX); }

    // 功能:清空整棵树(递归释放内存)
    void clear()
    {
        clear(root);
        root = nullptr;
    }

private:
    // 递归实现:对 [ql,qr] 打标 v,当前结点 u 覆盖 [l,r]
    void rangeApply(Node *&u, ll l, ll r, ll ql, ll qr, const Tag &v)
    {
        if (qr < l || r < ql)
            return;
        if (!u)
            u = new Node();
        if (ql <= l && r <= qr)
        {
            apply(u, l, r, v);
            return;
        }
        push(u, l, r);
        ll m = (l + r) >> 1;
        rangeApply(u->ls, l, m, ql, qr, v), rangeApply(u->rs, m + 1, r, ql, qr, v);
        pull(u);
    }

    // 递归实现:查询 [ql,qr],当前结点 u 覆盖 [l,r]
    Info rangeQuery(Node *u, ll l, ll r, ll ql, ll qr)
    {
        if (!u || qr < l || r < ql)
            return Info::empty();
        if (ql <= l && r <= qr)
            return u->info;
        push(u, l, r);
        ll m = (l + r) >> 1;
        return rangeQuery(u->ls, l, m, ql, qr) + rangeQuery(u->rs, m + 1, r, ql, qr);
    }

    // 递归实现:单点赋值,当前结点 u 覆盖 [l,r]
    void modifyPoint(Node *&u, ll l, ll r, ll x, const Info &infoX)
    {
        if (!u)
            u = new Node();
        if (l == r)
        {
            u->info = infoX;
            return;
        }
        push(u, l, r);
        ll m = (l + r) >> 1;
        x <= m ? modifyPoint(u->ls, l, m, x, infoX) : modifyPoint(u->rs, m + 1, r, x, infoX);
        pull(u);
    }

    // 递归实现:释放以 u 为根的子树
    void clear(Node *u)
    {
        if (!u)
            return;
        clear(u->ls), clear(u->rs), delete u;
    }
};

李超线段树

算法介绍

李超线段树用线段树的形态在整个横坐标区间 [l,r] 上维护一组直线 y = kx + b 的下包(或上包),每个结点保存在该结点覆盖区间上“当前更优”的一条直线。插入一条直线时,比较其与结点已存直线在区间中点与端点处的函数值,若新线在整个子区间更优则替换并结束,若只在一半区间更优则递归下发到对应儿子,查询时在点 x 处自上而下取经过的所有直线的更优值。该结构支持在 O(log U) 时间内进行“全域直线插入”与“点查询”,扩展后可支持“区间插入直线”。常用于动态维护分段线性函数的最小值/最大值,如斜率优化、DP 转移、最优代价维护等场景。

常见例题

题目:给定 q 次操作,横坐标取值域为 [Xmin, Xmax]。操作一为在区间 [l,r] 上加入一条直线 y = kx + b(只在 [l,r] 有效),操作二为在点 x 处询问当前所有有效直线的最小函数值。

做法:建立一棵覆盖 [Xmin, Xmax] 的李超线段树,插入时若是全域插入就对根递归,如果是区间插入就只把直线递归分发到与 [l,r] 相交的结点;查询时从根走到叶,沿途用当前结点保存的直线在 x 处取更优值并与子树结果取更小值(维护最小值时比较“更小”的线,维护最大值时取反或相反的比较)。

代码
// 李超线段树(Li Chao Tree),维护“最小值”版本
// 功能:插入直线 y = kx + b(全域或区间),在点 x 查询所有已插入直线的最小值
// 复杂度:单次插入/查询 O(log(R-L+1)),按需开点,适合超大坐标范围
struct LiChaoTree
{
    // 直线:y = kx + b
    struct Line
    {
        ll k, b; // 斜率与截距
        Line(ll kk = 0, ll bb = (ll)4e18) : k(kk), b(bb) {}
        // 计算 y 值
        ll eval(ll x) const { return k * x + b; }
    };

    // 结点:保存一条在当前区间更优的直线,左右儿子按需开点
    struct Node
    {
        Line ln;
        Node *ls, *rs;
        Node() : ln(), ls(nullptr), rs(nullptr) {}
    };

    Node *root;
    ll L, R; // 根指针与整体覆盖区间

    // 构造函数:建立覆盖 [L_,R_] 的空树
    LiChaoTree(ll L_ = 0, ll R_ = 0) : root(nullptr), L(L_), R(R_) {}

    // 在整个 [L,R] 上插入一条直线 y = kx + b
    void addLine(ll k, ll b)
    {
        Line ln(k, b);
        addLine(root, L, R, ln);
    }

    // 只在子区间 [l,r] 上插入一条直线 y = kx + b(闭区间)
    void addSegment(ll l, ll r, ll k, ll b)
    {
        if (l > r)
            return;
        Line ln(k, b);
        addSegment(root, L, R, l, r, ln);
    }

    // 在点 x 查询所有直线的最小值;若没有插入过有效线,返回 +INF
    ll query(ll x) const
    {
        return query(root, L, R, x);
    }

    // 释放整棵树
    void clear()
    {
        clear(root);
        root = nullptr;
    }

private:
    // 内部:确保结点存在
    void ensure(Node *&u)
    {
        if (!u)
            u = new Node();
    }

    // 内部:在结点 u 所覆盖的 [l,r] 上插入一条直线 ln(全域)
    void addLine(Node *&u, ll l, ll r, Line ln)
    {
        ensure(u);
        ll m = (l + r) >> 1;
        bool leftBetter = ln.eval(l) < u->ln.eval(l);
        bool midBetter = ln.eval(m) < u->ln.eval(m);
        bool rightBetter = ln.eval(r) < u->ln.eval(r);

        if (midBetter)
            swap(u->ln, ln); // 保证结点处保存“当前在中点更优”的线
        if (l == r)
            return;

        if (leftBetter != midBetter)
        {
            addLine(u->ls, l, m, ln);
        }
        else if (rightBetter != midBetter)
        {
            addLine(u->rs, m + 1, r, ln);
        }
    }

    // 内部:在结点 u 所覆盖的 [l,r] 上,仅对 [ql,qr] 部分插入直线 ln(区间插入)
    void addSegment(Node *&u, ll l, ll r, ll ql, ll qr, Line ln)
    {
        if (qr < l || r < ql)
            return;
        if (ql <= l && r <= qr)
        {
            addLine(u, l, r, ln);
            return;
        }
        ensure(u);
        ll m = (l + r) >> 1;
        addSegment(u->ls, l, m, ql, qr, ln), addSegment(u->rs, m + 1, r, ql, qr, ln);
    }

    // 内部:在点 x 查询最小值
    ll query(Node *u, ll l, ll r, ll x) const
    {
        if (!u)
            return (ll)4e18;
        ll res = u->ln.eval(x);
        if (l == r)
            return res;
        ll m = (l + r) >> 1;
        if (x <= m)
            return min(res, query(u->ls, l, m, x));
        return min(res, query(u->rs, m + 1, r, x));
    }

    // 内部:释放以 u 为根的子树
    void clear(Node *u)
    {
        if (!u)
            return;
        clear(u->ls), clear(u->rs), delete u;
    }
};

代码维护的是“最小值”李超树,如果需要维护“最大值”,可以把比较方向改为选更大的线,或在插入时把直线系数取相反数并在查询后取相反数。


主席树 (可持久化权值线段树)

算法介绍

主席树是在权值域上的可持久化线段树。对原数组做坐标压缩,将所有不同值映射到 [1,m] 的权值下标区间 [l,r](本文统一使用闭区间),第 i 个版本的根结点 root[i] 表示前缀 a[1..i] 在权值域上的频次数组。插入第 i 个元素时,从 root[i-1] 沿着 [1,m] 的路径复制结点并在命中的叶子把频次加一得到 root[i]。查询子数组 [L,R] 的第 k 小时,用两棵根 root[R] 和 root[L-1] 同时向下走,依据左右子树频次差决定走向;查询“≤x 的个数”时在 [1,rank(x)] 上做频次差求和。由于每次插入仅复制 O(log m) 个结点,空间 O((n+q) log m),所有操作均为 O(log m)。

常见例题

题目:给定长度为 n 的数组与 q 次询问。每次询问给出区间 [L,R] 与整数 k,输出子数组 a[L..R] 的第 k 小元素。

做法:先离散化所有值,建立 n 个前缀版本 root[i]。回答 [L,R] 的第 k 小时,从区间 [1,m] 的根结点出发比较左右子树的频次差 cntLeft = sum(left(root[R])) − sum(left(root[L-1])),若 cntLeft ≥ k 则在左子树递归,否则在右子树递归并令 k 减去 cntLeft,直到到达某个叶子下标 pos,最后用反向映射把 pos 还原成原值即为答案。

代码
// 主席树(可持久化权值线段树),闭区间 [1,m] 作为权值域
// 功能:buildPrefix 线性建立 n 个前缀版本;kth(L,R,k) 求子数组第k小;countLeq(L,R,x) 求子数组 ≤x 的个数
// 说明:内部完成坐标压缩,接口用原值;每个函数都有用途注释,括号与单行规约遵循你的格式要求
struct PersistentSeg
{
    // 结点存左儿子下标、右儿子下标、该权值段出现次数之和
    struct Node
    {
        int ls, rs, sum;
        Node(int L = 0, int R = 0, int S = 0) : ls(L), rs(R), sum(S) {}
    };

    int n;            // 原数组长度
    int m;            // 离散后权值域大小
    vector<int> xs;   // 离散值(有序去重),xs[pos-1] 是权值下标 pos 对应的原值
    vector<Node> tr;  // 结点池,结点0为“空结点”
    vector<int> root; // root[i] 表示前缀 a[1..i] 的版本根(root[0] 为空前缀)

    // 构造函数:空构造
    PersistentSeg() : n(0), m(0)
    {
        tr.reserve(1 << 20);
        tr.push_back(Node());
    }

    // 功能:对数组 a 进行坐标压缩,建立 xs 与映射
    template <typename T>
    void compress(const vector<T> &a)
    {
        xs.assign(a.begin(), a.end());
        sort(xs.begin(), xs.end());
        xs.erase(unique(xs.begin(), xs.end()), xs.end());
        m = (int)xs.size();
    }

    // 功能:把原值 v 映射为权值下标 pos(1..m)
    template <typename T>
    int rankOf(const T &v)
    {
        int pos = (int)(lower_bound(xs.begin(), xs.end(), v) - xs.begin()) + 1;
        return pos;
    }

    // 功能:在旧根 oldRoot 的基础上,把权值下标 pos 的频次 +1,返回新根;当前覆盖 [l,r]
    int addOne(int oldRoot, int l, int r, int pos)
    {
        int cur = clone(oldRoot);
        tr[cur].sum = tr[oldRoot].sum + 1;
        if (l == r)
            return cur;
        int mid = (l + r) >> 1;
        if (pos <= mid)
            tr[cur].ls = addOne(tr[oldRoot].ls, l, mid, pos);
        else
            tr[cur].rs = addOne(tr[oldRoot].rs, mid + 1, r, pos);
        return cur;
    }

    // 功能:克隆一个结点并返回新结点下标
    int clone(int p)
    {
        tr.push_back(tr[p]);
        return (int)tr.size() - 1;
    }

    // 功能:基于数组 a 建立 n+1 个前缀版本 root,root[0]=0;区间为 [1,m]
    template <typename T>
    void buildPrefix(const vector<T> &a)
    {
        n = (int)a.size();
        compress(a);
        root.assign(n + 1, 0);
        for (int i = 1; i <= n; i++)
        {
            int pos = rankOf(a[i - 1]);
            root[i] = addOne(root[i - 1], 1, m, pos);
        }
    }

    // 功能:查询子数组 [L,R] 的第 k 小,返回原值;要求 1<=L<=R<=n 且 1<=k<=R-L+1
    ll kth(int L, int R, int k)
    {
        int u = root[R], v = root[L - 1];
        int l = 1, r = m;
        while (l < r)
        {
            int mid = (l + r) >> 1;
            int cntLeft = tr[tr[u].ls].sum - tr[tr[v].ls].sum;
            if (cntLeft >= k)
                u = tr[u].ls, v = tr[v].ls, r = mid;
            else
                u = tr[u].rs, v = tr[v].rs, l = mid + 1, k -= cntLeft;
        }
        return xs[l - 1];
    }

    // 功能:统计子数组 [L,R] 中 ≤ x 的元素个数
    ll countLeq(int L, int R, ll x)
    {
        if (!m)
            return 0;
        int pos = (int)(upper_bound(xs.begin(), xs.end(), x) - xs.begin());
        if (pos == 0)
            return 0;
        return queryPrefix(root[R], 1, m, 1, pos) - queryPrefix(root[L - 1], 1, m, 1, pos);
    }

    // 功能:区间求和,求版本根 rt 在权值下标区间 [ql,qr] 的频次和;当前覆盖 [l,r]
    ll queryPrefix(int rt, int l, int r, int ql, int qr)
    {
        if (!rt || qr < l || r < ql)
            return 0;
        if (ql <= l && r <= qr)
            return tr[rt].sum;
        int mid = (l + r) >> 1;
        ll res = 0;
        if (ql <= mid)
            res += queryPrefix(tr[rt].ls, l, mid, ql, qr);
        if (qr > mid)
            res += queryPrefix(tr[rt].rs, mid + 1, r, ql, qr);
        return res;
    }
};

若题目要支持“第 k 大”,可以把 k 替换为区间长度减 k 再加一;若要支持“严格小于 x 的个数”,把 upper_bound 改成 lower_bound。该模板对值域极大或含负数都适用,因为已内置坐标压缩;如需在线插入删除,也可维护“可持久化差分”或转做可删主席树(BIT/线段树 + 上界二分),思路与本模板配合即可。


区间最值操作 & 区间历史最值

算法介绍

区间最值操作指在闭区间 [l,r] 上执行 a[i] ← min(a[i], x) 或 a[i] ← max(a[i], x),常与区间加法与区间查询(区间和、区间最大值、区间最小值)同时出现。经典做法是 Segment Tree Beats:每个结点维护当前最大值 max1、严格次大值 max2、最大值出现次数 cntMax,以及对称的 min1、min2、cntMin,再维护区间和 sum 与区间加标记 add。对区间 chmin(x) 时,若 x ≥ max1 则无需处理;若 x > max2,则可以在该结点“整段生效”,把 max1 降到 x,并用 cntMax 精确修正 sum;否则向下分裂递归并回收。整体复杂度可用势能法证明为均摊 O(log n)(区间 chmax(x) 对称)。这套维护量与操作流程与 OI Wiki 的描述一致,并能在同一棵树内同时支持 rangeAdd/rangeChmin/rangeChmax 与 rangeQuery(sum、min、max)。

“区间历史最值”问题要求在一系列区间更新后,支持查询某段内“某位置曾达到过的最大值/最小值/历史版本和”等历史信息。一个常见且实用的版本是“历史最大值 with 区间加/取 min/max”:在传播“会让值上升”的操作时,把“上升量”同步计入历史信息;对“只会下降”的操作(如 chmin)历史信息不变。实现时可在 beats 的节点上增设历史字段与“历史加标记”,在整段生效的场景下把对某一“值类”(如全为当前最大值或当前最小值的那部分)带来的提升累加进历史即可;必要时还要像维护 max1、max2 一样为历史值做“面向值类”的打标与合并。OI Wiki把这类做法称作“历史信息”,并给出“历史最大值/历史最小值/历史版本和”的建模思路及与 beats 的配合要点。

常见例题

题目: 维护一个长度为 n 的序列,支持 q 次操作:一是区间加 v;二是把 [l,r] 内所有数与 x 取 min;三是把 [l,r] 内所有数与 x 取 max;四是询问 [l,r] 的区间和与区间最大值;五是询问 [l,r] 的“历史最大值”。

做法: 用 Segment Tree Beats 维护当前量(sum、max1/max2/cntMax、min1/min2/cntMin、add),并在“整段生效”的三类操作里同步更新历史信息:对 rangeAdd(v) 若 v>0 则结点历史最大值整体加 v;对 rangeChmax(x) 若 x>min1 且 x≤min2 时只提升“等于 min1 的那一类”,可在该结点按 cntMin 精确更新历史;对 rangeChmin(x) 不提升历史。遇到无法整段生效时继续向下分裂,回收时历史量用左右儿子的历史量合并取 max/求和。有关 beats 的“整段生效条件”“为何要记次大/次小值”等细节与正确性,参见 OI Wiki 与若干教程梳理。

代码
// Segment Tree Beats:区间加 / 取min / 取max + 区间和 / 区间最值(闭区间 [l,r])
// 同时维护“历史最大值(曾达到过的最大值)”的区间查询
// 命名与封装风格对齐算竞模板:不使用 function,驼峰命名,struct + template<T>,详细注释
struct SegBeats
{
    struct Node
    {
        ll sum;        // 当前区间和
        ll max1, max2; // 当前最大值、严格次大值
        int cntMax;           // 当前最大值的出现次数
        ll min1, min2; // 当前最小值、严格次小值
        int cntMin;           // 当前最小值的出现次数
        ll add;        // 区间加法懒标记(对当前值)
        ll histMax;    // 区间内“历史最大值”的最大者
        ll histAdd;    // 历史最大值的“加标”(当整段整体上升时一并累加)

        Node()
        {
            sum = 0;
            max1 = -(1LL << 60);
            max2 = -(1LL << 60);
            cntMax = 0;
            min1 = (1LL << 60);
            min2 = (1LL << 60);
            cntMin = 0;
            add = 0;
            histMax = -(1LL << 60);
            histAdd = 0;
        }

        // 功能:以单个元素 x 初始化一个叶结点
        static Node fromValue(ll x)
        {
            Node u;
            u.sum = x;
            u.max1 = u.min1 = x;
            u.max2 = -(1LL << 60);
            u.min2 = (1LL << 60);
            u.cntMax = u.cntMin = 1;
            u.histMax = x; // 历史最大值初始等于初值
            return u;
        }
    };

    int n;                             // 数组大小
    vector<Node> tr;                   // 线段树数组,1-indexed
    SegBeats(int n_ = 0) { init(n_); } // 构造函数:创建空树
    template <class T>
    SegBeats(const vector<T> &a) { build(a); } // 构造函数:根据数组建树

    // 功能:初始化大小为 n_ 的空树
    void init(int n_)
    {
        n = n_;
        tr.assign(n * 4 + 5, Node());
    }

    // 功能:由数组 a 建树(下标从 1 开始)
    template <class T>
    void build(const vector<T> &a)
    {
        init((int)a.size() - 1);
        auto buildDfs = [&](auto &&self, int p, int l, int r) -> void
        {
            if (l == r)
            {
                tr[p] = Node::fromValue((ll)a[l]);
                return;
            }
            int m = (l + r) >> 1;
            self(self, p << 1, l, m);
            self(self, p << 1 | 1, m + 1, r);
            pull(p);
        };
        buildDfs(buildDfs, 1, 1, n);
    }

    // 功能:对 [l,r] 执行区间加 v
    void rangeAdd(int l, int r, ll v) { rangeAdd(1, 1, n, l, r, v); }

    // 功能:对 [l,r] 执行区间取 min(x)
    void rangeChmin(int l, int r, ll x) { rangeChmin(1, 1, n, l, r, x); }

    // 功能:对 [l,r] 执行区间取 max(x)
    void rangeChmax(int l, int r, ll x) { rangeChmax(1, 1, n, l, r, x); }

    // 功能:查询 [l,r] 的区间和
    ll querySum(int l, int r) { return querySum(1, 1, n, l, r); }

    // 功能:查询 [l,r] 的当前最大值
    ll queryMax(int l, int r) { return queryMax(1, 1, n, l, r); }

    // 功能:查询 [l,r] 的当前最小值
    ll queryMin(int l, int r) { return queryMin(1, 1, n, l, r); }

    // 功能:查询 [l,r] 的“历史最大值”的最大者(段内每个点的历史最大值取 max,再对区间取 max)
    ll queryHistMax(int l, int r) { return queryHistMax(1, 1, n, l, r); }

private:
    // 功能:把整段结点 p 的“当前值整体 +v”,并同步更新历史信息(仅 v>0 会提升历史)
    void applyAdd(int p, ll v, int len)
    {
        tr[p].sum += v * len;
        tr[p].max1 += v;
        tr[p].min1 += v;
        if (tr[p].max2 > -(1LL << 60))
            tr[p].max2 += v;
        if (tr[p].min2 < (1LL << 60))
            tr[p].min2 += v;
        tr[p].add += v;
        if (v > 0)
            tr[p].histAdd += v, tr[p].histMax += v;
    }

    // 功能:仅对“当前最大值这一类”执行减小到 x 的整段更新(要求 x 在 (max2,max1) 内)
    void applyChminOnMax(int p, ll x)
    {
        ll delta = tr[p].max1 - x;
        tr[p].sum -= delta * tr[p].cntMax;
        tr[p].max1 = x;
        if (tr[p].min1 == tr[p].max1)
            tr[p].min1 = x;
        if (tr[p].min2 == tr[p].max1)
            tr[p].min2 = x; // 两行短句合并

        // 历史最大值不因下降而改变;无需动 hist*
    }

    // 功能:仅对“当前最小值这一类”执行增大到 x 的整段更新(要求 x 在 (min1,min2) 内)
    void applyChmaxOnMin(int p, ll x)
    {
        ll delta = x - tr[p].min1;
        tr[p].sum += delta * tr[p].cntMin;
        tr[p].min1 = x;
        if (tr[p].max1 == tr[p].min1)
            tr[p].max1 = x;
        if (tr[p].max2 == tr[p].min1)
            tr[p].max2 = x;
        // 这一类被整体抬升,历史最大值需要同步抬升
        tr[p].histMax = max(tr[p].histMax, tr[p].max1); // 当前最大值已更新到可能更大
    }

    // 功能:从子结点汇总到父结点 p
    void pull(int p)
    {
        Node &u = tr[p], &L = tr[p << 1], &R = tr[p << 1 | 1];
        u.sum = L.sum + R.sum;

        if (L.max1 > R.max1)
            u.max1 = L.max1, u.cntMax = L.cntMax, u.max2 = max(L.max2, R.max1);
        else if (L.max1 < R.max1)
            u.max1 = R.max1, u.cntMax = R.cntMax, u.max2 = max(R.max2, L.max1);
        else
            u.max1 = L.max1, u.cntMax = L.cntMax + R.cntMax, u.max2 = max(L.max2, R.max2);

        if (L.min1 < R.min1)
            u.min1 = L.min1, u.cntMin = L.cntMin, u.min2 = min(L.min2, R.min1);
        else if (L.min1 > R.min1)
            u.min1 = R.min1, u.cntMin = R.cntMin, u.min2 = min(R.min2, L.min1);
        else
            u.min1 = L.min1, u.cntMin = L.cntMin + R.cntMin, u.min2 = min(L.min2, R.min2);

        u.add = 0; // pull 不改变 add;add 仅在 push 时下传并在 applyAdd 中叠加

        // 历史最大值是左右子区间历史最大值的最大者,同时也要不小于当前 max1
        u.histMax = max({L.histMax, R.histMax, u.max1});
        u.histAdd = 0;
    }

    // 功能:把懒标记从 p 下推到左右子
    void push(int p, int l, int r)
    {
        int m = (l + r) >> 1;
        if (tr[p].add != 0)
        {
            applyAdd(p << 1, tr[p].add, m - l + 1);
            applyAdd(p << 1 | 1, tr[p].add, r - m);
            tr[p].add = 0;
        }
        // 对“整段生效的 chmin/chmax”没有单独的懒标;我们在进入儿子前按需把父节点的 max/min 约束“压过去”
        if (tr[p << 1].max1 > tr[p].max1)
            applyChminOnMax(p << 1, tr[p].max1);
        if (tr[p << 1 | 1].max1 > tr[p].max1)
            applyChminOnMax(p << 1 | 1, tr[p].max1);
        if (tr[p << 1].min1 < tr[p].min1)
            applyChmaxOnMin(p << 1, tr[p].min1);
        if (tr[p << 1 | 1].min1 < tr[p].min1)
            applyChmaxOnMin(p << 1 | 1, tr[p].min1);
        // 历史加标记对我们这份实现只依赖 applyAdd 中的 v>0 逻辑;这里不存在单独的“历史懒”需要下推
    }

    // 功能:在结点 p 覆盖的 [l,r] 上执行区间加 v
    void rangeAdd(int p, int l, int r, int ql, int qr, ll v)
    {
        if (qr < l || r < ql)
            return;
        if (ql <= l && r <= qr)
        {
            applyAdd(p, v, r - l + 1);
            return;
        }
        push(p, l, r);
        int m = (l + r) >> 1;
        rangeAdd(p << 1, l, m, ql, qr, v);
        rangeAdd(p << 1 | 1, m + 1, r, ql, qr, v);
        pull(p);
    }

    // 功能:在结点 p 覆盖的 [l,r] 上执行区间 chmin(x)
    void rangeChmin(int p, int l, int r, int ql, int qr, ll x)
    {
        if (qr < l || r < ql || tr[p].max1 <= x)
            return;
        if (ql <= l && r <= qr && tr[p].max2 < x)
        {
            applyChminOnMax(p, x);
            return;
        }
        push(p, l, r);
        int m = (l + r) >> 1;
        rangeChmin(p << 1, l, m, ql, qr, x);
        rangeChmin(p << 1 | 1, m + 1, r, ql, qr, x);
        pull(p);
    }

    // 功能:在结点 p 覆盖的 [l,r] 上执行区间 chmax(x)
    void rangeChmax(int p, int l, int r, int ql, int qr, ll x)
    {
        if (qr < l || r < ql || tr[p].min1 >= x)
            return;
        if (ql <= l && r <= qr && tr[p].min2 > x)
        {
            applyChmaxOnMin(p, x);
            return;
        }
        push(p, l, r);
        int m = (l + r) >> 1;
        rangeChmax(p << 1, l, m, ql, qr, x);
        rangeChmax(p << 1 | 1, m + 1, r, ql, qr, x);
        pull(p);
    }

    // 功能:查询 [l,r] 的区间和
    ll querySum(int p, int l, int r, int ql, int qr)
    {
        if (qr < l || r < ql)
            return 0LL;
        if (ql <= l && r <= qr)
            return tr[p].sum;
        push(p, l, r);
        int m = (l + r) >> 1;
        return querySum(p << 1, l, m, ql, qr) + querySum(p << 1 | 1, m + 1, r, ql, qr);
    }

    // 功能:查询 [l,r] 的当前最大值
    ll queryMax(int p, int l, int r, int ql, int qr)
    {
        if (qr < l || r < ql)
            return -(1LL << 60);
        if (ql <= l && r <= qr)
            return tr[p].max1;
        push(p, l, r);
        int m = (l + r) >> 1;
        ll a = queryMax(p << 1, l, m, ql, qr), b = queryMax(p << 1 | 1, m + 1, r, ql, qr);
        return a > b ? a : b;
    }

    // 功能:查询 [l,r] 的当前最小值
    ll queryMin(int p, int l, int r, int ql, int qr)
    {
        if (qr < l || r < ql)
            return (1LL << 60);
        if (ql <= l && r <= qr)
            return tr[p].min1;
        push(p, l, r);
        int m = (l + r) >> 1;
        ll a = queryMin(p << 1, l, m, ql, qr), b = queryMin(p << 1 | 1, m + 1, r, ql, qr);
        return a < b ? a : b;
    }

    // 功能:查询 [l,r] 的“历史最大值”的最大者
    ll queryHistMax(int p, int l, int r, int ql, int qr)
    {
        if (qr < l || r < ql)
            return -(1LL << 60);
        if (ql <= l && r <= qr)
            return tr[p].histMax;
        push(p, l, r);
        int m = (l + r) >> 1;
        ll a = queryHistMax(p << 1, l, m, ql, qr), b = queryHistMax(p << 1 | 1, m + 1, r, ql, qr);
        return a > b ? a : b;
    }
};

平衡树

朴素二叉搜索树

算法介绍

朴素二叉搜索树按中序有序地维护键值,插入与删除通过沿着大小关系走到相应位置完成,查找、前驱与后继均可在期望 O(h) 完成,h 为树高。没有随机化或自平衡手段时,最坏会退化到链,比赛中常把它作为讲解用或样例对拍基线。本文代码以指针节点写法实现插入、删除、查找、前驱、后继与按序统计的基本接口,并配合严格注释与构造函数,命名风格与 v3.2 对齐。

常见例题

题目:给定若干操作,插入一个整数,删除一个整数,询问一个整数是否存在,查询某数的前驱与后继。

做法:用朴素二叉搜索树维护键的集合,插入沿路径找到空位挂上新节点,删除分三类:无子、单子、双子(用后继或前驱替换)。查询、前驱与后继均是沿根向下比较并记录答案。

代码
// 朴素二叉搜索树
// 功能:insert / erase / count / predecessor / successor,支持重复或不重复按需调整
// 复杂度:O(h),h为树高;最坏退化为O(n)
template <typename T>
struct NaiveBST
{
    struct Node
    {
        T key;
        int cnt, sz;
        Node *left, *right;
        // 构造:键key,计数cnt=1,子指针为空
        Node(const T &k) : key(k), cnt(1), sz(1), left(nullptr), right(nullptr) {}
    };
    Node *root;

    NaiveBST() : root(nullptr) {}

    // 维护以p为根的子树大小
    static int size(Node *p) { return p ? p->sz : 0; }
    static void pull(Node *p)
    {
        if (p)
            p->sz = p->cnt + size(p->left) + size(p->right);
    }

    // 插入键k(可重复)
    void insert(const T &k)
    {
        auto go = [&](auto &&self, Node *&p) -> void
        {
            if (!p)
            {
                p = new Node(k);
                return;
            }
            if (k == p->key)
            {
                p->cnt++;
                pull(p);
                return;
            }
            if (k < p->key)
                self(self, p->left);
            else
                self(self, p->right);
            pull(p);
        };
        go(go, root);
    }

    // 查询键k出现次数
    int count(const T &k) const
    {
        Node *p = root;
        while (p && p->key != k)
            p = k < p->key ? p->left : p->right;
        return p ? p->cnt : 0;
    }

    // 删除一个k(存在时减一,不存在则忽略)
    void erase(const T &k)
    {
        auto go = [&](auto &&self, Node *&p) -> void
        {
            if (!p)
                return;
            if (k < p->key)
                self(self, p->left);
            else if (k > p->key)
                self(self, p->right);
            else
            {
                if (p->cnt > 1)
                {
                    p->cnt--;
                    pull(p);
                    return;
                }
                if (!p->left || !p->right)
                {
                    Node *q = p->left ? p->left : p->right;
                    delete p;
                    p = q;
                }
                else
                {
                    Node *s = p->right, *pre = p;
                    while (s->left)
                        pre = s, s = s->left;
                    p->key = s->key;
                    p->cnt = s->cnt;
                    s->cnt = 1;
                    if (pre == p)
                        self(self, p->right);
                    else
                        self(self, pre->left);
                }
            }
            pull(p);
        };
        go(go, root);
    }

    // 前驱:严格小于k的最大值,若不存在返回false
    bool predecessor(const T &k, T &res) const
    {
        Node *p = root;
        bool ok = false;
        while (p)
        {
            if (k <= p->key)
                p = p->left;
            else
            {
                res = p->key, ok = true;
                p = p->right;
            }
        }
        return ok;
    }

    // 后继:严格大于k的最小值,若不存在返回false
    bool successor(const T &k, T &res) const
    {
        Node *p = root;
        bool ok = false;
        while (p)
        {
            if (k >= p->key)
                p = p->right;
            else
            {
                res = p->key, ok = true;
                p = p->left;
            }
        }
        return ok;
    }

    // 第k小(1-indexed),若不存在返回false
    bool kth(int k, T &res) const
    {
        Node *p = root;
        while (p)
        {
            int ls = size(p->left);
            if (k <= ls)
                p = p->left;
            else if (k <= ls + p->cnt)
            {
                res = p->key;
                return true;
            }
            else
                k -= ls + p->cnt, p = p->right;
        }
        return false;
    }

    // 小于等于k的排名(元素计重)
    int rankOf(const T &k) const
    {
        Node *p = root;
        int r = 0;
        while (p)
        {
            if (k < p->key)
                p = p->left;
            else
            {
                r += size(p->left) + (k >= p->key ? p->cnt : 0);
                p = p->right;
            }
        }
        return r;
    }
};

Treap

算法介绍

旋转 Treap 在二叉搜索树键序上再叠加一个堆序随机优先级,插入与删除先按键走位,再通过旋转把堆序恢复,从而以期望 O(log n) 保持平衡。它实现简单、常数小,适合在线操作。本文 Treap 提供插入、删除、计数、前驱、后继、kth、rank 接口。

常见例题

题目:给定 q 次操作,包含插入 x、删除 x(若有多个仅删一个)、查询 x 是否存在、求 x 的前驱与后继、求有序第 k 小、求 x 的有序排名。

做法:用 Treap 维护集合,操作按标准模板实现即可,随机优先级保证期望平衡。

代码
// 旋转 Treap(BST + 堆序,支持重复计数)
// 功能:insert / erase / count / predecessor / successor / kth / rankOf
// 复杂度:期望 O(log n)
template <typename T>
struct Treap
{
    struct Node
    {
        T key;
        int pri, cnt, sz;
        Node *left, *right;
        Node(const T &k, int p) : key(k), pri(p), cnt(1), sz(1), left(nullptr), right(nullptr) {}
    };
    Node *root;
    mt19937 rng;

    Treap() : root(nullptr), rng(chrono::steady_clock::now().time_since_epoch().count()) {}

    static int size(Node *p) { return p ? p->sz : 0; }
    static void pull(Node *p)
    {
        if (p)
            p->sz = p->cnt + size(p->left) + size(p->right);
    }

    // 右旋
    static void rotateRight(Node *&p)
    {
        Node *q = p->left;
        p->left = q->right;
        q->right = p;
        pull(p);
        p = q;
        pull(p);
    }
    // 左旋
    static void rotateLeft(Node *&p)
    {
        Node *q = p->right;
        p->right = q->left;
        q->left = p;
        pull(p);
        p = q;
        pull(p);
    }

    // 插入键k
    void insert(const T &k)
    {
        auto go = [&](auto &&self, Node *&p) -> void
        {
            if (!p)
            {
                p = new Node(k, (int)rng());
                return;
            }
            if (k == p->key)
            {
                p->cnt++;
                pull(p);
                return;
            }
            if (k < p->key)
            {
                self(self, p->left);
                if (p->left->pri > p->pri)
                    rotateRight(p);
            }
            else
            {
                self(self, p->right);
                if (p->right->pri > p->pri)
                    rotateLeft(p);
            }
            pull(p);
        };
        go(go, root);
    }

    // 删除一个k
    void erase(const T &k)
    {
        auto go = [&](auto &&self, Node *&p) -> void
        {
            if (!p)
                return;
            if (k < p->key)
                self(self, p->left);
            else if (k > p->key)
                self(self, p->right);
            else
            {
                if (p->cnt > 1)
                {
                    p->cnt--;
                    pull(p);
                    return;
                }
                if (!p->left || !p->right)
                {
                    Node *q = p->left ? p->left : p->right;
                    delete p;
                    p = q;
                }
                else if (p->left->pri > p->right->pri)
                    rotateRight(p), self(self, p->right);
                else
                    rotateLeft(p), self(self, p->left);
            }
            pull(p);
        };
        go(go, root);
    }

    // 查询键k的个数
    int count(const T &k) const
    {
        Node *p = root;
        while (p && p->key != k)
            p = k < p->key ? p->left : p->right;
        return p ? p->cnt : 0;
    }

    // 查询小于k的最大值
    bool predecessor(const T &k, T &res) const
    {
        Node *p = root;
        bool ok = false;
        while (p)
        {
            if (k <= p->key)
                p = p->left;
            else
            {
                res = p->key, ok = true;
                p = p->right;
            }
        }
        return ok;
    }
    // 查询大于k的最小值

    bool successor(const T &k, T &res) const
    {
        Node *p = root;
        bool ok = false;
        while (p)
        {
            if (k >= p->key)
                p = p->right;
            else
            {
                res = p->key, ok = true;
                p = p->left;
            }
        }
        return ok;
    }

    // 查询第k小的数
    bool kth(int k, T &res) const
    {
        Node *p = root;
        while (p)
        {
            int ls = size(p->left);
            if (k <= ls)
                p = p->left;
            else if (k <= ls + p->cnt)
            {
                res = p->key;
                return true;
            }
            else
                k -= ls + p->cnt, p = p->right;
        }
        return false;
    }

    // 键k的排名,即小于k的数的个数加1
    int rankOf(const T &k) const
    {
        Node *p = root;
        int r = 0;
        while (p)
        {
            if (k < p->key)
                p = p->left;
            else
            {
                r += size(p->left) + (k >= p->key ? p->cnt : 0);
                p = p->right;
            }
        }
        return r;
    }
};

无旋 Treap

算法介绍

无旋 Treap(FHQ Treap)通过 split 与 merge 两个基本操作维护平衡,不需要显式旋转。按键分裂把一棵树拆为两棵满足全部键关系的树,按优先级合并时保证堆序,插入等价于先按键分裂再把新点夹在中间合并,删除等价于把等于 k 的那一段拿出来后丢掉再把剩余两段合并。由于结构天然可持久化、实现简短,它在比赛中非常常用。

常见例题

题目:与上一节相同的一组集合操作与序统计。

做法:用 split(root, ≤k) 与 merge(left, right) 组合完成插入、删除、rank、kth、前驱与后继等所有操作,所有操作期望 O(log n)。

代码
// FHQ Treap(无旋Treap),支持重复计数
// 功能:insert / erase / count / predecessor / successor / kth / rankOf
// 复杂度:期望 O(log n)
template <typename T>
struct FHQTreap
{
    struct Node
    {
        T key;
        int pri, cnt, sz;
        Node *left, *right;
        Node(const T &k, int p) : key(k), pri(p), cnt(1), sz(1), left(nullptr), right(nullptr) {}
    };
    Node *root;
    mt19937 rng;

    FHQTreap() : root(nullptr), rng(chrono::steady_clock::now().time_since_epoch().count()) {}

    static int size(Node *p) { return p ? p->sz : 0; }
    static void pull(Node *p)
    {
        if (p)
            p->sz = p->cnt + size(p->left) + size(p->right);
    }

    // 合并两棵树,要求所有 left 的键 <= 所有 right 的键
    static Node *merge(Node *left, Node *right)
    {
        if (!left || !right)
            return left ? left : right;
        if (left->pri > right->pri)
        {
            left->right = merge(left->right, right);
            pull(left);
            return left;
        }
        else
        {
            right->left = merge(left, right->left);
            pull(right);
            return right;
        }
    }

    // 按键 k 分裂为 (<=k) 与 (>k)
    static pair<Node *, Node *> split(Node *p, const T &k)
    {
        if (!p)
            return {nullptr, nullptr};
        if (p->key <= k)
        {
            auto t = split(p->right, k);
            p->right = t.first;
            pull(p);
            return {p, t.second};
        }
        else
        {
            auto t = split(p->left, k);
            p->left = t.second;
            pull(p);
            return {t.first, p};
        }
    }

    // 插入键k
    void insert(const T &k)
    {
        auto lr = split(root, k);
        auto ll = split(lr.first, k - 1); // 注意:若T非数值类型,此处应使用另一个split逻辑:先 split(root, k-ε)
        if (ll.second && ll.second->key == k)
            ll.second->cnt++, pull(ll.second);
        else
            ll.second = merge(ll.second, new Node(k, (int)rng()));
        root = merge(merge(ll.first, ll.second), lr.second);
    }

    // 若T不是可做 k-1 的类型,上面的“等于段”拆分方式可改用 count() 决策后直接插入到中间

    // 删除一个k
    void erase(const T &k)
    {
        auto lr = split(root, k);
        auto ll = split(lr.first, k - 1);
        if (ll.second)
        {
            if (ll.second->cnt > 1)
                ll.second->cnt--, pull(ll.second);
            else
            {
                delete ll.second;
                ll.second = nullptr;
            }
        }
        root = merge(merge(ll.first, ll.second), lr.second);
    }

    // 次数
    int count(const T &k)
    {
        auto lr = split(root, k);
        auto ll = split(lr.first, k - 1);
        int c = ll.second ? ll.second->cnt : 0;
        root = merge(merge(ll.first, ll.second), lr.second);
        return c;
    }

    // 排名(元素计重):<=k 的个数
    int rankOf(const T &k)
    {
        auto lr = split(root, k);
        int ans = size(lr.first);
        root = merge(lr.first, lr.second);
        return ans;
    }

    // 第k小(1-indexed)
    bool kth(int k, T &res)
    {
        Node *p = root;
        while (p)
        {
            int ls = size(p->left);
            if (k <= ls)
                p = p->left;
            else if (k <= ls + p->cnt)
            {
                res = p->key;
                return true;
            }
            else
                k -= ls + p->cnt, p = p->right;
        }
        return false;
    }

    // 前驱
    bool predecessor(const T &k, T &res)
    {
        auto lr = split(root, k - 1);
        if (!lr.first)
        {
            root = merge(lr.first, lr.second);
            return false;
        }
        Node *p = lr.first;
        while (p->right)
            p = p->right;
        res = p->key;
        root = merge(lr.first, lr.second);
        return true;
    }

    // 后继
    bool successor(const T &k, T &res)
    {
        auto lr = split(root, k);
        if (!lr.second)
        {
            root = merge(lr.first, lr.second);
            return false;
        }
        Node *p = lr.second;
        while (p->left)
            p = p->left;
        res = p->key;
        root = merge(lr.first, lr.second);
        return true;
    }
};

说明 若键类型不是整数,k-1 的写法不可用,可把“等于段”的 split 改为先对 k 做一次 split 得到 ≤k>k,然后再对 ≤k 进行“严格小于 k”的 split。实现方式是提供另一个 splitStrict(Node*, const T&),比较时用 <>= 区分。


Splay 树

算法介绍

Splay 通过旋转把最近访问的节点伸展到根,从而获得良好的均摊复杂度。基本操作是单旋与双旋(zig、zig-zig、zig-zag),配合 pushUp 与必要的 pushDown。它可在不记录随机优先级的情况下完成集合所有常用操作。本文实现支持插入、删除、kth、rank、前驱与后继,并给出严格注释与构造函数,接口贴合 v3.2 的常见用法。

常见例题

同 Treap 的集合操作与序统计,要求均摊 O(log n)。做法: splay 到根并在根或相邻子树上继续操作,删除时把要删的键 splay 到根,再把根的左子树最大点伸展到根并接上原右子树。

代码
// Splay 树(集合版,支持重复计数)
// 功能:insert / erase / count / kth / rankOf / predecessor / successor
// 复杂度:均摊 O(log n)
template <typename T>
struct Splay
{
    struct Node
    {
        T key;
        int cnt, sz;
        Node *fa, *ch[2];
        Node(const T &k) : key(k), cnt(1), sz(1), fa(nullptr) { ch[0] = ch[1] = nullptr; }
    };
    Node *root;

    Splay() : root(nullptr) {}

    static int size(Node *p) { return p ? p->sz : 0; }
    static void pull(Node *p)
    {
        if (p)
            p->sz = p->cnt + size(p->ch[0]) + size(p->ch[1]);
    }

    // 判断p是否为其父的右儿子
    static int dir(Node *p) { return p->fa && p == p->fa->ch[1]; }

    // 连接 child 到 parent 的 d 儿子位上
    static void connect(Node *child, Node *parent, int d)
    {
        if (parent)
            parent->ch[d] = child;
        if (child)
            child->fa = parent;
    }

    // 单次旋转
    static void rotate(Node *p)
    {
        Node *q = p->fa, *g = q->fa;
        int d = dir(p), dq = dir(q);
        connect(p->ch[d ^ 1], q, d);
        connect(q, p, d ^ 1);
        if (g)
            connect(p, g, dq);
        else
            p->fa = nullptr;
        pull(q), pull(p);
    }

    // 把p伸展到目标父亲 f(若f为nullptr则伸展到根)
    void splay(Node *p, Node *f = nullptr)
    {
        while (p->fa != f)
        {
            Node *q = p->fa, *g = q->fa;
            if (g != f)
                rotate((dir(p) == dir(q)) ? q : p);
            rotate(p);
        }
        if (!f)
            root = p;
    }

    // 在树中查找键k,若存在把该节点伸展到根,若不存在把最后访问节点伸展到根
    Node *find(const T &k)
    {
        Node *p = root, *last = nullptr;
        while (p && p->key != k)
            last = p, p = k < p->key ? p->ch[0] : p->ch[1];
        splay(p ? p : last);
        return p;
    }

    // 插入k(可重复)
    void insert(const T &k)
    {
        if (!root)
        {
            root = new Node(k);
            return;
        }
        Node *p = root, *fa = nullptr;
        while (p && p->key != k)
            fa = p, p = k < p->key ? p->ch[0] : p->ch[1];
        if (p)
        {
            p->cnt++, pull(p), splay(p);
        }
        else
        {
            Node *x = new Node(k);
            if (k < fa->key)
                connect(x, fa, 0);
            else
                connect(x, fa, 1);
            splay(x);
        }
    }

    // 计数
    int count(const T &k)
    {
        Node *p = find(k);
        return p && p->key == k ? p->cnt : 0;
    }

    // 前驱(严格小于)
    bool predecessor(const T &k, T &res)
    {
        find(k);
        Node *p = root;
        if (p && p->key < k)
        {
            res = p->key;
            return true;
        }
        p = p->ch[0];
        if (!p)
            return false;
        while (p->ch[1])
            p = p->ch[1];
        res = p->key;
        splay(p);
        return true;
    }

    // 后继(严格大于)
    bool successor(const T &k, T &res)
    {
        find(k);
        Node *p = root;
        if (p && p->key > k)
        {
            res = p->key;
            return true;
        }
        p = p->ch[1];
        if (!p)
            return false;
        while (p->ch[0])
            p = p->ch[0];
        res = p->key;
        splay(p);
        return true;
    }

    // 第k小(1-indexed)
    bool kth(int k, T &res)
    {
        Node *p = root;
        if (!p || k <= 0 || k > size(p))
            return false;
        while (p)
        {
            int ls = size(p->ch[0]);
            if (k <= ls)
                p = p->ch[0];
            else if (k <= ls + p->cnt)
            {
                res = p->key;
                splay(p);
                return true;
            }
            else
                k -= ls + p->cnt, p = p->ch[1];
        }
        return false;
    }

    // 排名(元素计重):<=k 的个数
    int rankOf(const T &k)
    {
        find(k);
        int ls = size(root->ch[0]);
        return k < root->key ? ls : ls + root->cnt;
    }

    // 删除一个k(若有多个仅删一个)
    void erase(const T &k)
    {
        Node *p = find(k);
        if (!p || p->key != k)
            return;
        if (p->cnt > 1)
        {
            p->cnt++;
            p->cnt--;
            pull(p);
            return;
        } // 占位以便解释:计数-1
        if (!p->ch[0])
        {
            root = p->ch[1];
            if (root)
                root->fa = nullptr;
            delete p;
            return;
        }
        if (!p->ch[1])
        {
            root = p->ch[0];
            if (root)
                root->fa = nullptr;
            delete p;
            return;
        }
        Node *x = p->ch[0];
        while (x->ch[1])
            x = x->ch[1];
        splay(x, p);             // 把左子树最大点旋到p的左儿子位置
        connect(p->ch[1], x, 1); // 接上原右子树
        x->fa = nullptr;
        root = x;
        delete p;
        pull(root);
    }
};

可持久化数据结构

可持久化数组

算法介绍

可持久化数组以一棵静态结构的“权值存储+路径复制”线段树为载体,每次单点修改时仅复制根到叶子的路径,其他结点复用旧版本指针,因此每次版本增量空间为 O(log n),单点赋值和单点查询均为 O(log n)。所有区间与递归均采用闭区间 [l,r]。本节用 Info 作为叶子所存的元素信息,你可以把它扩展为更复杂的结构,只要定义好合并与赋值即可。

常见例题

题目:给定一个长度为 n 的数组以及 q 次操作,操作一为创建新版本并把位置 x 赋值为 v,操作二为在给定版本 ver 上查询位置 x 的值。

做法:以初始数组建出第 0 版根,之后每次赋值用路径复制构造新根保存到版本表,查询时在对应版本的根上沿闭区间 [l,r] 逐层定位到叶子读取 Info。

代码
// 可持久化数组(基于路径复制的静态线段树),所有区间为闭区间 [l,r]
// Info 用于存储单点信息与赋值;若仅需数值可用 Info{val}
template <typename T>
struct Info
{
    T v;
    Info(T _v = T{}) : v(_v) {}
};

// 结点:左右子指针与区间信息
template <typename Info>
struct Node
{
    int ls, rs;
    Info info;
    Node(int l_ = 0, int r_ = 0, const Info &x = Info()) : ls(l_), rs(r_), info(x) {}
};

template <typename Info>
struct PersistentArray
{
    int n;                 // 数组长度
    vector<Node<Info>> tr; // 结点池
    vector<int> root;      // 各版本根
    PersistentArray(int n_ = 0) : n(0)
    {
        if (n_)
            init(n_);
    }

    // 初始化长度并清空版本,默认建立空数组版本0
    void init(int n_)
    {
        n = n_;
        tr.clear(), tr.reserve(n * 25);
        root.clear();
        root.push_back(build(0, n - 1));
    }

    // 用原始数组a构建版本0
    template <typename U>
    void buildFrom(const vector<U> &a)
    {
        n = (int)a.size();
        tr.clear(), tr.reserve(n * 25);
        root.clear();
        auto buildArr = [&](auto &&self, int l, int r) -> int
        {
            int p = newNode();
            if (l == r)
            {
                tr[p].info = Info(a[l]);
                return p;
            }
            int m = (l + r) >> 1;
            tr[p].ls = self(self, l, m), tr[p].rs = self(self, m + 1, r);
            return p;
        };
        root.push_back(buildArr(buildArr, 0, n - 1));
    }

    // 读取版本ver在位置x的Info
    Info get(int ver, int x) const
    {
        int p = root[ver], l = 0, r = n - 1;
        while (l != r)
        {
            int m = (l + r) >> 1;
            x <= m ? p = tr[p].ls, r = m : p = tr[p].rs, l = m + 1;
        }
        return tr[p].info;
    }

    // 基于版本ver创建新版本,在位置x赋值为v,返回新版本编号
    int setPoint(int ver, int x, const Info &v)
    {
        int np = copyPath(root[ver], 0, n - 1, x, v);
        root.push_back(np);
        return (int)root.size() - 1;
    }

private:
    // 新建空结点
    int newNode()
    {
        tr.emplace_back();
        return (int)tr.size() - 1;
    }

    // 建空树
    int build(int l, int r)
    {
        int p = newNode();
        if (l == r)
            return p;
        int m = (l + r) >> 1;
        tr[p].ls = build(l, m), tr[p].rs = build(m + 1, r);
        return p;
    }

    // 从旧根op出发复制到叶子的路径,将位置x改为v,返回新根
    int copyPath(int op, int l, int r, int x, const Info &v)
    {
        int p = newNode();
        tr[p] = tr[op];
        if (l == r)
        {
            tr[p].info = v;
            return p;
        }
        int m = (l + r) >> 1;
        if (x <= m)
            tr[p].ls = copyPath(tr[op].ls, l, m, x, v);
        else
            tr[p].rs = copyPath(tr[op].rs, m + 1, r, x, v);
        return p;
    }
};

可持久化线段树

算法介绍

可持久化线段树在“可持久化数组”的基础上增加区间聚合能力。每次单点修改仍是复制路径,但结点信息由左右儿子通过 Info 的合并算子得到,典型是区间和或区间最大。它适合在多版本上做区间查询,例如时光倒流或离线前缀问题。每次修改与查询的时间 O(log n),每个新版本的增量空间 O(log n)。

常见例题

题目: 给定数组与若干版本,每次在上一个版本基础上将坐标 x 的值改为 v 生成新版本,然后询问某一版本 ver 的区间 [l,r] 的区间和。

做法: 定义 Info 存 sum 与合并 operator+,用 copyPath 在叶子更新值并沿路径用合并回溯得到新根,rangeQuery 在给定版本根上按闭区间 [l,r] 递归求解。

代码
// 可持久化线段树
// 特点:支持多版本的单点修改与区间查询,每次修改仅复制根到叶子的路径,其他部分复用旧版本
// 复杂度:单次修改与查询 O(log n),每次修改增量空间 O(log n)

template <typename T>
struct SegInfo
{
    T sum;                                       // 区间和,这里示例维护区间和
    SegInfo(T v = T{}) : sum(v) {}               // 构造函数,默认值为0
    static SegInfo empty() { return SegInfo(); } // 提供一个空信息,用于越界或无效区间
    friend SegInfo operator+(const SegInfo &a, const SegInfo &b)
    {
        // 合并两个子区间的结果
        return SegInfo(a.sum + b.sum);
    }
};

template <typename Info>
struct PersistentSegTree
{
    // 结点定义,每个结点存储左右儿子编号与区间信息
    struct Node
    {
        int ls, rs; // 左右儿子在结点数组中的下标
        Info info;  // 当前结点存储的信息
        Node(int l = 0, int r = 0, const Info &x = Info()) : ls(l), rs(r), info(x) {}
    };

    int n;            // 数组长度
    vector<Node> tr;  // 结点池,所有版本共享
    vector<int> root; // 保存每个版本的根节点下标

    // 构造函数,默认空,或者直接指定长度
    PersistentSegTree(int n_ = 0) : n(0)
    {
        if (n_)
            init(n_);
    }

    // 构造函数:用数组初始化
    template <typename T>
    PersistentSegTree(const vector<T> &a) { init(a); }

    // 功能:初始化长度 n_ 的空树,版本0为全空
    void init(int n_)
    {
        n = n_;
        tr.clear();
        tr.reserve(n * 25); // 预留空间,避免频繁扩容
        root.clear();
        root.push_back(build(0, n - 1)); // 建立初始空版本
    }

    // 功能:用数组 a 初始化版本0
    template <typename T>
    void init(const vector<T> &a)
    {
        n = (int)a.size();
        tr.clear();
        tr.reserve(n * 25);
        root.clear();
        auto buildArr = [&](auto &&self, int l, int r) -> int
        {
            int p = newNode();
            if (l == r)
            {
                tr[p].info = Info(a[l]);
                return p;
            }
            int m = (l + r) >> 1;
            tr[p].ls = self(self, l, m), tr[p].rs = self(self, m + 1, r);
            tr[p].info = tr[tr[p].ls].info + tr[tr[p].rs].info;
            return p;
        };
        root.push_back(buildArr(buildArr, 0, n - 1));
    }

    // 功能:在版本 ver 上修改位置 x 的值为 v,生成一个新版本,返回新版本编号
    int modify(int ver, int x, const Info &v)
    {
        int np = copyPath(root[ver], 0, n - 1, x, v);
        root.push_back(np);
        return (int)root.size() - 1;
    }

    // 功能:在版本 ver 上查询区间 [L,R] 的信息
    Info rangeQuery(int ver, int L, int R) const
    {
        return query(root[ver], 0, n - 1, L, R);
    }

private:
    // 功能:新建一个空结点,返回编号
    int newNode()
    {
        tr.emplace_back();
        return (int)tr.size() - 1;
    }

    // 功能:建立空树,区间 [l,r],所有 Info 都为空
    int build(int l, int r)
    {
        int p = newNode();
        if (l == r)
            return p;
        int m = (l + r) >> 1;
        tr[p].ls = build(l, m), tr[p].rs = build(m + 1, r);
        return p;
    }

    // 功能:在旧版本结点 op 的基础上复制路径,将位置 x 更新为 v,返回新根
    int copyPath(int op, int l, int r, int x, const Info &v)
    {
        int p = newNode();
        tr[p] = tr[op]; // 复制当前结点
        if (l == r)
        {
            tr[p].info = v;
            return p;
        } // 到达叶子,直接更新值
        int m = (l + r) >> 1;
        if (x <= m)
            tr[p].ls = copyPath(tr[op].ls, l, m, x, v);
        else
            tr[p].rs = copyPath(tr[op].rs, m + 1, r, x, v);
        tr[p].info = tr[tr[p].ls].info + tr[tr[p].rs].info; // 回溯更新信息
        return p;
    }

    // 功能:在根 p 表示的区间 [l,r] 上查询 [L,R],返回区间信息
    Info query(int p, int l, int r, int L, int R) const
    {
        if (R < l || r < L)
            return Info::empty(); // 无交集返回空信息
        if (L <= l && r <= R)
            return tr[p].info; // 完全覆盖直接返回
        int m = (l + r) >> 1;
        return query(tr[p].ls, l, m, L, R) + query(tr[p].rs, m + 1, r, L, R);
    }
};

可持久化 Trie

算法介绍

可持久化 Trie 通过“复制外加复用未修改分支”的方式,为每次插入生成一个新根,从而得到版本化的字典树。典型是二进制 0/1 Trie,用来支持“在某个版本内求区间最大异或”等查询。每次插入的增量空间为 O(bit),查询为 O(bit),bit 是键的位数或字符集深度。所有查询都在指定版本的根上进行。

常见例题

题目:给定 q 次操作,其中 insert v 创建一个新版本并在该版本中插入整数 v,query ver x 要求在版本 ver 的集合内求与 x 的最大异或值。

做法:采用 0/1 可持久化 Trie,版本 i 的根源于版本 i−1 的根,只在经过的位路径上复制新结点。最大异或查询时从高位到低位优先选择与 x 当前位相反的分支,如果不存在则走相同分支。

代码

// 可持久化 01-Trie,按位插入与查询最大异或,支持版本化根
struct Persistent01Trie
{
    struct Node
    {
        int ch[2];
        int cnt;
        Node() : ch{0, 0}, cnt(0) {}
    };
    int W;            // 位宽(例如31表示非负int)
    vector<Node> tr;  // 结点池
    vector<int> root; // 各版本根
    Persistent01Trie(int bitWidth = 31) : W(bitWidth) { tr.reserve(1 << 20), root.push_back(newNode()); }

    // 在版本ver的基础上插入值x,返回新版本编号
    int insert(int ver, int x)
    {
        int np = copyPath(root[ver], x);
        root.push_back(np);
        return (int)root.size() - 1;
    }

    // 在版本ver上查询与x的最大异或
    int maxXor(int ver, int x) const
    {
        int p = root[ver], res = 0;
        for (int i = W; i >= 0; i--)
        {
            int b = (x >> i) & 1, t = b ^ 1, nxt = tr[p].ch[t];
            if (nxt && tr[nxt].cnt)
                p = nxt, res |= (1 << i);
            else
                p = tr[p].ch[b];
        }
        return res;
    }

private:
    int newNode()
    {
        tr.emplace_back();
        return (int)tr.size() - 1;
    }

    // 基于旧根op复制插入x路径,返回新根
    int copyPath(int op, int x)
    {
        int rp = newNode(), p = rp;
        tr[p] = tr[op], tr[p].cnt++;
        for (int i = W; i >= 0; i--)
        {
            int b = (x >> i) & 1;
            int old = tr[op].ch[b];
            tr[p].ch[b] = newNode();
            tr[tr[p].ch[b]] = old ? tr[old] : Node();
            p = tr[p].ch[b], tr[p].cnt++;
            op = old;
        }
        return rp;
    }
};

莫队

朴素莫队

算法介绍

把所有区间询问离线,按照块排序规则重排后,用一条指针链在数组上滑动维护当前答案。当前区间 [L,R] 与目标区间 [l,r] 的差异通过增删元素完成状态转移,整体复杂度约为 O((n+q)√n),常数小、实现稳定,特别适合可加可删、难以在线维护的数据统计类问题。本模板与现有 v3.2 的风格保持一致,命名、注释与 I/O 习惯完全对齐。

常见例题

题目:给定长度为 n 的数组与 q 次询问,询问区间 [l,r] 内不同数的个数。

做法:在增删函数中用计数桶维护每个值出现次数,增加某个值从 0 到 1 时答案加一,减少某个值从 1 到 0 时答案减一,其他情况不变,按照莫队顺序依次处理即可。

代码
// 朴素莫队(闭区间 [l,r] 语义)
// 功能:离线重排区间询问,通过增删元素维护答案;支持任意可“增删”的统计
// 复杂度:排序 O(q log q),移动指针与增删 O((n+q) * 块长);常取块长 ≈ n/√q 或 √n
struct Mo
{
    struct Query
    {
        int l, r, id;                    // 闭区间 [l,r] 与原始编号
        Query() {}
        Query(int l_, int r_, int id_) : l(l_), r(r_), id(id_) {}
    };

    int n;                                // 数组长度
    int blk;                              // 分块大小
    vector<int> a;                        // 原数组(1-indexed 更顺手,本模板用 1..n)
    vector<Query> qs;                     // 询问列表

    // 构造与初始化,支持空构造或直接给定数组
    Mo() : n(0), blk(1) {}
    Mo(const vector<int> &arr) { init(arr); }

    // 设置数组并初始化块长
    void init(const vector<int> &arr)
    {
        n = (int)arr.size() - 1;          // 约定arr从1开始;若从0开始请在外部补一个哨兵
        a = arr;
        blk = max(1, (int)sqrt(max(1, n)));
        qs.clear();
    }

    // 添加一个闭区间询问
    void addQuery(int l, int r, int id) { qs.emplace_back(l, r, id); }

    // 主过程,传入四个操作与一个取答函数
    // addL(x):把下标 x 插入到当前区间左侧;addR(x):把下标 x 插入到当前区间右侧
    // delL(x):把下标 x 从当前区间左侧删除;delR(x):把下标 x 从当前区间右侧删除
    // getAns():返回当前区间的答案
    template<class AddL, class AddR, class DelL, class DelR, class Get>
    vector<ll> solve(AddL addL, AddR addR, DelL delL, DelR delR, Get getAns)
    {
        vector<int> ord(qs.size());
        iota(ord.begin(), ord.end(), 0);

        auto key = [&](int x) -> pair<int,int>
        {
            int b = qs[x].l / blk;
            int t = qs[x].r;
            if (b & 1) t = -t;
            return {b, t};
        };

        sort(ord.begin(), ord.end(), [&](int i, int j)
        {
            auto ki = key(i), kj = key(j);
            if (ki.first != kj.first) return ki.first < kj.first;
            return ki.second < kj.second;
        });

        vector<ll> ans(qs.size());
        int L = 1, R = 0;

        for (int idx : ord)
        {
            int l = qs[idx].l, r = qs[idx].r;

            while (L > l) addL(--L);
            while (R < r) addR(++R);
            while (L < l) delL(L++);
            while (R > r) delR(R--);

            ans[qs[idx].id] = getAns();
        }
        return ans;
    }
};

// 使用示例:区间不同数个数
// 说明:cnt[val] 统计 val 在当前 [L,R] 出现次数,cur 为当前不同数个数
// 注意:请先把原数组压缩到 [1..M],避免值域过大
struct MoDistinctDemo
{
    vector<int> a;                 // 1..n
    vector<int> cnt;               // 值域计数
    ll cur;                 // 当前答案(不同数个数)

    MoDistinctDemo() : cur(0) {}

    void attachArray(const vector<int> &arr, int maxVal)
    {
        a = arr; cnt.assign(maxVal + 1, 0); cur = 0;
    }

    // 把下标 x 的元素加入左侧
    void addL(int x)
    {
        int v = a[x];
        if (cnt[v] == 0) cur++;
        cnt[v]++;
    }

    // 把下标 x 的元素加入右侧
    void addR(int x)
    {
        int v = a[x];
        if (cnt[v] == 0) cur++;
        cnt[v]++;
    }

    // 把下标 x 的元素从左侧删除
    void delL(int x)
    {
        int v = a[x];
        cnt[v]--;
        if (cnt[v] == 0) cur--;
    }

    // 把下标 x 的元素从右侧删除
    void delR(int x)
    {
        int v = a[x];
        cnt[v]--;
        if (cnt[v] == 0) cur--;
    }

    ll getAns() { return cur; }
};

带修莫队

算法介绍

在朴素莫队的两维移动之外再引入“时间维”,把修改操作当作第三维度的离线事件。每个询问记录它发生时刻之前已生效的修改次数 t,排序按块顺序比较 (l/blk, r/blk, t/blk)。维护当前 [L,R] 与当前时刻 T,两端指针通过增删移动,时间指针通过“应用一次修改”或“撤销一次修改”在 T 与目标 t 之间移动即可。复杂度常见取块长为 n^(2/3) 达到 O(n^(5/3)) 左右的实践效果。

常见例题

题目:给定数组初值,包含两类操作,类型一是把位置 x 的值改为 v,类型二是询问区间 [l,r] 内不同数的个数,要求按照输入顺序输出所有询问答案。

做法:将所有操作离线,询问带上当前已执行的修改数 t,排序后用三指针维护当前 L、R、T,时间前进就把修改位置的旧值替换为新值并对其是否位于 [L,R] 做一次增删,时间后退则反向应用。

代码
// 带修莫队(闭区间 [l,r] 语义;时间维支持“点修改”)
// 功能:同时维护 [L,R] 与时间 T;支持“应用修改/撤销修改”的回滚式增删
struct MoWithUpdate
{
    struct Query
    {
        int l, r, t, id;                   // 询问区间与发生时的修改数 t
        Query() {}
        Query(int l_, int r_, int t_, int id_) : l(l_), r(r_), t(t_), id(id_) {}
    };

    struct Update
    {
        int pos, pre, now;                 // 把 pos 处的值从 pre 改为 now
        Update() {}
        Update(int pos_, int pre_, int now_) : pos(pos_), pre(pre_), now(now_) {}
    };

    int n;                                 // 数组长度
    int blk;                               // 分块大小(常取 n^(2/3))
    vector<int> a;                         // 当前数组(1..n)
    vector<Query> qs;                      // 询问集
    vector<Update> us;                     // 修改集

    MoWithUpdate() : n(0), blk(1) {}
    MoWithUpdate(const vector<int> &arr) { init(arr); }

    void init(const vector<int> &arr)
    {
        n = (int)arr.size() - 1;
        a = arr;
        blk = max(1, (int)pow(max(1, n), 2.0 / 3.0));
        qs.clear(); us.clear();
    }

    void addQuery(int l, int r, int t, int id) { qs.emplace_back(l, r, t, id); }
    void addUpdate(int pos, int pre, int now) { us.emplace_back(pos, pre, now); }

    template<class Add, class Del, class Get>
    vector<ll> solve(Add addPos, Del delPos, Get getAns)
    {
        vector<int> ord(qs.size());
        iota(ord.begin(), ord.end(), 0);

        auto key = [&](int i) -> tuple<int,int,int>
        {
            int b1 = qs[i].l / blk, b2 = qs[i].r / blk, b3 = qs[i].t / blk;
            return make_tuple(b1, b2, b3);
        };

        sort(ord.begin(), ord.end(), [&](int i, int j){ return key(i) < key(j); });

        vector<ll> ans(qs.size());
        int L = 1, R = 0, T = 0;

        auto apply = [&](int k, bool forward)
        {
            int p = us[k].pos;
            int valFrom = forward ? us[k].pre : us[k].now;
            int valTo = forward ? us[k].now : us[k].pre;
            if (L <= p && p <= R) delPos(p);        // 先移除旧值的贡献
            a[p] = valTo;
            if (L <= p && p <= R) addPos(p);        // 再加入新值的贡献
        };

        for (int idx : ord)
        {
            int l = qs[idx].l, r = qs[idx].r, t = qs[idx].t;

            while (T < t) apply(T++, true);
            while (T > t) apply(--T, false);
            while (L > l) addPos(--L);
            while (R < r) addPos(++R);
            while (L < l) delPos(L++);
            while (R > r) delPos(R--);

            ans[qs[idx].id] = getAns();
        }
        return ans;
    }
};

// 使用示例:区间不同数个数(与朴素莫队相同的增删逻辑)
// 说明:对时间应用时,只在修改位置恰好处于 [L,R] 时做一次“先删后加”的切换
struct MoUpdDistinctDemo
{
    vector<int> a;                // 由外部 MoWithUpdate::a 引用或同步
    vector<int> cnt;
    ll cur;

    void attachArray(const vector<int> &arr, int maxVal)
    {
        a = arr; cnt.assign(maxVal + 1, 0); cur = 0;
    }
    void addPos(int i)
    {
        int v = a[i];
        if (cnt[v] == 0) cur++;
        cnt[v]++;
    }
    void delPos(int i)
    {
        int v = a[i];
        cnt[v]--;
        if (cnt[v] == 0) cur--;
    }
    ll getAns() { return cur; }
};

树上莫队

算法介绍

把树做一次欧拉序或 dfs 序展开到序列上,再把每条路径询问映射为序列区间。令 tin[u] 表示 u 第一次被访问的时间戳,tout[u] 表示离开时间戳。常见做法是“每个节点出现两次”的欧拉序,与“访问一次即切换一次选中状态”的思想配合,路径 u 到 v 可以转化为区间 [tin[u], tin[v]] 的若干“开关”加上额外处理的最近公共祖先 lca(u,v)。当指针移动经过某个结点的时间戳时,若该结点此前未被选中则加入贡献,否则移除贡献。复杂度与朴素莫队同阶,块长通常取 √(2n) 以匹配欧拉序的长度。

常见例题

题目:给定一棵带颜色的树,有 q 次询问,每次询问路径 [u,v] 上不同颜色的个数。

做法:对树做欧拉序并预处理二进制 LCA。把每次询问规范化为 tin[u] 与 tin[v] 的区间并记录 lca,莫队增删时对经过的结点进行“开关”,当一个结点从未选中切为选中时把其颜色计数加一并可能带来答案变化,反之移除贡献;若 lca 不在区间内,最后再特判把 lca 的颜色临时加入一次统计得到答案。

代码
// 树上莫队(闭区间 [l,r] 语义,欧拉序长度为 2n)
// 功能:把树上路径询问转化为序列莫队;支持“开关式”可加可删的统计
struct TreeMo
{
    int n, lg;                               // n 个点,lg 为 LCA 对数深度
    vector<vector<int>> g;                   // 邻接表
    vector<int> color;                       // 每个结点的颜色(或值),1..n
    vector<int> euler, in, out, ver;         // euler 为 1..2n 的时间戳序列;ver[t] = 该时间戳对应结点
    vector<int> depth;                       // 深度
    vector<vector<int>> up;                  // LCA 二进制跳表

    // 统计部分
    vector<int> vis;                         // 结点是否已被“选中”
    vector<int> cnt;                         // 颜色出现次数
    ll cur;                           // 当前答案

    // 构造与初始化
    TreeMo() : n(0), lg(0), cur(0) {}
    TreeMo(int n_) { init(n_); }

    // 初始化图结构与容器
    void init(int n_)
    {
        n = n_;
        g.assign(n + 1, {});
        color.assign(n + 1, 0);
        in.assign(n + 1, 0);
        out.assign(n + 1, 0);
        depth.assign(n + 1, 0);
        ver.assign(2 * n + 2, 0);
        euler.assign(2 * n + 2, 0);
        lg = __lg(max(1, n)) + 1;
        up.assign(lg, vector<int>(n + 1, 0));
        vis.assign(n + 1, 0);
        cnt.clear();
        cur = 0;
    }

    // 加边
    void addEdge(int u, int v)
    {
        g[u].push_back(v), g[v].push_back(u);
    }

    // 预处理欧拉序与 LCA,根默认 1
    void build(int root = 1)
    {
        int timer = 0;
        auto dfs = [&](auto &&self, int u, int p) -> void
        {
            up[0][u] = p;
            for (int k = 1; k < lg; k++) up[k][u] = up[k - 1][up[k - 1][u]];
            in[u] = ++timer; ver[timer] = u; euler[timer] = u;
            for (int v : g[u]) if (v != p) depth[v] = depth[u] + 1, self(self, v, u);
            out[u] = ++timer; ver[timer] = u; euler[timer] = u;
        };
        dfs(dfs, root, 0);
    }

    // LCA
    int lca(int a, int b)
    {
        if (depth[a] < depth[b]) swap(a, b);
        int d = depth[a] - depth[b];
        for (int k = 0; k < lg; k++) if (d >> k & 1) a = up[k][a];
        if (a == b) return a;
        for (int k = lg - 1; k >= 0; k--) if (up[k][a] != up[k][b]) a = up[k][a], b = up[k][b];
        return up[0][a];
    }

    // 开关一个结点 u 的贡献
    void toggle(int u)
    {
        int c = color[u];
        if (!vis[u]) { vis[u] = 1; if ((int)cnt.size() <= c) cnt.resize(c + 1, 0); if (cnt[c] == 0) cur++; cnt[c]++; }
        else { vis[u] = 0; cnt[c]--; if (cnt[c] == 0) cur--; }
    }

    // 询问结构
    struct Query
    {
        int l, r, id, w;                     // 映射到欧拉序的闭区间 [l,r];w = lca(u,v),若 w 在 [l,r] 外需特判
        Query() {}
        Query(int l_, int r_, int id_, int w_) : l(l_), r(r_), id(id_), w(w_) {}
    };

    // 把树上路径 (u,v) 转化为欧拉序区间
    Query makeQuery(int u, int v, int id)
    {
        if (in[u] > in[v]) swap(u, v);
        int w = lca(u, v);
        if (w == u) return Query(in[u], in[v], id, 0);
        return Query(out[u], in[v], id, w);
    }

    // 求解所有查询,返回答案数组
    vector<ll> solve(vector<Query> qs)
    {
        int m = (int)qs.size();
        int B = max(1, (int)sqrt(2 * n));
        sort(qs.begin(), qs.end(), [&](const Query &A, const Query &Bq)
        {
            int blA = A.l / B, blB = Bq.l / B;
            if (blA != blB) return blA < blB;
            int brA = A.r / B, brB = Bq.r / B;
            if (brA != brB) return brA < brB;
            return A.w < Bq.w;
        });

        vector<ll> ans(m);
        int L = 1, R = 0;

        auto move = [&](int x)
        {
            int u = ver[x];
            toggle(u);
        };

        for (auto &q : qs)
        {
            while (L > q.l) move(--L);
            while (R < q.r) move(++R);
            while (L < q.l) move(L++);
            while (R > q.r) move(R--);

            ll now = cur;
            if (q.w) { toggle(q.w); now = cur; toggle(q.w); }
            ans[q.id] = now;
        }
        return ans;
    }
};

// 使用示例说明:
// 1) 读入树与颜色,调用 init(n)、addEdge、build(root);
// 2) 对每个询问(u,v) 调用 makeQuery(u,v,id) 收集为数组;
// 3) 调用 solve(queries) 获得答案数组;
// 4) 若统计目标不是“不同颜色个数”,仅需在 toggle 中按目标逻辑增删贡献即可。

杂项

跳表

算法介绍

跳表是带“多级索引”的有序链表。每个节点在若干层出现,上层稀疏、下层稠密。查找从最高层逐层下降,遇到超越目标就下移一层继续,期望时间复杂度 O(log n),插入与删除沿搜索路径回溯并按抛硬币或固定随机高度将新节点接入各层。它常被用作比赛中的有序集合备胎,在无法使用平衡树、或者需要自己掌控比较与持久化细节的题目中很便利。

常见例题

题目:维护一个动态集合,支持插入一个键、删除一个键、查询一个键是否存在、查询严格小于 x 的最大元素。做法:用跳表维护严格递增的键序列。插入时按搜索路径记录每层的“前驱”,随机高度后把新节点挂在各层前驱之后。删除时同理断链。查询前驱时从顶层右移直到下一个大于等于 x,再落到下一层反复,最终在最底层拿到严格小于 x 的位置。

代码
// 跳表(Skip List),支持 insert / erase / contains / predecessor
// 说明:模板参数K为可比较的键,严格弱序(operator<);随机层高采用几何分布
template<typename K>
struct SkipList
{
    struct Node
    {
        K key;                   // 键
        vector<Node*> nxt;       // 各层后继指针
        Node(int h, const K &k): key(k), nxt(h, nullptr) {}  // 构造:给定高度与键
    };

    int maxLevel;                // 允许的最大层数
    double prob;                 // 上升到更高一层的概率
    Node *head;                  // 头哨兵(-inf)
    mt19937_64 rng;              // 随机数
    uniform_real_distribution<double> uni;

    // 构造:指定最大层与晋升概率p
    SkipList(int maxLevel_ = 20, double p_ = 0.5)
        : maxLevel(maxLevel_), prob(p_), head(new Node(maxLevel, K{})), rng(chrono::steady_clock::now().time_since_epoch().count()), uni(0.0, 1.0) {}

    // 产生一个随机高度(至少1)
    int randomLevel()
    {
        int h = 1;
        while (h < maxLevel && uni(rng) < prob) h++;
        return h;
    }

    // 查找严格小于x的前驱结点数组,用于插入/删除
    void findPrev(const K &x, vector<Node*> &prev)
    {
        prev.assign(maxLevel, nullptr);
        Node *p = head;
        for (int lvl = maxLevel - 1; lvl >= 0; lvl--)
        {
            while (p->nxt[lvl] && p->nxt[lvl]->key < x) p = p->nxt[lvl];
            prev[lvl] = p;
        }
    }

    // 是否包含键x
    bool contains(const K &x)
    {
        Node *p = head;
        for (int lvl = maxLevel - 1; lvl >= 0; lvl--)
            while (p->nxt[lvl] && p->nxt[lvl]->key < x) p = p->nxt[lvl];
        p = p->nxt[0];
        return p && !(x < p->key) && !(p->key < x);
    }

    // 插入键x,若已存在则不重复插
    bool insert(const K &x)
    {
        vector<Node*> prev; findPrev(x, prev);
        Node *q = prev[0]->nxt[0];
        if (q && !(x < q->key) && !(q->key < x)) return false;
        int h = randomLevel(); Node *cur = new Node(h, x);
        for (int i = 0; i < h; i++) cur->nxt[i] = prev[i]->nxt[i], prev[i]->nxt[i] = cur;
        return true;
    }

    // 删除键x,若不存在返回false
    bool erase(const K &x)
    {
        vector<Node*> prev; findPrev(x, prev);
        Node *q = prev[0]->nxt[0];
        if (!q || (x < q->key) || (q->key < x)) return false;
        for (int i = 0; i < (int)q->nxt.size(); i++) prev[i]->nxt[i] = q->nxt[i];
        delete q; return true;
    }

    // 返回严格小于x的最大元素是否存在并写入out
    bool predecessor(const K &x, K &out)
    {
        Node *p = head;
        for (int lvl = maxLevel - 1; lvl >= 0; lvl--)
            while (p->nxt[lvl] && p->nxt[lvl]->key < x) p = p->nxt[lvl];
        if (p == head) return false;
        out = p->key; return true;
    }
};

ST表

算法介绍

ST 表(Sparse Table)在静态数组上实现 O(1) RMQ 或可幂等函数的区间查询。预处理 f[i][k] 表示区间 [i,i+2^k-1] 的信息,合并时用幂等函数(如 min、max、gcd)。查询 [l,r] 时令 k=⌊log2(r−l+1)⌋,答案由 f[l][k] 与 f[r−2^k+1][k] 合并即可。由于只做一遍静态预处理,适合没有修改、查询极多的场景。

常见例题

题目:给定数组 a 与 q 次询问,每次询问区间 [l,r] 的最小值与最大值。做法:以 min 和 max 为幂等函数分别建两张 ST 表,查询时按区间长度取 k,并在两张表各做一次合并即可返回答案。

代码

// ST 表(Sparse Table)
// 说明:Info 需要提供默认构造与从值构造,以及 operator+ 作为“区间合并”
//      对于 RMQ,Info::operator+ 可写成 min/max;对于 gcd/按位与或,同理
// 复杂度:预处理 O(n log n),单次查询 O(1)
template <typename Info>
struct SparseTable
{
    int n;                                  // 原数组长度
    int maxK;                               // 最高位 k,使得 2^k <= n
    vector<int> lg;                         // 预存 log2
    vector<vector<Info>> f;                 // f[k][i] 表示区间 [i, i+2^k-1] 的信息

    // 构造函数:空构造
    SparseTable(): n(0), maxK(0) {}

    // 构造函数:从值数组 a 构建(a 为 0-based)
    template <typename T>
    SparseTable(const vector<T> &a) { init(a); }

    // 功能:从值数组 a 初始化,完成整张稀疏表的预处理
    template <typename T>
    void init(const vector<T> &a)
    {
        n = (int)a.size();
        if (!n) { maxK = 0; f.clear(); lg.clear(); return; }
        maxK = __lg(n);
        lg.resize(n + 1);
        lg[1] = 0;
        for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;

        f.assign(maxK + 1, vector<Info>(n));
        for (int i = 0; i < n; i++) f[0][i] = Info(a[i]);

        for (int k = 1; k <= maxK; k++)
        {
            int len = 1 << k, half = len >> 1, lim = n - len + 1;
            for (int i = 0; i < lim; i++) f[k][i] = f[k - 1][i] + f[k - 1][i + half];
        }
    }

    // 功能:区间查询,返回闭区间 [l,r] 的 Info 聚合
    Info query(int l, int r)
    {
        assert(0 <= l && l <= r && r < n);
        int k = lg[r - l + 1], len = 1 << k;
        return f[k][l] + f[k][r - len + 1];
    }
};

// 下面给出常用的三个 Info 示例,均以 ll 为底层值。
// 你可以按需替换底层类型或运算,并与 SparseTable 一起直接使用。

// 区间最小值
struct InfoMin
{
    ll v;
    InfoMin(ll x = (ll)4e18): v(x) {}
    friend InfoMin operator+(const InfoMin &a, const InfoMin &b) { return InfoMin(min(a.v, b.v)); }
};

// 区间最大值
struct InfoMax
{
    ll v;
    InfoMax(ll x = (ll)-4e18): v(x) {}
    friend InfoMax operator+(const InfoMax &a, const InfoMax &b) { return InfoMax(max(a.v, b.v)); }
};

// 区间 gcd(注意 0 的幺元语义:gcd(x,0)=x)
struct InfoGcd
{
    ll v;
    InfoGcd(ll x = 0): v(x) {}
    static ll g(ll a, ll b)
    {
        while (b) a %= b, swap(a, b);
        return a;
    }
    friend InfoGcd operator+(const InfoGcd &a, const InfoGcd &b) { return InfoGcd(g(a.v, b.v)); }
};

// 使用示例:
// vector<ll> a = {...};
// SparseTable<InfoMin> stMin(a); auto mn = stMin.query(l, r).v;
// SparseTable<InfoMax> stMax(a); auto mx = stMax.query(l, r).v;
// SparseTable<InfoGcd> stGcd(a); auto gd = stGcd.query(l, r).v;


线性基

算法介绍

线性基用于在 GF(2) 上维护一组数的异或线性独立基。插入一个数时按最高位到最低位试图“消去”已有基向量;若最终非零则作为新基放入。它能在 O(log A) 时间内支持插入与询问可得到的最大异或和;若额外维护“可取到零”的标记,还可回答是否能异或成给定值。比赛中常用于最大异或对、与图上“环异或”相关的可达值分析。

常见例题

题目:给一个数组,支持插入一个数与询问当前集合能得到的最大异或和。做法:线性基顺序插入,查询时从高位到低位尝试贪心改进答案即可。若题目要求“是否能异或得到某值”,则在插入阶段不丢弃与现有基完全消去成零的元素,并把 canZero 置真,判断时先用基消去目标,再看是否变成零或 canZero 为真。

代码
// 线性基(Binary Linear Basis),支持插入与最大异或和查询,可选支持判断可达性
template<typename T = unsigned ll, int MAXL = 63>
struct LinearBasis
{
    array<T, MAXL + 1> b;   // b[i] 存放最高位在 i 的基向量
    bool canZero;           // 是否可以异或出 0(存在冗余向量)

    // 构造:清空基
    LinearBasis() { clear(); }

    // 清空
    void clear()
    {
        b.fill(0);
        canZero = false;
    }

    // 插入一个向量 x;返回是否扩大了线性空间
    bool insert(T x)
    {
        for (int i = MAXL; i >= 0; i--)
        {
            if (!((x >> i) & 1)) continue;
            if (!b[i]) { b[i] = x; return true; }
            x ^= b[i];
        }
        canZero = true; return false;
    }

    // 查询能得到的最大异或和
    T maxXor()
    {
        T ans = 0;
        for (int i = MAXL; i >= 0; i--) if ((ans ^ b[i]) > ans) ans ^= b[i];
        return ans;
    }

    // 判断是否能异或得到给定值 x
    bool canMake(T x)
    {
        for (int i = MAXL; i >= 0; i--) if ((x >> i) & 1) x ^= b[i] ? b[i] : 0;
        return x == 0 || canZero;
    }

    // 合并另一个线性基(把对方所有基向量插入本基)
    void mergeFrom(const LinearBasis &o)
    {
        for (int i = MAXL; i >= 0; i--) if (o.b[i]) insert(o.b[i]);
        if (o.canZero) canZero = true;
    }
};

Sqrt Tree

算法介绍

Sqrt Tree 是一种利用“分治 + 分块 + 预处理”思想的数据结构,用来在接近 O(1) 的时间内回答 区间可结合运算 查询。其核心做法是:

  1. 把数组分层,每一层的块长大约是上层块长的平方根。
  2. 每层块内预处理所有前缀和后缀结果,用于回答跨块的子段查询。
  3. 在层与层之间,预处理“整块到整块”的聚合表(between),用于快速合并中间整块。
  4. 查询 [l,r] 时,找到合适的层,使得 [l,r] 覆盖了至少两个完整块。答案由三部分组成:左端块的后缀、右端块的前缀、中间整块的聚合。
  5. 更新时,修改元素所在的块并在各层递归重建相关的前缀/后缀和 between 表。

整个结构对区间的查询近似 O(1),单点修改大约 O(√n)。其优势是常数小,实现纯静态时几乎是最优的区间 RMQ/区间和解法之一。所有区间均视作闭区间 [l,r]。

常见例题

题目:给定长度为 n 的数组 a,q 次操作:
1 l r —— 询问区间 [l,r] 的最小值
2 x v —— 把位置 x 的元素改为 v

做法: 定义 struct Info { int val; … }operator+ 定义为取两个 Info 的最小值,Info() 的默认构造返回 +∞ 作为幺元。构建 SqrtTree 后,每次操作一直接调用 query(l,r) 得到区间最小值;操作二调用 update(x,Info(v)) 完成单点修改。由于查询接近 O(1),更新 O(√n),整体能轻松应对 q,n=10^5 级别的数据规模。

代码
// Sqrt Tree(闭区间 [l,r] 查询 + 单点修改)
struct Info
{
    ll val; // 示例:区间和;若你要改成 min/max/gcd,请相应修改默认构造与 operator+
    Info(ll v = 0) : val(v) {}
    friend Info operator+(const Info &a, const Info &b) { return Info(a.val + b.val); }
};

struct SqrtTree
{
private:
    int n;       // 原数组长度
    int lg;      // 2^lg >= n 的最小 lg
    int indexSz; // 顶层索引层大小(把每块代表拼接到尾部)

    vector<Info> v;      // 底层数组 + 索引层拼接后的数组,长度 n + indexSz
    vector<int> clz;     // 用来 O(1) 定位最高位的预处理表
    vector<int> layers;  // 每一层的对数规模(块大小的 log2)
    vector<int> onLayer; // 把某个 log 值映射到所属层下标

    vector<vector<Info>> pref;    // pref[layer][i]:该层内块的 L..i 前缀聚合
    vector<vector<Info>> suf;     // suf[layer][i] :该层内块的 i..R 后缀聚合
    vector<vector<Info>> between; // between[layer-1][ofs + block(i,j)]:整块 i..j 的聚合

    // 功能:在指定层 layer 的一个块 [l,r) 上构建块内前/后缀
    void buildBlock(int layer, int l, int r)
    {
        pref[layer][l] = v[l];
        for (int i = l + 1; i < r; i++)
            pref[layer][i] = pref[layer][i - 1] + v[i];
        suf[layer][r - 1] = v[r - 1];
        for (int i = r - 2; i >= l; i--)
            suf[layer][i] = v[i] + suf[layer][i + 1];
    }

    // 功能:为除顶层以外的层构建“层间整块聚合”表
    void buildBetween(int layer, int lBound, int rBound, int betweenOffs)
    {
        int bSzLog = (layers[layer] + 1) >> 1;
        int bCntLog = layers[layer] >> 1;
        int bSz = 1 << bSzLog;
        int bCnt = (rBound - lBound + bSz - 1) >> bSzLog;
        for (int i = 0; i < bCnt; i++)
        {
            Info ans = suf[layer][lBound + (i << bSzLog)];
            between[layer - 1][betweenOffs + lBound + (i << bCntLog) + i] = ans;
            for (int j = i + 1; j < bCnt; j++)
            {
                Info add = suf[layer][lBound + (j << bSzLog)];
                ans = ans + add;
                between[layer - 1][betweenOffs + lBound + (i << bCntLog) + j] = ans;
            }
        }
    }

    // 功能:为顶层构建索引层,把每个块的“块首后缀”抽到 v 尾部
    void buildBetweenZero()
    {
        int bSzLog = (lg + 1) >> 1;
        for (int i = 0; i < indexSz; i++)
            v[n + i] = suf[0][i << bSzLog];
        build(1, n, n + indexSz, (1 << lg) - n);
    }

    // 功能:当顶层某个块变更时,刷新其对应的索引元素并沿层维护
    void updateBetweenZero(int bid)
    {
        int bSzLog = (lg + 1) >> 1;
        v[n + bid] = suf[0][bid << bSzLog];
        update(1, n, n + indexSz, (1 << lg) - n, n + bid);
    }

    // 功能:递归构建层 layer 管辖区间 [lBound,rBound),并继续对子层构建
    void build(int layer, int lBound, int rBound, int betweenOffs)
    {
        if (layer >= (int)layers.size())
            return;
        int bSz = 1 << ((layers[layer] + 1) >> 1);
        for (int l = lBound; l < rBound; l += bSz)
        {
            int r = min(l + bSz, rBound);
            buildBlock(layer, l, r);
            build(layer + 1, l, r, betweenOffs);
        }
        if (layer == 0)
            buildBetweenZero();
        else
            buildBetween(layer, lBound, rBound, betweenOffs);
    }

    // 功能:递归更新,将位置 x 所在块的块内信息与层间整块聚合重建
    void update(int layer, int lBound, int rBound, int betweenOffs, int x)
    {
        if (layer >= (int)layers.size())
            return;
        int bSzLog = (layers[layer] + 1) >> 1;
        int bSz = 1 << bSzLog;
        int blockIdx = (x - lBound) >> bSzLog;
        int l = lBound + (blockIdx << bSzLog);
        int r = min(l + bSz, rBound);
        buildBlock(layer, l, r);
        if (layer == 0)
            updateBetweenZero(blockIdx);
        else
            buildBetween(layer, lBound, rBound, betweenOffs);
        update(layer + 1, l, r, betweenOffs, x);
    }

    // 功能:内部查询,闭区间 [l,r],betweenOffs 为层间偏移,base 为该层基底
    Info query(int l, int r, int betweenOffs, int base)
    {
        if (l == r)
            return v[l];
        if (l + 1 == r)
            return v[l] + v[r];
        int layer = onLayer[clz[(l - base) ^ (r - base)]];
        int bSzLog = (layers[layer] + 1) >> 1;
        int bCntLog = layers[layer] >> 1;
        int lBound = (((l - base) >> layers[layer]) << layers[layer]) + base;
        int lBlock = ((l - lBound) >> bSzLog) + 1;
        int rBlock = ((r - lBound) >> bSzLog) - 1;

        Info ans = suf[layer][l];
        if (lBlock <= rBlock)
        {
            Info mid = (layer == 0)
                           ? query(n + lBlock, n + rBlock, (1 << lg) - n, n)
                           : between[layer - 1][betweenOffs + lBound + (lBlock << bCntLog) + rBlock];
            ans = ans + mid;
        }
        ans = ans + pref[layer][r];
        return ans;
    }

public:
    // 构造函数:用底层数组构建整棵 Sqrt Tree
    SqrtTree(const vector<Info> &a)
        : n((int)a.size()), lg(log2Up(max(1, n))), v(a), clz(1 << lg), onLayer(lg + 1)
    {
        clz[0] = 0;
        for (int i = 1; i < (int)clz.size(); i++)
            clz[i] = clz[i >> 1] + 1;

        int tlg = lg;
        while (tlg > 1)
        {
            onLayer[tlg] = (int)layers.size();
            layers.push_back(tlg);
            tlg = (tlg + 1) >> 1;
        }
        for (int i = lg - 1; i >= 0; i--)
            onLayer[i] = max(onLayer[i], onLayer[i + 1]);

        int betweenLayers = max(0, (int)layers.size() - 1);
        int bSzLog = (lg + 1) >> 1;
        int bSz = 1 << bSzLog;
        indexSz = (n + bSz - 1) >> bSzLog;

        v.resize(n + indexSz);
        pref.assign(layers.size(), vector<Info>(n + indexSz));
        suf.assign(layers.size(), vector<Info>(n + indexSz));
        between.assign(betweenLayers, vector<Info>((1 << lg) + bSz));

        build(0, 0, n, 0);
    }

    // 功能:闭区间查询,返回 [l,r] 上的聚合
    Info query(int l, int r)
    {
        if (l > r)
            return Info();
        return query(l, r, 0, 0);
    }

    // 功能:单点修改,把 v[x] 改为 item,并维护相关层
    void update(int x, const Info &item)
    {
        v[x] = item;
        update(0, 0, n, 0, x);
    }

    // 工具函数:最小 lg 使 2^lg >= m
    static int log2Up(int m)
    {
        int res = 0;
        while ((1 << res) < m)
            res++;
        return res;
    }
};


三、字符串

字符串哈希

朴素字符串哈希

算法介绍

用多项式滚动哈希把字符串映射到整数域,前缀哈希配合幂数组在 O(1) 时间得到任意子串 [l,r] 的哈希值;采用 64 位无符号整数“自然溢出”模拟取模,常数极小,实战中足够稳健。为适配比赛需求,构造时一次性预处理前缀与幂,查询全是闭区间 [l,r]。本节实现对齐你 v3.2 的思路并统一命名与注释规范。

常见例题

题目:给定字符串 s,多次询问子串 [l1,r1] 与 [l2,r2] 是否相等。

做法:用本节滚动哈希,分别计算两段的哈希并比较即可,相等时高度概率为真;若需要零碰撞,可转“双模”。

代码
// 朴素字符串哈希(64位自然溢出),子串索引均为闭区间 [l,r] 且 0-based
// 功能:O(1) 获取任意子串哈希;构造时预处理前缀与幂
// 说明:默认 base=131,可在构造时自定义;若需零碰撞可配合“双模”一起判等
struct StringHash64
{
    using U64 = ull;
    int n;                 // 字符串长度
    U64 base;              // 基数
    vector<U64> pref;      // 前缀哈希:pref[i] = s[0..i-1]
    vector<U64> power;     // 幂数组:power[i] = base^i

    // 构造:给定字符串与可选基数
    StringHash64(const string &s = "", U64 b = 131) { init(s, b); }

    // 初始化或重建
    void init(const string &s, U64 b = 131)
    {
        base = b;
        n = (int)s.size();
        pref.assign(n + 1, 0);
        power.assign(n + 1, 1);
        for (int i = 1; i <= n; i++) power[i] = power[i - 1] * base;
        for (int i = 1; i <= n; i++) pref[i] = pref[i - 1] * base + (U64)(unsigned char)s[i - 1];
    }

    // 取得子串 s[l..r] 的哈希,闭区间且 0-based
    U64 getHash(int l, int r)
    {
        if (l > r) return 0;
        return pref[r + 1] - pref[l] * power[r - l + 1];
    }
};

双模字符串哈希

算法介绍

用两个不同模数与基数组合出二维哈希,查询仍是 O(1)。为避免 64 位乘法溢出,使用 __int128 做中间乘法取模;默认模数选用 1e9+7 与 1e9+9,也可自定义。构造一次,查询多次,所有子串均以闭区间 [l,r] 为准。

常见例题

题目:给定字符串 s 与 q 次询问,每次判断 [l1,r1] 与 [l2,r2] 是否相等且需要“零碰撞”。

做法:用双模哈希分别计算两个模下的子串哈希,二者都相等即视为严格相等。

代码
// 双模字符串哈希,子串索引闭区间 [l,r],0-based
// 功能:O(1) 出任意子串双哈希;支持自定义基数与模数
struct StringHash2
{
    using i128 = __int128_t;
    using i64 = long long;

    int n;                   // 字符串长度
    i64 mod1, mod2;          // 两个模数
    i64 base1, base2;        // 两个基数
    vector<i64> p1, p2;      // 幂
    vector<i64> h1, h2;      // 前缀哈希

    // 构造:可自定义基数与模数
    StringHash2(const string &s = "", i64 b1 = 131, i64 b2 = 137, i64 m1 = 1000000007LL, i64 m2 = 1000000009LL)
    {
        init(s, b1, b2, m1, m2);
    }

    // 初始化或重建
    void init(const string &s, i64 b1 = 131, i64 b2 = 137, i64 m1 = 1000000007LL, i64 m2 = 1000000009LL)
    {
        base1 = b1, base2 = b2, mod1 = m1, mod2 = m2;
        n = (int)s.size();
        p1.assign(n + 1, 1), p2.assign(n + 1, 1);
        h1.assign(n + 1, 0), h2.assign(n + 1, 0);
        for (int i = 1; i <= n; i++) p1[i] = (i128)p1[i - 1] * base1 % mod1, p2[i] = (i128)p2[i - 1] * base2 % mod2;
        for (int i = 1; i <= n; i++)
        {
            int v = (unsigned char)s[i - 1];
            h1[i] = ((i128)h1[i - 1] * base1 + v) % mod1;
            h2[i] = ((i128)h2[i - 1] * base2 + v) % mod2;
        }
    }

    // 取得子串 s[l..r] 的双哈希
    pair<i64,i64> getHash(int l, int r)
    {
        if (l > r) return {0, 0};
        i64 x1 = (h1[r + 1] - (i128)h1[l] * p1[r - l + 1]) % mod1; if (x1 < 0) x1 += mod1;
        i64 x2 = (h2[r + 1] - (i128)h2[l] * p2[r - l + 1]) % mod2; if (x2 < 0) x2 += mod2;
        return {x1, x2};
    }
};

Trie 树

朴素 Trie

算法介绍

用树形指针按字符扩展保存字符串集合,根到某结点路径唯一表示一个前缀。节点上记录经过计数与以该节点为结尾的单词数,可实现插入、精确查词与前缀计数。相比 v3.2 中用静态数组的写法,这里采用动态扩容向量,接口更安全、构造更灵活,仍保持 O(|s|) 的时间复杂度与常数优势。

常见例题

题目:多次插入小写字符串,支持查询字符串 x 是否出现过,以及查询以字符串 p 为前缀的单词数量。

做法:用本节 Trie 的 insert 插入;countWord 判断精确出现次数是否大于 0;countPrefix 返回经过计数实现前缀统计。

代码
// 朴素 Trie(小写字母),支持插入、精确查询、前缀计数
// 节点存:child[26]、pass(路过计数)、end(该处结尾的单词数)
struct Trie
{
    struct Node
    {
        int child[26];  // 下一个字符转移
        int pass;       // 经过该节点的字符串数量
        int end;        // 以该节点为结尾的字符串数量
        Node() : child{}, pass(0), end(0) {}
    };

    vector<Node> t;     // 节点池,t[0] 为根

    // 构造:可选预留容量
    Trie(int reserveSize = 1) { t.reserve(max(1, reserveSize)); t.clear(); t.emplace_back(); }

    // 清空为仅根节点
    void clear() { t.clear(); t.emplace_back(); }

    // 插入字符串 s
    void insert(const string &s)
    {
        int p = 0;
        for (char ch : s)
        {
            int c = ch - 'a';
            if (!t[p].child[c]) t[p].child[c] = newNode();
            p = t[p].child[c], t[p].pass++;
        }
        t[p].end++;
    }

    // 统计字符串 s 出现次数
    int countWord(const string &s)
    {
        int p = 0;
        for (char ch : s)
        {
            int c = ch - 'a';
            if (!t[p].child[c]) return 0;
            p = t[p].child[c];
        }
        return t[p].end;
    }

    // 统计以 s 为前缀的字符串数量
    int countPrefix(const string &s)
    {
        int p = 0;
        for (char ch : s)
        {
            int c = ch - 'a';
            if (!t[p].child[c]) return 0;
            p = t[p].child[c];
        }
        return t[p].pass + t[p].end;
    }

private:
    // 新建节点
    int newNode()
    {
        t.emplace_back();
        return (int)t.size() - 1;
    }
};

01-Trie

算法介绍

把整数按二进制位从高到低插入到二叉 Trie 中,常见操作有最大异或、计数 (x xor y) < k 的数量等。节点保存两个儿子与经过计数,插入和查询都是 O(B)(B 为比特数,常取 31 或 63)。本实现既能做 maxXor,也提供 countLessThan 用于一类计数题,与 v3.2 片段中的计数操作一致方向。

常见例题

题目:维护一个多重集合,支持插入整数 v,查询给定 x 的最大异或值,以及统计集合中有多少 y 满足 (x xor y) < k

做法:插入用 insert;最大异或用 maxXor;小于计数用 countLessThan(x,k) 自高位贪心选择可行分支累加计数。

代码
// 01-Trie(整数二进制字典树),支持插入、最大异或、(x xor y) < k 的计数
// 默认处理 31 位非负整数,可通过构造参数设置最高位
struct BinaryTrie
{
    struct Node
    {
        int nxt[2];   // 左0、右1
        int cnt;      // 经过该节点的数量
        Node() : nxt{0,0}, cnt(0) {}
    };

    vector<Node> t;   // 节点池
    int maxBit;       // 最高位(含),如 30 表示处理 [30..0]

    // 构造:给定最高位,默认 30(适配 int)
    BinaryTrie(int highest = 30) : maxBit(highest) { t.clear(); t.emplace_back(); }

    // 清空
    void clear() { t.clear(); t.emplace_back(); }

    // 插入一个数 x
    void insert(int x)
    {
        int p = 0; t[p].cnt++;
        for (int i = maxBit; i >= 0; i--)
        {
            int b = (x >> i) & 1;
            if (!t[p].nxt[b]) t[p].nxt[b] = newNode();
            p = t[p].nxt[b], t[p].cnt++;
        }
    }

    // 最大异或:在集合里找 y 使 x xor y 最大,若集合为空返回 0
    int maxXor(int x)
    {
        if (t[0].cnt == 0) return 0;
        int p = 0, ans = 0;
        for (int i = maxBit; i >= 0; i--)
        {
            int b = (x >> i) & 1, want = b ^ 1;
            int go = t[p].nxt[want] && t[t[p].nxt[want]].cnt ? want : b;
            p = t[p].nxt[go], ans |= (go ^ b) << i;
        }
        return ans;
    }

    // 计数:(x xor y) < k 的 y 的数量;若集合为空返回 0
    ll countLessThan(int x, int k)
    {
        if (t[0].cnt == 0) return 0;
        ll res = 0;
        int p = 0;
        for (int i = maxBit; i >= 0 && p; i--)
        {
            int xb = (x >> i) & 1, kb = (k >> i) & 1;
            int left = t[p].nxt[xb];          // 选择与 xb 相同的分支 => 该位异或为 0
            int right = t[p].nxt[xb ^ 1];     // 选择相反分支 => 该位异或为 1
            if (kb)                           // k 的该位为 1:可以把“该位异或为 0”的整棵都计入,并继续在“该位异或为 1”的分支上走
            {
                if (left) res += t[left].cnt;
                p = right;
            }
            else                               // k 的该位为 0:该位异或只能为 0,继续在相同分支上走
            {
                p = left;
            }
        }
        return res;
    }

private:
    int newNode()
    {
        t.emplace_back();
        return (int)t.size() - 1;
    }
};

前缀

KMP

算法介绍

KMP 算法用来在主串中查找模式串的所有出现位置。它先对模式串构造一个前缀函数(也称为 “失配表”或“最长前缀-后缀匹配长度”数组),用于在发生不匹配时,让匹配跳过已经确定的部分,从而将总时间复杂度控制在 O(n + m),其中 n 是主串长度,m 是模式串长度。

前缀函数 π[i] 表示模式串 pat[0..i] 的最长真前缀(proper prefix,也就是不包含整个字符串自身)恰好等于它本身的一个真后缀的长度。构建 π 数组时间 O(m),匹配阶段用 π 数组避开重复比较,实现线性匹配。

常见例题

题目: 主串 S 长度 n,模式串 P 长度 m,求所有 S 中 P 的出现位置。输入 S 和 P,即输出所有 i(0 ≤ i ≤ n−m)使得 S[i..i+m−1] == P。

做法: 用 KMP,先 build 前缀函数 π,对模式串 P 构造 π。然后遍历主串 S,用指针 j 表示模式当前匹配的位置,从 j = 0 开始,遍历 S 的每个字符 i。若 S[i] == P[j] 则 j++;否则用 π 表回退 j。每当 j == m 时说明匹配成功,记录 i − m + 1,然后 j = π[j−1] 继续找下一个匹配。

代码
// KMP 模板
// 功能:构造模式 pat 的前缀函数 pi 数组与在主串中查找所有匹配起始位置
// 复杂度:构造 pi O(m),查找 O(n)
struct KMPMatcher
{
    string pat;               // 模式串
    vector<int> pi;           // 前缀函数数组,长度 m

    // 构造函数,接受模式串 pat_ 初始化 pi
    KMPMatcher(const string &pat_) : pat(pat_), pi((int)pat_.size(), 0) { build(); }

    // 功能:构造前缀函数 pi
    void build()
    {
        int m = (int)pat.size();
        pi[0] = 0;
        for (int i = 1; i < m; i++)
        {
            int j = pi[i - 1];
            while (j > 0 && pat[i] != pat[j]) j = pi[j - 1];
            if (pat[i] == pat[j]) j++;
            pi[i] = j;
        }
    }

    // 功能:在主串 text 中查找所有 pat 的匹配起始位置,返回所有 0-based 的位置
    vector<int> matchPositions(const string &text)
    {
        int n = (int)text.size(), m = (int)pat.size();
        vector<int> res;
        int j = 0;
        for (int i = 0; i < n; i++)
        {
            while (j > 0 && text[i] != pat[j]) j = pi[j - 1];
            if (text[i] == pat[j]) j++;
            if (j == m)
            {
                res.push_back(i - m + 1);
                j = pi[j - 1];
            }
        }
        return res;
    }
};

Z 函数(拓展 KMP )

算法介绍

Z 函数对字符串 S 的每个位置 i(0 ≤ i < n)定义 Z[i] 为从 i 开始,与 S[0..] 的最长公共前缀长度。常用于快速匹配、字符串周期、找回文、字符串组合问题等。Z 算法在 O(n) 时间内构造所有 Z[i],用两个指针维护当前已知的右边界区间 [L,R],在这个区间内利用之前计算结果加速,否则直接匹配。

常见例题

题目: 给定一个字符串 S 和一个模式串 P,求主串 S 中 P 所有出现的位置。一个做法是构造字符串 P + '#' + S('#' 是不在字母表中的分隔符),然后用 Z 函数计算对整个合并串的 Z 数组,任何 i ≥ m+1 且 Z[i] ≥ m 的位置 i−(m+1) 即为匹配起始位置。

代码
// Z 函数模板
// 功能:构造字符串 s 的 Z 数组 Z[i] 表示 s[i..] 与 s[0..] 的最长公共前缀长度
// 复杂度:O(n)
struct ZFunction
{
    string s;              // 输入字符串
    vector<int> z;         // Z 数组,长度 n

    // 构造函数,接受字符串 s_ 并构造 z
    ZFunction(const string &s_) : s(s_), z((int)s_.size(), 0) { build(); }

    // 功能:构造 z 数组
    void build()
    {
        int n = (int)s.size();
        int L = 0, R = 0;
        z[0] = n;
        for (int i = 1; i < n; i++)
        {
            if (i <= R) z[i] = min(R - i + 1, z[i - L]);
            else z[i] = 0;
            while (i + z[i] < n && s[z[i]] == s[i + z[i]]) z[i]++;
            if (i + z[i] - 1 > R) { L = i; R = i + z[i] - 1; }
        }
    }

    // 功能:判断模式 pat 是否在 text 中出现,返回所有起始位置
    vector<int> matchPositions(const string &pat, const string &text)
    {
        string comb = pat + '#' + text;
        ZFunction zf(comb);
        int m = (int)pat.size();
        vector<int> res;
        for (int i = m + 1; i < (int)zf.z.size(); i++)
            if (zf.z[i] >= m) res.push_back(i - (m + 1));
        return res;
    }
};

AC 自动机

算法介绍

AC 自动机用来在一个主串中同时匹配多模式(多串)集合。先把所有模式插入一个 Trie 树中,每个节点记录若是模式末尾则有输出,并构造 fail(失配指针)链接,用 BFS 从根开始为每个节点设定 fail 指向最长真后缀的节点;匹配过程按主串字符逐步从当前状态转移(存在 Trie 边优先),若无匹配边则沿 fail 指针回退直到根或有匹配边;所有输出模式在匹配状态中随时可以被报告。总体构建时间是 O(sum of pattern lengths * σ),匹配主串时间 O(n + total报告数),σ 是字母表大小。

常见例题

题目: 给定一个 n(1 ≤ n ≤ 10^5)大小的文本串 S 与 k(1 ≤ k ≤ 10^3)个模式串 P_i,总长度之和 M ≤ 10^5。要求输出每个 P_i 在 S 中出现的次数。

做法: 构造 AC 自动机,把所有模式插入;构造 fail 表;然后对 S 扫描,用状态机转移并累加每个模式末尾节点遇到的输出次数。

代码
// AC 自动机模板
// 功能:支持插入多个模式串,构造 fail 指针,并在主串中匹配所有模式的出现次数
// 复杂度:插入总长度 M,构造 fail O(M σ),匹配主串长度 n 时间 O(n + 出现次数)
struct ACAutomaton
{
    static const int SIGMA = 26;           // 字母表大小,可按需修改
    struct Node
    {
        int next[SIGMA];                  // Trie 树转移边
        int fail;                        // 失配指针
        int outCount;                    // 该节点结束了多少模式
        Node() : fail(0), outCount(0) { for (int i = 0; i < SIGMA; i++) next[i] = -1; }
    };

    vector<Node> nodes;                    // 所有节点,根为 0

    // 构造函数,初始化根节点
    ACAutomaton() { nodes.emplace_back(); }

    // 功能:插入一个模式串 pattern,编号不返回,仅统计 outCount
    void insert(const string &pattern)
    {
        int u = 0;
        for (char ch : pattern)
        {
            int c = ch - 'a';
            if (nodes[u].next[c] == -1)
            {
                nodes[u].next[c] = (int)nodes.size();
                nodes.emplace_back();
            }
            u = nodes[u].next[c];
        }
        nodes[u].outCount++;
    }

    // 功能:构造 fail 指针与输出计数(累加父 fail 链上的 outCount)
    void build()
    {
        queue<int> q;
        for (int c = 0; c < SIGMA; c++)
            if (nodes[0].next[c] != -1) { nodes[nodes[0].next[c]].fail = 0; q.push(nodes[0].next[c]); }
            else nodes[0].next[c] = 0;
        while (!q.empty())
        {
            int u = q.front(); q.pop();
            for (int c = 0; c < SIGMA; c++)
            {
                int v = nodes[u].next[c];
                if (v != -1)
                {
                    nodes[v].fail = nodes[nodes[u].fail].next[c];
                    nodes[v].outCount += nodes[nodes[nodes[u].fail].outCount];
                    q.push(v);
                }
                else nodes[u].next[c] = nodes[nodes[u].fail].next[c];
            }
        }
    }

    // 功能:匹配主串 text,返回总出现次数(每次遇到模式结尾节点时累加该节点 outCount)
    ll matchCount(const string &text)
    {
        int u = 0;
        ll total = 0;
        for (char ch : text)
        {
            int c = ch - 'a';
            u = nodes[u].next[c];
            total += nodes[u].outCount;
        }
        return total;
    }

    // 可选功能:返回每个模式串是否出现,或出现次数;这需要在 insert 时记录每个 pattern 所在节点索引,然后在 match 过程中对这些节点的 outCount 累加或记录
};

后缀

后缀数组

算法介绍

后缀数组对字符串 s 的全部后缀进行排序并记录起点索引,常配合 LCP 数组解决最长公共子串、重复子串计数等问题。常用的倍增算法在 O(n log n) 内构建 SA;在 SA 基础上用 Kasai 可在线性时间构建 LCP,LCP[i] 表示 SA[i] 与 SA[i−1] 两个后缀的最长公共前缀长度。这些结论在竞赛资料中是经典结论。

常见例题

题目: 给定字符串 s,要求统计不同子串个数。

做法: 先构造 SA 与 LCP。不同子串数量等于所有后缀长度之和减去相邻后缀的 LCP 之和,即 ∑(n−SA[i]) − ∑LCP[i]。若题目改为求两个串的最长公共子串,可把两串用分隔符拼接后,据分隔符归属只在跨边界的相邻后缀上取最小 LCP 即可快速得到答案。

代码
// 后缀数组(倍增 O(n log n))+ Kasai LCP(O(n))
// 用途:构造 SA, rank, LCP;支持按需获取不同子串数
struct SuffixArray
{
    string s;                // 原串
    int n;                   // 长度
    vector<int> sa;          // 后缀数组,0..n-1 的排列
    vector<int> rk;          // rk[i] = 后缀 i 在 SA 中的名次
    vector<int> lcp;         // lcp[i] = LCP(SA[i], SA[i-1]),lcp[0]=0

    // 构造函数,直接构建
    SuffixArray(const string &s_ = "") { init(s_); }

    // 初始化并构建 SA 与 LCP
    void init(const string &s_)
    {
        s = s_; n = (int)s.size();
        sa.resize(n), rk.resize(n), lcp.assign(n, 0);
        if (!n) return;
        buildSA(); buildLCP();
    }

    // 倍增法构建 SA
    void buildSA()
    {
        vector<int> tmp(n), key(n);
        for (int i = 0; i < n; i++) sa[i] = i, rk[i] = s[i];
        for (int k = 1;; k <<= 1)
        {
            auto cmp = [&](int a, int b)
            {
                if (rk[a] != rk[b]) return rk[a] < rk[b];
                int ra = a + k < n ? rk[a + k] : -1;
                int rb = b + k < n ? rk[b + k] : -1;
                return ra < rb;
            };
            // 基于 pair(rk[i], rk[i+k]) 的排序,可用稳定排序或基数排序优化
            iota(sa.begin(), sa.end(), 0);
            stable_sort(sa.begin(), sa.end(), cmp);
            tmp[sa[0]] = 0;
            for (int i = 1; i < n; i++) tmp[sa[i]] = tmp[sa[i - 1]] + (cmp(sa[i - 1], sa[i]) ? 1 : 0);
            for (int i = 0; i < n; i++) rk[i] = tmp[i];
            if (rk[sa.back()] == n - 1) break;
        }
    }

    // Kasai 构建 LCP,lcp[i] = LCP(SA[i], SA[i-1])
    void buildLCP()
    {
        for (int i = 0; i < n; i++) rk[sa[i]] = i;
        int k = 0;
        for (int i = 0; i < n; i++)
        {
            if (rk[i] == 0) { k = 0; continue; }
            int j = sa[rk[i] - 1];
            while (i + k < n && j + k < n && s[i + k] == s[j + k]) k++;
            lcp[rk[i]] = k;
            if (k) k--;
        }
    }

    // 返回不同子串个数
    ll distinctSubstrings() const
    {
        ll all = 1LL * n * (n + 1) / 2;
        ll sumLcp = 0;
        for (int i = 1; i < n; i++) sumLcp += lcp[i];
        return all - sumLcp;
    }
};

后缀自动机

算法介绍

后缀自动机是正则子串集合的最小有向无环自动机,可在线 O(n) 构建并解决大量子串计数与匹配问题。状态保存 endpos 等价类的最长长度 len,后缀链接 link,转移 next。按字符从左到右扩展,针对每个新字符维护“走转移、克隆、连后缀”的三分支,保证图规模 O(n)。常用性质包括不同子串计数等于 ∑(len[v]−len[link[v]])。这些结论与作法可参见标准资料。

常见例题

题目: 给定字符串 s,输出不同子串个数。

做法: 在线构造 SAM,遍历所有状态累加 len[v]−len[link[v]]。若题目改为给定 t,问 t 是否为 s 的子串,则从起点按 t 的字符尝试沿转移走,任何一步走不动即失败,否则成功。

代码
// 后缀自动机(小写字母 a..z),在线 O(n) 构建
// 功能:支持插入与不同子串计数;可按 need 扩展到计数出现次数等
struct SuffixAutomaton
{
    struct State
    {
        int next[26];
        int link;
        int len;
        State() : link(-1), len(0) { for (int i = 0; i < 26; i++) next[i] = -1; }
    };

    vector<State> st;
    int last;

    // 构造函数,预留 2n 空间
    SuffixAutomaton(int reserveLen = 0)
    {
        st.reserve(reserveLen ? 2 * reserveLen : 0);
        st.push_back(State());
        last = 0;
    }

    // 插入单个字符 ch
    void extend(char ch)
    {
        int c = ch - 'a';
        int cur = (int)st.size();
        st.push_back(State());
        st[cur].len = st[last].len + 1;

        int p = last;
        while (p != -1 && st[p].next[c] == -1) st[p].next[c] = cur, p = st[p].link;
        if (p == -1) st[cur].link = 0;
        else
        {
            int q = st[p].next[c];
            if (st[p].len + 1 == st[q].len) st[cur].link = q;
            else
            {
                int clone = (int)st.size();
                st.push_back(st[q]);
                st[clone].len = st[p].len + 1;
                while (p != -1 && st[p].next[c] == q) st[p].next[c] = clone, p = st[p].link;
                st[q].link = st[cur].link = clone;
            }
        }
        last = cur;
    }

    // 批量插入字符串
    void build(const string &s)
    {
        for (char ch : s) extend(ch);
    }

    // 统计不同子串个数: sum(len[v] - len[link[v]])
    ll countDistinct()
    {
        ll ans = 0;
        for (int v = 1; v < (int)st.size(); v++) ans += st[v].len - st[st[v].link].len;
        return ans;
    }

    // 判断 t 是否为子串
    bool contains(const string &t)
    {
        int u = 0;
        for (char ch : t)
        {
            int c = ch - 'a';
            if (st[u].next[c] == -1) return false;
            u = st[u].next[c];
        }
        return true;
    }
};

最小表示法

算法介绍

最小表示法将环串映射为其字典序最小的旋转,常用于循环同构判定与哈希规约。Booth 算法在 O(n) 时间求得起点索引,通过在 s+s 上用两指针比较维护最优起点。

常见例题

题目: 给定两个由小写字母构成的环串,判断它们是否同构。

做法: 分别求两串的最小旋转起点 posA 与 posB,然后取各自从该起点起的 n 个字符比较是否相同。

代码
// Booth 算法:返回最小表示的起始下标
// 功能:在 O(n) 时间内求 s 的字典序最小旋转起点
int minimalRotationIndex(const string &s)
{
    string t = s + s;
    int n = (int)s.size();
    int i = 0, j = 1, k = 0;
    while (i < n && j < n && k < n)
    {
        char a = t[i + k], b = t[j + k];
        if (a == b) { k++; continue; }
        if (a > b) i = i + k + 1; else j = j + k + 1;
        if (i == j) j++;
        k = 0;
    }
    return min(i, j);
}

回文串

Manacher

算法介绍

Manacher 算法在 O(n) 时间内求出以每个位置为中心的最长回文半径,通常同时维护奇半径与偶半径数组。它利用当前最右回文 [l,r] 的对称性将部分中心的答案直接从镜像拷贝,然后向外扩展修正。

常见例题

题目: 给定字符串 s,求最长回文子串在 s 中的 [l,r] 位置。

做法: 先用 Manacher 求出奇偶半径,再在遍历中把奇中心 i 的回文区间折算为 [i−d1[i]+1, i+d1[i]−1],偶中心 i 的为 [i−d2[i], i+d2[i]−1],取长度最大的区间即可。

代码
// Manacher:求奇回文半径 d1 与偶回文半径 d2
// 语义:d1[i] = 以 i 为中心的最长奇回文半径;d2[i] = 以 i-1、i 为中心的最长偶回文半径
struct Manacher
{
    string s;
    int n;
    vector<int> d1, d2;

    Manacher(const string &s_ = "") { init(s_); }

    void init(const string &s_)
    {
        s = s_; n = (int)s.size();
        d1.assign(n, 0), d2.assign(n, 0);
        if (!n) return;

        // 奇半径
        for (int i = 0, l = 0, r = -1; i < n; i++)
        {
            int k = 1;
            if (i <= r) k = min(d1[l + r - i], r - i + 1);
            while (0 <= i - k && i + k < n && s[i - k] == s[i + k]) k++;
            d1[i] = k;
            if (i + k - 1 > r) l = i - k + 1, r = i + k - 1;
        }

        // 偶半径
        for (int i = 0, l = 0, r = -1; i < n; i++)
        {
            int k = 0;
            if (i <= r) k = min(d2[l + r - i + 1], r - i + 1);
            while (0 <= i - k - 1 && i + k < n && s[i - k - 1] == s[i + k]) k++;
            d2[i] = k;
            if (i + k - 1 > r) l = i - k, r = i + k - 1;
        }
    }

    // 求最长回文子串对应的闭区间 [L,R]
    pair<int,int> longestInterval()
    {
        int bestLen = 0, L = 0, R = -1;

        for (int i = 0; i < n; i++)
        {
            int len = 2 * d1[i] - 1;
            if (len > bestLen)
            {
                bestLen = len;
                L = i - d1[i] + 1, R = i + d1[i] - 1;
            }
        }
        for (int i = 0; i < n; i++)
        {
            int len = 2 * d2[i];
            if (len > bestLen)
            {
                bestLen = len;
                L = i - d2[i], R = i + d2[i] - 1;
            }
        }
        return {L, R};
    }
};

回文自动机

算法介绍

回文自动机以两棵带 fail 指针的“回文树”在线维护全部不同回文子串,存在长度 −1 与 0 的两个根。每次向右加入一个字符,在当前“最大后缀回文”出发沿 fail 找到能扩大为新回文的位置,若不存在就新建节点并设定 fail 指向相应最长真后缀回文。总构建是 O(n) 节点与转移量级的,能在线统计不同回文数量与每个回文出现次数。

常见例题

题目: 给定字符串 s,输出不同回文子串个数与每个回文出现次数。

做法: 在线构建 Eertree,节点中新建时把计数置为 1,整串读完以后按节点长度从大到小把出现次数沿 fail 聚合回父回文,这样每个节点的 cnt 就是对应回文在 s 中的出现次数,而节点总数减去 2 就是不同回文的数量。

代码
// 回文自动机(Eertree),字符集默认 a..z
// 节点含:len(回文长度)、fail(最长真回文后缀链接)、next(按字符转移)、cnt(出现次数)、num(以该回文为结尾的回文后缀个数)
struct Eertree
{
    struct Node
    {
        int len, fail, cnt, num;
        int nxt[26];
        Node(int L = 0) : len(L), fail(0), cnt(0), num(0) { for (int i = 0; i < 26; i++) nxt[i] = 0; }
    };

    vector<Node> t;     // 节点数组,1-based 方便,t[1] 长度 -1,t[2] 长度 0
    string s;           // 动态维护的串,从下标 0 开始追加
    int last;           // 指向当前最大后缀回文所在的节点
    int n;              // s 的当前长度

    // 构造与初始化
    Eertree() { init(); }

    void init()
    {
        t.clear(); s.clear(); last = 2; n = 0;
        t.reserve(1 << 20);
        t.emplace_back(Node(0));               // 占位 0
        t.emplace_back(Node(-1));              // 1: len=-1 的虚根
        t.emplace_back(Node(0));               // 2: len=0 的空串根
        t[1].fail = 1; t[2].fail = 1;
    }

    // 辅助:沿 fail 走,寻找以 s[pos] 为右端点时能继续扩展的回文
    int getFail(int x, int pos)
    {
        while (true)
        {
            int L = t[x].len;
            if (pos - 1 - L >= 0 && s[pos - 1 - L] == s[pos]) return x;
            x = t[x].fail;
        }
    }

    // 插入单个字符 ch,返回新建节点编号(若无新建则返回已存在的编号)
    int addChar(char ch)
    {
        int c = ch - 'a';
        s.push_back(ch); n++;
        int cur = getFail(last, n - 1);

        if (!t[cur].nxt[c])
        {
            int now = (int)t.size();
            t.emplace_back(Node(t[cur].len + 2));
            int failTo = t[getFail(t[cur].fail, n - 1)].nxt[c];
            t[now].fail = failTo ? failTo : 2;
            t[cur].nxt[c] = now;
            t[now].num = t[t[now].fail].num + 1;
        }
        last = t[cur].nxt[c];
        t[last].cnt++;
        return last;
    }

    // 统计每个回文出现次数,需在整串加入完之后调用
    void countOccurrences()
    {
        vector<pair<int,int>> ord;
        ord.reserve((int)t.size() - 3);
        for (int i = 3; i < (int)t.size(); i++) ord.emplace_back(t[i].len, i);
        sort(ord.rbegin(), ord.rend());
        for (auto [L, id] : ord) t[t[id].fail].cnt += t[id].cnt;
    }

    // 返回不同回文子串个数(不含两个根)
    int distinct() const
    {
        return (int)t.size() - 3;
    }
};

四、动态规划

基础 DP

背包 DP

算法介绍

把最优子结构刻画为状态,经典是容量与物品编号这两维。以闭区间 [l,r] 统一叙述时,区间的意义在“枚举体积或价值的范围”与“滚动数组转移窗口”中均应保持闭包一致。0/1 背包用从右到左的体积维度更新避免重复使用,同类地,完全背包从左到右更新,多重背包可用二进制拆分转化为若干件 0/1 物品。思路与正确性要点、常见细节与优化在竞赛社区与资料中有系统总结,适合直接套用与按需扩展。也可据题意改写为计数型背包或多目标背包。

常见例题

题目:给定 n 件物品,第 i 件体积 w[i]、价值 v[i]。给定容量上限 C,回答 q 次询问,每次给出 [l,r] 表示只允许使用编号落在 [l,r] 的物品,问在容量不超过 C 的前提下的最大价值。

做法:预处理两棵“分治卷积背包树”,左树自左向右、右树自右向左各维护一段的 0/1 背包 dp,在线回答时把 [l,r] 拆成“左树覆盖的若干段”与“右树覆盖的若干段”,把这些段的 dp 表按容量做一次合并即可。容量维始终按 [0,C] 闭区间滚动更新。

代码
// 背包DP模板,支持0/1、完全、多重三种模式,容量闭区间为[0,cap]
// 用途:solve01/solveComplete/solveMultiple 返回最大价值;可据此拓展为计数或可行性
struct Knapsack
{
    int n, cap;                       // 物品数与容量上限
    vector<int> w, v, c;              // 体积、价值、数量(多重)
    vector<ll> f;              // dp[0..cap],闭区间

    Knapsack(): n(0), cap(0) {}
    Knapsack(int cap_): n(0), cap(cap_) { f.assign(cap + 1, 0); }

    void init(int cap_)
    {
        cap = cap_;
        f.assign(cap + 1, 0);
        w.clear(), v.clear(), c.clear();
        n = 0;
    }

    void addItem01(int wi, int vi)
    {
        w.push_back(wi), v.push_back(vi), c.push_back(1), n++;
    }

    void addItemComplete(int wi, int vi)
    {
        w.push_back(wi), v.push_back(vi), c.push_back(-1), n++;
    }

    void addItemMultiple(int wi, int vi, int ci)
    {
        w.push_back(wi), v.push_back(vi), c.push_back(ci), n++;
    }

    // 0/1背包,f[x] 表示容量x的最大价值
    void solve01()
    {
        for (int i = 0; i < n; i++)
        {
            if (c[i] != 1) continue;
            for (int x = cap; x >= w[i]; x--) f[x] = max(f[x], f[x - w[i]] + v[i]);
        }
    }

    // 完全背包,从小到大,确保同一轮可以重复使用
    void solveComplete()
    {
        for (int i = 0; i < n; i++)
        {
            if (c[i] != -1) continue;
            for (int x = w[i]; x <= cap; x++) f[x] = max(f[x], f[x - w[i]] + v[i]);
        }
    }

    // 多重背包,二进制拆分为若干件0/1物品
    void solveMultiple()
    {
        vector<pair<int,int>> items;
        for (int i = 0; i < n; i++)
        {
            if (c[i] <= 0) continue;
            int k = c[i], wi = w[i], vi = v[i];
            for (int p = 1; p <= k; p <<= 1)
            {
                items.emplace_back(p * wi, p * vi);
                k -= p;
            }
            if (k) items.emplace_back(k * wi, k * vi);
        }
        for (auto [wi, vi] : items)
        {
            for (int x = cap; x >= wi; x--) f[x] = max(f[x], f[x - wi] + vi);
        }
    }

    // 查询容量闭区间 [0,cap] 中的最优值
    ll best() const
    {
        return *max_element(f.begin(), f.end());
    }
};

区间 DP

算法介绍

区间 DP 把答案定义在子区间 [l,r] 上,核心是按区间长度升序枚举,再在 [l,r] 内枚举切分点 k 或最后合并位置。典型如石子合并、矩阵连乘、括号计数与最优二叉树,统一转移形如 由 dp[l][k]dp[k+1][r] 合并得到,并在需要时加上跨区间代价。

常见例题

题目:给定长度为 n 的环状石子堆,石子重量 a[i]。每次合并相邻两堆代价等于两堆石子总数,求最小总代价。

做法:先把环断成链复制一倍,在 [l,r] 上用 dp 计算最小合并代价,转移是 dp[l] [r] = min_{k in [l,r-1]} dp[l] [k] + dp[k+1] [r] + sum[l..r],答案是长度为 n 的所有连续区间的最小值。

代码
// 区间DP模板(闭区间 [l,r]),以“石子合并最小代价”为例
struct IntervalDP
{
    int n;                                  // 原序列长度
    vector<ll> a, pre;               // 倍长序列与前缀和
    vector<vector<ll>> dp;           // dp[l] [r] 最小代价

    IntervalDP(): n(0) {}

    // 构造:传入权值数组
    IntervalDP(const vector<int> &w) { init(w); }

    void init(const vector<int> &w)
    {
        n = (int)w.size();
        a.assign(2 * n + 1, 0);
        for (int i = 1; i <= 2 * n; i++) a[i] = w[(i - 1) % n];
        pre.assign(2 * n + 1, 0);
        for (int i = 1; i <= 2 * n; i++) pre[i] = pre[i - 1] + a[i];
        dp.assign(2 * n + 2, vector<ll>(2 * n + 2, 0));
        for (int len = 2; len <= n; len++)
        {
            for (int l = 1, r = l + len - 1; r <= 2 * n; l++, r++)
            {
                ll best = (ll)4e18;
                for (int k = l; k < r; k++)
                {
                    ll cost = dp[l][k] + dp[k + 1][r] + pre[r] - pre[l - 1];
                    if (cost < best) best = cost;
                }
                dp[l][r] = best;
            }
        }
    }

    // 查询环上最小合并代价
    ll answer()
    {
        ll ans = (ll)4e18;
        for (int l = 1, r = l + n - 1; r <= 2 * n; l++, r++) ans = min(ans, dp[l][r]);
        return ans;
    }
};

DAG上的 DP

算法介绍

在有向无环图上把拓扑序作为递推顺序,状态可以是“到达某点的最优值”或“从某点出发的最优值”。拓扑序保证所有入边的来源都已计算完,转移时遍历入边或出边即可。

常见例题

题目:给定 n 点 m 边的 DAG,边权为非负,求从源 s 到每个点的最长路径长度。

做法:拓扑排序后按序放松转移,f[u] 已知时,用 f[v] = max(f[v], f[u] + w(u,v)) 更新所有出边。

代码
// DAG最长路DP(闭区间编号 [1,n])
// 功能:给定DAG与源点s,计算 f[u] 为 s 到 u 的最长路
struct DagDP
{
    int n, m, s;
    vector<vector<pair<int,int>>> g;
    vector<int> indeg, order;
    vector<ll> f;

    DagDP(): n(0), m(0), s(1) {}
    DagDP(int n_, int s_ = 1): n(n_), m(0), s(s_)
    {
        g.assign(n + 1, {});
        indeg.assign(n + 1, 0);
    }

    void addEdge(int u, int v, int w)
    {
        g[u].push_back({v, w});
        indeg[v]++; m++;
    }

    void topo()
    {
        queue<int> q;
        for (int i = 1; i <= n; i++) if (!indeg[i]) q.push(i);
        while (!q.empty())
        {
            int u = q.front(); q.pop();
            order.push_back(u);
            for (auto [v, w] : g[u]) if (--indeg[v] == 0) q.push(v);
        }
    }

    void run()
    {
        f.assign(n + 1, (ll)-4e18);
        f[s] = 0;
        topo();
        for (int u : order)
        {
            if (f[u] <= (ll)-3e18) continue;
            for (auto [v, w] : g[u]) f[v] = max(f[v], f[u] + w);
        }
    }
};

参考思路可在资料中找到拓展,如把最短路改为最长路径并处理权值范围、或统计条数。


树形 DP

算法介绍

在树上把子树记为闭区间 [l,r] 的“结构块”隐喻,核心是“先算子再回父”。常见如树的最大独立集、最大匹配、直径与重心等,都可以把每个结点的信息由其孩子信息合并得到。

常见例题

题目:给定一棵树,求最大独立集的大小。

做法:设 f[u] [0] 表示 u 不选时的最大独立集大小,f[u] [1] 表示 u 选时的值。转移为 f[u] [1] = 1 + Σ f[v] [0],f[u] [0] = Σ max(f[v] [0], f[v] [1])。

代码
// 树形DP:最大独立集
struct TreeDP
{
    int n;
    vector<vector<int>> g;
    vector<array<int,2>> f;
    int root;

    TreeDP(): n(0), root(1) {}
    TreeDP(int n_, int root_ = 1): n(n_), root(root_) { g.assign(n + 1, {}); }

    void addEdge(int u, int v)
    {
        g[u].push_back(v), g[v].push_back(u);
    }

    void dfs(int u, int p)
    {
        f[u] = {0, 1};
        for (int v : g[u]) if (v != p)
        {
            dfs(v, u);
            f[u][1] += f[v][0];
            f[u][0] += max(f[v][0], f[v][1]);
        }
    }

    int maxIndependentSet()
    {
        f.assign(n + 1, {0, 0});
        dfs(root, 0);
        return max(f[root][0], f[root][1]);
    }
};

树形背包 DP

算法介绍

在树上做容量或数量约束的背包,通常是“父合并子”的分组卷积。每处理一个孩子 v,就把当前 u 的 dp 用一个小背包与子树 v 的 dp 在容量闭区间 [0,C] 上做合并卷积,复杂度与度数和容量乘积相关。

常见例题

题目:给定一棵树,每个结点有价值 val[u] 和体积 cost[u],选出的结点集合必须满足“选择结点则必须选择其父亲”,容量不超过 C,最大化价值。

做法:在结点 u 的 dp 上先放入自身,再逐子合并,dp[u] [x] 表示在 u 的子树内、容量恰为 x 的最大价值。

代码
// 树形背包:父选则子可选,容量闭区间[0,cap]
struct TreeKnapsack
{
    int n, cap, root;
    vector<vector<int>> g;
    vector<int> w, val;
    vector<vector<ll>> dp, tmp;

    TreeKnapsack(): n(0), cap(0), root(1) {}
    TreeKnapsack(int n_, int cap_, int root_ = 1): n(n_), cap(cap_), root(root_)
    {
        g.assign(n + 1, {});
        w.assign(n + 1, 0);
        val.assign(n + 1, 0);
        dp.assign(n + 1, vector<ll>(cap + 1, (ll)-4e18));
        tmp.assign(cap + 1, vector<ll>(cap + 1, (ll)-4e18));
    }

    void addEdge(int u, int v) { g[u].push_back(v), g[v].push_back(u); }

    void setNode(int u, int wi, int vi) { w[u] = wi, val[u] = vi; }

    void dfs(int u, int p)
    {
        for (int x = 0; x <= cap; x++) dp[u][x] = (ll)-4e18;
        if (w[u] <= cap) dp[u][w[u]] = val[u];

        for (int v : g[u]) if (v != p)
        {
            dfs(v, u);
            vector<ll> ndp(cap + 1, (ll)-4e18);
            for (int x = 0; x <= cap; x++)
            {
                if (dp[u][x] <= (ll)-3e18) continue;
                ndp[x] = max(ndp[x], dp[u][x]);
                for (int y = 0; y + x <= cap; y++)
                {
                    if (dp[v][y] <= (ll)-3e18) continue;
                    ndp[x + y] = max(ndp[x + y], dp[u][x] + dp[v][y]);
                }
            }
            dp[u].swap(ndp);
        }
    }

    ll solve()
    {
        dfs(root, 0);
        ll ans = 0;
        for (int x = 0; x <= cap; x++) ans = max(ans, dp[root][x]);
        return ans;
    }
};

状压 DP

算法介绍

把若干二进制属性压成一个整数的位集,以子集枚举和位运算作为转移的基本动作。旅行商问题的 Held–Karp 写法是竞赛最常用模板之一,状态 dp[mask] [i] 表示访问了 mask 中的点并以 i 结尾的最短路,按 [l,r] 的闭区间理解时常把城市编号范围视作 [1,n],每个 mask 覆盖该闭区间内的子集。

常见例题

题目:给定完全图边权,求 TSP 最短回路。

做法:dp[1<<s] [s]=0 起步,枚举子集与终点,转移枚举上一点。可顺带保存路径重建信息。

代码
// 状压DP:Held–Karp求TSP
struct TspDp
{
    int n, start;
    vector<vector<ll>> w;
    vector<vector<ll>> dp;
    vector<vector<int>> pre;

    TspDp(): n(0), start(0) {}
    TspDp(int n_, int s_ = 0): n(n_), start(s_)
    {
        w.assign(n, vector<ll>(n, (ll)4e18));
        int full = 1 << n;
        dp.assign(full, vector<ll>(n, (ll)4e18));
        pre.assign(full, vector<int>(n, -1));
    }

    void setEdge(int i, int j, ll c) { w[i][j] = c; }

    ll solve()
    {
        int full = 1 << n;
        dp[1 << start][start] = 0;
        for (int mask = 0; mask < full; mask++)
        {
            for (int u = 0; u < n; u++)
            {
                if (!(mask >> u & 1)) continue;
                ll cur = dp[mask][u];
                if (cur >= (ll)3e18) continue;
                for (int v = 0; v < n; v++)
                {
                    if (mask >> v & 1) continue;
                    int nmask = mask | (1 << v);
                    ll cand = cur + w[u][v];
                    if (cand < dp[nmask][v]) dp[nmask][v] = cand, pre[nmask][v] = u;
                }
            }
        }
        ll ans = (ll)4e18;
        for (int u = 0; u < n; u++) ans = min(ans, dp[full - 1][u] + w[u][start]);
        return ans;
    }
};

数位 DP

算法介绍

围绕十进制或任意进制的逐位扫描,用“是否贴合上界的限制 tight”“是否已经开始计数的标志 started”“附加性质的记忆”构成状态。区间计数通常把 [L,R] 问题拆成 solve(R) − solve(L−1)。

常见例题

题目:统计区间 [l,r] 内各位数字之和为 S 的整数个数。

做法:编写 solve(x) 返回 [0,x] 的计数,状态为位置 pos、当前和 sum、是否贴合 tight、是否已开始 started,记忆化搜索即得。闭区间 [l,r] 直接相减。

代码
// 数位DP:统计 [l,r] 之间数位和等于S的个数
struct DigitDP
{
    vector<int> dig;
    ll memo[20][200][2][2];
    bool vis[20][200][2][2];
    int target;

    DigitDP(): target(0) { clear(); }

    void clear()
    {
        for (int i = 0; i < 20; i++)
            for (int s = 0; s < 200; s++)
                for (int t = 0; t < 2; t++)
                    for (int st = 0; st < 2; st++)
                        vis[i][s][t][st] = false, memo[i][s][t][st] = 0;
    }

    ll dfs(int pos, int sum, int tight, int started)
    {
        if (pos == (int)dig.size()) return sum == target;
        if (vis[pos][sum][tight][started]) return memo[pos][sum][tight][started];
        vis[pos][sum][tight][started] = true;
        int up = tight ? dig[pos] : 9;
        ll res = 0;
        for (int d = 0; d <= up; d++)
        {
            int nt = tight && (d == up);
            int ns = started || d != 0;
            int nsum = sum + (ns ? d : 0);
            if (nsum > target) continue;
            res += dfs(pos + 1, nsum, nt, ns);
        }
        return memo[pos][sum][tight][started] = res;
    }

    ll solve(ll x, int S)
    {
        if (x < 0) return 0;
        target = S;
        clear();
        dig.clear();
        if (x == 0) dig.push_back(0); else
        {
            vector<int> t;
            while (x) t.push_back((int)(x % 10)), x /= 10;
            for (int i = (int)t.size() - 1; i >= 0; i--) dig.push_back(t[i]);
        }
        return dfs(0, 0, 1, 0);
    }

    ll countInRange(ll L, ll R, int S)
    {
        return solve(R, S) - solve(L - 1, S);
    }
};

插头 DP

算法介绍

按行推进的小网格状态压缩,把每一列的“连接口”编码在位集里,转移时维护插头配对关系,用以解决棋盘型计数与铺砖问题。模板化时,把每行处理为从左到右的若干单元格,枚举“上下左右”的连通开闭,配合哈希表或数组滚动。竞赛社区对 Plug DP 的讲解与样例丰富,适合按题面图案定制状态含义与转移。

常见例题

题目:统计用 1×2 多米诺铺满 n×m 棋盘的方案数(允许 n、m 交换以减维)。

做法:行为主维度,状态是当前行每列“是否有向上插头”,对每个格子决定横放或竖放,竖放改变当前行与下一行的插头位,横放在当前行内消除一对相邻插头;到达最后一格且状态全零即累计答案。

代码
// 插头DP示意:多米诺铺满计数(n行m列,按行推进)
// 说明:为了清晰起见,这里用unordered_map做稀疏滚动,状态为mask表示每列向上插头
struct PlugDP
{
    int n, m;
    static const int MOD = 1000000007;

    PlugDP(): n(0), m(0) {}
    PlugDP(int n_, int m_): n(n_), m(m_) {}

    int addmod(int a, int b) { a += b; if (a >= MOD) a -= MOD; return a; }

    int solve()
    {
        using um = unordered_map<int,int>;
        um cur, nxt;
        cur[0] = 1;
        for (int i = 1; i <= n; i++)
        {
            for (int j = 1; j <= m; j++)
            {
                nxt.clear();
                for (auto &kv : cur)
                {
                    int mask = kv.first, ways = kv.second;
                    int up = (mask >> (j - 1)) & 1;
                    int left = (j > 1) ? ((mask >> (j - 2)) & 1) : 0;

                    // 若当前格已被竖块占据(来自上一行),则只能不放新的竖块,把该位清0
                    if (up)
                    {
                        int nmask = mask & ~(1 << (j - 1));
                        nxt[nmask] = addmod(nxt[nmask], ways);
                    }
                    else
                    {
                        // 尝试横放,与左侧未占用且left==0时可行
                        if (j < m && !left)
                        {
                            int nmask = mask | (1 << (j - 1));      // 在当前位置放入“临时占位”,与下个格子配对后清除
                            nmask |= (1 << j);                        // 标记 j+1 位为占用
                            nmask &= ~(1 << (j - 1));                 // 立刻消去成对占用,保持“无悬挂插头”
                            nmask &= ~(1 << j);
                            nxt[nmask] = addmod(nxt[nmask], ways);
                        }
                        // 尝试竖放,向下一行延伸
                        if (i < n)
                        {
                            int nmask = mask | (1 << (j - 1));
                            nxt[nmask] = addmod(nxt[nmask], ways);
                        }
                    }
                }
                cur.swap(nxt);
            }
            // 换行时把mask整体右移m位并检查是否为0
            nxt.clear();
            for (auto &kv : cur)
            {
                int mask = kv.first, ways = kv.second;
                if (mask == 0) nxt[0] = addmod(nxt[0], ways);
            }
            cur.swap(nxt);
        }
        return cur.count(0) ? cur[0] : 0;
    }
};

概率 DP

算法介绍

在概率图或随机过程上用 DP 表示到达某状态的概率或期望值。到达概率通常为前驱概率与转移概率的线性组合;期望值问题常把 E[state] 写成方程并移项,若结构是 DAG 则可直接拓扑解,若含环则可转线性方程组或高斯消元。竞赛中最常见的是“从起点随机走到终点的期望步数”与“命中概率累计”,与 DAG DP 的写法一脉相承。

常见例题

题目:在 DAG 上从源 s 以等概率沿出边随机前进,终点集合 T 为吸收状态,求到达每个点 u 的被访问概率。

做法:拓扑序从 s 向外传播,p[v] += p[u] / outdeg(u);终点的概率直接累加。

代码
// 概率传播DP:DAG上从源均匀随机游走的到达概率
struct ProbDag
{
    int n, s;
    vector<vector<int>> g;
    vector<int> indeg, outdeg, order;
    vector<ld> prob;

    ProbDag(): n(0), s(1) {}
    ProbDag(int n_, int s_ = 1): n(n_), s(s_)
    {
        g.assign(n + 1, {});
        indeg.assign(n + 1, 0);
        outdeg.assign(n + 1, 0);
    }

    void addEdge(int u, int v)
    {
        g[u].push_back(v);
        indeg[v]++, outdeg[u]++;
    }

    void topo()
    {
        queue<int> q;
        for (int i = 1; i <= n; i++) if (!indeg[i]) q.push(i);
        while (!q.empty())
        {
            int u = q.front(); q.pop();
            order.push_back(u);
            for (int v : g[u]) if (--indeg[v] == 0) q.push(v);
        }
    }

    void run()
    {
        prob.assign(n + 1, 0);
        prob[s] = 1;
        topo();
        for (int u : order)
        {
            if (outdeg[u] == 0) continue;
            ld step = prob[u] / (ld)outdeg[u];
            for (int v : g[u]) prob[v] += step;
        }
    }
};

计数 DP

算法介绍

把“方案数”作为状态值进行加法计数,常见的如路径条数、匹配计数、序列构造计数等。与背包、区间、树形、数位等各分支结合即可形成“计数背包”“区间括号合法数”“树形着色数”“数位性质计数”等,并以模数取模。

常见例题

题目:计算长度为 n 的序列个数,使其每一项来自 [1,m] 且相邻两项差的绝对值不超过 d。

做法:设 f[i] [x] 表示前 i 项且第 i 项为 x 的方案数,从 f[i−1] [y] 在闭区间 [max(1,x−d), min(m,x+d)] 内把可达的 y 累加过来,最后求 Σ f[n] [x]。

代码
// 计数DP:相邻差不超过d的序列计数,模mod
struct CountSeq
{
    int n, m, d, mod;
    vector<vector<int>> f;

    CountSeq(): n(0), m(0), d(0), mod(1000000007) {}
    CountSeq(int n_, int m_, int d_, int mod_ = 1000000007): n(n_), m(m_), d(d_), mod(mod_) {}

    int addmod(int a, int b) { a += b; if (a >= mod) a -= mod; return a; }

    int solve()
    {
        f.assign(n + 1, vector<int>(m + 1, 0));
        for (int x = 1; x <= m; x++) f[1][x] = 1;
        for (int i = 2; i <= n; i++)
        {
            for (int x = 1; x <= m; x++)
            {
                int L = max(1, x - d), R = min(m, x + d);
                // 朴素累加,可按需加前缀和优化
                int s = 0;
                for (int y = L; y <= R; y++) s = addmod(s, f[i - 1][y]);
                f[i][x] = s;
            }
        }
        int ans = 0;
        for (int x = 1; x <= m; x++) ans = addmod(ans, f[n][x]);
        return ans;
    }
};

DP 优化

分治优化

算法介绍

处理形如 dp[t] [i] 等于在 j<i 的转移中取最小值的多阶段 DP,如果满足决策单调性,即最优断点 opt[t] [i] 随 i 单调不减,就可以用分治计算一段 [L,R] 的 dp,并只在 [optL,optR] 范围里搜索断点,整体把 O(n·range) 压到 O(n log n·range′) 或 O(nk log n)。充分条件常见于代价函数满足四边形不等式或满足凸性等情形。

常见例题

题目:给定数组 a,定义代价 C(j+1,i) 为区间 [j+1,i] 的某个可预处理代价,要求分成 k 段使 ∑C 最小,输出 dp[k] [n]。

做法:经典 k 段分割。若该 C 满足单调性,则对每一层 t 固定,用分治在 [1,n] 上递归求 dp[t] [·],每个区间只在上层给出的最优断点范围内枚举 j。

代码
// 分治优化 DP 模板(闭区间下标约定 [1,n])
// 目标:已知 dpPrev[j],计算 dpCur[i] = min_{j∈[1..i-1]} { dpPrev[j] + cost(j+1, i) }
// 要求:最优断点对 i 单调(opt[i] 非递减)
template <typename T, class Cost>
struct DCDP
{
    int n;                   // 状态个数
    vector<T> dpPrev, dpCur; // 上一层/当前层
    Cost cost;               // 代价仿函数:cost(l, r)
    const T INF;

    DCDP(int n_, T inf_, Cost c) : n(n_), cost(c), INF(inf_)
    {
        dpPrev.assign(n + 1, INF);
        dpCur.assign(n + 1, INF);
    }

    // 计算一层 dpCur,外部应先填好 dpPrev
    void computeOneLayer()
    {
        auto solve = [&](auto &&self, int l, int r, int optL, int optR) -> void
        {
            if (l > r)
                return;
            int mid = (l + r) >> 1, bestK = max(1, min(mid - 1, optL));
            T best = INF;
            int up = min(optR, mid - 1);
            for (int k = bestK; k <= up; k++)
            {
                T cand = dpPrev[k] + cost(k + 1, mid);
                if (cand < best)
                    best = cand, bestK = k;
            }
            dpCur[mid] = best;
            self(self, l, mid - 1, optL, bestK);
            self(self, mid + 1, r, bestK, optR);
        };
        dpCur[1] = dpPrev[1]; // 视题意决定边界
        solve(solve, 2, n, 1, n - 1);
    }
};

单调栈/单调队列优化

算法介绍

当转移只依赖一个“窗口”或需要维护区间的最值时,可以用单调队列 O(1) 摊还取得滑动窗口最值,从而将 DP 的 O(n·w) 压到 O(n)。当转移需要维护“上一个更小/更大元素”之类结构化信息,则用单调栈在线维护边界以消掉循环。可参考滑动窗口最值的经典讲解与讲义。

常见例题

题目:dp[i] = min_{j∈[i-W,i-1]}(dp[j]) + cost[i]。

做法:用一个维护 dp[j] 的递增队列,队首始终是窗口内的最小值。每步弹出越界下标,再把新 j 入队并弹去队尾不优的元素,dp[i] 直接取队首即可。

代码
// 单调队列优化固定窗口 DP:dp[i] = min(dp[i-W..i-1]) + w[i]
template <typename T>
struct MonoQueueDP
{
    int n, W;
    vector<T> w, dp;
    MonoQueueDP(int n_, int W_) : n(n_), W(W_), w(n + 1), dp(n + 1) {}

    // 计算 dp[1..n],假设 dp[0]=0 且从 i=1 开始窗口为 [max(1,i-W), i-1]
    void solve()
    {
        deque<int> q;
        dp[0] = T{};
        for (int i = 1; i <= n; i++)
        {
            while (!q.empty() && q.front() < i - W)
                q.pop_front();
            while (!q.empty() && dp[q.back()] >= dp[i - 1])
                q.pop_back();
            q.push_back(i - 1);
            dp[i] = dp[q.front()] + w[i];
        }
    }
};


四边形不等式优化 (Knuth 优化)

算法介绍

对于区间 DP 形如 dp[i] [j] = min_{k∈[i..j-1]}(dp[i] [k] + dp[k+1] [j]) + w(i,j),若 w 满足单调性和四边形不等式,则最优断点满足 opt[i] [j-1] ≤ opt[i] [j] ≤ opt[i+1] [j],可以把枚举 k 的范围压到一个小区间,时间从 O(n^3) 降到 O(n^2)。

常见例题

题目:区间合并代价 w(i,j) 已给,要求最小合并代价。

做法:经典区间 DP。若满足 Knuth 条件,则在长度枚举时,把 k 的枚举限制在 [opt[i] [j-1], opt[i+1] [j]],边算边维护 opt。

代码
// Knuth 优化模板(闭区间 [1,n])
// 前提:满足 Knuth 条件,则 opt[i] [j] ∈ [opt[i] [j-1], opt[i+1] [j]]
template <typename T, class Weight>
struct KnuthDP
{
    int n;
    vector<vector<T>> dp;
    vector<vector<int>> opt;
    Weight w; // w(i,j) 给出区间代价
    const T INF;

    KnuthDP(int n_, T inf_, Weight W) : n(n_), w(W), INF(inf_)
    {
        dp.assign(n + 1, vector<T>(n + 1, INF));
        opt.assign(n + 1, vector<int>(n + 1, 0));
        for (int i = 1; i <= n; i++)
            dp[i][i] = T{}, opt[i][i] = i;
    }

    // 计算所有区间的最优值
    void solve()
    {
        for (int len = 2; len <= n; len++)
        {
            for (int i = 1; i + len - 1 <= n; i++)
            {
                int j = i + len - 1;
                int L = (i + 1 <= j ? opt[i][j - 1] : i), R = (i <= j - 1 ? opt[i + 1][j] : j);
                if (L > R)
                    swap(L, R);
                T best = INF;
                int who = L;
                for (int k = L; k <= R; k++)
                {
                    T cand = dp[i][k] + dp[k + 1][j] + w(i, j);
                    if (cand < best)
                        best = cand, who = k;
                }
                dp[i][j] = best;
                opt[i][j] = who;
            }
        }
    }
};

斜率优化

算法介绍

当转移为 dp[i] = min_{j<i}(m[j]·x[i] + b[j]) 且满足斜率 m 单调、查询 x 单调时,可以用维护下凸壳的双端队列把每次转移降到均摊 O(1)。若不满足单调,可用 Li Chao 线段树在 O(log U) 查询。

常见例题

题目:已知 m[j]、b[j] 和 x[i] 单调递增,求每个 i 的 dp[i] = min_j(m[j]·x[i] + b[j])。

做法:维护下凸壳。插入直线时通过“交点单调”弹出尾部无用线;查询时因为 x 单调,只需比较队首两条线的值即可。

代码
// 单调斜率 + 单调查询的下凸壳(Deque 版)
// 维护 y = m x + b 的下凸壳,addLine 要求 m 单调递增,query x 单调递增
template <typename T>
struct MonotoneCHT
{
    struct Line
    {
        T m, b;
        T eval(T x) const { return m * x + b; }
    };
    deque<Line> q;

    MonotoneCHT() {}

    // 判断 l2 是否被 l1 与 l3 弃用(交点顺序法,避免浮点:用叉积比较)
    bool bad(const Line &l1, const Line &l2, const Line &l3)
    {
        // (b3 - b1)/(m1 - m3) <= (b2 - b1)/(m1 - m2)  等价于交点(l1,l3) 在交点(l1,l2) 左侧
        return (__int128)(l3.b - l1.b) * (l1.m - l2.m) <= (__int128)(l2.b - l1.b) * (l1.m - l3.m);
    }

    // 插入一条斜率单调递增的直线
    void addLine(T m, T b)
    {
        Line L{m, b};
        while (q.size() >= 2 && bad(q[q.size() - 2], q.back(), L))
            q.pop_back();
        q.push_back(L);
    }

    // 在单调递增的 x 上查询最小值
    T query(T x)
    {
        while (q.size() >= 2 && q[0].eval(x) >= q[1].eval(x))
            q.pop_front();
        return q.front().eval(x);
    }
};

WQS 二分优化

算法介绍

当目标里带有“恰好/至少 选 K 次”这类约束,而基础 DP 在去掉次数约束后可做时,可以给每次“使用一次”的动作加上拉格朗日乘子 λ 的线性罚分,把“最大化价值,使用次数≥K”变成“最大化 价值−λ·次数”。然后对 λ 二分,使最优解的使用次数恰好跨过 K,从而在 O(log 答) 次调用基础 DP 内得到带次数约束的解,这就是 WQS(又称 Alien Trick)。

常见例题

题目:给定 n 个物品,选择若干个使得 f(选择集合) 最大,但必须选恰好 K 个,且基础 DP 能在加上线性“每选一个扣 λ”后在 O(n·something) 内求出最大值与“被选个数”。

做法:对 λ 做二分,每次跑一次“带罚分”的 DP,得到 pair{value, cnt}。若 cnt ≥ K,说明 λ 偏小,调大;否则调小。最后在最接近的 λ 上把价值加回 λ·K 得到答案。

代码
// WQS 二分模板(把“使用一次”的动作加上 -lambda 的线性罚分)
// run(lambda) 需返回 {最佳价值, 使用次数},二分到使次数跨过 K 的临界
template <typename T, class Runner>
struct WQS
{
    int K;      // 目标使用次数
    Runner run; // 给定 lambda 的求解器:返回 {value, cnt}
    T lo, hi;   // 二分范围,需由题目取值域给出
    int iters;  // 二分次数

    WQS(int K_, Runner R, T L, T H, int I = 60) : K(K_), run(R), lo(L), hi(H), iters(I) {}

    // 返回“恢复罚分”后的最优值
    T solve()
    {
        T L = lo, R = hi;
        for (int _ = 0; _ < iters; _++)
        {
            T mid = (L + R) / 2;
            auto [val, cnt] = run(mid);
            if (cnt >= K)
                L = mid;
            else
                R = mid;
        }
        auto [val, cnt] = run(L);
        return val + (T)K * L; // 把 -lambda*cnt 的惩罚加回 K*lambda
    }
};

杂项

树上背包

算法介绍

树上背包是把背包状态按树的结构自底向上合并。常见模型是每个结点 i 有体积 w[i] 与价值 v[i],选择的点集必须是某个结点子树内的若干点,容量限制为 C,问能取得的最大价值。做法是以任一根 root 把树定根,节点区间一律理解为它的整棵子树;对每个结点维护 f[i][k] 表示在结点 i 的子树内选体积恰为 k 的最大价值,然后把儿子一个个并入当前结点的 DP 表。由于合并本质是分组背包的两层循环,复杂度大致是 O(∑ 子树大小 × C);对总节点数 n,最坏可视作 O(nC)。

常见例题

题目:给定一棵 n 个点的树,点 i 有体积 w[i] 和价值 v[i],给一个容量 C,选择若干点组成集合 S,要求 S 是某个结点的子树内的点集(比如选整棵树就是 root 的子树),且 ∑w[i] ≤ C,最大化 ∑v[i]。

做法:把树在 root 处定根后做 DFS。每个结点先把自己作为一个“组”的初值 f[i][w[i]] = v[i],再依次把每个儿子的 DP 表与当前表做一遍“分组背包”式合并:枚举当前容量 k,从大到小枚举给儿子的容量 t,把 f[i][k] 用 f[i][k−t] + f[child] [t] 更新即可。最终答案取 root 的 f[root] [0..C] 的最大值。

代码
// 树上背包(闭区间容量 [0,C]),基于“逐个儿子并入”的树形 DP
// 功能:给定树、每点体积与价值、总容量C,求某个根子树内的最大价值
// 说明:若要做“整棵树任意连通子集”或“必须选父才能选子”等限制,可在合并时加判
struct TreeKnapsack
{
    int n;                              // 节点数
    int cap;                            // 背包容量C
    vector<int> head, to, nxt;          // 邻接表
    vector<int> weight, value;          // 点权体积与价值
    vector<vector<ll>> dp;       // dp[u][k]:u子树内体积恰为k的最大价值
    int edgeCnt;

    // 构造函数
    TreeKnapsack(int n_ = 0, int cap_ = 0) { init(n_, cap_); }

    // 初始化图与 DP 表
    void init(int n_, int cap_)
    {
        n = n_, cap = cap_;
        head.assign(n + 1, -1);
        to.assign(2 * n + 5, 0);
        nxt.assign(2 * n + 5, -1);
        weight.assign(n + 1, 0);
        value.assign(n + 1, 0);
        dp.assign(n + 1, vector<ll>(cap + 1, (ll)-4e18));
        edgeCnt = 0;
    }

    // 加一条无向边 u-v
    void addEdge(int u, int v)
    {
        to[edgeCnt] = v; nxt[edgeCnt] = head[u]; head[u] = edgeCnt++;
        to[edgeCnt] = u; nxt[edgeCnt] = head[v]; head[v] = edgeCnt++;
    }

    // 设定点权
    void setNode(int u, int w, int val)
    {
        weight[u] = w;
        value[u] = val;
    }

    // 树上背包的 DFS,p为当前点,fa为父节点,覆盖子树整段
    void dfs(int p, int fa)
    {
        // 初始化当前点只选自己的状态
        for (int k = 0; k <= cap; k++) dp[p][k] = (ll)-4e18;
        if (weight[p] <= cap) dp[p][weight[p]] = value[p];

        // 依次合并每个儿子
        for (int e = head[p]; e != -1; e = nxt[e])
        {
            int v = to[e]; if (v == fa) continue;
            dfs(v, p);

            // 临时数组承接合并结果
            vector<ll> tmp(cap + 1, (ll)-4e18);

            // 把儿子v的整段 [0,cap] 合并到p的整段 [0,cap]
            for (int k = 0; k <= cap; k++)
            {
                if (dp[p][k] < (ll)-3e18) continue;
                tmp[k] = max(tmp[k], dp[p][k]); // 可以不从v拿
                for (int t = 0; t + k <= cap; t++)
                {
                    if (dp[v][t] < (ll)-3e18) tmp[k + t] = max(tmp[k + t], dp[p][k] + dp[v][t]);
                }
            }
            dp[p].swap(tmp);
        }
    }

    // 计算以root为根的答案,返回在容量[0,cap]闭区间内的最大价值
    ll solve(int root)
    {
        dfs(root, 0);
        ll ans = 0;
        for (int k = 0; k <= cap; k++) ans = max(ans, dp[root][k]);
        return ans;
    }
};

SOS DP

算法介绍

SOS DP(Sum Over Subsets DP)用快速子集/超集 Zeta 变换在 O(n 2^n) 时间内把所有子集的和一次性算完,n 为比特数。给定数组 f[S](S 为 0..(1<<n)-1 的掩码),常见需求是得到 g[S] = ∑{T ⊆ S} f[T] 或 h[S] = ∑ f[T]。做法是在每一位 b 上做一轮“维度前缀和”:若 S 在第 b 位为 1,就把 f[S] 加到 f[S without b](子集 Zeta);或反向加到 f[S with b](超集 Zeta)。对应的 Möbius 反演只需把“加”改成“减”,即可把前缀和还原。该技巧是“子集前缀和”的标准实现,属于快速 Zeta / Möbius 变换的特例。

常见例题

题目: 给定数组 a[S],S 是 n 位掩码,要求对每个 S 输出 g[S] = ∑_{T ⊆ S} a[T],并且支持把 g 转回 a。

做法: 先把 f 初始化为 a,执行子集 Zeta 变换得到 g。若要从 g 恢复 a,执行子集 Möbius 反演。若需要“超集求和”,把转移方向改为向包含当前位的超集累加即可。子集枚举的对照写法可以在 cp-algorithms 里看到,对比其 O(3^n) 或 O(n2^n) 的做法有助于理解速度优势。

代码
// SOS DP:子集/超集 Zeta 与 Möbius 反演,闭区间掩码范围 [0, (1<<n)-1]
// 功能:zetaSubset 计算 g[S]=∑_{T⊆S} f[T];mobiusSubset 反演回原f
//      zetaSuperset 计算 g[S]=∑_{T⊇S} f[T];mobiusSuperset 反演回原f
// 说明:类型T需支持+=、-= 运算,常用为 ll 或 int
template <typename T>
struct SosDp
{
    int n;              // 比特数
    int size;           // 2^n
    vector<T> f;        // 工作数组

    // 构造与初始化
    SosDp(int n_ = 0) { init(n_); }

    void init(int n_)
    {
        n = n_;
        size = 1 << n;
        f.assign(size, T{});
    }

    // 装入原数组 a[0..(1<<n)-1]
    template <class Arr>
    void load(const Arr &a)
    {
        int m = (int)a.size();
        if (m != size) init(__lg(m));
        for (int s = 0; s < size; s++) f[s] = a[s];
    }

    // 子集 Zeta:得到 g[S] = ∑_{T⊆S} f[T]
    // 对每一位b,若S含b,则把 f[S^(1<<b)] += f[S],相当于对“去掉b”的子集做前缀和
    void zetaSubset()
    {
        for (int b = 0; b < n; b++)
        {
            for (int s = 0; s < size; s++)
            {
                if (s & (1 << b)) f[s] += f[s ^ (1 << b)];
            }
        }
    }

    // 子集 Möbius:把 g 反演回原 f
    void mobiusSubset()
    {
        for (int b = 0; b < n; b++)
        {
            for (int s = 0; s < size; s++)
            {
                if (s & (1 << b)) f[s] -= f[s ^ (1 << b)];
            }
        }
    }

    // 超集 Zeta:得到 g[S] = ∑_{T⊇S} f[T]
    // 对每一位b,若S不含b,则把 f[S] 加到 f[S|(1<<b)]
    void zetaSuperset()
    {
        for (int b = 0; b < n; b++)
        {
            for (int s = 0; s < size; s++)
            {
                if ((s & (1 << b)) == 0) f[s] += f[s | (1 << b)];
            }
        }
    }

    // 超集 Möbius:把 g 反演回原 f
    void mobiusSuperset()
    {
        for (int b = 0; b < n; b++)
        {
            for (int s = 0; s < size; s++)
            {
                if ((s & (1 << b)) == 0) f[s] -= f[s | (1 << b)];
            }
        }
    }

    // 读取当前数组
    const vector<T>& get() const { return f; }
};

五、图论

树的直径

算法介绍

树的直径是指树上最远两点间的最短路径长度与端点。经典做法是“两次 BFS/DFS”:先从任意点 s 出发找到最远点 u,再从 u 出发找到最远点 v,路径 u–v 即为一条直径。这一结论与实现被广泛采用,整体复杂度 O(n) 且只需邻接表即可完成。直径中心可由 u–v 的路径中点得到,如果路径长度为偶数则中心唯一,否则有两个中心。

代码
// 树的直径(两次搜索版),建图使用 vector<vector<int>> 存图,禁止链式向前星
// 功能:求任意一条直径的两个端点与长度,可选恢复路径
struct TreeDiameter
{
    int n;                                   // 点数
    vector<vector<int>> g;                   // 邻接表
    vector<int> parent, dist;                // 父指针与距离

    // 构造:给定点数初始化空图
    TreeDiameter(int n_ = 0) { init(n_); }

    // 初始化:清空图与辅助数组
    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        parent.assign(n, -1);
        dist.assign(n, 0);
    }

    // 加边:无向边 u-v
    void addEdge(int u, int v)
    {
        g[u].push_back(v); g[v].push_back(u);
    }

    // 从源点 s 做一次 BFS,返回最远点
    int farthestFrom(int s)
    {
        fill(parent.begin(), parent.end(), -1);
        fill(dist.begin(), dist.end(), 0);
        queue<int> q; q.push(s);
        parent[s] = s;
        int last = s;
        while (!q.empty())
        {
            int u = q.front(); q.pop();
            last = u;
            for (int v : g[u])
            {
                if (parent[v] != -1) continue;
                parent[v] = u; dist[v] = dist[u] + 1; q.push(v);
            }
        }
        return last;
    }

    // 求直径端点与长度,返回 {u, v, len}
    tuple<int,int,int> diameter()
    {
        int u = farthestFrom(0);
        int v = farthestFrom(u);
        return {u, v, dist[v]};
    }

    // 恢复一条直径路径
    vector<int> diameterPath()
    {
        auto [u, v, _] = diameter();
        vector<int> path;
        for (int x = v; ; x = parent[x])
        {
            path.push_back(x);
            if (x == u) break;
        }
        reverse(path.begin(), path.end());
        return path;
    }
};

树的中心

算法介绍

树的中心指删除该点(或两点)后,每个连通块的最大大小最小;等价于直径路径的中点。基于前节求得的直径路径,若直径长度为偶数则中心唯一,为奇数则有两个中心;也可使用“剥叶法”自外向内分层剥离直至剩下 1 或 2 个点。

代码
// 树的中心:基于直径路径求中心(O(n))
// 功能:返回 1 或 2 个中心点
struct TreeCenter
{
    TreeDiameter td;             // 复用直径模板

    TreeCenter(int n_ = 0) : td(n_) {}

    void init(int n_) { td.init(n_); }

    void addEdge(int u, int v) { td.addEdge(u, v); }

    vector<int> centers()
    {
        vector<int> path = td.diameterPath();
        int len = (int)path.size() - 1;
        if (len % 2 == 0) return { path[len / 2] };
        return { path[len / 2], path[len / 2 + 1] };
    }
};

树的重心

算法介绍

树的重心是指删除该点后,每个连通块的节点数均不超过 n/2,重心最多两个。经典做法是一次 DFS 计算子树大小,令每点的最大“子块”规模为 max(maxChild, n - size[u]),取该量最小者即为重心。

代码
// 树的重心:一次 DFS 统计子树规模并评估最大块
// 功能:返回所有重心(1 或 2 个)
struct TreeCentroid
{
    int n;
    vector<vector<int>> g;
    vector<int> sz, vis;     // 子树大小与访问标记

    TreeCentroid(int n_ = 0) { init(n_); }

    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        sz.assign(n, 0);
        vis.assign(n, 0);
    }

    void addEdge(int u, int v)
    {
        g[u].push_back(v); g[v].push_back(u);
    }

    int dfsSize(int u, int p)
    {
        sz[u] = 1;
        for (int v : g[u])
        {
            if (v == p) continue;
            sz[u] += dfsSize(v, u);
        }
        return sz[u];
    }

    vector<int> getCentroids()
    {
        dfsSize(0, -1);
        int best = n + 1;
        vector<int> ans;
        auto dfs = [&](auto &&self, int u, int p) -> void
        {
            int maxPart = n - sz[u];
            for (int v : g[u])
            {
                if (v == p) continue;
                maxPart = max(maxPart, sz[v]);
                self(self, v, u);
            }
            if (maxPart < best) { best = maxPart; ans.assign(1, u); }
            else if (maxPart == best) ans.push_back(u);
        };
        dfs(dfs, 0, -1);
        return ans;
    }
};

基环树

算法介绍

基环树(pseudotree)是指在一个连通无向图上仅含一条简单环,其余边均在环的树枝上,即 n 个点、n 条边的连通图。典型套路是先在线性时间内找出一条环并标记环上节点,再对每个环点向外做 DFS 或 BFS 处理分支树上的问题(如到环距离、每个连通块统计等)。找环可用 DFS 的“父边回退”法;或者利用“度为 1 的点反复剥离”的拓扑消圈方法。

代码
// 基环树:找环 + 标记环 + 计算到环最短距离
// 功能:findCycle() 标记唯一简单环上的点;distToCycle() 计算每点到环的距离
struct PseudoTree
{
    int n;
    vector<vector<int>> g;
    vector<int> parent, vis, inCycle;   // 0 未访,1 在栈,2 出栈结束;inCycle 标记环点
    vector<int> cycleStack;             // 当前 DFS 栈记录
    vector<int> cycle;                  // 环上节点集合

    PseudoTree(int n_ = 0) { init(n_); }

    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        parent.assign(n, -1);
        vis.assign(n, 0);
        inCycle.assign(n, 0);
        cycleStack.clear();
        cycle.clear();
    }

    void addEdge(int u, int v)
    {
        g[u].push_back(v); g[v].push_back(u);
    }

    // 内部:当在 u 的 DFS 中遇到回到栈内的点 v 时,从栈上切出环
    void extractCycle(int u, int v)
    {
        cycle.clear();
        cycle.push_back(v);
        for (int i = (int)cycleStack.size() - 1; i >= 0; i--)
        {
            cycle.push_back(cycleStack[i]);
            if (cycleStack[i] == v) break;
        }
        for (int x : cycle) inCycle[x] = 1;
        reverse(cycle.begin(), cycle.end());
    }

    // DFS 找到唯一环,返回是否已经找到
    bool dfsFind(int u, int p)
    {
        vis[u] = 1; cycleStack.push_back(u);
        for (int v : g[u])
        {
            if (v == p) continue;
            if (vis[v] == 0)
            {
                parent[v] = u;
                if (dfsFind(v, u)) return true;
            }
            else if (vis[v] == 1)
            {
                extractCycle(u, v);
                return true;
            }
        }
        vis[u] = 2; cycleStack.pop_back();
        return false;
    }

    // 标记唯一简单环上的全部点
    void findCycle()
    {
        dfsFind(0, -1);
    }

    // 计算每个点到环的最短距离,多源 BFS
    vector<int> distToCycle()
    {
        vector<int> dist(n, -1);
        queue<int> q;
        for (int i = 0; i < n; i++)
        {
            if (inCycle[i]) dist[i] = 0, q.push(i);
        }
        while (!q.empty())
        {
            int u = q.front(); q.pop();
            for (int v : g[u])
            {
                if (dist[v] != -1) continue;
                dist[v] = dist[u] + 1; q.push(v);
            }
        }
        return dist;
    }
};

最近公共祖先

算法介绍

二进制提升做法先用一次深度优先搜索得到每个点的深度与进出时间,再预处理 up[k][v] 表示点 v 的第 2^k 级祖先。查询时把较深的点按二进制分解向上跳齐深度,然后同时尝试从高位到低位比较祖先,能同时抬升就抬升,最终它们的父亲即为最近公共祖先。该方法预处理 O(n log n),单次查询 O(log n)。

代码
// 最近公共祖先(LCA)- 二进制提升
// 功能:O(n log n) 预处理,O(log n) 查询 LCA/距离/第 k 级祖先
struct LCA
{
    int n, lg;                              // 点数与最高位
    vector<vector<int>> g;                  // 邻接表
    vector<int> depth, tin, tout;           // 深度与进出时间
    vector<vector<int>> up;                 // up[k][v] 为 v 的 2^k 级祖先
    int timer;                              // 时间戳

    // 构造:给定点数初始化空图
    LCA(int n_ = 0) { init(n_); }

    // 初始化:清空并准备空间
    void init(int n_)
    {
        n = n_;
        lg = n ? __lg(n) + 1 : 1;
        g.assign(n, {});
        depth.assign(n, 0);
        tin.assign(n, 0);
        tout.assign(n, 0);
        up.assign(lg, vector<int>(n, 0));
        timer = 0;
    }

    // 加边:无向边 u-v
    void addEdge(int u, int v)
    {
        g[u].push_back(v), g[v].push_back(u);
    }

    // 预处理:指定根 r,计算 depth / tin / tout / up
    void build(int r = 0)
    {
        auto dfs = [&](auto &&self, int u, int p) -> void
        {
            tin[u] = ++timer;
            up[0][u] = p == -1 ? u : p;
            for (int k = 1; k < lg; k++) up[k][u] = up[k - 1][up[k - 1][u]];
            for (int v : g[u])
            {
                if (v == p) continue;
                depth[v] = depth[u] + 1;
                self(self, v, u);
            }
            tout[u] = ++timer;
        };
        dfs(dfs, r, -1);
    }

    // 判定 u 是否是 v 的祖先
    bool isAncestor(int u, int v)
    {
        return tin[u] <= tin[v] && tout[v] <= tout[u];
    }

    // 查询 u 与 v 的最近公共祖先
    int lca(int u, int v)
    {
        if (isAncestor(u, v)) return u;
        if (isAncestor(v, u)) return v;
        for (int k = lg - 1; k >= 0; k--) if (!isAncestor(up[k][u], v)) u = up[k][u];
        return up[0][u];
    }

    // 查询 u 的第 k 级祖先(k>=0),若越界则返回根的祖先自身
    int kthAncestor(int u, int k)
    {
        for (int i = 0; i < lg; i++) if (k >> i & 1) u = up[i][u];
        return u;
    }

    // 查询两点距离
    int dist(int u, int v)
    {
        int w = lca(u, v);
        return depth[u] + depth[v] - 2 * depth[w];
    }
};

树链剖分

算法介绍

树链剖分把树分解成若干条重链与轻边,使任意一条简单路径被 O(log n) 条链覆盖,从而把树上的路径与子树问题转化为若干段连续的 dfs 序区间问题,再用线段树或树状数组维护。重儿子定义为子树规模最大的儿子,剖分时每次优先沿重儿子延伸到同一条链。该方法对路径加减、区间合并、以及配合懒标记做区间更新等问题非常通用。

常见例题

题目: 给定一棵 n 点树,并支持 q 次操作,操作一为把路径 [u,v] 上所有点的权值加上 x,操作二为查询路径 [u,v] 上权值之和。

做法: 用树链剖分把路径拆成若干个 dfs 序的连续段,每段在底层线段树上执行一次区间加或区间和即可,合并所有段的查询结果即为答案。

代码
// 树链剖分(闭区间 dfs 序),与线段树对接使用
// 功能:提供路径上的若干段 [l,r](dfs 序)以便在线段树上做区间操作;也支持子树一段
// 说明:回调 onRange(l,r) 交给外层数据结构(如线段树)处理,路径查询/修改只需遍历若干段
struct HLD
{
    int n;                                  // 点数
    vector<vector<int>> g;                  // 邻接表
    vector<int> parent, depth, heavy;       // 父亲、深度、重儿子
    vector<int> head, pos, sz;              // 链顶、dfs 序位置、子树大小
    int cur;                                

    HLD(int n_ = 0) { init(n_); }

    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        parent.assign(n, -1);
        depth.assign(n, 0);
        heavy.assign(n, -1);
        head.assign(n, 0);
        pos.assign(n, 0);
        sz.assign(n, 0);
        cur = 0;
    }

    void addEdge(int u, int v)
    {
        g[u].push_back(v), g[v].push_back(u);
    }

    // 预处理:求 size 与 heavy
    void dfsSize(int u, int p)
    {
        parent[u] = p;
        sz[u] = 1;
        int mx = 0;
        for (int v : g[u])
        {
            if (v == p) continue;
            depth[v] = depth[u] + 1;
            dfsSize(v, u);
            if (sz[v] > mx) heavy[u] = v, mx = sz[v];
            sz[u] += sz[v];
        }
    }

    // 预处理:分配 head 与 pos
    void dfsDecomp(int u, int h)
    {
        head[u] = h;
        pos[u] = cur++;
        if (heavy[u] != -1) dfsDecomp(heavy[u], h);
        for (int v : g[u])
        {
            if (v == parent[u] || v == heavy[u]) continue;
            dfsDecomp(v, v);
        }
    }

    // 对整棵树以 root 为根做剖分
    void build(int root = 0)
    {
        dfsSize(root, -1);
        dfsDecomp(root, root);
    }

    // 路径分解:把 u-v 路径拆成若干个 dfs 序闭区间段,逐段调用 onRange(l,r)
    template<class F>
    void forEachPath(int u, int v, F onRange)
    {
        while (head[u] != head[v])
        {
            if (depth[head[u]] < depth[head[v]]) swap(u, v);
            onRange(pos[head[u]], pos[u]);
            u = parent[head[u]];
        }
        if (depth[u] > depth[v]) swap(u, v);
        onRange(pos[u], pos[v]);
    }

    // 子树一段:返回以 u 为根的子树对应的 dfs 序闭区间
    pair<int,int> subtree(int u)
    {
        return { pos[u], pos[u] + sz[u] - 1 };
    }
};

启发式合并 (DSU on Tree)

算法介绍

DSU on Tree(又称 Sack、小并大)在每个结点处理完所有轻儿子后,清空轻儿子对答案的贡献,仅保留重儿子的贡献,再把轻儿子的节点数据合并进来,借助“轻儿子总被清空”的事实把总复杂度压到 O(n log n)。该技巧适合“对每个结点 u 求其整棵子树上的某种统计量”的批量问题,典型如“子树中出现频率为 k 的颜色数”等。

常见例题

题目: 给定一棵 n 点带颜色的树,要求对每个点 u 输出其子树中不同颜色的数量。

做法: 以 u 为根的子树先递归处理所有轻儿子并清空其增加的计数,保留重儿子的计数,然后把轻儿子的节点逐个加入当前计数,最后答案即为当前计数器的不同颜色数。若题目改为统计出现次数恰为 k 的颜色数,则把计数器替换为“频率桶”即可。

代码
// DSU on Tree(Sack)通用框架,闭区间语义体现在 dfs 序编号的使用上
// 功能:对每个结点 u 计算其子树统计量;回调 onAdd(x,delta) 定义如何计入/移除一个结点
// 说明:先处理并清空所有轻儿子,保留重儿子,然后把轻儿子元素并入,最后在 u 点处记录答案
struct DsuOnTree
{
    int n, root;
    vector<vector<int>> g;
    vector<int> sz, heavy, parent;
    vector<int> order, in, out;   // dfs 序闭区间 [in[u], out[u]]
    int timer;

    DsuOnTree(int n_ = 0) { init(n_); }

    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        sz.assign(n, 0);
        heavy.assign(n, -1);
        parent.assign(n, -1);
        order.assign(n, 0);
        in.assign(n, 0);
        out.assign(n, 0);
        timer = 0; root = 0;
    }

    void addEdge(int u, int v)
    {
        g[u].push_back(v), g[v].push_back(u);
    }

    // 预处理 size / heavy 与 dfn
    void build(int r = 0)
    {
        root = r;
        auto dfs1 = [&](auto &&self, int u, int p) -> void
        {
            parent[u] = p;
            sz[u] = 1;
            in[u] = timer; order[timer] = u; timer++;
            int mx = 0;
            for (int v : g[u])
            {
                if (v == p) continue;
                self(self, v, u);
                if (sz[v] > mx) heavy[u] = v, mx = sz[v];
                sz[u] += sz[v];
            }
            out[u] = timer - 1;
        };
        dfs1(dfs1, r, -1);
    }

    // 主过程:onAdd(x, +1/-1) 定义计数方式;onAnswer(u) 在结点 u 处收割答案
    template<class AddF, class AnsF>
    void run(AddF onAdd, AnsF onAnswer)
    {
        auto addSubtree = [&](auto &&self, int u, int p, int delta) -> void
        {
            onAdd(u, delta);
            for (int v : g[u])
            {
                if (v == p) continue;
                if (v == heavy[u]) continue;
                self(self, v, u, delta);
            }
        };

        auto dfs2 = [&](auto &&self, int u, int p, bool keep) -> void
        {
            for (int v : g[u])
            {
                if (v == p || v == heavy[u]) continue;
                self(self, v, u, false);
            }
            if (heavy[u] != -1) self(self, heavy[u], u, true);
            for (int v : g[u])
            {
                if (v == p || v == heavy[u]) continue;
                addSubtree(addSubtree, v, u, +1);
            }
            onAdd(u, +1);
            onAnswer(u);
            if (!keep)
            {
                addSubtree(addSubtree, u, p, -1);
            }
        };
        dfs2(dfs2, root, -1, true);
    }
};

重构树 (虚拟树构建)

算法介绍

虚拟树(也称辅助树、压缩树)针对给定的一组关键点 S,将原树按这些点及其两两 LCA 压缩成一棵只包含必要结构的树,从而把许多“只与 S 有关”的树上问题降到更小规模上处理。经典构造步骤是按 dfs 序排序关键点,依次用栈维护一条上升链,遇到新点时用 LCA 确定分叉位置并连边,最终得到一棵以这些关键点与必要 LCA 构成的树。

常见例题

题目: 给定一棵树与若干次查询,每次给定一个点集 S,要求只在由 S 压缩得到的虚拟树上做一次 DP 求答案。

做法: 先用 LCA 预处理出 dfs 序与最近公共祖先。对 S 按 dfs 序排序,依序将两两相邻点的 LCA 加入,再去重后用栈按“入栈节点始终形成一条自根向下的链”的规则连边,得到虚拟树,随后在这棵小树上进行一次 DP。

代码
// 虚拟树构建(Auxiliary Tree),依赖已建好的 LCA(含 tin)
// 功能:给定点集 key,返回由 key 及其必要 LCA 组成的虚拟树邻接表与节点列表
struct VirtualTree
{
    LCA *plca;                    // 指向外部 LCA(需已 build)
    vector<vector<int>> vt;       // 虚拟树邻接表
    vector<int> nodes;            // 虚拟树包含的全部点(key ∪ lca)
    
    VirtualTree(LCA *p = nullptr) : plca(p) {}

    // 构建:传入关键点集合 key(任意顺序、互不相同),返回虚拟树根
    int build(const vector<int> &key)
    {
        nodes = key;
        auto &tin = plca->tin;
        auto cmp = [&](int a, int b){ return tin[a] < tin[b]; };

        // 把相邻点的 LCA 也加入
        vector<int> vec = nodes;
        sort(vec.begin(), vec.end(), cmp);
        int m = vec.size();
        for (int i = 0; i + 1 < m; i++)
        {
            int w = plca->lca(vec[i], vec[i + 1]);
            vec.push_back(w);
        }
        sort(vec.begin(), vec.end(), cmp);
        vec.erase(unique(vec.begin(), vec.end()), vec.end());
        nodes = vec;

        // 用栈按 dfs 序连边
        vt.assign(nodes.size(), {});
        vector<int> st;
        auto id = [&](int x)
        {
            return (int)(lower_bound(nodes.begin(), nodes.end(), x, cmp) - nodes.begin());
        };
        auto link = [&](int u, int v) { vt[u].push_back(v), vt[v].push_back(u); };

        st.push_back(0);                // vec[0] 为 dfs 序最早的点,作为根
        for (int i = 1; i < (int)nodes.size(); i++)
        {
            int x = nodes[i];
            while (!st.empty() && !plca->isAncestor(st.back(), x)) st.pop_back();
            int u = st.back(), iu = id(u), ix = id(x);
            link(iu, ix);
            st.push_back(x);
        }
        return 0;                       // 返回虚拟树根在 nodes 数组中的下标(一般为 0)
    }
};

图遍历

深度优先搜索 (DFS)

算法介绍

深度优先搜索是一种从起点开始尽可能向前走到不能再前进时回溯的遍历方法。它的本质是用递归或栈模拟对图的“走迷宫”过程,适用于连通性判定、拓扑排序、寻找连通分量、割点与桥的求解等。每次访问一个点时,把它标记为已访问,然后继续递归访问所有未访问的邻接点,直到无路可走再回退。

代码
// 深度优先搜索 (DFS)
// 功能:遍历整张图并统计连通分量
struct GraphDFS
{
    int n;                 // 节点数
    vector<vector<int>> g; // 邻接表
    vector<int> vis;       // 访问标记

    // 构造函数,初始化 n 个点的图
    GraphDFS(int n_ = 0) { init(n_); }

    // 初始化,清空邻接表和标记数组
    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        vis.assign(n, 0);
    }

    // 添加一条无向边
    void addEdge(int u, int v)
    {
        g[u].push_back(v);
        g[v].push_back(u);
    }

    // 从节点 u 开始的 DFS,覆盖连通块
    void dfs(int u)
    {
        vis[u] = 1;
        for (int v : g[u])
            if (!vis[v])
                dfs(v);
    }

    // 统计连通分量个数
    int countComponents()
    {
        int cnt = 0;
        for (int i = 0; i < n; i++)
            if (!vis[i])
            {
                dfs(i);
                cnt++;
            }
        return cnt;
    }
};

广度优先搜索 (BFS)

算法介绍

广度优先搜索是一种从起点开始,先遍历所有与起点距离为 1 的点,再遍历所有与起点距离为 2 的点,依次类推的图遍历方法。其本质是队列层次遍历,适用于最短路、层次图构建、判定二分图等。BFS 每次出队一个点,访问所有未访问的邻接点并入队,直到队列为空。

代码
// 广度优先搜索 (BFS)
// 功能:在无权图中计算从起点到其他点的最短路
struct GraphBFS
{
    int n;                 // 节点数
    vector<vector<int>> g; // 邻接表
    vector<int> dist;      // 距离数组

    // 构造函数,初始化 n 个点的图
    GraphBFS(int n_ = 0) { init(n_); }

    // 初始化,清空邻接表和距离数组
    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        dist.assign(n, -1);
    }

    // 添加一条无向边
    void addEdge(int u, int v)
    {
        g[u].push_back(v);
        g[v].push_back(u);
    }

    // 从起点 s 开始的 BFS,计算所有点到 s 的最短路
    void bfs(int s)
    {
        queue<int> q;
        dist[s] = 0;
        q.push(s);
        while (!q.empty())
        {
            int u = q.front();
            q.pop();
            for (int v : g[u])
                if (dist[v] == -1)
                {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
        }
    }
};

图连通性

割点与桥 (关节点 & 割边)

算法介绍

在无向图中,若删除一个点及其关联边导致连通块数增加,这个点就是割点;若删除一条边导致连通块数增加,这条边就是桥。用一次深度优先搜索维护时间戳 dfn 与返祖值 low,沿树边向下搜索时记录父子关系,遇到返祖边就更新 low。对根的儿子数不少于二是割点;对非根,若存在儿子 v 使得 low[v] ≥ dfn[u] 则 u 是割点;若存在儿子 v 使得 low[v] > dfn[u] 则边 (u,v) 是桥。整套判定在线性时间完成。

常见例题

给定一个 n 点 m 边的无向图,输出所有割点与所有桥的集合,然后把所有桥删除后统计剩余的连通块个数并输出。做法是一次 DFS 求出 dfn 与 low,收集割点与桥,最后在桥删去后用并查集或二次 DFS 统计连通块数量即可。

代码
// 割点与桥(闭区间对本节无意义,这里按图的顶点集合与边集合来表述)
// 功能:O(n+m) 找出所有割点与桥,并可在删除桥后按需做进一步统计
struct CutAndBridge
{
    int n;                                      // 顶点数
    vector<vector<int>> g;                      // 无向图邻接表
    vector<int> dfn, low;                       // 时间戳与返祖值
    vector<char> isCut;                         // 是否为割点
    vector<pair<int,int>> bridges;              // 桥集合(u,v)保证u为父、v为子
    int timer;                                  // DFS 时钟

    // 构造与初始化
    CutAndBridge(int n_ = 0) { init(n_); }

    // 功能:重置规模并清空图
    void init(int n_)
    {
        n = n_;
        g.assign(n + 1, {});
        dfn.assign(n + 1, 0);
        low.assign(n + 1, 0);
        isCut.assign(n + 1, 0);
        bridges.clear();
        timer = 0;
    }

    // 功能:加无向边 u-v
    void addEdge(int u, int v)
    {
        g[u].push_back(v); g[v].push_back(u);
    }

    // 功能:主过程,找割点与桥
    void build()
    {
        for (int i = 1; i <= n; i++) if (!dfn[i]) dfsRoot(i);
    }

    // 功能:返回所有割点下标
    vector<int> cutVertices()
    {
        vector<int> res;
        for (int i = 1; i <= n; i++) if (isCut[i]) res.push_back(i);
        return res;
    }

    // 功能:返回所有桥(父,子)
    const vector<pair<int,int>>& getBridges() { return bridges; }

private:
    // 功能:从 root 开始的一棵 DFS 树根的特判入口
    void dfsRoot(int root)
    {
        int child = 0;
        auto selfDfs = [&](auto &&self, int u, int p) -> void
        {
            dfn[u] = low[u] = ++timer;
            for (int v : g[u])
            {
                if (v == p) continue;
                if (!dfn[v])
                {
                    if (u == root) child++;
                    self(self, v, u);
                    low[u] = min(low[u], low[v]);
                    if (low[v] >= dfn[u] && u != root) isCut[u] = 1;
                    if (low[v] > dfn[u]) bridges.emplace_back(u, v);
                }
                else low[u] = min(low[u], dfn[v]);
            }
        };
        selfDfs(selfDfs, root, 0);
        if (child >= 2) isCut[root] = 1;
    }
};

双连通分量 (边/点双连通)

算法介绍

点双连通分量是极大顶点集合,使得任意两点至少存在两条点不交路径;边双连通分量是极大边集合,使得任意两点至少存在两条边不交路径。点双常用基于 Tarjan 的栈做法:DFS 时把树边与返祖边压栈,当遇到一个结点 u 的某个儿子 v 满足 low[v] ≥ dfn[u] 就不断弹栈直到 (u,v),弹出的一组边上的点构成一个点双。边双可以先用上节找出所有桥,再把桥删去,按剩余边做一次 DFS 得到每个边双分量编号。

常见例题

给定无向图,先输出其点双分量的个数及每个分量包含的点集;再输出边双分量的个数。做法是用点双 Tarjan 边栈法依次弹出每个分量并收集点集;边双先找桥,再把桥视作删除,用一次 DFS 标好每个连通块编号即为边双编号。若需要建“圆方树”,可把每个点双压成一个方点,并把包含关系连边形成树形结构以支持后续树上查询。

代码
// 点双与边双(无向图)
// 功能:输出全部点双分量的点集;标出所有边双编号;保留桥以便后续使用
struct Biconnected
{
    int n;                                      // 顶点数
    vector<vector<int>> g;                      // 无向图
    vector<int> dfn, low;                       // 时间戳与返祖值
    int timer;                                  // 时钟
    vector<pair<int,int>> edgeStack;            // 点双用的边栈
    vector<vector<int>> vertexBcc;              // 点双分量,每个分量的点集
    vector<pair<int,int>> bridges;              // 桥集合(父,子)
    vector<int> compId;                         // 边双编号(给每个点一个编号)

    Biconnected(int n_ = 0) { init(n_); }

    // 功能:重置图
    void init(int n_)
    {
        n = n_;
        g.assign(n + 1, {});
        dfn.assign(n + 1, 0);
        low.assign(n + 1, 0);
        timer = 0;
        edgeStack.clear();
        vertexBcc.clear();
        bridges.clear();
        compId.assign(n + 1, 0);
    }

    // 功能:加边
    void addEdge(int u, int v)
    {
        g[u].push_back(v); g[v].push_back(u);
    }

    // 功能:构建点双与桥
    void buildPointBcc()
    {
        for (int i = 1; i <= n; i++) if (!dfn[i]) dfsPointBcc(i, 0);
    }

    // 功能:构建边双编号(把桥视作删除后再连通分量编号)
    void buildEdgeBcc()
    {
        vector<vector<pair<int,int>>> adj(n + 1);
        vector<vector<int>> id(n + 1);
        // 先把所有边加入,等会儿过滤桥
        for (int u = 1; u <= n; u++) for (int v : g[u]) if (u < v) adj[u].push_back({v, 1}), adj[v].push_back({u, 1});
        // 标记桥集便于过滤
        unordered_set<ll> isBridge;
        auto key = [&](int a, int b) -> ll { if (a > b) swap(a, b); return (1LL * a << 32) ^ b; };
        for (auto [u, v] : bridges) isBridge.insert(key(u, v));
        vector<char> vis(n + 1, 0);
        int cid = 0;
        auto dfs = [&](auto &&self, int u) -> void
        {
            vis[u] = 1; compId[u] = cid;
            for (auto [v, w] : adj[u])
            {
                if (isBridge.count(key(u, v))) continue;
                if (!vis[v]) self(self, v);
            }
        };
        for (int i = 1; i <= n; i++) if (!vis[i]) { cid++; dfs(dfs, i); }
    }

private:
    // 功能:点双 Tarjan(递归时维护边栈),当前在 u,父边来自 p
    void dfsPointBcc(int u, int p)
    {
        dfn[u] = low[u] = ++timer;
        for (int v : g[u])
        {
            if (v == p) continue;
            if (!dfn[v])
            {
                edgeStack.emplace_back(u, v);
                dfsPointBcc(v, u);
                low[u] = min(low[u], low[v]);
                if (low[v] >= dfn[u])
                {
                    vector<int> comp;
                    while (true)
                    {
                        auto e = edgeStack.back(); edgeStack.pop_back();
                        comp.push_back(e.first); comp.push_back(e.second);
                        if (e.first == u && e.second == v) break;
                    }
                    sort(comp.begin(), comp.end()); comp.erase(unique(comp.begin(), comp.end()), comp.end());
                    vertexBcc.push_back(comp);
                }
                if (low[v] > dfn[u]) bridges.emplace_back(u, v);
            }
            else if (dfn[v] < dfn[u])
            {
                edgeStack.emplace_back(u, v);
                low[u] = min(low[u], dfn[v]);
            }
        }
    }
};

强连通分量 (SCC 缩点)

算法介绍

在有向图中,强连通分量是极大点集,使得任意两点相互可达。Tarjan 算法用一次 DFS 与栈维护 index 与 lowlink,遇到一个结点的 lowlink 等于自身 index 时,从栈顶不断弹出直到它,形成一个 SCC。把所有 SCC 缩为点并保留跨分量边即可得到凝聚图,这是一张有向无环图,常在上面做拓扑序与 DP。此处实现与经典 Tarjan 完全一致,时间复杂度 O(n+m),凝聚图无环的原因是若有环则原图中这些分量应属同一 SCC。

常见例题

给定 n 个点的有向图,先缩点得到凝聚图,再计算每个 SCC 的点权之和,最后在凝聚 DAG 上求最长路径的最大点权和。做法是 Tarjan 分解并得到 compId 与 compWeight,按照 compId 把跨分量的边加入新图并去重,随后拓扑 DP 求解答案。

代码
// Tarjan 强连通分量 + 缩点 DAG
// 功能:O(n+m) 求 compId、每个分量的节点集合与凝聚图
struct SCC
{
    int n;                                  // 顶点数
    vector<vector<int>> g;                  // 有向图
    vector<int> dfn, low, inStack;          // 时间戳、返祖值、是否在栈
    vector<int> st;                          // 栈
    int timer, sccCnt;                      // 时钟与分量计数
    vector<int> compId;                     // 每个点所属分量编号(1..sccCnt)
    vector<vector<int>> compNodes;          // 每个分量包含的点
    vector<vector<int>> dag;                // 凝聚图(1..sccCnt)

    SCC(int n_ = 0) { init(n_); }

    // 功能:重置大小与清空
    void init(int n_)
    {
        n = n_;
        g.assign(n + 1, {});
        dfn.assign(n + 1, 0);
        low.assign(n + 1, 0);
        inStack.assign(n + 1, 0);
        compId.assign(n + 1, 0);
        st.clear();
        compNodes.clear();
        dag.clear();
        timer = 0; sccCnt = 0;
    }

    // 功能:加有向边 u->v
    void addEdge(int u, int v)
    {
        g[u].push_back(v);
    }

    // 功能:主过程,分解并构建凝聚图
    void build()
    {
        for (int i = 1; i <= n; i++) if (!dfn[i]) dfs(i);
        compNodes.assign(sccCnt + 1, {});
        for (int i = 1; i <= n; i++) compNodes[compId[i]].push_back(i);
        dag.assign(sccCnt + 1, {});
        unordered_set<ll> seen;
        auto key = [&](int a, int b) -> ll { return (1LL * a << 32) ^ b; };
        for (int u = 1; u <= n; u++)
        {
            int cu = compId[u];
            for (int v : g[u])
            {
                int cv = compId[v]; if (cu == cv) continue;
                ll k = key(cu, cv); if (seen.insert(k).second) dag[cu].push_back(cv);
            }
        }
    }

    // 功能:返回凝聚图
    const vector<vector<int>>& getDag() { return dag; }

private:
    // 功能:Tarjan 递归
    void dfs(int u)
    {
        dfn[u] = low[u] = ++timer;
        st.push_back(u); inStack[u] = 1;
        for (int v : g[u])
        {
            if (!dfn[v]) { dfs(v); low[u] = min(low[u], low[v]); }
            else if (inStack[v]) low[u] = min(low[u], dfn[v]);
        }
        if (low[u] == dfn[u])
        {
            ++sccCnt;
            while (true)
            {
                int x = st.back(); st.pop_back(); inStack[x] = 0;
                compId[x] = sccCnt;
                if (x == u) break;
            }
        }
    }
};

最短路

Dijkstra

算法介绍

在非负边权图上,从单源 s 出发求到所有点的最短路。使用优先队列按当前最小距离“贪心”地弹出点,并对其所有出边做松弛。时间复杂度在邻接表与二叉堆下为 O(m log n)。这是单源最短路的标准解,要求所有边权非负。

常见例题

题目:给定 n 点 m 边的带权无向图,q 次询问,每次给出 s,t,问从 s 到 t 的最短路长度。

做法:每次以 s 为源运行 Dijkstra,或若 q 很多而图稠密,考虑预处理所有点对(例如换成 Floyd,见下)。在 Dijkstra 中把所有边加入邻接表,源点距离置 0,出队时若弹出已过期的对直接跳过,遍历出边做松弛。

代码
// 单源最短路 Dijkstra(非负权,闭区间下标 0..n-1)
// 构造函数与接口:addEdge(u,v,w) 添加边;run(s) 返回 dist;也提供一次性查询 s->t
template <typename T>
struct Dijkstra
{
    int n;
    vector<vector<pair<int, T>>> g;

    Dijkstra(int n_ = 0) : n(n_), g(n) {}

    // 功能:添加一条有向边 u->v,权重为 w
    void addEdge(int u, int v, T w) { g[u].push_back({v, w}); }

    // 功能:从源点 s 运行 Dijkstra,返回到每个点的最短距离(不可达为 +inf)
    vector<T> run(int s) const
    {
        const T INF = numeric_limits<T>::max() / 4;
        vector<T> dist(n, INF);
        priority_queue<pair<T, int>, vector<pair<T, int>>, greater<pair<T, int>>> pq;
        dist[s] = 0;
        pq.push({0, s});
        while (!pq.empty())
        {
            auto [d, u] = pq.top();
            pq.pop();
            if (d != dist[u])
                continue;
            for (auto [v, w] : g[u])
            {
                T nd = d + w;
                if (nd < dist[v])
                    dist[v] = nd, pq.push({nd, v});
            }
        }
        return dist;
    }

    // 功能:返回 s 到 t 的最短距离
    T query(int s, int t) const
    {
        auto dist = run(s);
        return dist[t];
    }
};

Floyd

算法介绍

Floyd–Warshall 用动态规划求所有点对最短路。设 dist[i][j] 为 i 到 j 的当前最短距离,按中转点 k 从小到大逐步放宽可用的中转集合。可处理负边权但不能包含可达的负环。复杂度 O(n^3)。

常见例题

题目:给定 n≤500 的带权有向图,有负边但无负环,回答任意点对最短路。

做法:初始化 dist[i][i]=0、边权赋值、其它置 +inf。三层循环按 k,i,j 依次转移即可。

代码
// Floyd–Warshall(支持负边,禁止负环),闭区间点集 [0,n-1]
template <typename T>
struct Floyd
{
    int n;
    vector<vector<T>> dist;

    Floyd(int n_ = 0) : n(n_), dist(n, vector<T>(n, numeric_limits<T>::max() / 4))
    {
        for (int i = 0; i < n; i++)
            dist[i][i] = 0;
    }

    // 功能:添加边 u->v 权重 w(若多条边,取最小)
    void addEdge(int u, int v, T w) { dist[u][v] = min(dist[u][v], w); }

    // 功能:执行 Floyd,得到所有点对最短路
    void run()
    {
        for (int k = 0; k < n; k++)
            for (int i = 0; i < n; i++)
                if (dist[i][k] < numeric_limits<T>::max() / 8)
                    for (int j = 0; j < n; j++)
                        if (dist[k][j] < numeric_limits<T>::max() / 8)
                            dist[i][j] = min(dist[i][j], dist[i][k] + dist[k][j]);
    }

    // 功能:返回 i 到 j 的最短路
    T query(int i, int j) const { return dist[i][j]; }
};

SPFA

算法介绍

SPFA 是 Bellman–Ford 的队列优化版本,单源求最短路,可处理负边权并可用“入队次数≥n”来判定是否存在可达负环。最坏复杂度 O(nm),竞赛中常配合差分约束、稀疏图与随机化点序能跑很快,但要注意卡 SPFA 的数据。

常见例题

题目:给定含负边的有向图,判断从 s 是否能到达负环;若无负环,输出到所有点的最短路。

做法:以 s 为源初始化 dist[s]=0,用队列维护待松弛点,更新时若 cnt[v] 达到 n 说明存在可达负环。

代码
// SPFA(负边可用,负环检测),闭区间点集 [0,n-1]
template <typename T>
struct Spfa
{
    int n;
    vector<vector<pair<int, T>>> g;

    Spfa(int n_ = 0) : n(n_), g(n) {}

    // 功能:添加有向边 u->v,权重 w
    void addEdge(int u, int v, T w) { g[u].push_back({v, w}); }

    // 功能:从 s 计算最短路;返回 pair<是否无负环, dist>
    pair<bool, vector<T>> run(int s) const
    {
        const T INF = numeric_limits<T>::max() / 4;
        vector<T> dist(n, INF);
        vector<int> inq(n, 0), cnt(n, 0);
        queue<int> q;
        dist[s] = 0;
        q.push(s);
        inq[s] = 1;

        while (!q.empty())
        {
            int u = q.front();
            q.pop();
            inq[u] = 0;
            for (auto [v, w] : g[u])
            {
                if (dist[u] != INF && dist[u] + w < dist[v])
                {
                    dist[v] = dist[u] + w;
                    if (!inq[v])
                    {
                        q.push(v);
                        inq[v] = 1;
                    }
                    if (++cnt[v] >= n)
                        return {false, dist};
                }
            }
        }
        return {true, dist};
    }
};

K 短路

算法介绍

经典的“k-th shortest path”做法是在非负权图上把 Dijkstra 的“每点最多进一次堆”改为“允许进堆多次”,对每个点统计被弹出次数,某点第 k 次被弹出即得到从源到该点的第 k 短路长。若仅需 s→t 的第 k 短,可在弹出 t 第 k 次时返回。

常见例题

题目:给定非负有向图与 s,t,k,求 s 到 t 的第 k 短路长度,若不存在输出 −1。

做法:以 (dist,u) 为堆元素,从 (0,s) 开始弹出,记录 popCnt[u],当 u==t 且 popCnt[u]==k 则返回当前距离;每次弹出 (d,u) 后,把所有出边 (u,v,w) 以 (d+w,v) 入堆。注意边权需非负。

代码
// k-th shortest path(非负权),弹出 t 第 k 次即为答案
template <typename T>
struct KthShortest
{
    int n;
    vector<vector<pair<int, T>>> g;

    KthShortest(int n_ = 0) : n(n_), g(n) {}
    void addEdge(int u, int v, T w) { g[u].push_back({v, w}); }

    // 功能:返回 s 到 t 的第 k 短路长;不存在则返回 INF
    T run(int s, int t, int k) const
    {
        const T INF = numeric_limits<T>::max() / 4;
        vector<int> popCnt(n, 0);
        priority_queue<pair<T, int>, vector<pair<T, int>>, greater<pair<T, int>>> pq;
        pq.push({0, s});
        while (!pq.empty())
        {
            auto [d, u] = pq.top();
            pq.pop();
            if (++popCnt[u] == k && u == t)
                return d;
            if (popCnt[u] > k)
                continue;
            for (auto [v, w] : g[u])
                pq.push({d + w, v});
        }
        return INF;
    }
};

差分约束

算法介绍

差分约束系统由形如 x_v - x_u ≤ w 的不等式组成。把每条不等式建成一条边 u→v 权重 w,则存在解当且仅当图中无可达负环。通常加一个超级源向所有点连 0 边,再跑最短路。若只判可行可用 SPFA 检测负环;若要最小可行解,可取 dist[v] 作为最小解的一组势函数。参考与推导见 cp-algorithms 及相关讲解。

常见例题

题目:已知 m 个约束 x_b - x_a ≤ c,求是否有解;若有解,输出一组满足约束的 {x_i}。

做法:建图后加超级源 s,对每个 i 连边 s→i 权重 0,跑 SPFA。如果出现“入队计数≥n”的点则无解;否则取 dist[i] 作为一组可行解(或做等距平移)。

代码
// 差分约束:判可行 + 给出一组解
template <typename T>
struct DiffConstraints
{
    int n;
    vector<vector<pair<int, T>>> g;

    DiffConstraints(int n_ = 0) : n(n_), g(n + 1) {}

    // 功能:添加约束 x_v - x_u <= w  =>  边 u->v 权重 w
    void addConstraint(int u, int v, T w) { g[u].push_back({v, w}); }

    // 功能:判可行并给出一组解;返回 pair<是否可行, 解向量>
    pair<bool, vector<T>> solve() const
    {
        int s = n;
        for (int i = 0; i < n; i++)
            g[s].push_back({i, 0});

        const T INF = numeric_limits<T>::max() / 4;
        vector<T> dist(n + 1, INF);
        vector<int> inq(n + 1, 0), cnt(n + 1, 0);
        queue<int> q;
        dist[s] = 0;
        q.push(s);
        inq[s] = 1;

        while (!q.empty())
        {
            int u = q.front();
            q.pop();
            inq[u] = 0;
            for (auto [v, w] : g[u])
            {
                if (dist[u] != INF && dist[u] + w < dist[v])
                {
                    dist[v] = dist[u] + w;
                    if (!inq[v])
                    {
                        q.push(v);
                        inq[v] = 1;
                    }
                    if (++cnt[v] >= n + 1)
                        return {false, {}};
                }
            }
        }
        dist.pop_back();
        return {true, dist};
    }
};

同余最短路 (CRT + Dijkstra)

算法介绍

这类题把“目标状态的编号”拆成同余类做最短路。常见模型是给一个模数 m 与若干“步长及其代价”,从 0 类出发,允许把当前余数 r 通过某个步长 a 走到 (r + a) mod m ,代价为 w(a)。在“到达数值 X 的最小代价”等问题中,先在“余数图”上以 m 个点做 Dijkstra 得到到每个余数类的最小代价 d[r],再通过 CRT 或取特定余数来还原到原问题域。关于 CRT 的基本性质与构造可参考这些资料;在实现上我们只用到“按模 m 的余数类图上跑 Dijkstra”。

常见例题

题目:给定正整数 m 与一组“可用步长” {a_i, cost_i},每次可把 x 变成 x + a_i,花费 cost_i,问把 0 变成任意非负整数 y 的最小代价,输出所有 0..N 的答案或特定 y 的答案。

做法:在 m 个余数结点上建有向边 r → (r + a_i) mod m,权重为 cost_i,从 0 结点做 Dijkstra 得到最小代价 d[r]。则到达任意 y 的最小代价为 d[y mod m],前提是“可以通过若干步把 0 的余数变成 y 的余数”(即 gcd(所有 a_i, m) | y)。若题目还包含多个模并要求满足多组同余,可先用 CRT 合并为一个模,然后同法处理。

代码
// 同余最短路:在余数图上 Dijkstra,d[r] 给出达到余数 r 的最小代价
template <typename T>
struct RemainderGraph
{
    int m;
    vector<vector<pair<int, T>>> g;

    RemainderGraph(int m_ = 1) : m(m_), g(m) {}

    // 功能:添加一种“步长 a,代价 w”,即从 r 走到 (r+a)%m,花费 w
    void addStep(int a, T w)
    {
        a %= m;
        if (a < 0)
            a += m;
        for (int r = 0; r < m; r++)
            g[r].push_back({(r + a) % m, w});
    }

    // 功能:以余数 0 为源跑 Dijkstra,返回到各余数的最小代价
    vector<T> run() const
    {
        const T INF = numeric_limits<T>::max() / 4;
        vector<T> dist(m, INF);
        priority_queue<pair<T, int>, vector<pair<T, int>>, greater<pair<T, int>>> pq;
        dist[0] = 0;
        pq.push({0, 0});
        while (!pq.empty())
        {
            auto [d, u] = pq.top();
            pq.pop();
            if (d != dist[u])
                continue;
            for (auto [v, w] : g[u])
            {
                T nd = d + w;
                if (nd < dist[v])
                    dist[v] = nd, pq.push({nd, v});
            }
        }
        return dist;
    }

    // 功能:求达到整数 y 的最小代价(若不可达返回 INF)
    T costTo(ll y) const
    {
        auto dist = run();
        int r = (int)((y % m + m) % m);
        return dist[r];
    }
};

最小生成树

Prim

算法介绍

Prim 以点为中心扩展生成树,始于任意一个起点,每次把横跨生成树与补图的最小权边纳入结果,相当于在一个“可访问点的最小出边”优先队列上反复取最小。用二叉堆可做到 O((n+m)log n)。与 Kruskal 相比,Prim 更适合稠密图或以邻接表直接增量地“长”出树的场景。正确性来自贪心切割性质:对任意割,其最小横切边必属于某棵最小生成树。

常见例题

题目:给定 n 个点 m 条无向边与权值,保证图连通,求最小生成树权值与一组生成树边。

做法:用 Prim,从任意起点 s 入队,堆里维护每个“未纳入点”的最小连接代价,取出最小代价点时把它与其前驱边放入答案并松弛相邻边。若需要输出边集,记录每个点被选入时的入边即可。

代码
// Prim 最小生成树,邻接表 + 小根堆
// 功能:addEdge(u,v,w) 添加无向边;run(s) 从起点 s 生成一棵 MST,返回总权并填充 treeEdges
// 复杂度:O((n+m)log n)
// 说明:若图不连通,将只覆盖与 s 连通的分量;可循环多次以得到最小生成森林
template<typename T>
struct PrimMST
{
    struct Edge { int to; T w; };
    int n;                                   // 点数
    vector<vector<Edge>> g;                  // 邻接表
    vector<int> parent;                      // 每个点被纳入时的前驱
    vector<pair<int,int>> treeEdges;         // 生成树边集

    PrimMST(int n_ = 0) { init(n_); }        // 构造:给定点数
    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        parent.assign(n, -1);
        treeEdges.clear();
    }

    // 添加无向边 u-v,权 w
    void addEdge(int u, int v, const T &w)
    {
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }

    // 运行 Prim,从起点 s 开始;返回生成树总权,treeEdges 给出边集
    T run(int s)
    {
        vector<T> dist(n, numeric_limits<T>::max());
        vector<char> used(n, 0);
        parent.assign(n, -1);
        treeEdges.clear();

        using Node = pair<T,int>;                            // (代价, 点)
        priority_queue<Node, vector<Node>, greater<Node>> pq;

        dist[s] = T{}; pq.push({T{}, s});
        T total{}; int picked = 0;

        while (!pq.empty() && picked < n)
        {
            auto [d, u] = pq.top(); pq.pop();
            if (used[u]) continue;
            used[u] = 1; total = total + d; picked++;

            if (parent[u] != -1) treeEdges.emplace_back(parent[u], u);

            for (auto &e : g[u])
            {
                if (!used[e.to] && e.w < dist[e.to])
                {
                    dist[e.to] = e.w; parent[e.to] = u;
                    pq.push({dist[e.to], e.to});
                }
            }
        }
        return total;
    }
};

Kruskal

算法介绍

Kruskal 以边为中心扩展,从小到大排序所有边,遇到不成环则纳入。用并查集判环,可达 O(m log m)。这是最经典的 MST 策略,尤其适用于稀疏图或边天然按权产生的场景。

常见例题

题目:给定 n 点 m 边的无向连通图,输出最小生成树的权值与一组树边。

做法:将边按权升序排序,遍历时若边的两个端点属于不同集合就合并之并计入答案与边集,否则跳过直到取满 n−1 条边为止。

代码
// Kruskal 最小生成树,边排序 + DSU 判环
// 功能:addEdge(u,v,w) 添加无向边;run() 返回总权并填充 treeEdges
// 复杂度:O(m log m)
template<typename T>
struct KruskalMST
{
    struct Edge { int u, v; T w; };
    struct DSU
    {
        int n; vector<int> p, sz;
        DSU(int n_ = 0) { init(n_); }
        void init(int n_)
        {
            n = n_; p.resize(n); sz.assign(n, 1);
            for (int i = 0; i < n; i++) p[i] = i;
        }
        int find(int x)
        {
            while (x != p[x]) x = p[x] = p[p[x]];
            return x;
        }
        bool unite(int a, int b)
        {
            a = find(a); b = find(b);
            if (a == b) return false;
            if (sz[a] < sz[b]) swap(a, b);
            p[b] = a; sz[a] += sz[b];
            return true;
        }
    };

    int n; vector<Edge> edges; vector<Edge> treeEdges;

    KruskalMST(int n_ = 0) { init(n_); }
    void init(int n_)
    {
        n = n_; edges.clear(); treeEdges.clear();
    }

    void addEdge(int u, int v, const T &w)
    {
        edges.push_back({u, v, w});
    }

    T run()
    {
        sort(edges.begin(), edges.end(), [&](const Edge &a, const Edge &b){ return a.w < b.w; });
        DSU dsu(n);
        treeEdges.clear();
        T total{}; int taken = 0;

        for (auto &e : edges)
        {
            if (dsu.unite(e.u, e.v))
            {
                total = total + e.w; treeEdges.push_back(e); taken++;
                if (taken == n - 1) break;
            }
        }
        return total;
    }
};

朱刘算法 Edmonds

算法介绍

在有向图上,以指定根 r 为根的最小入树(最小树形图、最小有向生成树)可用 Chu–Liu/Edmonds 算法求解。核心步骤是在每个非根点上选择一条最小入边;若无环则这些边构成最小树形图;若出现有向环,就把环缩为一个新点并对入环边做“减权”再递归,最终在收缩图上得到最优结构后再把环“还原”并删去环内一条入边。朴素实现为 O(nm),Tarjan 与后续改进可降至 O(m log n) 或 O(m + n log n)。

常见例题

题目:给定 n 个点 m 条有向边及权值,给定根 r,要求以 r 为根的最小树形图的总权,并输出一组树边。

做法:采用 Chu–Liu/Edmonds。每轮为每个非根点选最小入边;若不存在入边则无解;若选出的边集无环则结束;若有环则将环缩点并对进入环的边“减去该入边差值”,在收缩图上递归;回溯时把收缩点对应回原环并删去其中一条入边即可恢复一棵有向生成树。

代码
// Chu–Liu/Edmonds 最小树形图(有向最小生成树),返回总权并给出一组边
// 功能:addEdge(u,v,w) 添加有向边;run(root) 计算以 root 为根的最小树形图
// 复杂度:O(nm) 朴素实现;适合竞赛常规约束;需要连通性:每个非根点必须至少有一条入边
template<typename T>
struct EdmondsArborescence
{
    struct Edge { int u, v; T w; };     // u->v, w
    int n; vector<Edge> edges;           // 顶点编号 [0..n-1]
    vector<Edge> treeEdges;              // 输出的一组树边

    EdmondsArborescence(int n_ = 0) { init(n_); }
    void init(int n_)
    {
        n = n_; edges.clear(); treeEdges.clear();
    }

    void addEdge(int u, int v, const T &w)
    {
        edges.push_back({u, v, w});
    }

    // 主过程:以 root 为根计算最小树形图;不可达则抛出无解标记(返回权值的极大值)
    T run(int root)
    {
        const T INF = numeric_limits<T>::max() / 4;
        T total{}; treeEdges.clear();

        // 原图到当前收缩图的映射
        vector<Edge> es = edges;
        int N = n;

        while (true)
        {
            vector<T> in(N, INF);           // 每个点的最小入边权
            vector<int> pre(N, -1);         // 每个点的最小入边的起点
            for (auto &e : es)
            {
                if (e.u != e.v && e.w < in[e.v]) in[e.v] = e.w, pre[e.v] = e.u;
            }
            in[root] = T{}; pre[root] = -1;
            for (int v = 0; v < N; v++) if (in[v] == INF) return INF; // 有点无入边,根不可达,无解

            // 寻找环并编号
            int cnt = 0;
            vector<int> id(N, -1), vis(N, -1);
            for (int v = 0; v < N; v++)
            {
                int u = v;
                while (vis[u] != v && id[u] == -1 && u != root) vis[u] = v, u = pre[u];
                if (u != root && id[u] == -1)
                {
                    for (int x = pre[u]; x != u; x = pre[x]) id[x] = cnt;
                    id[u] = cnt; cnt++;
                }
            }
            if (cnt == 0)
            {
                for (int v = 0; v < N; v++) if (v != root) total = total + in[v];
                // 可选:恢复具体边。做法是在最终一轮 in/pre 上记录边,然后沿着 pre[v]->v 收集
                // 这里给出一份轻量级恢复:将最后一轮的 pre 作为树边来源
                for (int v = 0; v < N; v++) if (v != root) treeEdges.push_back({pre[v], v, in[v]});
                break;
            }

            for (int v = 0, k = 0; v < N; v++) if (id[v] == -1) id[v] = cnt + k++;

            // 累加当前轮入边权,并构造收缩后的边集(减权)
            for (int v = 0; v < N; v++) if (v != root) total = total + in[v];

            int NN = cnt + (N - count(id.begin(), id.end(), -1) - cnt); // 新图点数
            vector<Edge> nes; nes.reserve(es.size());
            for (auto &e : es)
            {
                int uu = id[e.u], vv = id[e.v];
                if (uu != vv)
                {
                    T w2 = e.w - in[e.v];
                    nes.push_back({uu, vv, w2});
                }
            }
            es.swap(nes); N = *max_element(id.begin(), id.end()) + 1; root = id[root];
        }
        return total;
    }
};

以上实现聚焦“总权值”,边恢复给出最后一轮的轻量做法,通常在竞赛中已可满足“给出一组最优解边”的要求;若需严格恢复全过程,可在每轮收缩时为边维护“原始边指针”并在回溯阶段按环断边规则精确还原。


图匹配

二分图最大匹配 (Hopcroft–Karp)

算法介绍

把所有未匹配的左侧点作为分层起点,用 BFS 建立按“未匹配边/已匹配边”交替的层次网络,再在该网络中用 DFS 同时寻找一批点不相交的最短增广路并一次性增广。每一轮都会把增广路长度至少加二,整体复杂度为 O(E√V) 。

常见例题

题目:给一个左侧大小为 n、右侧大小为 m 的二分图,q 次操作只查询当前最大匹配规模。

做法:用 Hopcroft–Karp 预处理整图的最大匹配;若只是静态询问直接输出匹配数;若离线删除边再恢复,可按时间线倒序把边“加回来”,每次在已有匹配上继续跑若干轮 BFS+DFS 增广即可,保证均摊高效。

代码
// Hopcroft–Karp 二分图最大匹配
// 顶点编号:左侧 1..n,右侧 1..m;边用 addEdge(u,v) 加(u∈[1..n], v∈[1..m])
// 接口:maxMatching() 返回最大匹配规模;matchL[x]/matchR[y] 为匹配对端,未匹配为 0
struct HopcroftKarp
{
    int n, m;                   // 左右集大小
    vector<vector<int>> adj;    // 邻接表:左 u -> 右 v
    vector<int> matchL, matchR; // 左右两侧的匹配对象
    vector<int> dist;           // BFS 层号,存左侧点到“可增广层”的距离

    // 构造:给定左右顶点数,邻接与匹配表清空
    HopcroftKarp(int n_ = 0, int m_ = 0)
    {
        init(n_, m_);
    }

    // 初始化:清空图与匹配关系
    void init(int n_, int m_)
    {
        n = n_, m = m_;
        adj.assign(n + 1, {});
        matchL.assign(n + 1, 0);
        matchR.assign(m + 1, 0);
        dist.assign(n + 1, 0);
    }

    // 加边:左 u 到 右 v
    void addEdge(int u, int v)
    {
        adj[u].push_back(v);
    }

    // BFS 分层:把所有未匹配的左点放到队列,逐层扩展到“可抵达未匹配右点的一层”
    bool bfs()
    {
        queue<int> q;
        for (int u = 1; u <= n; u++)
        {
            if (!matchL[u]) dist[u] = 0, q.push(u);
            else dist[u] = -1;
        }
        bool found = false;
        while (!q.empty())
        {
            int u = q.front(); q.pop();
            for (int v : adj[u])
            {
                int w = matchR[v];
                if (!w) found = true;                         // 抵达空闲右点,说明存在最短增广路
                else if (dist[w] == -1) dist[w] = dist[u] + 1, q.push(w); // 走匹配边扩展层
            }
        }
        return found;
    }

    // DFS 寻找与层次一致的增广路,从左点 u 出发
    bool dfs(int u)
    {
        for (int v : adj[u])
        {
            int w = matchR[v];
            if (!w || (dist[w] == dist[u] + 1 && dfs(w)))
            {
                matchL[u] = v, matchR[v] = u;
                return true;
            }
        }
        dist[u] = -1; // 剪枝:该层无解,避免重复搜索
        return false;
    }

    // 主过程:反复分层 + 在层内找一批点不相交最短增广路
    int maxMatching()
    {
        int res = 0;
        while (bfs())
        {
            for (int u = 1; u <= n; u++)
                if (!matchL[u] && dfs(u)) res++;
        }
        return res;
    }
};

二分图最大权匹配 (匈牙利算法 Kuhn–Munkres)

算法介绍

KM 算法解决指派/二分图最大权匹配问题,典型实现用“顶标(潜在值) + 交错树 + 松弛(slack)”的增广框架,每轮沿零边扩展或调整顶标,使得至少出现一条新的零边从而增广,时间复杂度 O(n^3)。

常见例题

题目:给定 n×n 的收益矩阵 w[i][j],要让每个左点 i 选恰一条到右点 j 的边,最大化权值和。

做法:若是最小费用匹配,把输入取相反数转成最大化;用 KM 维护左/右侧的标号 labelL/labelR、匹配 rightMatch、访问集 visL/visR 与每个右点的松弛量 slack,按经典流程寻找交错树中的增广路即可。

代码
// Kuhn–Munkres(Hungarian)算法,解决 n×n 二分图最大权匹配
// 顶点编号:左 1..n,右 1..n;权值矩阵 w 需事先给定
struct KuhnMunkres
{
    int n;                                  // 规模
    vector<vector<ll>> w;            // 权值矩阵
    vector<ll> labelL, labelR;       // 左右顶标
    vector<int> matchR, pre;                 // 右侧匹配到的左点;pre 记录增广树父亲
    vector<ll> slack;                // 每个右点的当前最小松弛量
    vector<int> slackFrom;                  // slack 来源的左点
    vector<char> visL, visR;                // 本轮是否在交错树中
    const ll INF = (1LL<<60);

    // 构造:给定 n,或直接给矩阵
    KuhnMunkres(int n_ = 0)
    {
        if (n_) init(n_);
        else n = 0;
    }
    KuhnMunkres(const vector<vector<ll>> &w_)
    {
        init(w_);
    }

    // 初始化:只定规模
    void init(int n_)
    {
        n = n_;
        w.assign(n + 1, vector<ll>(n + 1, 0));
    }

    // 初始化:用矩阵
    void init(const vector<vector<ll>> &w_)
    {
        n = (int)w_.size() - 1; // 约定 w_[1..n][1..n]
        w = w_;
    }

    // 主过程:返回最大权值,matchR[j] = i
    ll maxWeight()
    {
        labelL.assign(n + 1, 0), labelR.assign(n + 1, 0);
        matchR.assign(n + 1, 0), pre.assign(n + 1, 0);
        for (int i = 1; i <= n; i++)
        {
            labelL[i] = w[i][1];
            for (int j = 2; j <= n; j++) if (w[i][j] > labelL[i]) labelL[i] = w[i][j];
        }
        for (int s = 1; s <= n; s++)
        {
            slack.assign(n + 1, INF), slackFrom.assign(n + 1, 0);
            visL.assign(n + 1, 0), visR.assign(n + 1, 0);
            queue<int> q; q.push(s); visL[s] = 1; pre.assign(n + 1, 0);
            for (int j = 1; j <= n; j++) slack[j] = labelL[s] + labelR[j] - w[s][j], slackFrom[j] = s;

            int augR = 0; // 找到可增广的右点
            while (true)
            {
                // 尝试在零边上扩展
                while (!q.empty() && !augR)
                {
                    int u = q.front(); q.pop();
                    for (int v = 1; v <= n; v++)
                    {
                        if (visR[v]) continue;
                        ll gap = labelL[u] + labelR[v] - w[u][v];
                        if (gap == 0)
                        {
                            visR[v] = 1, pre[v] = u;
                            if (!matchR[v]) { augR = v; break; }
                            int nxt = matchR[v];
                            if (!visL[nxt]) visL[nxt] = 1, q.push(nxt);
                        }
                        else if (slack[v] > gap) slack[v] = gap, slackFrom[v] = u;
                    }
                }
                if (augR) break;

                // 调整顶标:把所有 visL 增加 delta,把所有 visR 减少 delta,让至少一个 slack 归零
                ll delta = INF;
                for (int v = 1; v <= n; v++) if (!visR[v] && slack[v] < delta) delta = slack[v];
                for (int i = 1; i <= n; i++) if (visL[i]) labelL[i] -= delta;
                for (int j = 1; j <= n; j++) if (visR[j]) labelR[j] += delta; else slack[j] -= delta;

                // 检查新的零边
                for (int v = 1; v <= n; v++)
                {
                    if (!visR[v] && slack[v] == 0)
                    {
                        visR[v] = 1; pre[v] = slackFrom[v];
                        if (!matchR[v]) { augR = v; break; }
                        int nxt = matchR[v];
                        if (!visL[nxt]) visL[nxt] = 1, q.push(nxt);
                    }
                }
                if (augR) break;
            }

            // 反转交错路完成增广
            while (augR)
            {
                int u = pre[augR], pv = matchR[augR];
                matchR[augR] = u, augR = pv;
            }
        }
        ll ans = 0;
        for (int j = 1; j <= n; j++) ans += w[matchR[j]][j];
        return ans;
    }
};

一般图最大匹配 (带花树 Blossom)

算法介绍

Edmonds 的带花树算法把“奇环”压缩为单点继续找增广路,从而把一般图的最大匹配转化为在反复收缩/还原的过程中寻找交替路的问题。经典实现的时间复杂度可做到 O(V^3);

常见例题

题目:给定 n 个点的一般无向图,求最大匹配规模并输出匹配边。

做法:用 Edmonds Blossom 维护交替森林,BFS 扩展时一旦遇到“偶层到偶层”的边触发找花并把整个奇环收缩成超点,继续在收缩图上 BFS;发现可增广的自由点后顺着父边回溯增广;若 BFS 无法继续且没有自由点可增广则获得最大匹配。

代码
// Edmonds Blossom 一般图最大匹配(无权版本),返回最大匹配规模
// 顶点编号 1..n;用 addEdge(u,v) 加边;match[x] 给出匹配对端
struct Blossom
{
    int n;
    vector<vector<int>> adj;    // 无向图
    vector<int> match;          // 匹配对端,0 表示自由
    vector<int> base;           // 每个点当前所在“花”的基点
    vector<int> parent;         // 交替树中的父边对应的点
    vector<int> type;           // 点在 BFS 中的层类型:0 未访问,1 偶层,2 奇层
    vector<int> q;              // 简易队列
    int qs, qe;                 // 队头队尾
    vector<int> mark;           // LCA 辅助标记
    vector<int> inQueue;        // 是否在队列

    Blossom(int n_ = 0) { init(n_); }

    void init(int n_) 
    {
        n = n_;
        adj.assign(n + 1, {});
        match.assign(n + 1, 0);
        base.resize(n + 1);
        parent.assign(n + 1, 0);
        type.assign(n + 1, 0);
        inQueue.assign(n + 1, 0);
        mark.assign(n + 1, 0);
        q.assign(n + 5, 0);
    }

    void addEdge(int u, int v)
    {
        if (u == v) return;
        adj[u].push_back(v), adj[v].push_back(u);
    }

    // 找到两点 u、v 的交替树最低公共祖先(按“基点”视角)
    int lca(int u, int v)
    {
        static int stamp = 0; stamp++;
        while (true)
        {
            u = base[u];
            mark[u] = stamp;
            if (!match[u]) break;
            u = parent[match[u]];
        }
        while (true)
        {
            v = base[v];
            if (mark[v] == stamp) return v;
            if (!match[v]) break;
            v = parent[match[v]];
        }
        return 0; // 不会到达
    }

    // 把包含环上点的“花”整体收缩到 lcaBase,并修正 parent/base/type
    void blossomContract(int u, int v, int lcaBase)
    {
        auto fix = [&](int x, int y)
        {
            while (base[x] != lcaBase) 
            {
                int m = match[x], b = base[x], mb = base[m];
                parent[b] = y, y = m;
                if (type[m] == 2) type[m] = 1, push(m);
                base[b] = base[mb] = lcaBase;
                x = parent[m];
            }
        };
        fix(u, v), fix(v, u);
    }

    // 入队一个偶层点
    void push(int x)
    {
        if (!inQueue[x]) inQueue[x] = 1, q[qe++] = x;
    }

    // 从某个自由点 s 开始做一轮 BFS,若找到可增广的自由点则完成一次增广并返回 true
    bool bfs(int s)
    {
        for (int i = 1; i <= n; i++) base[i] = i, parent[i] = 0, type[i] = 0, inQueue[i] = 0;
        qs = qe = 0; push(s); type[s] = 1;

        while (qs < qe)
        {
            int u = q[qs++];

            for (int v : adj[u])
            {
                if (base[u] == base[v] || match[u] == v) continue;
                if (type[v] == 2) continue;

                if (type[v] == 0)
                {
                    type[v] = 2, parent[v] = u;
                    if (!match[v])
                    {
                        // 找到从 s 到 v 的增广路,回溯反转
                        int x = u, y = v;
                        while (x)
                        {
                            int nx = match[x];
                            match[x] = y, match[y] = x;
                            x = parent[nx], y = nx;
                        }
                        return true;
                    }
                    type[match[v]] = 1, push(match[v]);
                }
                else if (type[v] == 1)
                {
                    int b = lca(u, v);
                    blossomContract(u, v, b);
                }
            }
        }
        return false;
    }

    // 求最大匹配:对每个自由点尝试做一轮 BFS 增广
    int maxMatching()
    {
        int res = 0;
        for (int s = 1; s <= n; s++)
        {
            if (!match[s] && bfs(s)) res++;
        }
        return res;
    }
};

网络流

最大流 (Dinic 算法)

算法介绍

Dinic 的核心是分层图与阻塞流。首先在残量网络上用 BFS 建立从源点出发的分层图,只保留分层单调递增的可行边;随后在这张分层图上用 DFS 沿层推进增广,直到没有可增广的流量为止;再重新分层并重复。分层保证了增广路径长度单调不降,当前弧优化避免重复扫描边,整体在一般图上表现稳定,二分图与单位容量网络上有很强的实践性能。实现上采用前向星或邻接表存边,加入反向边承载回退流量,所有增广都在残量网络上进行。

代码
// Dinic 最大流模板(带当前弧优化),索引从 0 到 n-1
// 功能:addEdge(u,v,c) 添加有向边;maxFlow(s,t) 返回最大流
// 复杂度:一般 O(min(n^{2/3}, m^{1/2}) * m) 的经验值,竞赛中表现优良
template <typename T>
struct Dinic
{
    struct Edge
    {
        int to, rev;   // 终点与反向边下标
        T cap;         // 残量容量
        Edge(int to_=0, int rev_=0, T cap_=0): to(to_), rev(rev_), cap(cap_) {}
    };

    int n;                          // 点数
    vector<vector<Edge>> g;         // 邻接表
    vector<int> level;              // 分层
    vector<int> it;                 // 当前弧指针

    // 构造函数,给定点数 n_ 初始化
    Dinic(int n_=0) { if (n_) init(n_); else n = 0; }

    // 初始化点数并清空图
    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        level.assign(n, 0);
        it.assign(n, 0);
    }

    // 添加有向边 u->v,容量为 c,同时添加反向边 v->u 容量为 0
    void addEdge(int u, int v, T c)
    {
        Edge a(v, (int)g[v].size(), c);
        Edge b(u, (int)g[u].size(), 0);
        g[u].push_back(a), g[v].push_back(b);
    }

    // 构建分层图,返回是否可达汇点
    bool bfs(int s, int t)
    {
        fill(level.begin(), level.end(), -1);
        queue<int> q;
        level[s] = 0, q.push(s);
        while (!q.empty())
        {
            int u = q.front(); q.pop();
            for (auto &e : g[u])
            {
                if (e.cap > 0 && level[e.to] == -1)
                {
                    level[e.to] = level[u] + 1;
                    if (e.to == t) return true;
                    q.push(e.to);
                }
            }
        }
        return level[t] != -1;
    }

    // 沿分层图 DFS 增广,当前弧优化
    T dfs(int u, int t, T f)
    {
        if (u == t || f == 0) return f;
        for (int &i = it[u]; i < (int)g[u].size(); i++)
        {
            Edge &e = g[u][i];
            if (e.cap > 0 && level[e.to] == level[u] + 1)
            {
                T ret = dfs(e.to, t, min(f, e.cap));
                if (ret > 0)
                {
                    e.cap -= ret;
                    g[e.to][e.rev].cap += ret;
                    return ret;
                }
            }
        }
        return 0;
    }

    // 主过程:反复分层+阻塞流求解最大流
    T maxFlow(int s, int t)
    {
        T flow = 0;
        while (bfs(s, t))
        {
            fill(it.begin(), it.end(), 0);
            while (true)
            {
                T pushed = dfs(s, t, numeric_limits<T>::max());
                if (pushed == 0) break;
                flow += pushed;
            }
        }
        return flow;
    }
};

最小费用最大流

算法介绍

最小费用最大流在满足流量约束的前提下使总费用最小。常用实现是“势能 + Dijkstra”的逐条最短增广路径法。边存费用 cost 和容量 cap,并维护反向边;若图上存在负边权但无负环,可先用 SPFA 计算初始势能 potential,再在每次增广时用 Dijkstra 按调整后的非负边权寻找最短路。沿最短路增广后更新势能,直到不存在从 s 到 t 的可行增广路径。若容量或费用较大,使用 long long 存储;若需要处理上下界,可在此模板上做上下界建图扩展。

代码
// 最小费用最大流模板(势能 + Dijkstra,支持负边权但无负环的情况)
// 功能:addEdge(u,v,cap,cost) 添加有向边;minCostMaxFlow(s,t) 返回 {最大流, 最小费用}
// 复杂度:每次增广一条最短路,配合堆优化 Dijkstra,经验表现稳定
template <typename T, typename C>
struct MinCostMaxFlow
{
    struct Edge
    {
        int to, rev;          // 终点与反向边下标
        T cap;                // 残量容量
        C cost;               // 单位费用
        Edge(int to_=0, int rev_=0, T cap_=0, C cost_=0): to(to_), rev(rev_), cap(cap_), cost(cost_) {}
    };

    int n;                                  // 点数
    vector<vector<Edge>> g;                 // 邻接表
    vector<C> dist, pot;                    // 最短路距离与势能
    vector<int> pvV, pvE;                   // 记录前驱点与前驱边
    const C INF_COST = numeric_limits<C>::max() / 4;
    const T INF_FLOW = numeric_limits<T>::max() / 4;

    // 构造函数
    MinCostMaxFlow(int n_=0) { if (n_) init(n_); else n = 0; }

    // 初始化
    void init(int n_)
    {
        n = n_;
        g.assign(n, {});
        dist.assign(n, 0);
        pot.assign(n, 0);
        pvV.assign(n, -1);
        pvE.assign(n, -1);
    }

    // 添加有向边 u->v
    void addEdge(int u, int v, T cap, C cost)
    {
        Edge a(v, (int)g[v].size(), cap, cost);
        Edge b(u, (int)g[u].size(), 0, -cost);
        g[u].push_back(a), g[v].push_back(b);
    }

    // 若存在负费用边,先用 SPFA 初始化势能
    void initPotentialWithSPFA(int s)
    {
        fill(pot.begin(), pot.end(), INF_COST);
        vector<bool> inq(n, false);
        queue<int> q;
        pot[s] = 0, q.push(s), inq[s] = true;
        while (!q.empty())
        {
            int u = q.front(); q.pop(); inq[u] = false;
            for (auto &e : g[u])
            {
                if (e.cap > 0 && pot[e.to] > pot[u] + e.cost)
                {
                    pot[e.to] = pot[u] + e.cost;
                    if (!inq[e.to]) q.push(e.to), inq[e.to] = true;
                }
            }
        }
        for (int i = 0; i < n; i++) if (pot[i] == INF_COST) pot[i] = 0;
    }

    // 单次 Dijkstra,求以势能矫正后的最短路
    bool dijkstra(int s, int t)
    {
        fill(dist.begin(), dist.end(), INF_COST);
        fill(pvV.begin(), pvV.end(), -1);
        fill(pvE.begin(), pvE.end(), -1);
        using P = pair<C,int>;
        priority_queue<P, vector<P>, greater<P>> pq;
        dist[s] = 0, pq.push({0, s});
        while (!pq.empty())
        {
            auto [d, u] = pq.top(); pq.pop();
            if (d != dist[u]) continue;
            for (int i = 0; i < (int)g[u].size(); i++)
            {
                auto &e = g[u][i];
                if (e.cap > 0)
                {
                    C w = e.cost + pot[u] - pot[e.to];
                    if (dist[e.to] > dist[u] + w)
                    {
                        dist[e.to] = dist[u] + w;
                        pvV[e.to] = u, pvE[e.to] = i;
                        pq.push({dist[e.to], e.to});
                    }
                }
            }
        }
        return dist[t] < INF_COST;
    }

    // 主过程:返回 {最大流, 最小费用}
    pair<T,C> minCostMaxFlow(int s, int t, bool hasNegativeCostEdges = false)
    {
        if (hasNegativeCostEdges) initPotentialWithSPFA(s);
        else fill(pot.begin(), pot.end(), 0);
        T flow = 0;
        C cost = 0;
        while (dijkstra(s, t))
        {
            for (int i = 0; i < n; i++) if (dist[i] < INF_COST) pot[i] += dist[i];
            T aug = INF_FLOW;
            for (int v = t; v != s; v = pvV[v])
            {
                auto &e = g[pvV[v]][pvE[v]];
                aug = min(aug, e.cap);
            }
            for (int v = t; v != s; v = pvV[v])
            {
                auto &e = g[pvV[v]][pvE[v]];
                auto &re = g[e.to][e.rev];
                e.cap -= aug, re.cap += aug;
                cost += aug * e.cost;
            }
            flow += aug;
        }
        return {flow, cost};
    }
};

图的环与回路

欧拉回路 (Eulerian Path/Circuit)

算法介绍

在一个连通的无向图中,所有点度数全为偶数则存在欧拉回路,若恰有两个点为奇度且其余为偶度则存在欧拉路径,起点为奇度点之一;在一个连通的有向图中,若每个点入度等于出度则存在欧拉回路,若恰有一个点出度比入度大 1 且恰有一个点入度比出度大 1 且其余入出度相等则存在欧拉路径,起点为出度大的那个点。实际构造用 Hierholzer 算法,把沿边走尽的回路按访问栈拼接,复杂度 O(n+m)。所有区间与遍历的端点在本文均按闭区间 [l,r] 表示。

常见例题

题目:给定一个 n 点 m 边的图,可能是无向也可能是有向,判断是否存在欧拉路径或欧拉回路;若存在,输出一条合法的经过每条边恰好一次的路径。

做法:先按图的有向或无向条件检查度数约束并检查“忽略零度点后的连通性”,若不满足则无解。满足时选择合法起点,使用 Hierholzer:从起点出发沿尚未使用的边尽量走到底,走不动时将当前点加入答案并回退,最终反转得到欧拉序列。

代码
// 欧拉路径/回路构造(支持无向/有向),Hierholzer算法
// 功能:检查存在性并返回一条欧拉路径/回路;图按1..n编号,闭区间索引
struct Eulerian
{
    struct Edge { int to, id; Edge(int t=0,int i=0):to(t),id(i){} };
    int n, m; bool directed;
    vector<vector<Edge>> g;
    vector<int> inDeg, outDeg, deg;
    vector<int> used, path;

    Eulerian(int n_=0, bool directed_=false): n(n_), m(0), directed(directed_) { init(n_, directed_); }

    void init(int n_, bool directed_)
    {
        n = n_; directed = directed_;
        g.assign(n + 1, {}); inDeg.assign(n + 1, 0); outDeg.assign(n + 1, 0); deg.assign(n + 1, 0);
        used.clear(); path.clear(); m = 0;
    }

    // 加边:无向边计两条半边;有向图计单向
    void addEdge(int u, int v)
    {
        if (!directed) {
            g[u].emplace_back(v, m); g[v].emplace_back(u, m); deg[u]++, deg[v]++, used.push_back(0), m++;
        } else {
            g[u].emplace_back(v, m); outDeg[u]++, inDeg[v]++, used.push_back(0), m++;
        }
    }

    // 选取起点并检查度数条件;返回{-1无解, 起点>=1}
    int pickStart()
    {
        if (!directed) {
            int oddCnt = 0, s = 1;
            for (int i = 1; i <= n; i++) if (deg[i] & 1) oddCnt++, s = i;
            if (oddCnt == 0) {
                for (int i = 1; i <= n; i++) if (deg[i]) return i;
                return -1; // 全零边
            }
            if (oddCnt == 2) return s;
            return -1;
        } else {
            int up = 0, down = 0, s = 1;
            for (int i = 1; i <= n; i++) {
                if (outDeg[i] - inDeg[i] == 1) up++, s = i;
                else if (inDeg[i] - outDeg[i] == 1) down++;
                else if (inDeg[i] != outDeg[i]) return -1;
            }
            if (up == 0 && down == 0) {
                for (int i = 1; i <= n; i++) if (inDeg[i] || outDeg[i]) return i;
                return -1;
            }
            if (up == 1 && down == 1) return s;
            return -1;
        }
    }

    // 忽略度为0的点后检查连通性
    bool isConnected()
    {
        vector<int> vis(n + 1, 0);
        int s = -1;
        if (!directed) {
            for (int i = 1; i <= n; i++) if (deg[i]) { s = i; break; }
            if (s == -1) return true;
            auto dfs = [&](auto &&self, int u) -> void
            {
                vis[u] = 1;
                for (auto e : g[u]) if (!vis[e.to]) self(self, e.to);
            };
            dfs(dfs, s);
            for (int i = 1; i <= n; i++) if (deg[i] && !vis[i]) return false;
            return true;
        } else {
            // 对有向图:在弱连通意义下联通;使用无向视图检查
            vector<vector<int>> ug(n + 1);
            for (int u = 1; u <= n; u++) for (auto e : g[u]) ug[u].push_back(e.to), ug[e.to].push_back(u);
            for (int i = 1; i <= n; i++) if (inDeg[i] || outDeg[i]) { s = i; break; }
            if (s == -1) return true;
            auto dfs = [&](auto &&self, int u) -> void
            {
                vis[u] = 1;
                for (int v : ug[u]) if (!vis[v]) self(self, v);
            };
            dfs(dfs, s);
            for (int i = 1; i <= n; i++) if ((inDeg[i] || outDeg[i]) && !vis[i]) return false;
            return true;
        }
    }

    // 返回是否存在欧拉路径/回路并构造一条序列(顶点序);无解返回空
    vector<int> build()
    {
        path.clear();
        if (!isConnected()) return {};
        int s = pickStart(); if (s == -1) return {};
        vector<int> st; vector<int> it(n + 1, 0);
        auto walk = [&](auto &&self, int u) -> void
        {
            for (int &i = it[u]; i < (int)g[u].size(); i++) {
                auto e = g[u][i]; if (used[e.id]) continue;
                used[e.id] = 1;
                if (!directed) {
                    // 无向边另一半也视作已用:通过“边id”即可避免重复
                }
                self(self, e.to);
            }
            path.push_back(u);
        };
        walk(walk, s);
        // 检查是否用完全部边
        for (int i = 0; i < m; i++) if (!used[i]) return {};
        reverse(path.begin(), path.end());
        return path;
    }
};

环的检测

算法介绍

在无向图中,用 DFS 时若访问到已访问且不是父亲的点则存在环;在有向图中,用颜色数组 0/1/2 标注未访问/递归栈内/已完成,若遇到颜色为 1 的点则存在有向环。两者都可以在发现环时回溯还原一条环路,时间复杂度 O(n+m)。

常见例题

题目:给定一个图(可能有向或无向),判断是否有环;若有,则输出任意一条环上的顶点序列。

做法:无向图用父指针 DFS,碰到非父已访节点回溯输出;有向图用染色 DFS,遇到回边时从当前点向上回溯至该点输出一条简单环。

代码
// 环检测与输出一条环,支持无向/有向,闭区间实现
struct CycleDetector
{
    int n; bool directed;
    vector<vector<int>> g;
    vector<int> color, parent, stackList; // color: 0,1,2 for directed; parent for undirected
    vector<int> cycle;

    CycleDetector(int n_=0, bool directed_=false): n(n_), directed(directed_) { init(n_, directed_); }

    void init(int n_, bool directed_)
    {
        n = n_; directed = directed_;
        g.assign(n + 1, {}); color.assign(n + 1, 0); parent.assign(n + 1, -1);
        stackList.clear(); cycle.clear();
    }

    void addEdge(int u, int v)
    {
        if (!directed) g[u].push_back(v), g[v].push_back(u);
        else g[u].push_back(v);
    }

    // 无向图:检测并返回一条环;若无环返回空
    vector<int> findUndirectedCycle()
    {
        vector<int> vis(n + 1, 0);
        auto dfs = [&](auto &&self, int u, int fa) -> bool
        {
            vis[u] = 1; parent[u] = fa;
            for (int v : g[u]) {
                if (v == fa) continue;
                if (!vis[v]) { if (self(self, v, u)) return true; }
                else {
                    // 回溯还原环 v -> ... -> u -> v
                    vector<int> tmp; int x = u; tmp.push_back(v);
                    while (x != v) tmp.push_back(x), x = parent[x];
                    tmp.push_back(v); reverse(tmp.begin(), tmp.end()); cycle = tmp; return true;
                }
            }
            return false;
        };
        for (int i = 1; i <= n; i++) if (!vis[i]) if (dfs(dfs, i, -1)) return cycle;
        return {};
    }

    // 有向图:检测并返回一条环;若无环返回空
    vector<int> findDirectedCycle()
    {
        vector<int> inStack(n + 1, 0);
        auto dfs = [&](auto &&self, int u) -> bool
        {
            color[u] = 1; inStack[u] = 1; stackList.push_back(u);
            for (int v : g[u]) {
                if (color[v] == 0) { parent[v] = u; if (self(self, v)) return true; }
                else if (inStack[v]) {
                    vector<int> tmp; int x = u; tmp.push_back(v);
                    while (x != v) tmp.push_back(x), x = parent[x];
                    tmp.push_back(v); reverse(tmp.begin(), tmp.end()); cycle = tmp; return true;
                }
            }
            inStack[u] = 0; color[u] = 2; stackList.pop_back(); return false;
        };
        for (int i = 1; i <= n; i++) if (color[i] == 0) if (dfs(dfs, i)) return cycle;
        return {};
    }
};

最小环

算法介绍

无向无权图的最小环长度可通过每个源点做一次 BFS 获得:对源 s 的 BFS 中,若遇到一条指向已访问且不是父亲的边 (u,v) 则得到一条长度为 dist[u]+dist[v]+1 的环,取最小值。加权图可对每条边 (u,v,w) 运行一次 Dijkstra,临时移除该边,最小环候选为 dist(u→v)+w,取全局最小。前者 O(n(n+m)),后者 O(m·(m log n)),在竞赛范围内常可通过剪枝或稀疏性通过。

常见例题

题目:给定一个 n 点 m 边的无向无权图,若图中有环,输出最小环的长度与一条对应的环;若没有,输出 -1。
做法:枚举源点 s 做 BFS,记录父亲与距离;当扫描到边 (u,v) 且 v 已访问且 parent[u] != v 时,更新答案并通过 parent 指针回溯还原该最小环的点集。

代码
// 无向无权图最小环:返回环长与一条环
struct MinCycleUnweighted
{
    int n; vector<vector<int>> g;
    MinCycleUnweighted(int n_=0): n(n_) { init(n_); }
    void init(int n_) { n = n_; g.assign(n + 1, {}); }
    void addEdge(int u, int v) { g[u].push_back(v), g[v].push_back(u); }

    pair<int, vector<int>> solve()
    {
        int best = inf; vector<int> bestCycle;
        vector<int> dist(n + 1), parent(n + 1);
        for (int s = 1; s <= n; s++) {
            fill(dist.begin(), dist.end(), -1); fill(parent.begin(), parent.end(), -1);
            queue<int> q; q.push(s); dist[s] = 0;
            while (!q.empty()) {
                int u = q.front(); q.pop();
                for (int v : g[u]) {
                    if (dist[v] == -1) dist[v] = dist[u] + 1, parent[v] = u, q.push(v);
                    else if (parent[u] != v) {
                        int len = dist[u] + dist[v] + 1;
                        if (len < best) {
                            best = len; vector<int> a, b; int x = u, y = v;
                            while (x != -1 && x != y) a.push_back(x), x = parent[x];
                            if (x == y) { a.push_back(y); reverse(a.begin(), a.end()); bestCycle = a; }
                            else {
                                x = u; while (x != -1) a.push_back(x), x = parent[x];
                                y = v; while (y != -1) b.push_back(y), y = parent[y];
                                // 找到 LCA 合并
                                while (!a.empty() && !b.empty() && a.back() == b.back()) a.pop_back(), b.pop_back();
                                reverse(a.begin(), a.end()); bestCycle = a; bestCycle.push_back(b.front());
                                for (int i = 1; i < (int)b.size(); i++) bestCycle.push_back(b[i]);
                            }
                        }
                    }
                }
            }
        }
        return best == inf ? make_pair(-1, vector<int>{}) : make_pair(best, bestCycle);
    }
};
// 加权有向/无向图最小环(按边跑最短路):返回最小环长度(不可还原路径时至少可给出长度)
// 若需还原路径,可在 Dijkstra 中记录前驱,然后在命中边(u,v,w)时回溯 u->...->v 再拼 w
template <typename W>
struct MinCycleWeighted
{
    int n; bool directed;
    struct E { int to; W w; };
    vector<vector<E>> g;

    MinCycleWeighted(int n_=0, bool directed_=false): n(n_), directed(directed_) { init(n_, directed_); }
    void init(int n_, bool directed_) { n = n_; directed = directed_; g.assign(n + 1, {}); }
    void addEdge(int u, int v, W w)
    {
        g[u].push_back({v, w});
        if (!directed) g[v].push_back({u, w});
    }

    W solve()
    {
        const W INF = numeric_limits<W>::max() / 4;
        W ans = INF;
        // 枚举每条边 (u,v,w) 作为最后一条环边
        for (int u = 1; u <= n; u++) for (auto e : g[u]) {
            int v = e.to; W w = e.w;
            // Dijkstra:求去掉这条“可回头”边后的 u->v 最短路
            vector<W> d(n + 1, INF);
            using P = pair<W,int>;
            priority_queue<P, vector<P>, greater<P>> pq;
            d[u] = 0; pq.push({0, u});
            while (!pq.empty()) {
                auto [du, x] = pq.top(); pq.pop();
                if (du != d[x]) continue;
                for (auto ed : g[x]) {
                    int y = ed.to; W nw = du + ed.w;
                    // 在无向图中避免直接用回边 (v->u) 走回去构成长度2环;有向图不需要这一步
                    if (!directed && ((x == u && y == v) || (x == v && y == u))) continue;
                    if (nw < d[y]) d[y] = nw, pq.push({nw, y});
                }
            }
            if (d[v] != INF) ans = min(ans, d[v] + w);
        }
        return ans == INF ? -1 : ans;
    }
};

杂项

拓扑排序

算法介绍

有向无环图的拓扑序是把所有顶点线性排列,使每条有向边均从序前点指向序后点。常用两法。方法一是 Kahn 算法,统计入度,把入度为 0 的点入队,不断弹出并“删除”其出边,若能取出 n 个点则存在拓扑序,否则有环。方法二是深搜后序逆序,递归访问未访问的点,沿出边深搜,离开时把点加入序列,最终逆序即拓扑序;若在 DFS 栈内再次遇到灰色点则有环。

常见例题

题目:给定 n 个任务与 m 条先后关系边 u→v,若无环输出任意拓扑序,若有环输出 -1。

做法:用 Kahn 算法求序列,弹出数不足 n 判环;或用 DFS 判环并产出序列。

代码
// 拓扑排序模板,提供 Kahn 与 DFS 两种实现,均基于闭区间边界注释
// 功能:判环与返回拓扑序;若有环返回空序列
struct TopoSort
{
    int n;                            // 点数,默认编号 [0,n-1]
    vector<vector<int>> g;            // 邻接表
    TopoSort(int n_ = 0) { init(n_); }
    void init(int n_) { n = n_; g.assign(n, {}); }
    void addEdge(int u, int v) { g[u].push_back(v); } // 加有向边 u->v

    // Kahn:O(n+m),队列弹出顺序即拓扑序
    vector<int> kahn()
    {
        vector<int> deg(n), ord; ord.reserve(n);
        for (int u = 0; u < n; u++) for (int v : g[u]) deg[v]++;
        queue<int> q;
        for (int u = 0; u < n; u++) if (!deg[u]) q.push(u);
        while (!q.empty())
        {
            int u = q.front(); q.pop(); ord.push_back(u);
            for (int v : g[u]) if (--deg[v] == 0) q.push(v);
        }
        if ((int)ord.size() != n) return {}; // 有环
        return ord;
    }

    // DFS:O(n+m),检测回边判环,后序逆序即拓扑序
    vector<int> dfsTopo()
    {
        vector<int> vis(n), ord; ord.reserve(n); bool hasCycle = false;
        // vis: 0=未访问,1=栈内,2=已完成
        auto dfs = [&](auto &&self, int u) -> void
        {
            vis[u] = 1;
            for (int v : g[u])
            {
                if (vis[v] == 0) self(self, v);
                else if (vis[v] == 1) hasCycle = true;
                if (hasCycle) return;
            }
            vis[u] = 2; ord.push_back(u);
        };
        for (int u = 0; u < n; u++) if (!vis[u]) dfs(dfs, u);
        if (hasCycle) return {};
        reverse(ord.begin(), ord.end());
        return ord;
    }
};

二分图

算法介绍

二分图是把顶点分为不相交的两侧 U 与 V,使得每条边必然跨侧。判定可用 BFS 或 DFS 染色,给起点色 0,相邻点着相反颜色,若遇到同色相邻则非二分。若图不连通,从所有未染色点重启。此判定等价于“图无奇环”。

常见例题

题目:给定无向图,判断是否是二分图;若是,输出两侧点集。

做法:对每个未染色连通块执行 BFS/DFS 染色,若冲突则输出 NO,否则输出 YES 与按颜色分组的结果。

代码
// 二分图判定与染色,支持多连通块
// 功能:check() 返回是否二分;getColor() 返回每点颜色 0/1(未判定前为-1)
struct BipartiteCheck
{
    int n;
    vector<vector<int>> g;
    vector<int> color; // -1=未染色, 0/1 两侧
    BipartiteCheck(int n_ = 0) { init(n_); }
    void init(int n_) { n = n_; g.assign(n, {}); color.assign(n, -1); }
    void addEdge(int u, int v) { g[u].push_back(v); g[v].push_back(u); }

    bool check()
    {
        queue<int> q;
        for (int s = 0; s < n; s++)
        {
            if (color[s] != -1) continue;
            color[s] = 0; q.push(s);
            while (!q.empty())
            {
                int u = q.front(); q.pop();
                for (int v : g[u])
                {
                    if (color[v] == -1) color[v] = color[u] ^ 1, q.push(v);
                    else if (color[v] == color[u]) return false;
                }
            }
        }
        return true;
    }

    const vector<int>& getColor() const { return color; }
};

2-SAT 问题

算法介绍

2-SAT 要求判断由若干子句构成的合取范式是否可满足,每个子句形如 (xᵢ ∨ xⱼ) 或其字面取反。经典解法构建蕴含图,对变量 x 建两个点 x 与 ¬x,子句 (a ∨ b) 等价于 (¬a → b) 与 (¬b → a)。在蕴含图上分解强连通分量,若同一变量的 x 与 ¬x 落在同一 SCC 中则无解;否则按分量的逆拓扑序给出可行赋值。实现通常用 Kosaraju 或 Tarjan,时间 O(n+m)。

常见例题

题目:给定 n 个布尔变量与 m 条约束,每条约束是“变量 a 取值可能取反后与变量 b 取值可能取反之间至少一个为真”,判断是否可满足并给出一个方案。

做法:把每条 (litA ∨ litB) 转成两条蕴含边,跑 SCC 判冲突并按分量拓扑序取值。

代码
// 2-SAT 模板,变量下标 [0..n-1],每个变量映射为两个点:x 与 x^1 表示互为取反
// 点编号规则:id(x, val) = x<<1 | (val?1:0),val=1 表示真,0 表示假
// 功能:addOr(a, va, b, vb) 添加 (a==va) ∨ (b==vb);solve() 返回是否可满足并给出赋值
struct TwoSat
{
    int n;
    vector<vector<int>> g, gr;
    vector<int> comp, vis, order;  // Kosaraju:order 为第一遍后序
    vector<int> assignVal;         // 变量最终取值

    TwoSat(int n_ = 0) { init(n_); }
    void init(int n_)
    {
        n = n_;
        g.assign(2 * n, {}); gr.assign(2 * n, {});
        comp.assign(2 * n, -1); vis.assign(2 * n, 0); order.clear();
        assignVal.assign(n, 0);
    }

    static int id(int x, bool val) { return x << 1 | (val ? 1 : 0); }
    static int neg(int u) { return u ^ 1; }

    // 添加蕴含边 u->v
    void addImp(int u, int v) { g[u].push_back(v); gr[v].push_back(u); }

    // (a==va) ∨ (b==vb)
    void addOr(int a, bool va, int b, bool vb)
    {
        int u = id(a, va), nu = neg(u);
        int v = id(b, vb), nv = neg(v);
        addImp(nu, v), addImp(nv, u);
    }

    // 强制 a==va
    void addTrue(int a, bool va) { addImp(neg(id(a, va)), id(a, va)); }

    // a==b
    void addEqual(int a, int b)
    {
        addOr(a, 1, b, 0); addOr(a, 0, b, 1);
        addOr(b, 1, a, 0); addOr(b, 0, a, 1);
    }

    // a XOR b == 1
    void addXor(int a, int b)
    {
        addOr(a, 1, b, 1); addOr(a, 0, b, 0);
    }

    // Kosaraju 求 SCC 并给出赋值
    bool solve()
    {
        // 第一遍 DFS:后序栈
        auto dfs1 = [&](auto &&self, int u) -> void
        {
            vis[u] = 1;
            for (int v : g[u]) if (!vis[v]) self(self, v);
            order.push_back(u);
        };
        for (int u = 0; u < 2 * n; u++) if (!vis[u]) dfs1(dfs1, u);

        // 第二遍 DFS:反图按 order 逆序染色
        int cid = 0;
        auto dfs2 = [&](auto &&self, int u) -> void
        {
            comp[u] = cid;
            for (int v : gr[u]) if (comp[v] == -1) self(self, v);
        };
        for (int i = 2 * n - 1; i >= 0; i--)
        {
            int u = order[i];
            if (comp[u] == -1) dfs2(dfs2, u), cid++;
        }

        // 冲突判定与拓扑序赋值
        for (int x = 0; x < n; x++) if (comp[id(x, 0)] == comp[id(x, 1)]) return false;
        vector<pair<int,int>> seq;
        seq.reserve(2 * n);
        for (int u = 0; u < 2 * n; u++) seq.emplace_back(comp[u], u);
        sort(seq.rbegin(), seq.rend());                // 按分量拓扑序(大到小)
        vector<int> val(2 * n, -1);                    // 结点布尔
        for (auto [_, u] : seq) if (val[u] == -1) val[u] = 0, val[neg(u)] = 1;
        for (int x = 0; x < n; x++) assignVal[x] = val[id(x, 1)];
        return true;
    }

    const vector<int>& assignment() const { return assignVal; } // 0/1
};

支配树

算法介绍

在以起点 s 可达的有向图中,若每条从 s 到 v 的路径都经过 u,则称 u 支配 v。每个 v 除 s 外存在唯一的立即支配点 idom(v),由所有 idom 关系构成的树即支配树。经典做法是 Lengauer–Tarjan 算法,核心是 DFS 编号后计算半支配点 semi,再用并查集的带权路径压缩与评估函数 eval/link 完成 idom 计算,总复杂度接近线性。

常见例题

题目:给定以 0 为入口的控制流图,构造其支配树并回答若干查询“u 是否支配 v”。

做法:用 Lengauer–Tarjan 求出每个点的 idom,随后在支配树上做一次 DFS 取入栈出栈时间 tin/tout,查询时判断 u 是否祖先即可。

代码
// 支配树(Lengauer–Tarjan),点编号 [0..n-1],只考虑从 root 可达部分
// 功能:build(root) 之后可查询 idom[v] 与支配树 children
struct DominatorTree
{
    int n;
    vector<vector<int>> g, rg, bucket, tree;   // 正图、反图、半支配桶、支配树
    vector<int> arr, rev, par, semi, idom, dsu, best, tin, tout;
    int timer;

    DominatorTree(int n_ = 0) { init(n_); }
    void init(int n_)
    {
        n = n_;
        g.assign(n, {}); rg.assign(n, {});
        bucket.assign(n, {}); tree.assign(n, {});
        arr.assign(n, -1); rev.assign(n, -1); par.assign(n, -1);
        semi.assign(n, -1); idom.assign(n, -1);
        dsu.assign(n, -1); best.assign(n, -1);
        tin.assign(n, 0); tout.assign(n, 0); timer = 0;
    }

    void addEdge(int u, int v) { g[u].push_back(v); }

    void build(int root)
    {
        // 1) DFS 编号
        int T = 0;
        auto dfs = [&](auto &&self, int u) -> void
        {
            arr[u] = T; rev[T] = u; semi[T] = T; dsu[T] = T; best[T] = T; T++;
            for (int v : g[u])
            {
                if (arr[v] == -1) par[v] = u, self(self, v);
                rg[arr[v]].push_back(arr[u]); // 反图按 DFS 序编号连边
            }
        };
        dfs(dfs, root);
        if (T == 0) return; // root 不可达

        // 2) Lengauer–Tarjan 主体
        auto eval = [&](auto &&self, int v) -> int
        {
            if (dsu[v] == v) return v;
            int u = self(self, dsu[v]);
            if (semi[best[dsu[v]]] < semi[best[v]]) best[v] = best[dsu[v]];
            return dsu[v] = u;
        };
        auto link = [&](int v, int p) { dsu[v] = p; };

        for (int w = T - 1; w >= 1; w--)
        {
            // 计算 semi[w]
            for (int v : rg[w])
            {
                int u = eval(eval, v);
                if (semi[u] < semi[w]) semi[w] = semi[u];
            }
            bucket[semi[w]].push_back(w);
            link(w, arr[par[rev[w]]]);

            // 处理等待在父亲处的点
            for (int v : bucket[arr[par[rev[w]]]])
            {
                int u = eval(eval, v);
                idom[v] = semi[u] < semi[v] ? u : arr[par[rev[w]]];
            }
            bucket[arr[par[rev[w]]]].clear();
        }

        // 3) 修正 idom 并映射回原编号
        for (int w = 1; w < T; w++) if (idom[w] != semi[w]) idom[w] = idom[idom[w]];
        vector<int> idomReal(n, -1);
        idomReal[rev[0]] = -1; // 根无 idom
        for (int w = 1; w < T; w++) idomReal[rev[w]] = rev[idom[w]];
        idom.swap(idomReal);

        // 4) 构造支配树并做 dfs 标记时间
        for (int v = 0; v < n; v++) if (idom[v] != -1) tree[idom[v]].push_back(v);
        auto dfs2 = [&](auto &&self, int u) -> void
        {
            tin[u] = ++timer;
            for (int v : tree[u]) self(self, v);
            tout[u] = ++timer;
        };
        dfs2(dfs2, root);
    }

    // 查询立即支配点
    int getIdom(int v) const { return idom[v]; }

    // 判断 u 是否支配 v(在支配树上 u 是否为 v 的祖先)
    bool dominates(int u, int v) const { return tin[u] <= tin[v] && tout[v] <= tout[u]; }
};

六、数学

数论

筛法求素数

算法介绍

采用欧拉筛(线性筛)一次遍历同时维护素数表与每个合数的最小质因子,使每个合数只被它的最小质因子“标记”一次,从而达到 O(n) 的时间复杂度,空间 O(n)。实现要点是维护最小质因子数组 minPrime,并在遍历每个 i 时仅用不超过 minPrime[i] 的素数去扩展。

代码
// 欧拉筛(线性筛)
// 功能:init(n) 线性时间筛出 [2..n] 全部素数,并记录每个数的最小质因子 minPrime;isPrime(x) O(1) 判断;factorize(x) 返回分解结果
struct EulerSieve
{
    int n;                // 当前筛的上界
    vector<int> primes;   // 素数表(升序)
    vector<int> minPrime; // minPrime[x] 为 x 的最小质因子(素数的最小质因子就是自身)

    EulerSieve(int n_ = 0)
    {
        if (n_)
            init(n_);
        else
            n = 0;
    } // 构造函数,可选直接筛;否则保持空状态

    void init(int n_) // 初始化并执行线性筛
    {
        n = n_;                    // 记录上界
        minPrime.assign(n + 1, 0); // 最小质因子数组清零(0 表示尚未确定)
        primes.clear();
        primes.reserve(n / 10);      // 预留一部分空间,经验值加速
        for (int i = 2; i <= n; i++) // i 从 2 到 n 逐个处理
        {
            if (!minPrime[i])
                minPrime[i] = i, primes.push_back(i); // 若未标记,则 i 为素数,最小质因子设为自身并加入素数表
            for (int p : primes)                      // 用素数表中的素数尝试扩展
            {
                ll v = 1LL * p * i; // 计算乘积 p*i(用 ll 防止溢出)
                if (p > minPrime[i] || v > n)
                    break;       // 只用不大于 minPrime[i] 的素数扩展;越界则停止
                minPrime[v] = p; // 合数 v 的最小质因子即为当前 p
            }
        }
    }

    bool isPrime(int x) // O(1) 判断素性(需保证 2<=x<=n)
    {
        if (x < 2 || x > n)
            return false;        // 越界或小于 2 均非素数
        return minPrime[x] == x; // 素数的最小质因子等于其自身
    }

    vector<pair<int, int>> factorize(int x) // 返回 x 的质因子分解(<质因子, 幂次>),需 1<x<=n
    {
        vector<pair<int, int>> ret; // 结果容器
        while (x > 1)               // 反复提取最小质因子
        {
            int p = minPrime[x], c = 0; // 当前最小质因子 p 以及其幂次计数 c
            while (x % p == 0)
                x /= p, ++c;        // 连续整除统计幂次
            ret.emplace_back(p, c); // 记录因子与幂次
        }
        return ret; // 返回分解结果
    }
};

快速幂

算法介绍

二进制幂把指数按二进制展开,仅做 O(log b) 次乘方即得 a^b。

代码
// 快速幂
const ll mod = 1e9 + 7;
struct FastPower
{
    ll fpow(ll x, ll k, ll p = mod)
    {
        ll res = 1; // 初始答案为 1
        while (k)          // 迭代直到 b 归零
        {
            if (k & 1)
                res = res * x % p;  // 若当前位为 1,累计乘
            x = x * x % p, k >>= 1; // 底数平方,指数右移
        }
        return res;
    }
};

扩展欧几里得 (EXGCD)

算法介绍

扩展欧几里得在求 gcd(a,b) 的同时给出整数解 x、y 满足 ax+by=gcd(a,b)。实现时注意负数处理、整除截断方向与返回值的一致性(本实现等价于 cp-algorithms 的标准写法)。

代码
// 扩展欧几里得
// 功能:exgcd(a,b,x,y) 求 gcd(a,b) 及一组系数 x,y 使 ax+by=g;solveLinear(a,m,c, x0, g) 用于解 ax≡c(mod m)
struct Exgcd
{
    static ll exgcd(ll a, ll b, ll &x, ll &y) // 返回 gcd(a,b),并通过引用返回 x,y
    {
        if (!b)
        {
            x = (a >= 0 ? 1 : -1), y = 0;
            return llabs(a);
        } // 递归基:b==0 时 gcd=|a|,此时 x=±1, y=0
        ll x1, y1;
        ll g = exgcd(b, a % b, x1, y1); // 递归计算 gcd(b, a%b) 及对应系数 x1,y1
        x = y1, y = x1 - (a / b) * y1;         // 回代得到当前层的 x,y
        return g;                              // 返回 gcd
    }

    static bool solveLinear(ll a, ll m, ll c, ll &x0, ll &g) // 解线性同余 ax≡c(mod m),输出一组解 x0 与 gcd g
    {
        ll y;
        g = exgcd(a, m, x0, y); // 先求出 ax+my=g 的一组解
        if (c % g)
            return false;                        // 若 c 不能被 g 整除则无解
        x0 = ((__int128)x0 * (c / g)) % (m / g); // 将通解中特解缩放为满足 ax≡c 的解,并对模 (m/g) 化到标准范围
        if (x0 < 0)
            x0 += m / g; // 调整为非负最小代表
        return true;     // 返回可解
    }
};

模逆元

算法介绍

在质数模 p 下利用费马小定理 a^{p−2} 得逆;一般模下只要 gcd(a,m)=1 即可用 EXGCD 求到逆元。若需要 1..n 的所有逆元还可 O(n) 线性预处理。


中国剩余定理 (CRT)

算法介绍

通用 CRT 逐对合并:对 x≡a1(mod m1)、x≡a2(mod m2),先用 exgcd 检查一致性,再在模 lcm(m1,m2) 的意义下构造合并后的同余。没有公共解时直接判无解。

代码
// 通用 CRT(允许非互质模)
// 功能:merge(a1,m1,a2,m2) 合并两条同余;crt(A,M) 合并一组同余;无解以 second=-1 表示
struct CRT
{
    static pair<ll, ll> merge(ll a1, ll m1, ll a2, ll m2) // 合并两式,返回 <r, M>
    {
        if (m1 == -1 || m2 == -1)
            return {0, -1}; // 若任一已无解,直接无解
        ll x, y;
        ll g = Exgcd::exgcd(m1, m2, x, y); // 求 m1*x + m2*y = g
        if ((a2 - a1) % g)
            return {0, -1};                                     // 若常数项差不能被 g 整除,无解
        ll lcm = m1 / g * m2;                            // 新模为 lcm(m1,m2)
        ll t = ((__int128)(a2 - a1) / g * x) % (m2 / g); // 计算调整系数 t
        ll r = (a1 + (__int128)m1 * t) % lcm;            // 构造新解 r
        if (r < 0)
            r += lcm;    // 调整为非负最小代表
        return {r, lcm}; // 返回合并结果
    }

    static pair<ll, ll> crt(const vector<ll> &A, const vector<ll> &M) // 合并多条同余
    {
        ll r = 0, mod = 1;               // 初始为空系统:x≡0(mod 1)
        for (int i = 0; i < (int)A.size(); i++) // 逐条合并
        {
            auto cur = merge(r, mod, ((A[i] % M[i]) + M[i]) % M[i], M[i]); // 将 A[i] 规范到 [0,Mi)
            if (cur.second == -1)
                return {0, -1};              // 若无解则返回无解
            r = cur.first, mod = cur.second; // 更新当前系统
        }
        return {r, mod}; // 返回合并后的整体解
    }
};

Miller–Rabin 素数测试

算法介绍

将 n−1 写作 2^s·d,随机或固定底数 a 做幂链测试;若在幂链中没有出现 1 或 n−1,则判为合数。对 64 位范围,使用一组固定底数可以实现确定性判断,并在测试前用若干小素数试除可显著降常。下面实现使用常见的一组 64 位确定性底数并采用 128 位中间量模乘

代码
// 64位确定性 Miller–Rabin
// 功能:isPrime(n) 判断无符号64位整数的素性;先小素数试除,再走 MR 幂链
struct MillerRabin
{
    static ull mul(ull a, ull b, ull m) // 安全模乘
    {
        return (unsigned __int128)a * b % m; // 128 位相乘取模
    }

    static ull powm(ull a, ull e, ull n) // 快速幂 a^e mod n
    {
        ull r = 1;
        a %= n;   // 初始 r=1,底数入模
        while (e) // 指数大于 0 继续
        {
            if (e & 1)
                r = mul(r, a, n);      // 二进制幂:当前位为 1 则乘
            a = mul(a, a, n), e >>= 1; // 底数平方,指数右移
        }
        return r; // 返回结果
    }

    static bool trial(ull a, ull s, ull d, ull n) // 对底 a 执行一次 MR 试验
    {
        ull x = powm(a, d, n); // 先算 a^d mod n
        if (x == 1 || x == n - 1)
            return true;                           // 若一开始落在安全值则通过
        for (ull i = 1; i < s; i++) // 连续平方 s-1 次
        {
            x = mul(x, x, n); // x ← x^2 mod n
            if (x == n - 1)
                return true; // 任何一步命中 n-1 即通过
        }
        return false; // 否则判合数
    }

    static bool isPrime(ull n) // 主过程:返回 n 是否为素
    {
        if (n < 2)
            return false; // 0、1 非素数
        for (ull p : {2ull, 3ull, 5ull, 7ull, 11ull, 13ull, 17ull, 19ull, 23ull, 29ull, 31ull, 37ull})
        {
            if (n % p == 0)
                return n == p; // 先用少量小素数试除,整除即看是否等于它本身
        }
        ull d = n - 1, s = 0; // 将 n-1 分解为 2^s·d
        while ((d & 1) == 0)
            d >>= 1, ++s; // 连续除以 2 得到奇数 d 与次数 s
        for (ull a : {2ull, 3ull, 5ull, 7ull, 11ull, 13ull, 17ull, 19ull, 23ull, 29ull, 31ull, 37ull})
        {
            if (!trial(a % n, s, d, n))
                return false; // 对每个底数做一次试验,失败即合数
        }
        return true; // 全部通过则为素数
    }
};

Pollard–Rho 大整数分解

算法介绍

利用伪随机迭代序列与“龟兔赛跑”寻找非平凡因子,期望时间与最小质因子 p 的平方根同阶,空间 O(1)。与 Miller–Rabin 联用即可在 64 位范围内高效分解。实现关键在于 64 位安全模乘、随机重启以及递归拆分。

代码
// 64位 Pollard–Rho 分解
// 功能:factor(n,res) 将 n 分解为若干质因子(含重数)压入 res;内部先判素再分解
struct PollardRho
{
    static ull mul(ull a, ull b, ull m) // 安全模乘
    {
        return (unsigned __int128)a * b % m;                                                        // 128 位相乘取模
    }

    static ull f(ull x, ull c, ull mod) // 迭代函数 f(x)=x^2+c (mod mod)
    {
        return (mul(x, x, mod) + c) % mod;                                                          // 计算一次迭代
    }

    static ull rho(ull n)                                             // 返回 n 的一个非平凡因子
    {
        if ((n & 1ull) == 0) return 2;                                                              // 偶数直接返回 2
        mt19937_64 rng(random_device{}());                                                // 随机数引擎
        uniform_int_distribution<ull> dist(2, n - 2);                           // 在 [2, n-2] 上取随机数
        while (true)                                                                                // 不断尝试直到成功
        {
            ull x = dist(rng), y = x, c = dist(rng); if (c == 0) c = 1;              // 随机初始化 x,y,c;c 不为 0
            ull d = 1;                                                               // d 为当前 gcd 候选
            while (d == 1)                                                                          // 当未找到因子时持续迭代
            {
                x = f(x, c, n); y = f(f(y, c, n), c, n);                                            // 龟兔赛跑:x 走一步,y 走两步
                ull diff = x > y ? x - y : y - x;                                    // 计算 |x-y|
                d = gcd(diff, n);                                                              // 取 gcd
            }
            if (d != n) return d;                                                                   // 找到非平凡因子则返回
        }
    }

    static void factor(ull n, vector<ull> &res)                       // 将 n 分解到 res
    {
        if (n == 1) return;                                                                         // 1 无需分解
        if (MillerRabin::isPrime(n)) { res.push_back(n); return; }                                  // 素数直接放入结果
        ull d = rho(n);                                                              // 找到一个非平凡因子 d
        factor(d, res), factor(n / d, res);                                                         // 递归分解 d 与 n/d
    }
};

离散对数(BSGS)

算法介绍

在模 m 的乘法群内求 g^x≡a 的最小非负解 x。设 n≈⌈√m⌉,先预存 baby 步 g^j 的哈希映射,再以 step=g^{-n} 为步长做 giant 步并查表匹配。若 gcd(g,m)≠1,可逐步约去公共因子并修正目标,直到进入可逆情形。本实现按 cp-algorithms 的思路给出通用 BSGS。

代码
// BSGS(Baby-step Giant-step)
// 功能:solve(g,a,m) 求最小非负 x 使 g^x≡a(mod m),若不存在返回 -1;兼容 gcd(g,m)≠1 的约化流程
struct BSGS
{
    static ll solve(ll g, ll a, ll m) // 返回离散对数解
    {
        g %= m, a %= m; // 先把底与目标化到模内
        if (m == 1)
            return 0;                                                  // 模 1 退化情形
        ll k = 0, t = 1;                                        // k 记录约化次数,t 累计被约去的常数因子
        for (ll d = gcd(g, m); d != 1; d = gcd(g, m)) // 当 gcd(g,m)≠1 时循环约化
        {
            if (a % d)
                return -1;                              // a 不可被 d 整除则无解
            m /= d, a /= d, t = (t * (g / d)) % m, ++k; // 同时缩模、缩 a,更新常数 t 并累加约化步数
            if (t == a)
                return k; // 约化过程中恰好命中答案则直接返回
        }

        ll n = (ll)sqrt((ld)m) + 1; // 取步长 n≈√m
        unordered_map<ll, ll> table;
        table.reserve(n * 2); // 预分配哈希表,存 baby 步
        ll cur = 1;    // cur = g^j
        for (ll j = 0; j < n; j++)
        {
            if (!table.count(cur))
                table[cur] = j;
            cur = (__int128)cur * g % m;
        } // 预存 j 与 g^j 的映射(只保留最小 j)

        ll gn = 1; // gn = g^n
        for (ll i = 0; i < n; i++)
            gn = (__int128)gn * g % m; // 连乘获得 g^n(也可用快速幂)
        ll invGn = ModInverse::invCoprime(gn, m);
        if (invGn == -1)
            return -1;                                        // 需要 g^n 的逆;若不存在说明仍不在乘法群
        ll rhs = a * ModInverse::invCoprime(t, m) % m; // 将方程化为 (g^n)^i * g^j ≡ a * t^{-1}

        for (ll i = 0; i <= n; i++) // giant 步逐步右乘 inv(gn)
        {
            auto it = table.find(rhs); // 查表是否存在某个 j 使得 g^j == rhs
            if (it != table.end())
                return i * n + it->second + k; // 命中即返回 x = i*n + j + 已约化次数 k
            rhs = (__int128)rhs * invGn % m;   // 下一次巨步:乘以 g^{-n}
        }
        return -1; // 遍历失败则无解
    }
};

原根求法

算法介绍

对奇素数 p,先分解 φ(p)=p−1 的全部质因子集合 {q_i},再枚举 g≥2,若对所有 q_i 都满足 g^{(p−1)/q_i}≠1(mod p),则 g 为原根。这里用 Pollard–Rho 分解 p−1,并用快速幂检验候选。

代码
// 原根(质数模)
// 功能:primitiveRoot(p) 返回最小原根;p=2 返回 1;非素数返回 -1
struct PrimitiveRoot
{
    static vector<ull> uniquePrimeFactors(ull x) // 求 x 的不同质因子集合
    {
        vector<ull> fac, all; // fac 保存去重后的因子;all 存放带重因子
        PollardRho::factor(x, all);          // 调用 Pollard-Rho 分解 x
        sort(all.begin(), all.end());        // 排序以便去重
        for (size_t i = 0; i < all.size();)  // 扫描去重
        {
            fac.push_back(all[i]); // 收集一个新的质因子
            size_t j = i;
            while (j < all.size() && all[j] == all[i])
                j++; // 跳过相同的因子
            i = j;   // 前进到下一个不同因子
        }
        return fac; // 返回去重后的质因子集合
    }

    static ll primitiveRoot(ll p) // 返回最小原根
    {
        if (p == 2)
            return 1; // p=2 的唯一原根是 1
        if (!MillerRabin::isPrime(p))
            return -1;                                            // 非素数没有保证存在原根,这里直接返回 -1
        ull phi = p - 1;                           // φ(p)=p-1(p 为素)
        vector<ull> fac = uniquePrimeFactors(phi); // 分解 p-1 得到所有不同质因子
        for (ll g = 2; g < p; g++)                         // 从 2 开始枚举候选 g
        {
            bool ok = true;    // 标记当前 g 是否通过所有检验
            for (auto q : fac) // 对每个 q 检验 g^{phi/q}≠1
            {
                if (FastPower::powMod(g, phi / q, p) == 1)
                {
                    ok = false;
                    break;
                } // 若等于 1,g 非原根,失败
            }
            if (ok)
                return g; // 全部不等于 1 则 g 是原根,返回之
        }
        return -1; // 理论上不会到达(素数一定存在原根),兜底返回
    }
};

矩阵与线性代数

高斯消元

算法介绍

用行初等变换把增广矩阵化为上三角或行阶梯形,随后回代得到解;若某行系数全零而常数项非零,则无解;若自由变量存在则有无穷多解。数值场景中采用按列的局部主元选择以提升稳定性;在竞赛中通常使用 ld 并设置合适的误差阈值以判断“接近零”的枢轴。

代码
// 高斯消元(部分选主元,实数域)—— 逐行详细注释
// 功能:给定 n×m 系数矩阵 A 与长度为 n 的常数向量 b,解线性方程组 A x = b;返回是否有解,并在 ans 中给出一组解;可同时返回秩与是否唯一
struct Gauss
{
    int n;                // 存放方程数量(行数)
    int m;                // 存放未知数数量(列数)
    ld eps;               // 判断零的阈值
    vector<vector<ld>> a; // 增广矩阵存放区:尺寸为 n×(m+1),最后一列是常数项

    Gauss(int n_ = 0, int m_ = 0, ld eps_ = 1e-12) // 构造函数,允许直接给定尺寸与误差阈值
    {
        if (n_ && m_)
            init(n_, m_, eps_); // 若传入合法尺寸则立即初始化
        else
            n = 0, m = 0, eps = eps_; // 否则仅记录 eps 并保持空矩阵
    }

    void init(int n_, int m_, ld eps_ = 1e-12) // 初始化尺寸与阈值并清零增广矩阵
    {
        n = n_, m = m_, eps = eps_;        // 记录行列与阈值
        a.assign(n, vector<ld>(m + 1, 0)); // 分配 n 行、m+1 列并全部置零
    }

    int sgn(ld x) // 返回带阈值的符号:-1/0/1
    {
        if (fabsl(x) <= eps)
            return 0;          // 绝对值不超过 eps 视为 0
        return x < 0 ? -1 : 1; // 否则按正负返回
    }

    bool solve(vector<ld> &ans, int &rankA, bool &unique) // 主过程:消元与回代;返回是否有解;输出一组解、矩阵秩与是否唯一
    {
        ans.assign(m, 0);         // 预设答案向量为全 0(若存在自由变量可视为取 0)
        rankA = 0;                // 初始化秩为 0
        vector<int> where(m, -1); // where[j] 记录第 j 列(变量 j)在哪一行成为主变量,默认 -1 表示自由

        for (int col = 0, row = 0; col < m && row < n; col++) // 逐列选择主元并向下消元,row 指向当前将要放置主元的行
        {
            int piv = row;                // 初始认为当前行就是主元行
            for (int i = row; i < n; i++) // 在当前列 col 的 [row..n-1] 中寻找绝对值最大的元素作为主元
                if (fabsl(a[i][col]) > fabsl(a[piv][col]))
                    piv = i; // 若发现更大的绝对值则更新主元行
            if (!sgn(a[piv][col]))
                continue; // 若整列均近似为 0,则此列无主元,继续下一列

            if (piv != row)
                swap(a[piv], a[row]);    // 将最大主元行交换到当前位置 row,提高数值稳定性
            where[col] = row;            // 记录变量 col 的主元出现在行 row
            ld inv = 1.0L / a[row][col]; // 计算主元的倒数以便归一化
            for (int j = col; j <= m; j++)
                a[row][j] *= inv; // 将主元行整体除以主元,使主元变为 1(从 col 开始到常数列一起缩放)

            for (int i = 0; i < n; i++)
                if (i != row) // 对其余所有行做消元(列 col 清为 0)
                {
                    if (!sgn(a[i][col]))
                        continue;                      // 若该行在 col 列本就为 0,跳过
                    ld factor = a[i][col];             // 取当前行在主元列的系数作为倍数
                    for (int j = col; j <= m; j++)     // 从 col 开始,行 i ← 行 i − factor × 行 row
                        a[i][j] -= factor * a[row][j]; // 元素级更新,确保该列被消成 0
                }

            row++; // 已成功放置一个主元行,行指针下移
        }

        for (int i = 0; i < n; i++) // 消元结束后检查是否出现 0 = 非 0 的矛盾行
        {
            bool allZero = true; // 标记本行系数是否全 0
            for (int j = 0; j < m; j++)
                if (sgn(a[i][j])) // 只要有一个非零系数就不是全 0
                {
                    allZero = false;
                    break;
                } // 发现非零后立即退出
            if (allZero && sgn(a[i][m]))
                return false; // 若系数全 0 但常数不为 0,则无解
        }

        for (int j = 0; j < m; j++)
            if (where[j] != -1)          // 根据 where 写出一组解:主变量直接由对应行的常数给出
                ans[j] = a[where[j]][m]; // 第 j 个未知数的值就是该行的常数项(主元为 1 且列已被清)
        rankA = 0;
        for (int j = 0; j < m; j++)
            if (where[j] != -1)
                rankA++;       // 统计主元列的个数作为秩
        unique = (rankA == m); // 若主元覆盖所有列则解唯一,否则存在自由变量
        return true;           // 到此说明无矛盾,至少存在一组可行解(自由变量取 0)
    }
};

矩阵快速幂

算法介绍

将二进制快速幂推广到矩阵乘法的半群;利用幂的二进制展开仅需 O(log k) 次乘法就能求得 A^k。竞赛中常在模 p 下计算以避免溢出,并按行列三重循环实现乘法;乘法内部可以跳过零项以降常。该技巧常用来求线性递推第 n 项或构造状态转移解 DP。

代码
// 方阵与矩阵快速幂(模意义)—— 逐行详细注释
// 功能:提供 n×n 方阵的乘法与幂运算;所有元素按 mod 取模;适合线性递推与图上路径计数等问题
template <typename T>
struct Matrix
{
    int n;               // 方阵维度
    ll mod;              // 取模数(若为 0 表示不取模但需保证不溢出)
    vector<vector<T>> a; // 存储矩阵元素的二维数组(行主序)

    Matrix(int n_ = 0, ll mod_ = 0, bool ident = false) // 构造函数,允许指定维度、模数与是否构造单位阵
    {
        if (n_)
            init(n_, mod_, ident); // 若给定维度则立即初始化
        else
            n = 0, mod = mod_; // 否则仅记录模数,维度置 0
    }

    void init(int n_, ll mod_, bool ident = false) // 初始化矩阵尺寸、模数并根据 ident 决定是否置单位阵
    {
        n = n_, mod = mod_;             // 记录维度与模数
        a.assign(n, vector<T>(n, T{})); // 分配 n×n 并全部置零
        if (ident)
            for (int i = 0; i < n; i++)
                a[i][i] = 1; // 如需单位阵则将对角线元素置 1
    }

    static Matrix identity(int n, ll mod) // 工具函数:构造 n 维单位阵(同一模数)
    {
        return Matrix(n, mod, true); // 直接使用带 ident 的构造
    }

    Matrix operator*(const Matrix &o) const // 重载乘法:C = A * B(维度一致)
    {
        Matrix res(n, mod, false);  // 结果矩阵初始化为全 0
        for (int i = 0; i < n; i++) // 枚举行 i
        {
            for (int k = 0; k < n; k++) // 枚举中间维 k(A 的列 / B 的行)
            {
                if (a[i][k] == T{})
                    continue;                                       // 若 A[i][k] 为 0,可跳过以降常
                ll aik = (mod ? (ll)(a[i][k] % mod) : (ll)a[i][k]); // 取出 A[i][k] 并按需取模
                for (int j = 0; j < n; j++)                         // 枚举列 j
                {
                    ll add = aik * (mod ? (ll)(o.a[k][j] % mod) : (ll)o.a[k][j]); // 计算一项乘积(按需取模)
                    ll cur = (mod ? (ll)(res.a[i][j] % mod) : (ll)res.a[i][j]);   // 取出现有值(按需取模)
                    ll nxt = mod ? (cur + add) % mod : (cur + add);               // 做加法并按需取模
                    res.a[i][j] = (T)nxt;                                         // 写回结果位置
                }
            }
        }
        return res; // 返回乘积
    }

    static Matrix power(Matrix base, ll exp) // 快速幂:返回 base^exp
    {
        Matrix res = identity(base.n, base.mod); // 初始答案设为单位阵
        while (exp)                              // 只要指数未清零就继续
        {
            if (exp & 1)
                res = res * base;          // 若当前位为 1,则把当前底乘入答案
            base = base * base, exp >>= 1; // 底矩阵自乘(平方),指数右移一位
        }
        return res; // 返回最终的幂
    }
};

杂项

排列与组合计算

算法介绍

组合数学的核心问题是计算排列数与组合数。排列数公式为 P(n, k) = n! / (n − k)!,组合数公式为 C(n, k) = n! / (k!(n − k)!)。在模数为质数的情况下,通常预处理阶乘与阶乘逆元数组,再通过 O(1) 时间求得结果。

常见例题

给定 n 和 m,求 C(n, m) 模 1e9+7 的值。做法是先预处理 0..n 的阶乘和逆元,之后直接调用组合数函数即可。

代码
// 排列与组合计算,模数必须为质数
// 功能:O(1) 时间计算排列数与组合数
// 复杂度:预处理O(n),单次查询O(1)
struct Comb
{
    int n;                     // 最大值
    int mod;                   // 模数
    vector<ll> fact, invFact;  // 阶乘与阶乘逆元

    // 构造函数,直接完成预处理
    Comb(int n_, int mod_) { init(n_, mod_); }

    // 预处理阶乘与逆元
    void init(int n_, int mod_)
    {
        n = n_; mod = mod_;
        fact.assign(n + 1, 1);
        invFact.assign(n + 1, 1);
        for (int i = 1; i <= n; i++) fact[i] = fact[i - 1] * i % mod;
        invFact[n] = qpow(fact[n], mod - 2);
        for (int i = n; i > 0; i--) invFact[i - 1] = invFact[i] * i % mod;
    }

    // 快速幂
    ll qpow(ll a, ll e)
    {
        ll r = 1;
        while (e) { if (e & 1) r = r * a % mod; a = a * a % mod; e >>= 1; }
        return r;
    }

    // 计算排列数P(n,k)
    ll perm(int n_, int k)
    {
        if (k < 0 || k > n_) return 0;
        return fact[n_] * invFact[n_ - k] % mod;
    }

    // 计算组合数C(n,k)
    ll comb(int n_, int k)
    {
        if (k < 0 || k > n_) return 0;
        return fact[n_] * invFact[k] % mod * invFact[n_ - k] % mod;
    }
};

容斥原理

算法介绍

容斥原理用于计算若干集合并的大小,公式为 |∪Si| = Σ|Si| − Σ|Si∩Sj| + Σ|Si∩Sj∩Sk| − …。实际应用中常结合子集枚举,复杂度为 O(2^n)。

常见例题

求 1..n 中能被至少一个质数 p1..pk 整除的整数个数。做法是对子集枚举,用容斥原理加减贡献。

代码
// 容斥原理模板
// 功能:通过枚举子集实现集合并大小计算
// 复杂度:O(2^k),k为集合数
ll countMultiples(ll n, const vector<int> &primes)
{
    int k = primes.size();
    ll ans = 0;
    for (int mask = 1; mask < (1 << k); mask++)
    {
        ll mul = 1, bits = 0;
        for (int i = 0; i < k; i++) if (mask >> i & 1)
        {
            bits++; mul *= primes[i];
            if (mul > n) { mul = n + 1; break; }
        }
        if (mul <= n) ans += (bits % 2 ? 1 : -1) * (n / mul);
    }
    return ans;
}

莫比乌斯反演

算法介绍

莫比乌斯函数 μ(n) 在 n 为平方数倍时为 0,否则为 (−1)^k,其中 k 是质因数个数。莫比乌斯反演公式为 f(n) = Σ_{d|n} g(d) ⇔ g(n) = Σ_{d|n} μ(d) f(n/d)。常见应用是容斥计数和数论函数转换。

常见例题

给定 n,统计 1 ≤ i, j ≤ n 的最大公约数为 1 的数对数量。做法是通过莫比乌斯反演,将答案写为 Σ μ(d) floor(n/d)^2。

代码
// 莫比乌斯函数与反演
// 功能:线性筛μ,支持快速反演
// 复杂度:O(n)
struct Mobius
{
    int n;
    vector<int> mu, primes, isComp;

    // 构造函数,筛出1..n的莫比乌斯函数
    Mobius(int n_) { init(n_); }

    void init(int n_)
    {
        n = n_;
        mu.assign(n + 1, 0);
        isComp.assign(n + 1, 0);
        mu[1] = 1;
        for (int i = 2; i <= n; i++)
        {
            if (!isComp[i]) { primes.push_back(i); mu[i] = -1; }
            for (int p : primes)
            {
                if (1LL * i * p > n) break;
                isComp[i * p] = 1;
                if (i % p == 0) { mu[i * p] = 0; break; }
                else mu[i * p] = -mu[i];
            }
        }
    }
};

杜教筛

算法介绍

杜教筛是求前缀和型数论函数的优化技巧。对于积性函数 f,若 f 与 g 卷积为常函数,则 S(n) = Σ f(i) 可通过分块递归求解,复杂度 O(n^{2/3})。常见应用为 μ 与 φ 的前缀和。

常见例题

计算 Σ_{i=1}^n μ(i) 与 Σ_{i=1}^n φ(i)。直接计算复杂度为 O(n),杜教筛可降至 O(n^{2/3})。

代码
// 杜教筛模板
// 功能:求积性函数前缀和,如莫比乌斯函数、欧拉函数
// 复杂度:O(n^{2/3})
struct DuSieve
{
    int n;
    unordered_map<int, ll> cacheMu, cachePhi;

    DuSieve(int n_) { n = n_; }

    // 莫比乌斯函数前缀和
    ll sumMu(int x)
    {
        if (x <= n) return preMu[x];
        if (cacheMu.count(x)) return cacheMu[x];
        ll ans = 1;
        for (int l = 2, r; l <= x; l = r + 1)
        {
            r = x / (x / l);
            ans -= (r - l + 1) * sumMu(x / l);
        }
        return cacheMu[x] = ans;
    }

    // 欧拉函数前缀和
    ll sumPhi(int x)
    {
        if (x <= n) return prePhi[x];
        if (cachePhi.count(x)) return cachePhi[x];
        ll ans = 1LL * x * (x + 1) / 2;
        for (int l = 2, r; l <= x; l = r + 1)
        {
            r = x / (x / l);
            ans -= 1LL * (r - l + 1) * sumPhi(x / l);
        }
        return cachePhi[x] = ans;
    }

    vector<int> preMu, prePhi;

    // 预处理1..n的μ和φ
    void preprocess()
    {
        preMu.assign(n + 1, 0);
        prePhi.assign(n + 1, 0);
        vector<int> primes, isComp(n + 1);
        preMu[1] = prePhi[1] = 1;
        for (int i = 2; i <= n; i++)
        {
            if (!isComp[i]) { primes.push_back(i); preMu[i] = -1; prePhi[i] = i - 1; }
            for (int p : primes)
            {
                if (1LL * i * p > n) break;
                isComp[i * p] = 1;
                if (i % p == 0)
                {
                    preMu[i * p] = 0; prePhi[i * p] = prePhi[i] * p; break;
                }
                else
                {
                    preMu[i * p] = -preMu[i]; prePhi[i * p] = prePhi[i] * (p - 1);
                }
            }
        }
        for (int i = 2; i <= n; i++) preMu[i] += preMu[i - 1], prePhi[i] += prePhi[i - 1];
    }
};

七、计算几何

基础计算几何

缺省部


const ld eps = 1e-8;    // 定义浮点数的最小误差范围
const ld pi = acos(-1); // 定义圆周率π
const ld infd = 1e20;   // 定义一个较大的无穷数

// 比较函数,返回x的符号
int dcmp(ld x)
{
    if (fabs(x) < eps) // 如果x在误差范围内,认为x为0
        return 0;
    if (x < 0) // x小于0返回-1
        return -1;
    else // x大于0返回1
        return 1;
}

// 计算x的平方
ld squ(ld x) { return x * x; }

// 点结构体定义
struct Point
{
    ld x, y; // 点的x和y坐标

    // 默认构造函数,将点初始化为原点(0, 0)
    Point() { x = y = 0; }

    // 带参数构造函数,根据给定坐标初始化点
    Point(ld _x, ld _y) { x = _x, y = _y; }

    // 输入函数,从标准输入流中读取x和y坐标
    void input()
    {
        cin >> x;
        cin >> y;
    }

    // 重载==运算符,判断两个点是否相等
    friend bool operator==(Point A, Point B)
    {
        return dcmp(A.x - B.x) == 0 && dcmp(A.y - B.y) == 0;
    }

    // 重载<运算符,定义点的排序规则
    friend bool operator<(Point A, Point B)
    {
        return dcmp(A.x - B.x) == 0 ? dcmp(A.y - B.y) < 0 : A.x < B.x;
    }

    // 重载-运算符,计算两个点之间的向量
    friend Point operator-(Point A, Point B)
    {
        return Point(A.x - B.x, A.y - B.y);
    }

    // 重载+运算符,计算两个点的向量和
    friend Point operator+(Point A, Point B)
    {
        return Point(A.x + B.x, A.y + B.y);
    }

    // 重载*运算符,将点乘以标量
    friend Point operator*(Point A, ld k)
    {
        return Point(A.x * k, A.y * k);
    }

    // 重载/运算符,将点除以标量
    friend Point operator/(Point A, ld k)
    {
        return Point(A.x / k, A.y / k);
    }

    // 重载^运算符,计算两个向量的叉积
    friend ld operator^(Point A, Point B)
    {
        return A.x * B.y - A.y * B.x;
    }

    // 重载*运算符,计算两个向量的点积
    friend ld operator*(Point A, Point B)
    {
        return A.x * B.x + A.y * B.y;
    }

    // 计算向量的模长的平方
    ld len2() { return x * x + y * y; }

    // 计算向量的模长
    ld len() { return sqrt(len2()); }

    // 计算向量的极角
    ld angle() { return atan2(y, x); }

    // 计算以当前点为顶点的两个向量之间的夹角(单位:弧度)
    ld rad(Point A, Point B)
    {
        Point P = *this;
        return fabs(atan2((A - P) ^ (B - P), (A - P) * (B - P)));
    }

    // 将向量缩放到模长为r
    Point trunc(ld r)
    {
        ld l = len();
        if (!dcmp(l)) // 如果向量模长为0,返回原向量
            return *this;
        r /= l; // 缩放比例
        return Point(x * r, y * r);
    }

    // 将向量逆时针旋转90度
    Point rotate_left() { return Point(-y, x); }

    // 将向量顺时针旋转90度
    Point rotate_right() { return Point(y, -x); }

    // 将向量绕点P逆时针旋转ang角度
    Point rotate(Point P, ld ang)
    {
        Point v = (*this) - P;
        ld c = cos(ang), s = sin(ang);
        return Point(P.x + v.x * c - v.y * s, P.y + v.x * s + v.y * c);
    }

    // 将向量绕原点逆时针旋转ang角度
    Point rotate(ld ang) { return rotate(Point(0, 0), ang); }
};

线

struct Line
{
    Point s, e; // 线段的两个端点
    Line() {}
    Line(Point _s, Point _e)
    {
        s = _s;
        e = _e;
    }

    // 输入线段的两个端点坐标
    void input()
    {
        s.input();
        e.input();
    }

    // 如果终点比起点小,交换起点和终点
    void adjust()
    {
        if (e < s)
            swap(e, s);
    }

    // 判断两条线段是否相等,只有起点和终点都相同时返回 true
    friend bool operator==(Line A, Line B) { return A.s == B.s && A.e == B.e; }

    // 构造一条通过点 p 且与 x 轴成角 ang 的直线
    Line(Point p, ld ang)
    {
        s = p;
        e = p + (dcmp(ang - pi / 2) == 0 ? Point(0, 1) : Point(1, tan(ang)));
    }

    // 根据一般式直线方程 ax+by+c=0 构造直线
    Line(ld a, ld b, ld c)
    {
        if (dcmp(a) == 0)
            s = Point(0, -c / b), e = Point(1, -c / b); // 垂直于 x 轴
        else if (dcmp(b) == 0)
            s = Point(-c / a, 0), e = Point(-c / a, 1); // 垂直于 y 轴
        else
            s = Point(0, -c / b), e = Point(1, (-c - a) / b); // 普通斜率
    }

    // 计算线段的长度
    ld len() { return (e - s).len(); }

    // 计算线段的斜角(弧度制)
    ld angle2() { return atan2(e.y - s.y, e.x - s.x); }

    // 计算线段与 x 轴之间的夹角,并确保角度在 [0, pi) 之间
    ld angle()
    {
        ld k = angle2();
        if (dcmp(k) < 0)
            k += pi;
        if (dcmp(k - pi) == 0)
            k -= pi;
        return k;
    }

    // 判断点 p 与线段的位置关系
    // 返回值为1表示 p 在直线左侧;3表示 p 在直线上
    int relation(Point p)
    {
        int c = dcmp((p - s) ^ (e - s));
        return !c ? 3 : 1 + (c > 0);
    }

    // 判断点 p 是否在线段上
    bool PointOnSegment(Point p) { return dcmp((p - s) ^ (e - s)) == 0 && dcmp((p - s) * (p - e)) <= 0; }

    // 判断两条线段是否平行
    bool parallel(Line v) { return dcmp((e - s) ^ (v.e - v.s)) == 0; }

    // 判断两线段的相交情况
    // 返回值 2 表示规范相交,1 表示非规范相交,0 表示不相交
    int segcrossseg(Line v)
    {
        int d1 = dcmp((e - s) ^ (v.s - s));
        int d2 = dcmp((e - s) ^ (v.e - s));
        int d3 = dcmp((v.e - v.s) ^ (s - v.s));
        int d4 = dcmp((v.e - v.s) ^ (e - v.s));
        if ((d1 ^ d2) == -2 && (d3 ^ d4) == -2)
            return 2;
        return (d1 == 0 && dcmp((v.s - s) * (v.s - e)) <= 0) ||
               (d2 == 0 && dcmp((v.e - s) * (v.e - e)) <= 0) ||
               (d3 == 0 && dcmp((s - v.s) * (s - v.e)) <= 0) ||
               (d4 == 0 && dcmp((e - v.s) * (e - v.e)) <= 0);
    }

    // 判断线段与直线的相交情况
    // 返回值 2 表示规范相交,1 表示非规范相交,0 表示不相交
    int linecrossseg(Line v)
    {
        int d1 = dcmp((e - s) ^ (v.s - s));
        int d2 = dcmp((e - s) ^ (v.e - s));
        if ((d1 ^ d2) == -2)
            return 2;
        return (d1 == 0 || d2 == 0);
    }

    // 判断两直线的相交情况
    // 返回值 0 表示平行,1 表示重合,2 表示相交
    int linecrossline(Line v)
    {
        if ((*this).parallel(v))
            return v.relation(s) == 3;
        return 2;
    }

    // 计算两条直线的交点
    Point Intersection(Line v)
    {
        ld a1 = (v.e - v.s) ^ (s - v.s);
        ld a2 = (v.e - v.s) ^ (e - v.s);
        return Point((s.x * a2 - e.x * a1) / (a2 - a1), (s.y * a2 - e.y * a1) / (a2 - a1));
    }

    // 计算点到直线的距离
    ld dispointtoline(Point p) { return fabs((p - s) ^ (e - s)) / len(); }

    // 计算点到线段的距离
    ld dispointtoseg(Point p)
    {
        if (dcmp((p - s) * (e - s)) < 0 || dcmp((p - e) * (s - e)) < 0)
            return min((p - s).len(), (p - e).len());
        return dispointtoline(p);
    }

    // 计算线段到线段的距离
    // 前提是两线段不相交,如果相交距离就是0
    ld dissegtoseg(Line v)
    {
        return min(min(dispointtoseg(v.s), dispointtoseg(v.e)), min(v.dispointtoseg(s), v.dispointtoseg(e)));
    }

    // 计算点 p 在直线上的投影
    Point lineprog(Point p)
    {
        return s + (((e - s) * ((e - s) * (p - s))) / ((e - s).len2()));
    }

    // 计算点 p 关于直线的对称点
    Point symmetrypoint(Point p)
    {
        Point q = lineprog(p);
        return Point(2 * q.x - p.x, 2 * q.y - p.y);
    }
};

杂项

海伦公式求三角形面积
ld area(Point A, Point B, Point C) 
{
	return fabs((A - B) ^ (C - B)) / 2; 
}
三点叉积
// 三点叉积
ld Xmul(Point a, Point b, Point o)
{
    return (a.x - o.x) * (b.y - o.y) - (b.x - o.x) * (a.y - o.y);
}
判断点是否在三角形内
// 判断点q是否在三角形abc中(使用极角法)
bool inTriangle(Point q, Point a, Point b, Point c)
{
    Line A(a, b), B(b, c), C(c, a);

    // 如果点q在三角形的边上
    if (A.PointOnSegment(q) || B.PointOnSegment(q) || C.PointOnSegment(q))
        return true;

    // 判断点q是否在三角形内部
    // 如果q在边ab、bc上方且在边ca下方
    if (Xmul(a, b, q) > eps && Xmul(b, c, q) > eps && Xmul(a, c, q) < -eps) 
        return true;

    // 否则,点q在三角形外部
    return false;
}

基础结构体

struct Circle
{
    Point p;  // 圆心
    ld r;     // 半径

    // 默认构造函数
    Circle() {}

    // 根据圆心和半径构造圆
    Circle(Point p, ld r)
    {
        this->p = p;
        this->r = r;
    }

    // 根据圆心坐标 (x, y) 和半径 r 构造圆
    Circle(ld x, ld y, ld r)
    {
        p = Point(x, y);
        this->r = r;
    }

    // 根据三个点构造圆,opt 为 0 表示外接圆,1 表示内切圆
    Circle(Point a, Point b, Point c, bool opt)
    {
        Line u, v;
        if (opt == 0)
        {   // 构造外接圆
            u = Line((a + b) / 2, ((a + b) / 2) + ((b - a).rotate_left()));  // 通过边 AB 的中点和其垂直线的直线
            v = Line((b + c) / 2, ((b + c) / 2) + ((c - b).rotate_left()));  // 通过边 BC 的中点和其垂直线的直线
            p = u.Intersection(v);  // 求两条直线的交点即为圆心
            r = (p - a).len();  // 计算半径
        }
        else
        {   // 构造内切圆
            ld m = atan2(b.y - a.y, b.x - a.x), n = atan2(c.y - a.y, c.x - a.x);
            u.s = a;
            u.e = u.s + Point(cos((n + m) / 2), sin((n + m) / 2));  // 角平分线
            v.s = b;
            m = atan2(a.y - b.y, a.x - b.x), n = atan2(c.y - b.y, c.x - b.x);
            v.e = v.s + Point(cos((n + m) / 2), sin((n + m) / 2));  // 角平分线
            p = u.Intersection(v);  // 求两条角平分线的交点即为内切圆圆心
            r = Line(a, b).dispointtoseg(p);  // 计算半径
        }
    }

    // 判断两个圆是否相等,圆心相同且半径相等时返回 true
    friend bool operator==(Circle A, Circle B) { return A.p == B.p && dcmp(A.r - B.r) == 0; }

    // 计算圆的面积
    ld area() { return pi * r * r; }

    // 计算圆的周长
    ld circumference() { return 2 * pi * r; }

    // 判断点与圆的关系
    // 返回值为 0 表示点在圆外,1 表示点在圆上,2 表示点在圆内
    int relation(Point b)
    {
        int opt = dcmp((b - p).len() - r);
        return opt < 0 ? 2 : (opt == 0);
    }

    // 判断线段与圆的关系
    int relationseg(Line v)
    {
        int opt = dcmp(v.dispointtoseg(p) - r);
        return opt < 0 ? 2 : (opt == 0);
    }

    // 判断直线与圆的关系
    int relationline(Line v)
    {
        int opt = dcmp(v.dispointtoline(p) - r);
        return opt < 0 ? 2 : (opt == 0);
    }

    // 判断两圆的相互关系
    // 5 表示相离,4 表示外切,3 表示相交,2 表示内切,1 表示内含
    int relationcircle(Circle A)
    {
        ld d = (p - A.p).len();  // 圆心距
        if (dcmp(d - r - A.r) > 0)
            return 5;  // 相离
        if (dcmp(d - r - A.r) == 0)
            return 4;  // 外切
        return 2 + dcmp(d - fabs(r - A.r));  // 相交或内切或内含
    }

    // 求直线与圆的交点
    vector<Point> pointcrossline(Line v)
    {
        vector<Point> vec;
        vec.clear();
        if (!(*this).relationline(v))
            return vec;  // 无交点
        Point a = v.lineprog(p);  // 圆心在直线上的投影点
        ld d = v.dispointtoline(p);  // 圆心到直线的距离
        d = sqrt(r * r - d * d);  // 计算交点的距离
        if (dcmp(d) == 0)
            vec.pb(a);  // 相切时仅有一个交点
        else
            vec.pb(a + (v.e - v.s).trunc(d)), vec.pb(a - (v.e - v.s).trunc(d));  // 相交时有两个交点
        return vec;
    }

    // 求两圆的交点
    vector<Point> pointcrosscircle(Circle A)
    {
        vector<Point> vec;
        vec.clear();
        int t = relationcircle(A);
        if (t == 5 || t == 1)
            return vec;  // 无交点或内含时无交点
        ld d = (p - A.p).len();  // 圆心距
        ld l = (d * d + r * r - A.r * A.r) / (2. * d);  // 计算交点的距离
        ld h = sqrt(r * r - l * l);  // 计算交点的偏移距离
        Point q = p + (A.p - p).trunc(l);  // 圆心之间的中点
        if (t == 2 || t == 4)
        {
            vec.pb(q);  // 内切或外切时仅有一个交点
            return vec;
        }
        vec.pb(q + ((A.p - p).rotate_left()).trunc(h));  // 第一个交点
        vec.pb(q + ((A.p - p).rotate_right()).trunc(h));  // 第二个交点
        return vec;
    }
};


最小圆覆盖

// 最小圆覆盖算法
// 该函数接受一个包含若干点的向量,计算能够覆盖所有点的最小圆。
// 输入:点的向量 p
// 输出:能够覆盖所有点的最小圆

Circle smallestcircle(vector<Point> p)
{
    int n = p.size(); // 获取点的数量
    random_shuffle(all(p)); // 随机打乱点的顺序
    Circle C = Circle(p[0], 0.0); // 初始化圆为第一个点,半径为0
    for (int i = 1; i < n; i++) // 遍历所有点
        if (C.relation(p[i]) == 0) // 如果点p[i]在当前圆的外部
        {
            C = Circle(p[i], 0.0); // 更新圆为该点,半径为0
            for (int j = 0; j < i; j++) // 再次遍历之前的点
                if (C.relation(p[j]) == 0) // 如果点p[j]在当前圆的外部
                {
                    // 计算以p[i]和p[j]为直径端点的圆
                    C = Circle((p[i] + p[j]) / 2, (p[i] - p[j]).len() / 2);
                    for (int k = 0; k < j; k++) // 再次遍历之前的点
                        if (C.relation(p[k]) == 0) // 如果点p[k]在当前圆的外部
                            // 计算以p[i]、p[j]、p[k]为外接圆的圆
                            C = Circle(p[i], p[j], p[k], 0);
                }
        }
    return C; // 返回能够覆盖所有点的最小圆
}


凸包

基础结构体

// 凸包结构体定义
struct Polygon
{
    int n;           // 凸包的顶点数
    vector<Point> p; // 存储凸包顶点的向量
    vector<Line> l;  // 存储凸包边的向量

    // 默认构造函数,初始化凸包的顶点数为0,并清空顶点和边的向量
    Polygon()
    {
        n = 0;
        p.clear();
        l.clear();
    }

    // 带参数的构造函数,接受一个点的向量并构建凸包
    Polygon(vector<Point> a)
    {
        n = a.size(); // 设置顶点数量
        p = a;        // 设置顶点列表
        l.resize(n);  // 根据顶点数量调整边的数量
        for (int i = 0; i < n; i++)
            l[i] = Line(p[i], p[(i + 1) % n]); // 依次连接相邻的顶点,构建边
    }

    // 计算凸包的面积
    ld area()
    {
        ld ans = 0;
        // 使用凸包的顶点通过向量叉积计算面积
        for (int i = 2; i < n; i++)
            ans += (p[i] - p[0]) ^ (p[i - 1] - p[0]);
        return fabs(ans) / 2; // 返回绝对值的一半,即为凸包的面积
    }

    // 计算凸包的直径(即两点之间的最大距离)
    ld diameter()
    {
        if (n == 2)
        {
            return (p[0] - p[1]).len(); // 如果凸包只有两个顶点,则返回这两个点之间的距离
        }
        int j = 2;  // 初始化变量 j,用于追踪与当前边相对的最远顶点
        ld ans = 0; // 初始化答案变量,用于存储最大距离
        for (int i = 0; i < n; i++)
        {
            // 使用旋转卡壳算法找到最大直径
            while (((p[(i + 1) % n] - p[i]) ^ (p[j] - p[i])) < ((p[(i + 1) % n] - p[i]) ^ (p[(j + 1) % n] - p[i])))
                j = (j + 1) % n; // 更新 j,寻找使得叉积最大的点
            // 更新答案为当前计算的最大距离
            ans = max(ans, max((p[i] - p[j]).len(), (p[(i + 1) % n] - p[(j + 1) % n]).len()));
        }
        return ans; // 返回计算出的凸包直径
    }
	
    // 计算凸包的周长
    ld perimeter()
    {
        ld res = 0.0;
        // 如果凸包只有两个顶点,则返回这两个点之间的距离
        if (n == 2)
            return (p[0] - p[1]).len();
        // 累加每条边的长度
        for (int i = 0; i < n; i++)
        {
            res += l[i].len();
        }
        return res;
    }
    
    // 判断点是否在凸包内
    bool inside(Point q)
    {
        // 遍历凸包的每一条边
        for (int i = 0; i < n; i++)
        {
            // 使用叉积法判断点 q 是否在当前边的左侧或线上
            if (((p[(i + 1) % n] - p[i]) ^ (q - p[i])) < 0)
            {
                // 如果点 q 在任意一条边的右侧,则点在凸包外
                return false;
            }
        }
        // 如果点 q 在所有边的左侧或线上,则点在凸包内或边上
        return true;
    }
    
    // 使用二分查找判断点是否在凸包内,复杂度为 O(log n)
    bool insideLogN(Point q)
    {
        // 二分查找点 a 是否在凸多边形内部
        int l = 1;
        int r = n - 1;

        while (l < r)
        {
            int mid = (l + r) / 2;

            // 检查点 a 是否在由点 p[0], p[mid], p[mid + 1] 构成的三角形内
            if (inTriangle(q, p[0], p[mid], p[mid + 1]) == true)
                return true;

            // 检查点 a 是否在 p[0], p[mid], p[mid + 1] 的边上或在三角形之外
            if (Xmul(p[0], p[mid], q) >= 0 && Xmul(p[0], p[mid + 1], q) <= 0 && Xmul(p[mid], p[mid + 1], q) < 0)
                return false;

            // 根据点 a 与 p[0] 的相对位置调整二分查找的区间
            if (Xmul(p[0], p[mid], q) > 0 && Xmul(p[0], p[mid + 1], q) > 0)
                l = mid + 1; // 点 a 位于 mid 右侧,移动左边界
            else
                r = mid; // 点 a 位于 mid 左侧,移动右边界
        }

        return false; // 最终未找到点 a 在凸多边形内
    }
    

    // 计算点到凸包的最短距离
    ld distanceToPoint(Point p)
    {
        ld minDist = infd; // 初始化为无穷大
        for (int i = 0; i < n; i++)
        {
            minDist = min(minDist, l[i].dispointtoseg(p));
        }
        return minDist;
    }
    
    // 从输入读取凸包的顶点
    void input()
    {
        cin >> n;    // 输入顶点数
        p.resize(n); // 调整点的向量大小
        for (int i = 0; i < n; i++)
            p[i].input(); // 输入每个顶点的坐标
    }
};

求凸包

// 计算点集的凸包
Polygon ConvexHull(vector<Point> a)
{
    // 凸包的严格定义:即不存在三点共线的情况
    // 如果允许三点共线的非严格凸包,则将 <= 改为 <
    int n = a.size(), m = -1;         // n为点的数量,m为凸包点的数量(初始化为-1)
    vector<Point> p(n * 2);           // 用于存储凸包的点集,最大为2倍的输入点集大小
    sort(all(a));                     // 首先对点集按x坐标排序,若x坐标相同则按y坐标排序

    // 构建凸包的下半部分
    for (int i = 0; i < n; i++)
    {
        // 检查是否需要移除当前点p[m]以保持凸包的凸性
        // 通过计算向量叉积判断三点之间的转折方向
        while (m > 0 && ((p[m] - p[m - 1]) ^ (a[i] - p[m - 1])) <= 0)
        {
            m--; // 如果叉积为0或负数,说明存在共线或右转,移除p[m]
        }
        p[++m] = a[i]; // 将当前点添加到凸包中
    }

    // 如果点集只有一个点,直接返回该点构成的凸包
    if (n == 1)
    {
        return Polygon(a);
    }

    int k = m; // 记录下半部分最后一个点的位置

    // 构建凸包的上半部分
    for (int i = n - 2; i >= 0; i--)
    {
        // 同样使用叉积判断是否需要移除当前点以保持凸性
        for (; m > k && ((p[m] - p[m - 1]) ^ (a[i] - p[m - 1])) <= 0; m--)
            ;
        p[++m] = a[i]; // 将当前点添加到凸包中
    }

    // 调整凸包的点集大小,去掉最后一个点(它是上半部分和下半部分的连接点)
    p.resize(m);
    
    return Polygon(p); // 返回由这些点构成的凸包
}

// 计算多边形A的凸包
Polygon ConvexHull(Polygon A)
{
    return ConvexHull(A.p); // 调用基于点集的凸包函数,传入多边形的顶点列表
}

求半平面交

// 计算半平面交得到的凸多边形
Polygon HalfPlanes(vector<Line> l)
{
    vector<Point> p; // 存储最终得到的多边形顶点
    int n = l.size(); // 半平面的数量

    // 比较函数,用于按照直线的极角进行排序
    auto cmp = [](Line A, Line B)
    {
        ld r = A.angle2() - B.angle2(); // 计算两条直线的极角差
        if (dcmp(r) != 0)
            return dcmp(r) < 0; // 按极角从小到大排序
        // 如果极角相同,按直线位置排序
        return dcmp((A.e - A.s) ^ (B.e - A.s)) < 0;
    };

    // 按照比较函数对所有半平面的直线进行排序
    sort(all(l), cmp);

    // 用于存储处理后的半平面边界和对应的交点
    vector<Line> q(n + 2); 
    vector<Point> b(n + 2);

    // 初始化双端队列的头尾指针
    int head = 0, tail = 0;

    // 将第一条直线加入队列
    q[0] = l[0];

    for (int i = 1; i < n; i++)
    {
        // 如果当前直线与前一条直线的极角不同,则处理
        if (dcmp(l[i].angle2() - l[i - 1].angle2()) != 0)
        {
            // 检查队列中头尾是否存在平行且重合的直线
            if (head < tail && q[head].parallel(q[head + 1]))
                return Polygon(p); // 如果存在重合直线,返回空多边形
            if (head < tail && q[tail].parallel(q[tail - 1]))
                return Polygon(p);

            // 移除队尾不符合条件的直线
            while (head < tail && l[i].relation(b[tail - 1]) == 2)
                tail--;

            // 移除队头不符合条件的直线
            while (head < tail && l[i].relation(b[head]) == 2)
                head++;

            // 将当前直线加入队尾
            q[++tail] = l[i];

            // 计算新加入的直线与前一条直线的交点
            if (head < tail)
                b[tail - 1] = q[tail].Intersection(q[tail - 1]);
        }
    }

    // 检查头尾直线是否构成封闭多边形
    while (head < tail && l[head].relation(b[tail - 1]) == 2)
        tail--;
    while (head < tail && l[tail].relation(b[head]) == 2)
        head++;

    // 如果剩下的直线不足以构成多边形,则返回空多边形
    if (tail - head <= 1)
        return Polygon(p);

    // 计算封闭多边形的最后一个交点
    b[tail] = q[head].Intersection(q[tail]);

    // 调整多边形顶点的大小并存储结果
    p.resize(tail - head + 1);
    for (int i = head; i <= tail; i++)
        p[i - head] = b[i];

    return Polygon(p); // 返回由这些顶点构成的凸多边形
}

静态凸包(单调链)


struct Point
{
    ll x; // x坐标
    ll y; // y坐标

    // 构造函数
    Point(ll x = 0, ll y = 0) : x(x), y(y) {}
};

// 重载相等运算符
bool operator==(const Point &a, const Point &b)
{
    return a.x == b.x && a.y == b.y;
}

// 重载加法运算符
Point operator+(const Point &a, const Point &b)
{
    return Point(a.x + b.x, a.y + b.y);
}

// 重载减法运算符
Point operator-(const Point &a, const Point &b)
{
    return Point(a.x - b.x, a.y - b.y);
}

// 计算点的点积
ll dot(const Point &a, const Point &b)
{
    return a.x * b.x + a.y * b.y;
}

// 计算点的叉积
ll cross(const Point &a, const Point &b)
{
    return a.x * b.y - a.y * b.x;
}

// 规范化点的顺序,使最底部的点在前
void norm(vector<Point> &h)
{
    int i = 0;
    for (int j = 0; j < int(h.size()); j++)
    {
        if (h[j].y < h[i].y || (h[j].y == h[i].y && h[j].x < h[i].x))
        {
            i = j; // 找到最底部的点
        }
    }
    rotate(h.begin(), h.begin() + i, h.end()); // 将最底部的点移到开头
}

// 符号函数,判断方向
int sgn(const Point &a)
{
    return a.y > 0 || (a.y == 0 && a.x > 0) ? 0 : 1;
}

// 获取凸包
vector<Point> getHull(vector<Point> p)
{
    vector<Point> h, l; // h为上半部分,l为下半部分
    // 按照x坐标排序,若x相同则按y坐标排序
    sort(p.begin(), p.end(), [&](auto a, auto b)
         {
        if (a.x != b.x) {
            return a.x < b.x;
        } else {
            return a.y < b.y;
        } });

    // 去重
    p.erase(unique(p.begin(), p.end()), p.end());

    // 如果点数小于等于1,直接返回
    if (p.size() <= 1)
    {
        return p;
    }

    // 构建下半部分
    for (auto a : p)
    {
        while (h.size() > 1 && cross(a - h.back(), a - h[h.size() - 2]) <= 0)
        {
            h.pop_back(); // 去掉不满足条件的点
        }
        h.push_back(a); // 添加当前点
    }

    // 构建上半部分
    for (int i = p.size() - 1; i >= 0; i--)
    {
        auto a = p[i];
        while (l.size() > 1 && cross(a - l.back(), a - l[l.size() - 2]) >= 0)
        {
            l.pop_back(); // 去掉不满足条件的点
        }
        l.push_back(a); // 添加当前点
    }

    l.pop_back();                          // 删除重复的点
    reverse(h.begin(), h.end());           // 反转上半部分
    h.pop_back();                          // 删除重复的点
    l.insert(l.end(), h.begin(), h.end()); // 合并上下半部分
    return l;                              // 返回最终的凸包
}


平面几何合集

// 定义点结构体
template <class T>
struct Point
{
    T x, y;

    Point(T x_ = 0, T y_ = 0) : x(x_), y(y_) {}

    // 类型转换构造函数
    template <class U>
    operator Point<U>()
    {
        return Point<U>(U(x), U(y));
    }

    // 运算符重载
    Point &operator+=(Point p) &
    {
        x += p.x;
        y += p.y;
        return *this;
    }
    Point &operator-=(Point p) &
    {
        x -= p.x;
        y -= p.y;
        return *this;
    }
    Point &operator*=(T v) &
    {
        x *= v;
        y *= v;
        return *this;
    }
    Point operator-() const
    {
        return Point(-x, -y);
    }

    // 友元函数
    friend Point operator+(Point a, Point b)
    {
        return a += b;
    }
    friend Point operator-(Point a, Point b)
    {
        return a -= b;
    }
    friend Point operator*(Point a, T b)
    {
        return a *= b;
    }
    friend Point operator*(T a, Point b)
    {
        return b *= a;
    }
    friend bool operator==(Point a, Point b)
    {
        return a.x == b.x && a.y == b.y;
    }
    friend istream &operator>>(istream &is, Point &p)
    {
        return is >> p.x >> p.y;
    }
    friend ostream &operator<<(ostream &os, Point p)
    {
        return os << "(" << p.x << ", " << p.y << ")";
    }
};

// 点积
template <class T>
T dot(Point<T> a, Point<T> b)
{
    return a.x * b.x + a.y * b.y;
}

// 叉积
template <class T>
T cross(Point<T> a, Point<T> b)
{
    return a.x * b.y - a.y * b.x;
}

// 计算平方长度
template <class T>
T square(Point<T> p)
{
    return dot(p, p);
}

// 计算长度
template <class T>
double length(Point<T> p)
{
    return sqrt(double(square(p)));
}

// 特化长双精度长度
ld length(Point<ld> p)
{
    return sqrt(square(p));
}

// 定义线段结构体
template <class T>
struct Line
{
    Point<T> a, b;

    Line(Point<T> a_ = Point<T>(), Point<T> b_ = Point<T>()) : a(a_), b(b_) {}
};

// 旋转点
template <class T>
Point<T> rotate(Point<T> a)
{
    return Point(-a.y, a.x);
}

// 符号函数
template <class T>
int sgn(Point<T> a)
{
    return a.y > 0 || (a.y == 0 && a.x > 0) ? 1 : -1;
}

// 判断点是否在左侧
template <class T>
bool pointOnLineLeft(Point<T> p, Line<T> l)
{
    return cross(l.b - l.a, p - l.a) > 0;
}

// 计算两条线段的交点
template <class T>
Point<T> lineIntersection(Line<T> l1, Line<T> l2)
{
    return l1.a + (l1.b - l1.a) * (cross(l2.b - l2.a, l1.a - l2.a) / cross(l2.b - l2.a, l1.a - l1.b));
}

// 判断点是否在线段上
template <class T>
bool pointOnSegment(Point<T> p, Line<T> l)
{
    return cross(p - l.a, l.b - l.a) == 0 && min(l.a.x, l.b.x) <= p.x && p.x <= max(l.a.x, l.b.x) && min(l.a.y, l.b.y) <= p.y && p.y <= max(l.a.y, l.b.y);
}

// 判断点是否在多边形内部
template <class T>
bool pointInPolygon(Point<T> a, vector<Point<T>> p)
{
    int n = p.size();
    for (int i = 0; i < n; i++)
    {
        if (pointOnSegment(a, Line(p[i], p[(i + 1) % n])))
        {
            return true; // 点在边上
        }
    }

    int t = 0; // 交点计数
    for (int i = 0; i < n; i++)
    {
        auto u = p[i];
        auto v = p[(i + 1) % n];
        if (u.x < a.x && v.x >= a.x && pointOnLineLeft(a, Line(v, u)))
        {
            t ^= 1; // 交点计数
        }
        if (u.x >= a.x && v.x < a.x && pointOnLineLeft(a, Line(u, v)))
        {
            t ^= 1;
        }
    }

    return t == 1; // 如果交点数为奇数,点在多边形内
}

// 线段交集检测
// 返回值:0 : 不相交,1 : 严格相交,2 : 重叠,3 : 端点相交
template <class T>
tuple<int, Point<T>, Point<T>> segmentIntersection(Line<T> l1, Line<T> l2)
{
    // 检查边界情况
    if (max(l1.a.x, l1.b.x) < min(l2.a.x, l2.b.x) ||
        min(l1.a.x, l1.b.x) > max(l2.a.x, l2.b.x) ||
        max(l1.a.y, l1.b.y) < min(l2.a.y, l2.b.y) ||
        min(l1.a.y, l1.b.y) > max(l2.a.y, l2.b.y))
    {
        return {0, Point<T>(), Point<T>()}; // 不相交
    }

    // 检查是否平行
    if (cross(l1.b - l1.a, l2.b - l2.a) == 0)
    {
        if (cross(l1.b - l1.a, l2.a - l1.a) != 0)
        {
            return {0, Point<T>(), Point<T>()}; // 不相交
        }
        else
        {
            // 计算重叠区间
            auto maxx1 = max(l1.a.x, l1.b.x);
            auto minx1 = min(l1.a.x, l1.b.x);
            auto maxy1 = max(l1.a.y, l1.b.y);
            auto miny1 = min(l1.a.y, l1.b.y);
            auto maxx2 = max(l2.a.x, l2.b.x);
            auto minx2 = min(l2.a.x, l2.b.x);
            auto maxy2 = max(l2.a.y, l2.b.y);
            auto miny2 = min(l2.a.y, l2.b.y);
            Point<T> p1(max(minx1, minx2), max(miny1, miny2));
            Point<T> p2(min(maxx1, maxx2), min(maxy1, maxy2));

            // 检查重叠情况
            if (!pointOnSegment(p1, l1))
            {
                swap(p1.y, p2.y); // 确保 p1 是在 l1 上
            }
            if (p1 == p2)
            {
                return {3, p1, p2}; // 端点相交
            }
            else
            {
                return {2, p1, p2}; // 重叠
            }
        }
    }

    // 计算交点
    auto cp1 = cross(l2.a - l1.a, l2.b - l1.a);
    auto cp2 = cross(l2.a - l1.b, l2.b - l1.b);
    auto cp3 = cross(l1.a - l2.a, l1.b - l2.a);
    auto cp4 = cross(l1.a - l2.b, l1.b - l2.b);

    // 检查相交情况
    if ((cp1 > 0 && cp2 > 0) || (cp1 < 0 && cp2 < 0) || (cp3 > 0 && cp4 > 0) || (cp3 < 0 && cp4 < 0))
    {
        return {0, Point<T>(), Point<T>()}; // 不相交
    }

    Point p = lineIntersection(l1, l2); // 计算交点
    if (cp1 != 0 && cp2 != 0 && cp3 != 0 && cp4 != 0)
    {
        return {1, p, p}; // 严格相交
    }
    else
    {
        return {3, p, p}; // 端点相交
    }
}

// 判断线段是否在多边形内
template <class T>
bool segmentInPolygon(Line<T> l, vector<Point<T>> p)
{
    int n = p.size();
    if (!pointInPolygon(l.a, p) || !pointInPolygon(l.b, p))
    {
        return false; // 线段的端点不在多边形内
    }
    for (int i = 0; i < n; i++)
    {
        auto u = p[i];
        auto v = p[(i + 1) % n];
        auto w = p[(i + 2) % n];
        auto [t, p1, p2] = segmentIntersection(l, Line(u, v));

        if (t == 1)
        {
            return false; // 严格相交,线段在多边形外
        }
        if (t == 0)
        {
            continue; // 不相交
        }
        if (t == 2)
        { // 重叠
            if (pointOnSegment(v, l) && v != l.a && v != l.b)
            {
                if (cross(v - u, w - v) > 0)
                {
                    return false; // 线段在多边形外
                }
            }
        }
        else
        { // 端点相交
            if (p1 != u && p1 != v)
            {
                if (pointOnLineLeft(l.a, Line(v, u)) || pointOnLineLeft(l.b, Line(v, u)))
                {
                    return false; // 线段在多边形外
                }
            }
            else if (p1 == v)
            {
                if (l.a == v)
                {
                    if (pointOnLineLeft(u, l))
                    {
                        if (pointOnLineLeft(w, l) && pointOnLineLeft(w, Line(u, v)))
                        {
                            return false; // 线段在多边形外
                        }
                    }
                    else
                    {
                        if (pointOnLineLeft(w, l) || pointOnLineLeft(w, Line(u, v)))
                        {
                            return false; // 线段在多边形外
                        }
                    }
                }
                else if (l.b == v)
                {
                    if (pointOnLineLeft(u, Line(l.b, l.a)))
                    {
                        if (pointOnLineLeft(w, Line(l.b, l.a)) && pointOnLineLeft(w, Line(u, v)))
                        {
                            return false; // 线段在多边形外
                        }
                    }
                    else
                    {
                        if (pointOnLineLeft(w, Line(l.b, l.a)) || pointOnLineLeft(w, Line(u, v)))
                        {
                            return false; // 线段在多边形外
                        }
                    }
                }
                else
                {
                    if (pointOnLineLeft(u, l))
                    {
                        if (pointOnLineLeft(w, Line(l.b, l.a)) || pointOnLineLeft(w, Line(u, v)))
                        {
                            return false; // 线段在多边形外
                        }
                    }
                    else
                    {
                        if (pointOnLineLeft(w, l) || pointOnLineLeft(w, Line(u, v)))
                        {
                            return false; // 线段在多边形外
                        }
                    }
                }
            }
        }
    }
    return true; // 线段在多边形内
}

// 计算凸包
template <class T>
vector<Point<T>> hp(vector<Line<T>> lines)
{
    sort(lines.begin(), lines.end(), [&](auto l1, auto l2)
         {
             auto d1 = l1.b - l1.a;
             auto d2 = l2.b - l2.a;

             if (sgn(d1) != sgn(d2))
             {
                 return sgn(d1) == 1; // 按方向排序
             }

             return cross(d1, d2) > 0; // 按叉积排序
         });

    deque<Line<T>> ls;  // 存储线段
    deque<Point<T>> ps; // 存储交点
    for (auto l : lines)
    {
        if (ls.empty())
        {
            ls.push_back(l);
            continue;
        }

        // 检查是否左侧
        while (!ps.empty() && !pointOnLineLeft(ps.back(), l))
        {
            ps.pop_back();
            ls.pop_back();
        }

        while (!ps.empty() && !pointOnLineLeft(ps[0], l))
        {
            ps.pop_front();
            ls.pop_front();
        }

        // 检查平行情况
        if (cross(l.b - l.a, ls.back().b - ls.back().a) == 0)
        {
            if (dot(l.b - l.a, ls.back().b - ls.back().a) > 0)
            {
                if (!pointOnLineLeft(ls.back().a, l))
                {
                    assert(ls.size() == 1);
                    ls[0] = l; // 用新线段替换
                }
                continue;
            }
            return {}; // 平行且不相交
        }

        ps.push_back(lineIntersection(ls.back(), l)); // 计算交点
        ls.push_back(l);                              // 添加线段
    }

    // 检查最后的交点
    while (!ps.empty() && !pointOnLineLeft(ps.back(), ls[0]))
    {
        ps.pop_back();
        ls.pop_back();
    }
    if (ls.size() <= 2)
    {
        return {}; // 线段数量不足
    }
    ps.push_back(lineIntersection(ls[0], ls.back())); // 计算最后交点

    return vector(ps.begin(), ps.end()); // 返回交点
}

八、其他

常用工具

快速读写

template <typename T>
inline T read()
{
    T x = 0;
    int y = 1;
    char ch = getchar();
    while (ch > '9' || ch < '0')
    {
        if (ch == '-')
            y = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    return x * y;
}

template <typename T>
inline void write(T x)
{
    if (x < 0)
    {
        putchar('-');
        x = -x;
    }
    if (x >= 10)
    {
        write(x / 10);
    }
    putchar(x % 10 + '0');
}


生成随机数

// 随机数工具,封装64位随机源与常用接口
// 功能:区间整数/浮点随机、打乱、采样、64位哈希
// 复杂度:所有操作均为O(len)或O(1)期望
struct Random
{
    using ull = unsigned long long;
    using ll = long long;
    mt19937_64 eng; // 64位梅森旋转随机引擎

    // 使用高分辨率时钟与地址熵混合的构造函数
    Random()
    {
        ull seed = chrono::high_resolution_clock::now().time_since_epoch().count();
        seed ^= ull(reinterpret_cast<uintptr_t>(this)) + 0x9e3779b97f4a7c15ULL;
        eng.seed(seed);
    }

    // 在闭区间[l,r]上生成均匀随机整数
    ll nextInt(ll l, ll r)
    {
        return uniform_int_distribution<ll>(l, r)(eng);
    }

    // 在闭区间[l,r]上生成均匀随机浮点
    ld nextReal(ld l, ld r)
    {
        return uniform_real_distribution<ld>(l, r)(eng);
    }

    // 原地打乱一个容器
    template <typename T>
    void shuffleVec(vector<T> &a)
    {
        shuffle(a.begin(), a.end(), eng);
    }

    // 从[0..n-1]中不放回采样k个下标,返回升序结果
    vector<int> sample(int n, int k)
    {
        vector<int> id(n);
        iota(id.begin(), id.end(), 0);
        for (int i = 0; i < k; i++)
            swap(id[i], id[nextInt(i, n - 1)]);
        sort(id.begin(), id.begin() + k);
        id.resize(k);
        return id;
    }

    // splitmix64 作为64位哈希(可用于随机哈希/置乱键值)
    ull splitmix64(ull x)
    {
        x += 0x9e3779b97f4a7c15ULL;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ULL;
        x = (x ^ (x >> 27)) * 0x94d049bb133111ebULL;
        return x ^ (x >> 31);
    }
};

取模类

template <class T>
constexpr T power(T a, ll b)
{
    T res = 1;
    for (; b; b /= 2, a *= a)
    {
        if (b % 2)
        {
            res *= a;
        }
    }
    return res;
}

constexpr ll mul(ll a, ll b, ll p)
{
    ll res = a * b - ll(1.L * a * b / p) * p;
    res %= p;
    if (res < 0)
    {
        res += p;
    }
    return res;
}
template <ll P>
struct MLong
{
    ll x;
    constexpr MLong() : x{} {}
    constexpr MLong(ll x) : x{norm(x % getMod())} {}

    static ll Mod;
    constexpr static ll getMod()
    {
        if (P > 0)
        {
            return P;
        }
        else
        {
            return Mod;
        }
    }
    constexpr static void setMod(ll Mod_)
    {
        Mod = Mod_;
    }
    constexpr ll norm(ll x) const
    {
        if (x < 0)
        {
            x += getMod();
        }
        if (x >= getMod())
        {
            x -= getMod();
        }
        return x;
    }
    constexpr ll val() const
    {
        return x;
    }
    explicit constexpr operator ll() const
    {
        return x;
    }
    constexpr MLong operator-() const
    {
        MLong res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MLong inv() const
    {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MLong &operator*=(MLong rhs) &
    {
        x = mul(x, rhs.x, getMod());
        return *this;
    }
    constexpr MLong &operator+=(MLong rhs) &
    {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MLong &operator-=(MLong rhs) &
    {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MLong &operator/=(MLong rhs) &
    {
        return *this *= rhs.inv();
    }
    friend constexpr MLong operator*(MLong lhs, MLong rhs)
    {
        MLong res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MLong operator+(MLong lhs, MLong rhs)
    {
        MLong res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MLong operator-(MLong lhs, MLong rhs)
    {
        MLong res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MLong operator/(MLong lhs, MLong rhs)
    {
        MLong res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr istream &operator>>(istream &is, MLong &a)
    {
        ll v;
        is >> v;
        a = MLong(v);
        return is;
    }
    friend constexpr ostream &operator<<(ostream &os, const MLong &a)
    {
        return os << a.val();
    }
    friend constexpr bool operator==(MLong lhs, MLong rhs)
    {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MLong lhs, MLong rhs)
    {
        return lhs.val() != rhs.val();
    }
};

template <>
ll MLong<0LL>::Mod = ll(1E18) + 9;

template <int P>
struct MInt
{
    int x;
    constexpr MInt() : x{} {}
    constexpr MInt(ll x) : x{norm(x % getMod())} {}

    static int Mod;
    constexpr static int getMod()
    {
        if (P > 0)
        {
            return P;
        }
        else
        {
            return Mod;
        }
    }
    constexpr static void setMod(int Mod_)
    {
        Mod = Mod_;
    }
    constexpr int norm(int x) const
    {
        if (x < 0)
        {
            x += getMod();
        }
        if (x >= getMod())
        {
            x -= getMod();
        }
        return x;
    }
    constexpr int val() const
    {
        return x;
    }
    explicit constexpr operator int() const
    {
        return x;
    }
    constexpr MInt operator-() const
    {
        MInt res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MInt inv() const
    {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MInt &operator*=(MInt rhs) &
    {
        x = 1LL * x * rhs.x % getMod();
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) &
    {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) &
    {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) &
    {
        return *this *= rhs.inv();
    }
    friend constexpr MInt operator*(MInt lhs, MInt rhs)
    {
        MInt res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MInt operator+(MInt lhs, MInt rhs)
    {
        MInt res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MInt operator-(MInt lhs, MInt rhs)
    {
        MInt res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MInt operator/(MInt lhs, MInt rhs)
    {
        MInt res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr istream &operator>>(istream &is, MInt &a)
    {
        ll v;
        is >> v;
        a = MInt(v);
        return is;
    }
    friend constexpr ostream &operator<<(ostream &os, const MInt &a)
    {
        return os << a.val();
    }
    friend constexpr bool operator==(MInt lhs, MInt rhs)
    {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MInt lhs, MInt rhs)
    {
        return lhs.val() != rhs.val();
    }
};

template <>
int MInt<0>::Mod = 998244353;

template <int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();

constexpr int P = 1e9 + 7;
using Z = MInt<P>;

分数类

template <class T>
struct Frac
{
    T num; // 分子
    T den; // 分母

    // 构造函数,初始化分子和分母
    Frac(T num_, T den_) : num(num_), den(den_)
    {
        // 如果分母为负,调整分子和分母的符号
        if (den < 0)
        {
            den = -den;
            num = -num;
        }
    }

    // 默认构造函数,初始化为 0/1
    Frac() : Frac(0, 1) {}

    // 只传入分子,分母默认为 1
    Frac(T num_) : Frac(num_, 1) {}

    // 将分数转换为 double 类型
    explicit operator double() const
    {
        return 1. * num / den; // 注意使用浮点数运算
    }

    // 加法赋值
    Frac &operator+=(const Frac &rhs)
    {
        num = num * rhs.den + rhs.num * den; // 计算新分子
        den *= rhs.den;                      // 计算新分母
        return *this;                        // 返回当前对象
    }

    // 减法赋值
    Frac &operator-=(const Frac &rhs)
    {
        num = num * rhs.den - rhs.num * den; // 计算新分子
        den *= rhs.den;                      // 计算新分母
        return *this;                        // 返回当前对象
    }

    // 乘法赋值
    Frac &operator*=(const Frac &rhs)
    {
        num *= rhs.num; // 计算新分子
        den *= rhs.den; // 计算新分母
        return *this;   // 返回当前对象
    }

    // 除法赋值
    Frac &operator/=(const Frac &rhs)
    {
        num *= rhs.den; // 计算新分子
        den *= rhs.num; // 计算新分母
        if (den < 0)
        { // 如果分母为负,调整符号
            num = -num;
            den = -den;
        }
        return *this; // 返回当前对象
    }

    // 加法运算符
    friend Frac operator+(Frac lhs, const Frac &rhs)
    {
        return lhs += rhs; // 使用 += 实现
    }

    // 减法运算符
    friend Frac operator-(Frac lhs, const Frac &rhs)
    {
        return lhs -= rhs; // 使用 -= 实现
    }

    // 乘法运算符
    friend Frac operator*(Frac lhs, const Frac &rhs)
    {
        return lhs *= rhs; // 使用 *= 实现
    }

    // 除法运算符
    friend Frac operator/(Frac lhs, const Frac &rhs)
    {
        return lhs /= rhs; // 使用 /= 实现
    }

    // 一元负运算符
    friend Frac operator-(const Frac &a)
    {
        return Frac(-a.num, a.den); // 返回相反数
    }

    // 相等比较
    friend bool operator==(const Frac &lhs, const Frac &rhs)
    {
        return lhs.num * rhs.den == rhs.num * lhs.den; // 交叉相乘比较
    }

    // 不相等比较
    friend bool operator!=(const Frac &lhs, const Frac &rhs)
    {
        return !(lhs == rhs); // 使用 == 实现
    }

    // 小于比较
    friend bool operator<(const Frac &lhs, const Frac &rhs)
    {
        return lhs.num * rhs.den < rhs.num * lhs.den; // 交叉相乘比较
    }

    // 大于比较
    friend bool operator>(const Frac &lhs, const Frac &rhs)
    {
        return rhs < lhs; // 使用 < 实现
    }

    // 小于等于比较
    friend bool operator<=(const Frac &lhs, const Frac &rhs)
    {
        return !(rhs < lhs); // 使用 < 实现
    }

    // 大于等于比较
    friend bool operator>=(const Frac &lhs, const Frac &rhs)
    {
        return !(lhs < rhs); // 使用 < 实现
    }

    // 输出流重载
    friend ostream &operator<<(ostream &os, Frac x)
    {
        T g = gcd(x.num, x.den); // 计算分子和分母的最大公约数
        if (x.den == g)
        {                           // 如果分母为最大公约数
            return os << x.num / g; // 只输出分子
        }
        else
        {
            return os << x.num / g << "/" << x.den / g; // 输出简化后的分数
        }
    }
};

大数运算

// 64位模运算工具,安全乘法、快速幂与扩展欧几里得求逆
// 功能:在任意64位模数mod下进行(a*b)%mod与pow、inv
// 复杂度:mul O(1),pow O(log b)
struct Mod64
{
    using u128 = __uint128_t;
    using i128 = __int128_t;
    ull mod;

    Mod64(ull m = (ull)1e9 + 7) { mod = m; }

    ull add(ull a, ull b)
    {
        ull c = a + b;
        return c >= mod ? c - mod : c;
    }

    ull sub(ull a, ull b)
    {
        return a >= b ? a - b : a + mod - b;
    }

    // 使用内建128位进行安全乘法
    ull mul(ull a, ull b)
    {
        return (u128)a * b % mod;
    }

    ull powmod(ull a, ull e)
    {
        ull r = 1ULL;
        while (e)
        {
            if (e & 1ULL) r = mul(r, a);
            a = mul(a, a);
            e >>= 1ULL;
        }
        return r;
    }

    // 扩展欧几里得求逆,要求a与mod互质
    ll inv(ll a)
    {
        auto exgcd = [&](auto &&self, ll x, ll y) -> pair<ll,ll>
        {
            if (!y) return {1, 0};
            auto [u, v] = self(self, y, x % y);
            return make_pair(v, u - (x / y) * v);
        };
        ll x = exgcd(exgcd, a % (ll)mod + (ll)mod, (ll)mod).first;
        x %= (ll)mod;
        if (x < 0) x += (ll)mod;
        return x;
    }
};

// 十进制大整数,仅存正数,支持加法与乘以int
// 功能:处理长度极大的十进制数;接口简洁,适合快速拼装输出
// 复杂度:加法O(n),乘以int O(n)
struct BigDec
{
    string s;                       // 最高位在s[0]
    BigDec(const string &t = "0") { s = strip(t); }

    // 去除前导零
    string strip(const string &x)
    {
        int i = 0;
        while (i + 1 < (int)x.size() && x[i] == '0') i++;
        return x.substr(i);
    }

    // 与另一个十进制正数相加
    BigDec add(const BigDec &o) const
    {
        string a = s, b = o.s;
        int i = (int)a.size() - 1, j = (int)b.size() - 1, c = 0;
        string r;
        while (i >= 0 || j >= 0 || c)
        {
            int x = i >= 0 ? a[i--] - '0' : 0;
            int y = j >= 0 ? b[j--] - '0' : 0;
            int z = x + y + c;
            r.push_back(char('0' + (z % 10)));
            c = z / 10;
        }
        reverse(r.begin(), r.end());
        return BigDec(r);
    }

    // 乘以一个非负32位整数
    BigDec mulInt(int m) const
    {
        if (m == 0) return BigDec("0");
        int c = 0;
        string r; r.resize(s.size());
        for (int i = (int)s.size() - 1; i >= 0; i--)
        {
            ll z = 1LL * (s[i] - '0') * m + c;
            r[i] = char('0' + (z % 10));
            c = (int)(z / 10);
        }
        if (c) r.insert(r.begin(), char('0' + (c % 10)));
        return BigDec(r);
    }
};

排序

单轴快排

// 快速排序函数,对数组 arr 的子数组 arr[l...r] 进行排序
void quickSort(int arr[], int l, int r)
{
    // 如果子数组长度为 0 或 1,无需排序,直接返回
    if (l >= r) 
        return;
    // 初始化左右指针和基准值
    int i = l - 1, j = r + 1;
     // 取中间元素作为基准值
    int pivot = arr[l + r >> 1];
    // 开始进行划分
    while (i < j)
    {
        // 在左侧找到第一个大于等于基准值的元素
        do {i++;} while (arr[i] < pivot);
        // 在右侧找到第一个小于等于基准值的元素
        do {j--;} while (arr[j] > pivot);
        // 如果 i < j,则交换 arr[i] 和 arr[j],将较小的元素放到左侧,较大的元素放到右侧
        if (i < j)
            swap(arr[i], arr[j]);
    }
    // 递归对左右两侧子数组进行快速排序
    quickSort(arr, l, j);   // 对左侧子数组 arr[l...j] 排序
    quickSort(arr, j + 1, r); // 对右侧子数组 arr[j+1...r] 排序
}

双轴快排

// 双轴快速排序函数,对数组 arr 的子数组 arr[l...r] 进行排序
void dualPivotQuickSort(int arr[], int l, int r)
{
    // 如果子数组长度为 0 或 1,无需排序,直接返回
    if (l >= r)
        return;

    // 初始化左右指针和两个基准值
    int i = l - 1, j = r + 1;
    // 选择首尾两个元素作为基准值
    int pivot1 = arr[l], pivot2 = arr[r];

    // 开始进行划分
    for (int k = l; k <= r; ++k)
    {
        // 将数组划分为三部分:小于 p 的部分、介于 p 和 q 之间的部分、大于 q 的部分
        if (arr[k] < pivot1)
        {
            ++i;
            swap(arr[i], arr[k]); // 将当前元素交换到小于 p 的部分
        }
        else if (arr[k] > pivot2)
        {
            --j;
            swap(arr[j], arr[k]); // 将当前元素交换到大于 q 的部分
            // 由于交换过来的元素大小未知,需要重新检查当前位置的元素
            --k;
        }
    }

    // 递归对小于 p 的部分和大于 q 的部分进行排序
    dualPivotQuickSort(arr, l, i);
    dualPivotQuickSort(arr, j, r);

    // 如果介于 p 和 q 之间的部分存在元素,则递归对其进行排序
    if (pivot1 < pivot2)
        dualPivotQuickSort(arr, i + 1, j - 1);
}

归并排序

// 合并两个有序数组 arr[l...m] 和 arr[m+1...r]
void merge(int arr[], int l, int m, int r)
{
    int n1 = m - l + 1; // 左子数组的长度
    int n2 = r - m;     // 右子数组的长度

    // 创建临时数组来存储左右子数组合并后的结果
    int L[n1], R[n2];

    // 将数据复制到临时数组中
    for (int i = 0; i < n1; ++i)
        L[i] = arr[l + i];
    for (int j = 0; j < n2; ++j)
        R[j] = arr[m + 1 + j];

    // 合并临时数组到 arr[l...r]
    int i = 0; // 初始化左子数组的索引
    int j = 0; // 初始化右子数组的索引
    int k = l; // 初始化合并后数组的索引
    while (i < n1 && j < n2)
    {
        if (L[i] <= R[j])
        {
            arr[k] = L[i];
            ++i;
        }
        else
        {
            arr[k] = R[j];
            ++j;
        }
        ++k;
    }

    // 将剩余的元素复制到 arr[l...r]
    while (i < n1)
    {
        arr[k] = L[i];
        ++i;
        ++k;
    }
    while (j < n2)
    {
        arr[k] = R[j];
        ++j;
        ++k;
    }
}

// 归并排序函数,对数组 arr 的子数组 arr[l...r] 进行排序
void mergeSort(int arr[], int l, int r)
{
    if (l < r)
    {
        // 计算中间位置
        int m = l + (r - l) / 2;

        // 递归地对左右两部分进行排序
        mergeSort(arr, l, m);
        mergeSort(arr, m + 1, r);

        // 合并已排序的两部分
        merge(arr, l, m, r);
    }
}

希尔排序

// 希尔排序函数
void shellSort(int arr[], int n)
{
    // 初始步长设定为数组长度的一半,并逐步缩小步长直到为 1
    for (int gap = n / 2; gap > 0; gap /= 2)
    {
        // 对每个步长进行插入排序
        for (int i = gap; i < n; i += 1)
        {
            // 将 arr[i] 插入到所在的子数组中的正确位置
            int temp = arr[i];
            int j;
            for (j = i; j >= gap && arr[j - gap] > temp; j -= gap)
            {
                arr[j] = arr[j - gap];
            }
            arr[j] = temp;
        }
    }
}

堆排序

// 调整以节点 i 为根的子树为最大堆
void heapify(int arr[], int n, int i)
{
    int largest = i;   // 初始化最大值为根节点
    int l = 2 * i + 1; // 左子节点的索引
    int r = 2 * i + 2; // 右子节点的索引

    // 如果左子节点比根节点大,更新最大值索引
    if (l < n && arr[l] > arr[largest])
        largest = l;

    // 如果右子节点比根节点大,更新最大值索引
    if (r < n && arr[r] > arr[largest])
        largest = r;

    // 如果最大值不是根节点,交换最大值和根节点,并递归调整子树
    if (largest != i)
    {
        swap(arr[i], arr[largest]);
        heapify(arr, n, largest);
    }
}

// 堆排序函数
void heapSort(int arr[], int n)
{
    // 构建最大堆
    for (int i = n / 2 - 1; i >= 0; i--)
        heapify(arr, n, i);

    // 依次取出堆顶元素并重新调整堆
    for (int i = n - 1; i > 0; i--)
    {
        swap(arr[0], arr[i]); // 将堆顶元素(最大值)与最后一个元素交换
        heapify(arr, i, 0);   // 重新调整剩余元素为最大堆
    }
}


杂项

CDQ分治

算法介绍

CDQ 是在“分治顺序”上做数据结构维护的技巧,用于解决具有“前缀依赖”的离线问题。经典应用是三维偏序计数与某些 DP 的转移优化。将数据按第一关键字排序后在分治过程中按第二关键字排序并用树状数组维护第三关键字,从而把三维问题降为若干次二维问题。

常见例题

给出 n 个三元组 (a, b, c),统计对数 (i, j) 使得 ai ≤ aj 且 bi ≤ bj 且 ci ≤ cj。做法是按 a 排序,做一次 CDQ,将左半区按 b 排序在线加入 BIT,再用右半区按 b 排序逐个用 c 查询并累加答案,递归左右两侧即可。

代码
// 三维偏序CDQ模板:统计(i<j)且a[i]<=a[j], b[i]<=b[j], c[i]<=c[j]的对数
// 功能:离线分治 + Fenwick 统计三维偏序
// 复杂度:O(n log^2 n)
template <typename T>
struct Fenwick
{
    int n; vector<T> bit;
    Fenwick(int n_ = 0) { init(n_); }
    void init(int n_) { n = n_; bit.assign(n + 1, T{}); }
    void add(int x, T v) { for (int i = x; i <= n; i += i & -i) bit[i] = bit[i] + v; }
    T sum(int x) { T r{}; for (int i = x; i > 0; i -= i & -i) r = r + bit[i]; return r; }
};

struct CDQ3D
{
    struct Node { int a, b, c, id; ll add, ans; };
    vector<Node> p, buf;
    Fenwick<ll> fw;

    CDQ3D() {}
    CDQ3D(const vector<tuple<int,int,int>> &v)
    {
        int n = (int)v.size();
        p.resize(n);
        for (int i = 0; i < n; i++)
        {
            auto [A,B,C] = v[i];
            p[i] = {A, B, C, i, 1, 0};
        }
        buf.resize(n);
        fw.init(200000); // 依据c的离散化上界设置;实际应用中先离散化c
    }

    // 若c值范围大,需要外部先对c离散化;此处提供辅助
    void discretizeC()
    {
        vector<int> vals; vals.reserve(p.size());
        for (auto &x : p) vals.push_back(x.c);
        sort(vals.begin(), vals.end());
        vals.erase(unique(vals.begin(), vals.end()), vals.end());
        for (auto &x : p) x.c = int(lower_bound(vals.begin(), vals.end(), x.c) - vals.begin()) + 1;
        fw.init((int)vals.size());
    }

    void solve()
    {
        sort(p.begin(), p.end(), [&](const Node &x, const Node &y)
        {
            if (x.a != y.a) return x.a < y.a;
            if (x.b != y.b) return x.b < y.b;
            return x.c < y.c;
        });
        for (int l = 0, r; l < (int)p.size(); l = r)
        {
            r = l;
            while (r < (int)p.size() && p[r].a == p[l].a && p[r].b == p[l].b && p[r].c == p[l].c) r++;
            p[l].add = r - l;
        }
        cdq(0, (int)p.size() - 1);
        sort(p.begin(), p.end(), [&](const Node &x, const Node &y){ return x.id < y.id; });
    }

    // 获取每个点被“支配”的对数(含重复合并)
    vector<ll> answers() const
    {
        vector<ll> res; res.reserve(p.size());
        for (auto &x : p) res.push_back(x.ans);
        return res;
    }

    void cdq(int L, int R)
    {
        if (L >= R) return;
        int M = (L + R) >> 1;
        cdq(L, M);
        cdq(M + 1, R);
        int i = L, j = M + 1, k = L;
        while (i <= M && j <= R)
        {
            if (p[i].b <= p[j].b)
            {
                fw.add(p[i].c, p[i].add);
                buf[k++] = p[i++];
            }
            else
            {
                p[j].ans += fw.sum(p[j].c);
                buf[k++] = p[j++];
            }
        }
        while (i <= M) fw.add(p[i].c, p[i].add), buf[k++] = p[i++];
        while (j <= R) p[j].ans += fw.sum(p[j].c), buf[k++] = p[j++];
        for (int t = L; t <= R; t++) p[t] = buf[t];
        for (int t = L; t <= M; t++) fw.add(p[t].c, -p[t].add);
    }
};

分数规划

算法介绍

分数规划常用于最大化形如 Σai xi / Σbi xi 的目标。典型做法是二分答案 mid,将目标转化为判断是否存在选择使 Σ(ai − mid·bi) ≥ 0 成立。若是选择恰好 k 个,则把每个对象的“贡献”替换为 ai − mid·bi,取前 k 大判断和是否非负。答案误差由二分精度控制。

常见例题

给定 n 个物品,每个物品有收益 ai 与权重 bi,要求选择恰好 k 个使平均收益 Σai / Σbi 最大。做法是对 mid 进行二分,每次把每个物品的值变为 ai − mid·bi,取前 k 大,若其和非负则说明可行。

代码
// 分数规划:恰取k个使Σai/Σbi最大
// 功能:二分答案 + 判定取前k大( ai - mid*bi )之和 >= 0
// 复杂度:O(n log n log 精度)
struct FractionPlan
{
    int n, k;
    vector<ld> a, b;

    FractionPlan(int n_, int k_, const vector<ld> &A, const vector<ld> &B)
    {
        n = n_; k = k_; a = A; b = B;
    }

    bool check(ld mid)
    {
        static vector<ld> v;
        v.resize(n);
        for (int i = 0; i < n; i++) v[i] = a[i] - mid * b[i];
        nth_element(v.begin(), v.begin() + k, v.end(), greater<ld>());
        ld s = 0;
        for (int i = 0; i < k; i++) s += v[i];
        return s >= 0;
    }

    ld solve(ld lo = 0, ld hi = 1e9, int iters = 80)
    {
        while (iters--) { ld mid = (lo + hi) / 2; if (check(mid)) lo = mid; else hi = mid; }
        return lo;
    }
};

珂朵莉树

算法介绍

珂朵莉树适合“数据随机且区间赋值为主”的场景,把相同值的连续段按 set 维护为不相交的区间节点。核心操作是 split,把位置 x 处切成左右两段,然后在若干相邻段上进行 assign、add 或查询。性能高度依赖数据分块的稳定性与随机性。

常见例题

给定长度为 n 的序列,支持多次区间赋值与区间求和,初始为全零。做法是使用 ODT 将 [l, r] 覆盖的若干节点替换为一个值为 v 的新节点,求和时遍历覆盖的节点累加长度乘以节点值。

代码
// 珂朵莉树(ODT),按值相同的连续段分块
// 功能:维护区间常值分段,支持区间赋值、区间加、区间求和、区间复制、区间交换、区间逆置与打印
// 适用:区间操作以“整段同值”为主、修改随机性较强的场景;理论最坏复杂,但实战常数优秀
// 复杂度:随机数据期望近似O(log 段数)每次
template <typename T>
struct ChthollyTree
{
    struct Node
    {
        int left, right;                 // 区间左右端点
        mutable T val;                   // 区间常值,mutable以便通过set迭代器原地修改
        Node(int L = 0, int R = -1, T V = T{}) : left(L), right(R), val(V) {}
        bool operator<(const Node &o) const { return left < o.left; } // 以left为键的有序分段
    };

    int n;                               // 序列长度
    T mod;                               // 模数,0表示不取模
    std::set<Node> seg;                  // 珂朵莉树主体
    std::vector<Node> bufA, bufB;        // 临时缓冲,供复制/交换/逆置时收集分段

    // 构造函数,若给定规模则仅完成基本初始化,不建树
    ChthollyTree(int n_ = 0, T mod_ = T{}) { init(n_, mod_); }

    // 初始化规模与模数并清空结构
    void init(int n_, T mod_ = T{})
    {
        n = n_;
        mod = mod_;
        seg.clear();
        bufA.clear();
        bufB.clear();
    }

    // 将值归一化到[0,mod);若mod为0则原样返回
    T norm(T x) const
    {
        if (mod == T{}) return x;
        x %= mod;
        if (x < T{}) x += mod;
        return x;
    }

    // 用长度为n的数组v建树,自动压缩相邻相同值为一段
    void build(const std::vector<T> &v)
    {
        init((int)v.size() - 1, mod);
        if ((int)v.size() <= 1) return;
        int L = 1; T cur = norm(v[1]);
        for (int i = 2; i <= n; i++)
        {
            T x = norm(v[i]);
            if (x == cur) continue;
            seg.insert(Node(L, i - 1, cur));
            L = i; cur = x;
        }
        seg.insert(Node(L, n, cur));
    }

    // 从输入流读取n个值建树,等价于build,便于赛时快速读入
    void buildFromStream(std::istream &in)
    {
        std::vector<T> v(n + 1);
        for (int i = 1; i <= n; i++) in >> v[i];
        build(v);
    }

    // 在位置p处分割并返回以p为左端点的迭代器
    // 说明:若已存在以p为左端点的分段则直接返回;若p越界于当前所有段左端之前则返回begin()
    typename std::set<Node>::iterator split(int p)
    {
        if (p <= 1) return seg.begin();
        if (p > n) return seg.end();
        auto it = seg.lower_bound(Node(p, -1, T{}));
        if (it != seg.end() && it->left == p) return it;
        if (it == seg.begin()) return seg.begin();
        --it;
        if (p > it->right) return std::next(it);
        int L = it->left, R = it->right; T v = it->val;
        seg.erase(it);
        seg.insert(Node(L, p - 1, v));
        return seg.insert(Node(p, R, v)).first;
    }

    // 区间求和,返回[l,r]上值之和(按需取模)
    T askInterval(int l, int r)
    {
        if (l > r) std::swap(l, r);
        l = std::max(l, 1); r = std::min(r, n);
        if (l > r) return T{};
        auto itR = split(r + 1);
        auto itL = split(l);
        __int128 acc = 0;
        for (auto it = itL; it != itR; ++it)
        {
            acc += (__int128)(it->right - it->left + 1) * it->val;
            if (mod != T{} && acc > (__int128)4e36) acc %= mod;
        }
        if (mod == T{}) return (T)acc;
        return (T)(acc % mod);
    }

    // 区间加,将[l,r]上所有值加上v(按需取模)
    void addInterval(int l, int r, T v)
    {
        if (l > r) std::swap(l, r);
        l = std::max(l, 1); r = std::min(r, n);
        if (l > r) return;
        auto itR = split(r + 1);
        auto itL = split(l);
        if (mod == T{})
        {
            for (auto it = itL; it != itR; ++it) it->val += v;
            return;
        }
        v %= mod;
        for (auto it = itL; it != itR; ++it) it->val = norm(it->val + v);
    }

    // 区间赋值,将[l,r]整体赋成v
    void assignInterval(int l, int r, T v)
    {
        if (l > r) std::swap(l, r);
        l = std::max(l, 1); r = std::min(r, n);
        if (l > r) return;
        auto itR = split(r + 1);
        auto itL = split(l);
        seg.erase(itL, itR);
        seg.insert(Node(l, r, norm(v)));
    }

    // 区间复制,把[l1,r1]的值复制到[l2,r2](两段长度必须相等)
    void copyInterval(int l1, int r1, int l2, int r2)
    {
        if (l1 > r1) std::swap(l1, r1);
        if (l2 > r2) std::swap(l2, r2);
        l1 = std::max(l1, 1); r1 = std::min(r1, n);
        l2 = std::max(l2, 1); r2 = std::min(r2, n);
        if (l1 > r1 || l2 > r2) return;
        int len1 = r1 - l1 + 1, len2 = r2 - l2 + 1;
        if (len1 != len2) return;

        bufA.clear();
        auto itR1 = split(r1 + 1);
        auto itL1 = split(l1);
        for (auto it = itL1; it != itR1; ++it) bufA.emplace_back(*it);

        auto itR2 = split(r2 + 1);
        auto itL2 = split(l2);
        seg.erase(itL2, itR2);

        for (auto &nd : bufA) seg.insert(Node(nd.left - l1 + l2, nd.right - l1 + l2, nd.val));
    }

    // 区间交换,交换[l1,r1]与[l2,r2](允许相交,长度必须相等)
    void swapInterval(int l1, int r1, int l2, int r2)
    {
        if (l1 > r1) std::swap(l1, r1);
        if (l2 > r2) std::swap(l2, r2);
        if (l1 > l2) { std::swap(l1, l2); std::swap(r1, r2); }
        l1 = std::max(l1, 1); r1 = std::min(r1, n);
        l2 = std::max(l2, 1); r2 = std::min(r2, n);
        if (l1 > r1 || l2 > r2) return;
        int len1 = r1 - l1 + 1, len2 = r2 - l2 + 1;
        if (len1 != len2) return;

        bufA.clear(); bufB.clear();
        auto itR1 = split(r1 + 1);
        auto itL1 = split(l1);
        for (auto it = itL1; it != itR1; ++it) bufA.emplace_back(*it);

        auto itR2 = split(r2 + 1);
        auto itL2 = split(l2);
        for (auto it = itL2; it != itR2; ++it) bufB.emplace_back(*it);

        itR1 = split(r1 + 1); itL1 = split(l1); seg.erase(itL1, itR1);
        itR2 = split(r2 + 1); itL2 = split(l2); seg.erase(itL2, itR2);

        for (auto &nd : bufB) seg.insert(Node(nd.left - l2 + l1, nd.right - l2 + l1, nd.val));
        for (auto &nd : bufA) seg.insert(Node(nd.left - l1 + l2, nd.right - l1 + l2, nd.val));
    }

    // 区间逆置,将[l,r]的段顺序反转并相应平移端点
    void reverseInterval(int l, int r)
    {
        if (l > r) std::swap(l, r);
        l = std::max(l, 1); r = std::min(r, n);
        if (l > r) return;

        bufA.clear();
        auto itR = split(r + 1);
        auto itL = split(l);
        for (auto it = itL; it != itR; ++it) bufA.emplace_back(*it);
        seg.erase(itL, itR);

        // 逆序插回,端点做对称映射
        for (int i = 0; i < (int)bufA.size(); i++)
        {
            int L = bufA[i].left, R = bufA[i].right; T v = bufA[i].val;
            seg.insert(Node(r - R + l, r - L + l, v));
        }
    }

    // 打印整棵树的展开序列到输出流;仅用于调试,复杂度O(段数 + n)
    void print(std::ostream &out) const
    {
        for (auto it = seg.begin(); it != seg.end(); ++it)
        {
            int L = std::max(it->left, 1), R = std::min(it->right, n);
            for (int i = L; i <= R; i++) out << it->val << (i == n ? '\n' : ' ');
        }
    }

    // 若当前为空,按常量v建成单段;便于快速清空成常量序列
    void assignAll(T v)
    {
        seg.clear();
        if (n > 0) seg.insert(Node(1, n, norm(v)));
    }
};


在线算法技巧

算法介绍

在线算法指边读入边产出,中途不回看历史。常见技巧包括双堆维护在线中位数、单调队列维护滑窗最值、Fenwick/线段树维护在线前缀与区间统计。这里给出双堆中位数作为可直接使用的在线模板。

常见例题

读入一串流式数据,贯穿过程需要随时输出当前所有读入数的中位数。做法是用一个最大堆维护较小的一半元素、一个最小堆维护较大的一半元素,使两堆大小之差不超过一并保持最大堆堆顶不大于最小堆堆顶即可。

代码
// 在线中位数:双堆维护
// 功能:支持插入与查询当前中位数
// 复杂度:插入O(log n),查询O(1)
struct OnlineMedian
{
    priority_queue<ll> L;                                       // 大根堆,存较小的一半
    priority_queue<ll, vector<ll>, greater<ll>> R; // 小根堆,存较大的一半

    OnlineMedian() {}

    void insert(ll x)
    {
        if (L.empty() || x <= L.top()) L.push(x);
        else R.push(x);
        if ((int)L.size() > (int)R.size() + 1) R.push(L.top()), L.pop();
        if ((int)R.size() > (int)L.size()) L.push(R.top()), R.pop();
    }

    ll median()
    {
        return L.top();
    }
};

模拟退火 / 随机化算法

算法介绍

模拟退火用于求近似全局最优,通过温度从高到低逐步收敛。每步在当前解附近随机扰动,若解更优则必然接受,否则以概率 exp(−Δ/ T) 接受,从而跳出局部最优。温度下降可用乘性衰减,也可结合 reheating 与多重起点提升稳定性。

常见例题

在一维连续区间 [L, R] 上最大化某个黑箱函数 f(x),例如带随机噪声的打分函数。做法是以区间中点为初始解,设置温度 T 并不断尝试在当前 x 附近的随机位移,若打分更高则接受,否则以一定概率接受,温度逐步衰减到阈值为止。

代码
// 模拟退火模板(实数域单峰/多峰黑箱优化)
// 功能:最大化用户给定f(x),支持温度衰减与随机扰动
// 复杂度:与迭代次数线性,常用几千到几万步
struct SimulatedAnnealing
{
    ld L, R;
    ld T0, Tend, alpha;     // 初温、终温与衰减因子
    function<ld(ld)> f; // 目标函数(仅本处使用,可换为捕获lambda)
    Random rnd;

    SimulatedAnnealing(ld L_, ld R_, ld T0_=1e2L, ld Tend_=1e-6L, ld alpha_=0.985L)
    {
        L = L_; R = R_; T0 = T0_; Tend = Tend_; alpha = alpha_;
    }

    // 在区间内跑一次模拟退火,返回近似最优x
    ld solve(ld x0)
    {
        ld x = x0, fx = f(x0);
        for (ld T = T0; T > Tend; T *= alpha)
        {
            ld nx = x + rnd.nextReal(-1.0L, 1.0L) * T;
            if (nx < L) nx = L; if (nx > R) nx = R;
            ld nfx = f(nx);
            ld delta = nfx - fx;
            if (delta >= 0) x = nx, fx = nfx;
            else
            {
                ld prob = expl(delta / T);
                ld coin = rnd.nextReal(0.0L, 1.0L);
                if (coin < prob) x = nx, fx = nfx;
            }
        }
        return x;
    }
};

声明

作者:ExtractStars - Codeforces

部分内容参考:

jiangly - Codeforces 在codeforces上的提交

OI Wiki - OI Wiki

有部分模板未经验证,可能存在错误,如果在赛场上使用导致比赛失利,概不负责,请谨慎使用该模板

docx版本(已排版)网盘链接:https://pan.baidu.com/s/15C0G5CY2aqQykuc5sHz_RQ?pwd=jfur

退役了,把使用的算法模板整理分享一下

posted @ 2025-09-20 16:40  ExtractStars  阅读(30)  评论(0)    收藏  举报