学习笔记(9)KD-Tree
p.s. 本文大量参考网络博客,如有侵犯望告知删改
前置:主定理(证复杂度用)
关于递归函数的复杂度分析,有这么一个定理可以用,但必须满足以下形式:
其中 \(a\) 表示每一层递归函数的调用次数,\(b\) 表示子问题规模是原问题的多少分之一
那么讨论函数 \(n^{\log_{b}{a}}\) 与 \(f(n)\) 的增长率关系有:
其中的所有讨论只要求近似而并不要求相等,\(c\) 是常数
另外注意以下情况主定理并不适用:
- \(n^{\log_{b}{a}}\) 与 \(f(n)\) 的增长率不可比
- \(n^{\log_{b}{a}}\) 比 \(f(n)\) 增长更快,但没有快到 \(O(n^{\epsilon})\) 倍
- \(n^{\log_{b}{a}}\) 比 \(f(n)\) 增长更慢,但没有慢到 \(O(n^{\epsilon})\) 倍
引入
\(KDT\) 可以说是线段树向高维拓展后的结果,可以作为树套树或 \(cdq\) 的平替使用;而对于满足如下形式的二维修改查询问题,可以利用主定理证明 \(KDT\) 是理论最优的范围修改查询数据结构(不过后文并没有进行比较):
- 对矩形中元素进行一次修改/查询
- 修改对修改有结合律,修改对范围信息有分配律与结合律,范围信息对范围信息有交换律与结合律
较多的人认为 \(KDT\) 相当于一颗高维的替罪羊树,同时也有人认为 \(KDT\) 的查询复杂度属于玄学。其实在具体的实现上并不一定必须采用替罪羊树,理论上正确复杂度的 \(KDT\) 似乎也并不是替罪羊树式的实现;而查询的复杂度是可以被证明的
考虑一颗线段树是如何维护一维序列的。我们每次选取中点 \(mid\) 将序列划分为两个子区间进行维护,查询与修改时只需要找到包含于目标区间内的区间修改根节点的信息,打标记并下传子树
类似的,当维护高维数据时,\(KDT\) 也可以参考这种划分方法——选取一个超平面将超长方体分成两部分,更具体一点地说,即对于每一维坐标“二分”。为保证复杂度,我们选取该维上所有坐标的中位数为 “\(mid\)“。插入与删除操作同样类比线段树,也可以根据实际的题目限制选择不同的最佳实现方式。当然对于一直坚持替罪羊树式维护的选手来说,虽然有可能被精心构造的数据卡掉(目前已知查询半平面,圆的信息时可以被卡 \(T\) 到飞起),但也不失为一种优秀的骗分选择
基本实现
大多数选手选择替罪羊树式的实现,本人没打过(同时了解到这种方式理论复杂度是错的)因此不过多记叙。事实上在建树与 \(Pushup\) 的实现上方法都大差不差,主要区别在于涉及修改时的维护
主体结构
可以分为 \(2\) 种(个人认为):
1). 平衡树式:根节点会表示一个实际的划分点。(如果有插入操作等需要改变树的结构的话——)由于 \(KDT\) 不满足 \(Splay\) 的伸展性与 \(FHQ\) 的随机优先级,于是只能考虑替罪羊树式的“偏暴力”维护(其实就是暴力)。设定平衡因子 \(\alpha\),若 \(siz\) 的比不满足平衡因子则将原来的树拍扁局部暴力重构,不同的 \(\alpha\) 会有复杂度及常数(一般 \(\alpha\) 取 \([0.6, 0.8]\) 效果最佳,不过建议根据题目调整,具体可以参考这篇博客)
2). 线段树式(\(Leafy\) 式):仅叶节点对应实际的点。建议参见 \(cmd\) 的博客,这种写法下,\(KDT\) 与高维线段树的相似(相同)体现地十分明显,不同于前者的维护方式下根节点会表示选取的划分点,该实现方式下根节点只维护左右区间合并的大区间信息,相对更费空间一点,不过如果需要打标记会更好处理
查询
记当前待查询的超矩形为 \(R\),查询时对树上的区间分三类讨论(类比线段树):
-
与 \(R\) 交集为空;
-
包含在 \(R\) 内;
-
与 \(R\) 有交但不包含于 \(R\);
查询时可以根据维护的区间极值进行剪枝,或者是套 \(A*\)。另外记得对于两种不同的主体结构写法,一种还需要考虑根节点自身的点而另一种只需要考虑合并叶子节点
由于查询时的剪枝与超平面内点的分布情况有关(越密集时能用于剪枝的条件越容易达成),因此应用 \(KDT\) 时一般要求**点集大小 \(\geq 2^{k}\) **(其中 \(k\) 为维数,跟复杂度有关),此时的 \(KDT\) 才比较能够体现其优秀,而不是看起来那样的暴力
修改后的重构
实现上也有除替罪羊树局部重构以外的两种:
1). 根号分治:将要插入的点累计下来,每达到 \(B\) 个将整棵树拍扁重构。由于总的修改复杂度是均摊的 \(O(\frac{n \log{n}}{B})\), 查询复杂度为 \(O(B + n^{1 - \frac{1}{k}})\),因此二者同阶时取 \(B = O(n \sqrt{n \ log{n}})\) 最优,总复杂度 \(O(\sqrt{n \log{n}} + n ^ {1 - \frac{1}{k}})\)(其中修改 \(O(\sqrt{n\log{n}})\),查询 \(O(\sqrt{n \log{n}} + n ^ {1 - \frac{1}{k}})\)),如果要去掉 \(\log\) 可以维护一大一小两棵树,见 \(cmd\) 的博客(常数较优并且好写,考场上建议用这种写法)
2). 二进制分组:应该说是重构的唯一正确实现。维护若干棵大小为 \(2^k\) 的 \(KDT\),满足这些树的大小之和为 \(n\)。插入时新建大小为 \(1\) 的树然后与更大规模的树向上合并,类似于二进制加法,查询时需要统计每棵树上的查询,修改是 \(O(n \log^{2}{n})\) 的,查询复杂度则是完美的 \(O(n^{1 - \frac{1}{k}})\)(但是实测却没有跑过丢掉 \(\log\) 的根号分治就很神奇,大概是我的实现不太优美……)。另外听说这种写法可以维护时间戳
一般来说只有插入需要重构,删除可以通过在树上打标记来实现。
关于复杂度
建树
轮替划分每个维度的话每次递归两边建树,中间只需要调用 \(nth\)_\(element\) 求中位数,有:
\(T(n) = 2T(\frac n2) + O(n)\),\(\log_{b}{a} = 1\),\(\epsilon = k = 0\),\(f(n) = O(n^{\log_{b}{a}}\log^{k}{n}) = O(n)\)
\(\Rightarrow T(n) = O(n \log{n})\)
另外有根据方差大小选择划分维度的做法,但听说 \(lxl\) 说这是错的,不过有些题目里会有出其不意的卡常效果
修改
单纯修改是 \(O(\log{n})\) 的,瓶颈在重构,替罪羊单次均摊好像也是 \(\log\) 的,不会证,另外两个做法前面已经提过了
查询
查邻域最坏还是 \(O(n^{k})\) 的,在保证每轮操作能够二等分当前区间时,即树高能够保证严格的 \(\log{n} + O(1)\) 时可证为 \(O(n^{1 - \frac 1k})\),而替罪羊树的树高为 \(O(\log{n})\),不能保证严格的 \(\log\),故称替罪羊树的维护复杂度是错误的(同时实现上一般带 \(2\) ~ \(4\) 倍的常数)
按轮换划分分析,每一轮会对 \(k\) 维依次划分,会产生 \(2^{k}\) 个部分,而一个用来划分的超平面最多跨越 \(2^{k - 1}\) 个部分(归纳可证),有:
\(T(n) = 2^{k - 1}T(\frac{n}{2^{k}}) + O(k)\),\(\log_{b}{a} = \frac{k - 1}{k}\), \(\epsilon = \frac{k - 1}{k}\),\(f(n) = O(n^{\log_{b}{a} - \epsilon}) = O(1)\)
\(\Rightarrow T(n) = O(n^{1 - \frac 1k})\)
例题
\(20MB\) 的空间限制卡掉了树套树,强制在线卡掉了 \(cdq\) 分治,不过任何比赛里的出题人应该都不会丧心病狂到这种程度、正解是 \(KDT\) 的话一般都会被骂、
根号分治写法(\(O(n\sqrt{n\log{n}})\)) \(\Downarrow\):
点击查看代码
#include <bits/stdc++.h>
#define N 200005
#define K 2
using namespace std;
inline int read(){
char ch = getchar(); int x = 0, f = 1;
while(!isdigit(ch)){if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)){x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
return x * f;
}
bool mbegin;
int n, cur, last, ql, qr;
struct Node{
int val;
int L[K], R[K], x[K];
}p[N], tmp;
bool cmp(Node a, Node b){return a.x[cur] < b.x[cur];}
struct KDT{
#define lc (now << 1)
#define rc (now << 1 | 1)
Node t[N << 2];
void pushup(int now){
t[now].val = t[lc].val + t[rc].val;
for(int i = 0; i < K; i++){
t[now].L[i] = min(t[lc].L[i], t[rc].L[i]);
t[now].R[i] = max(t[lc].R[i], t[rc].R[i]);
}
}
void build(int now, int l, int r, int d = 0){
if(l == r){
for(int i = 0; i < K; i++) t[now].L[i] = t[now].R[i] = p[l].x[i];
t[now].val = p[l].val;
return;
}
int mid = (l + r) >> 1;
cur = d, nth_element(p + l, p + mid, p + r + 1, cmp);
build(lc, l, mid, (d + 1) % K);
build(rc, mid + 1, r, (d + 1) % K);
pushup(now);
}
bool in(Node a, Node b){
for(int i = 0; i < K; i++) if(a.L[i] > b.L[i] || a.R[i] < b.R[i]) return 0;
return 1;
}
bool out(Node a, Node b){
for(int i = 0; i < K; i++) if(a.L[i] > b.R[i] || a.R[i] < b.L[i]) return 1;
return 0;
}
void query(int now){
if(out(tmp, t[now])) return;
if(in(tmp, t[now])){tmp.val += t[now].val; return;}
query(lc), query(rc);
}
}T;
bool mend;
int main(){
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
cerr << abs(&mbegin - &mend) / 1024 / 1024 << "MB" << endl;
n = read();
while(1){
int opt = read();
if(opt == 1){
++qr;
p[qr].x[0] = read() ^ last, p[qr].x[1] = read() ^ last, p[qr].val = read() ^ last;
if((qr - ql) * (qr - ql) > qr * 40) T.build(1, 1, ql = qr);
}
else if(opt == 2){
for(int i = 0; i < K; i++) tmp.L[i] = read() ^ last;
for(int i = 0; i < K; i++) tmp.R[i] = read() ^ last;
tmp.val = 0;
T.query(1);
for(int i = ql + 1; i <= qr; i++){
nex:
if(i > qr) break;
for(int j = 0; j < K; j++) if(tmp.L[j] > p[i].x[j] || tmp.R[j] < p[i].x[j]){++i; goto nex;}
tmp.val += p[i].val;
}
printf("%d\n", tmp.val);
last = tmp.val;
}
else break;
}
return 0;
}
二进制分组写法(\(O(n^{1 - \frac{1}{k}})\)) \(\Downarrow\):
点击查看代码
#include <bits/stdc++.h>
#define N 200005
#define LOG 18
#define K 2
using namespace std;
inline int read(){
char ch = getchar(); int x = 0, f = 1;
while(!isdigit(ch)){if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)){x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
return x * f;
}
bool mbegin;
int n, cur, last;
struct Node{
int x[K];
int val, sum, lc, rc;
int mn[K], mx[K];
}l, h, t[N];
bool cmp(int x, int y){return t[x].x[cur] < t[y].x[cur];}
struct KD_Tree{
int tot;
int root[LOG], p[N];
void pushup(int now){
t[now].sum = t[t[now].lc].sum + t[t[now].rc].sum + t[now].val;
for(int i = 0; i < K; i++){
t[now].mn[i] = t[now].mx[i] = t[now].x[i];
if(t[now].lc){
t[now].mn[i] = min(t[now].mn[i], t[t[now].lc].mn[i]);
t[now].mx[i] = max(t[now].mx[i], t[t[now].lc].mx[i]);
}
if(t[now].rc){
t[now].mn[i] = min(t[now].mn[i], t[t[now].rc].mn[i]);
t[now].mx[i] = max(t[now].mx[i], t[t[now].rc].mx[i]);
}
}
}
int build(int l, int r, int d = 0){
int mid = (l + r) >> 1;
cur = d, nth_element(p + l, p + mid, p + r + 1, cmp);
int now = p[mid];
if(l < mid) t[now].lc = build(l, mid - 1, (d + 1) % K);
if(r > mid) t[now].rc = build(mid + 1, r, (d + 1) % K);
pushup(now);
return now;
}
void append(int &now){
if(!now) return;
p[++tot] = now;
append(t[now].lc);
append(t[now].rc);
now = 0;
}
int query(int now){
if(!now) return 0;
bool flag = 0;
for(int i = 0; i < K; i++) flag |= (!(l.x[i] <= t[now].mn[i] && t[now].mx[i] <= h.x[i]));
if(!flag) return t[now].sum;
for(int i = 0; i < K; i++) if(t[now].mx[i] < l.x[i] || h.x[i] < t[now].mn[i]) return 0;
int res = 0;
flag = 0;
for(int i = 0; i < K; i++) flag |= (!(l.x[i] <= t[now].x[i] && t[now].x[i] <= h.x[i]));
if(!flag) res = t[now].val;
res += query(t[now].lc) + query(t[now].rc);
return res;
}
}T;
bool mend;
int main(){
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
cerr << abs(&mbegin - &mend) / 1024 / 1024 << "MB" << endl;
n = read();
n = 0;
while(1){
int opt = read();
if(opt == 1){
int x = read() ^ last, y = read() ^ last, A = read() ^ last;
t[++n] = {{x, y}, A};
T.p[T.tot = 1] = n;
for(int i = 0; ; i++){
if(!T.root[i]){T.root[i] = T.build(1, T.tot); break;}
else T.append(T.root[i]);
}
}
else if(opt == 2){
for(int i = 0; i < K; i++) l.x[i] = read() ^ last;
for(int i = 0; i < K; i++) h.x[i] = read() ^ last;
last = 0;
for(int i = 0; i < LOG; i++) last += T.query(T.root[i]);
printf("%d\n", last);
}
else break;
}
return 0;
}
根号分治优化(\(O(n \sqrt{n} + n^{\frac{5}{4}}\log{n})\)) \(\Downarrow\):
点击查看代码
#include <bits/stdc++.h>
#define N 200005
#define K 2
using namespace std;
inline int read(){
char ch = getchar(); int x = 0, f = 1;
while(!isdigit(ch)){if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)){x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
return x * f;
}
bool mbegin;
int n, cur, last, p0, p1, p2;
struct Node{
int val;
int L[K], R[K], x[K];
}p[N], tmp;
bool cmp(Node a, Node b){return a.x[cur] < b.x[cur];}
struct KDT{
#define lc (now << 1)
#define rc (now << 1 | 1)
Node t[N << 2];
void pushup(int now){
t[now].val = t[lc].val + t[rc].val;
for(int i = 0; i < K; i++){
t[now].L[i] = min(t[lc].L[i], t[rc].L[i]);
t[now].R[i] = max(t[lc].R[i], t[rc].R[i]);
}
}
void build(int now, int l, int r, int d = 0){
if(l == r){
for(int i = 0; i < K; i++) t[now].L[i] = t[now].R[i] = p[l].x[i];
t[now].val = p[l].val;
return;
}
int mid = (l + r) >> 1;
cur = d, nth_element(p + l, p + mid, p + r + 1, cmp);
build(lc, l, mid, (d + 1) % K);
build(rc, mid + 1, r, (d + 1) % K);
pushup(now);
}
bool in(Node a, Node b){
for(int i = 0; i < K; i++) if(a.L[i] > b.L[i] || a.R[i] < b.R[i]) return 0;
return 1;
}
bool out(Node a, Node b){
for(int i = 0; i < K; i++) if(a.L[i] > b.R[i] || a.R[i] < b.L[i]) return 1;
return 0;
}
void query(int now = 1){
if(out(tmp, t[now])) return;
if(in(tmp, t[now])){tmp.val += t[now].val; return;}
query(lc), query(rc);
}
}T1, T2;
bool mend;
int main(){
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
cerr << abs(&mbegin - &mend) / 1024 / 1024 << "MB" << endl;
n = read();
while(1){
int opt = read();
if(opt == 1){
++p2;
p[p2].x[0] = read() ^ last, p[p2].x[1] = read() ^ last, p[p2].val = read() ^ last;
if((p2 - p1) * (p2 - p1) > 1.5 * p2) T2.build(1, p0 + 1, p1 = p2);
else if((p1 - p0) > 1.2 * pow(p2, 0.75))
{T1.build(1, 1, p0 = p1); T2.build(1, p0 + 1, p1 = p2);}
}
else if(opt == 2){
for(int i = 0; i < K; i++) tmp.L[i] = read() ^ last;
for(int i = 0; i < K; i++) tmp.R[i] = read() ^ last;
tmp.val = 0;
T1.query();
T2.query();
for(int i = p1 + 1; i <= p2; i++){
nex:
if(i > p2) break;
for(int j = 0; j < K; j++) if(tmp.L[j] > p[i].x[j] || tmp.R[j] < p[i].x[j]){++i; goto nex;}
tmp.val += p[i].val;
}
printf("%d\n", tmp.val);
last = tmp.val;
}
else break;
}
return 0;
}
应用
一般用于二维~三维的区间处理/邻域查询(以及复杂度不太靠谱的第 \(K\) 近/远点查询),原因在于维数越高时 \(KDT\) 的复杂度也会相应地“指数爆炸”,该数据结构发明者还特此提出了 \(BBF\) 算法以进行优化,不过太高深了还学不来、
关于邻域查询:估价函数一定三思! 欧几里得距离下可以取估价函数(以最小距离为例) \(mind = \min(a.L[0] - x[0], a.R[0] - x[0]) + \min(a.L[1] - x[1], a.R[1] - x[1]) + ...\)
效果类似于从当前点出发作圆,而在曼哈顿距离下则不能照搬欧式距离的 \(\max/\min\) 形式,查询时假定查询的点在左右子树的矩形之外(因为在矩形内则最小曼哈顿距离固定,一定可以直接判断是否可以更新答案),因此估价函数需要取(以最小距离为例) \(\max(0, x[0] - a.R[0]) + \max(0, a.L[0] - x[0]) + \max(0, x[1] - a.R[1]) + \max(0, a.L[1] - x[1])...\) 否则极有可能无法得到正确答案
对于距离第 \(k\) 远/近问题可以直接开堆暴力维护,复杂度极其容易假(指可能被卡到 \(O(n^{2} \log)\)),一般建议骗分使用
欧式距离:
点击查看代码:
#include <bits/stdc++.h>
#define N 100005
#define K 2
#define ll long long
using namespace std;
inline int read(){
char ch = getchar(); int x = 0, f = 1;
while(!isdigit(ch)){if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)){x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
return x * f;
}
int n, m, cur, root, xi, yi, ki;
struct Point{
ll dis;
int id;
bool operator < (const Point &b)const{
return dis > b.dis || (dis == b.dis && id < b.id);
}
};
priority_queue<Point> q;
struct Node{
int lc, rc, id;
int L[K], R[K], x[K];
}p[N];
bool cmp(Node a, Node b){return a.x[cur] < b.x[cur];}
ll qp(ll x){return x * x;}
ll dis(Node a){return qp(a.x[0] - xi) + qp(a.x[1] - yi);}
ll maxd(Node a){return max(qp(a.L[0] - xi), qp(a.R[0] - xi)) + max(qp(a.L[1] - yi), qp(a.R[1] - yi));}
struct KDT{
int root, tot;
Node t[N << 2];
void pushup(int now){
if(!now) return;
for(int i = 0; i < K; i++){
t[now].L[i] = t[now].R[i] = t[now].x[i];
if(t[now].lc){
t[now].L[i] = min(t[now].L[i], t[t[now].lc].L[i]);
t[now].R[i] = max(t[now].R[i], t[t[now].lc].R[i]);
}
if(t[now].rc){
t[now].L[i] = min(t[now].L[i], t[t[now].rc].L[i]);
t[now].R[i] = max(t[now].R[i], t[t[now].rc].R[i]);
}
}
}
void build(int &now, int l = 1, int r = n, int d = 0){
now = ++tot;
int mid = (l + r) >> 1;
cur = d, nth_element(p + l, p + mid, p + r + 1, cmp);
t[now].id = p[mid].id;
for(int i = 0; i < K; i++) t[now].x[i] = p[mid].x[i];
if(l < mid) build(t[now].lc, l, mid - 1, (d + 1) % K);
if(r > mid) build(t[now].rc, mid + 1, r, (d + 1) % K);
pushup(now);
}
void query(int now = 1){
if(!now) return;
ll tmp = dis(t[now]);
if(tmp > q.top().dis || (tmp == q.top().dis && t[now].id < q.top().id))
q.pop(), q.push((Point){tmp, t[now].id});
ll ld = maxd(t[t[now].lc]), rd = maxd(t[t[now].rc]);
if(ld > rd){
if(ld >= q.top().dis) query(t[now].lc);
if(rd >= q.top().dis) query(t[now].rc);
}
else{
if(rd >= q.top().dis) query(t[now].rc);
if(ld >= q.top().dis) query(t[now].lc);
}
}
}T;
int main(){
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
n = read();
for(int i = 1; i <= n; i++) p[i].x[0] = read(), p[i].x[1] = read(), p[i].id = i;
T.build(T.root);
m = read();
while(m--){
xi = read(), yi = read(), ki = read();
while(!q.empty()) q.pop();
while(ki--) q.push((Point){-1ll, 0});
T.query();
printf("%d\n", q.top().id);
}
return 0;
}
曼哈顿距离(例题:天使玩偶/SJY摆棋子):
点击查看代码
#include <bits/stdc++.h>
#define N 300005
#define K 2
#define INF 0x7f7f7f7f
using namespace std;
inline int read(){
char ch = getchar(); int x = 0, f = 1;
while(!isdigit(ch)){if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)){x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
return x * f;
}
int n, m, cur, p1, p2, xi, yi, ans;
struct Node{
int lc, rc;
int L[K], R[K], x[K];
}p[N << 1];
bool cmp(Node a, Node b){return a.x[cur] < b.x[cur];}
int dis(Node a){return abs(a.x[0] - xi) + abs(a.x[1] - yi);}
int mind(Node a){
int res = 0;
res += max(0, xi - a.R[0]) + max(0, a.L[0] - xi);
res += max(0, yi - a.R[1]) + max(0, a.L[1] - yi);
return res;
}
struct KDT{
int root, tot;
Node t[N << 1];
void pushup(int now){
if(!now) return;
for(int i = 0; i < K; i++){
t[now].L[i] = t[now].R[i] = t[now].x[i];
if(t[now].lc){
t[now].L[i] = min(t[now].L[i], t[t[now].lc].L[i]);
t[now].R[i] = max(t[now].R[i], t[t[now].lc].R[i]);
}
if(t[now].rc){
t[now].L[i] = min(t[now].L[i], t[t[now].rc].L[i]);
t[now].R[i] = max(t[now].R[i], t[t[now].rc].R[i]);
}
}
}
void build(int &now, int l, int r, int d = 0){
now = ++tot;
int mid = (l + r) >> 1;
cur = d, nth_element(p + l, p + mid, p + r + 1, cmp);
for(int i = 0; i < K; i++) t[now].x[i] = p[mid].x[i];
t[now].lc = t[now].rc = 0;
if(l < mid) build(t[now].lc, l, mid - 1, (d + 1) % K);
if(r > mid) build(t[now].rc, mid + 1, r, (d + 1) % K);
pushup(now);
}
void query(int now = 1){
if(!now) return;
ans = min(ans, dis(t[now]));
int ld = mind(t[t[now].lc]), rd = mind(t[t[now].rc]);
if(ld < rd){
if(ld < ans) query(t[now].lc);
if(rd < ans) query(t[now].rc);
}
else{
if(rd < ans) query(t[now].rc);
if(ld < ans) query(t[now].lc);
}
}
}T;
int main(){
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
n = read(), m = read();
for(int i = 1; i <= n; i++) p[i].x[0] = read(), p[i].x[1] = read();
T.tot = 0, T.build(T.root, 1, p1 = p2 = n);
while(m--){
int opt = read();
xi = read(), yi = read();
if(opt == 1){
++p2;
p[p2].x[0] = xi, p[p2].x[1] = yi;
if((p2 - p1) * (p2 - p1) > 50 * p2) T.tot = 0, T.build(T.root, 1, p1 = p2);
}
else{
ans = INF;
T.query(T.root);
for(int i = p1 + 1; i <= p2; i++) ans = min(ans, dis(p[i]));
printf("%d\n", ans);
}
}
return 0;
}
加标记时直接按线段树写即可,应该没有什么理解上的门槛:
例如某道神秘联考题,要求:
- 单点查 \(a_{i}\)
- 对 \(x\) 到 \(y\) 的路径上的所有点 \(a_{i} \leftarrow ka_{i} + b\)
- 对子树 \(x\) 内的所有点 \(a_{i} \leftarrow ka_{i} + b\)
- 对距离 \(x\) 不超过 \(r\) 的所有点 \(a_{i} \leftarrow ka_{i} + b\)
点击查看代码
#include <bits/stdc++.h>
#define N 200005
#define K 2
#define p 998244353
#define pb push_back
using namespace std;
static char buf[0xfffff], *p1 = buf, *p2 = buf;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 0xfffff, stdin), p1 == p2)? EOF : *p1++)
inline int read(){
char ch = getchar(); int x = 0, f = 1;
while(!isdigit(ch)){if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)){x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
return x * f;
}
#undef getchar
int Case, n, m, cur, pos;
int dep[N], siz[N], fa[N], top[N], son[N], dfn[N];
vector<int> e[N];
inline int add(int x, int y){return (x + y >= p)? x - p + y : x + y;}
inline int sub(int x, int y){return (x < y)? x - y + p : x - y;}
inline int mul(int x, int y){return 1ll * x * y % p;}
struct Node{
int sum, cnt, tk = 1, tb = 0;
int x[K], L[K], R[K];
}tmp, P[N];
bool cmp(Node a, Node b){return a.x[cur] < b.x[cur];}
struct KDT{
#define Lc (now << 1)
#define Rc (now << 1 | 1)
#define mid ((l + r) >> 1)
Node t[N << 2];
void pushup(int now){
t[now].sum = add(t[Lc].sum, t[Rc].sum), t[now].cnt = add(t[Lc].cnt, t[Rc].cnt);
for(int i = 0; i < K; i++)
t[now].L[i] = min(t[Lc].L[i], t[Rc].L[i]), t[now].R[i] = max(t[Lc].R[i], t[Rc].R[i]);
}
void build(int now = 1, int l = 1, int r = n, int d = 0){
if(l == r){
t[now].sum = P[l].sum; t[now].cnt = 1;
for(int i = 0; i < K; i++) t[now].L[i] = t[now].R[i] = P[l].x[i];
return;
}
cur = d, nth_element(P + l, P + mid, P + r + 1, cmp);
build(Lc, l, mid, (d + 1) % K), build(Rc, mid + 1, r, (d + 1) % K); pushup(now);
}
bool in(Node a, Node b){
for(int i = 0; i < K; i++) if(a.L[i] > b.L[i] || a.R[i] < b.R[i]) return 0;
return 1;
}
bool out(Node a, Node b){
for(int i = 0; i < K; i++) if(a.L[i] > b.R[i] || a.R[i] < b.L[i]) return 1;
return 0;
}
void update(int now, int k, int b){
t[now].sum = add(mul(t[now].sum, k), mul(t[now].cnt, b));
t[now].tk = mul(t[now].tk, k), t[now].tb = add(mul(t[now].tb, k), b);
}
void pushdown(int now){
if(t[now].tk ^ 1 || t[now].tb) update(Lc, t[now].tk, t[now].tb), update(Rc, t[now].tk, t[now].tb);
t[now].tk = 1, t[now].tb = 0;
}
void modify(int k, int b, int now = 1){
if(out(tmp, t[now])) return;
if(in(tmp, t[now])){update(now, k, b); return;}
pushdown(now); modify(k, b, Lc), modify(k, b, Rc); pushup(now);
}
void query(int now = 1){
if(out(tmp, t[now])) return;
if(in(tmp, t[now])){tmp.sum = add(tmp.sum, t[now].sum); return;}
pushdown(now); query(Lc), query(Rc);
}
}T;
void dfs1(int u, int f){
siz[u] = 1, dep[u] = dep[fa[u] = f] + 1;
for(int v : e[u]){
if(v == f) continue;
dfs1(v, u);
siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
}
void dfs2(int u, int f){
dfn[u] = ++pos, top[u] = f;
if(son[u]) dfs2(son[u], f);
for(int v : e[u]) if(v ^ fa[u] && v ^ son[u]) dfs2(v, v);
}
void update(int u, int v, int k, int b){
while(top[u] ^ top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u, v);
tmp.L[0] = dfn[top[u]], tmp.R[0] = dfn[u], tmp.L[1] = 0, tmp.R[1] = n;
T.modify(k, b);
u = fa[top[u]];
}
if(dep[u] > dep[v]) swap(u, v);
tmp.L[0] = dfn[u], tmp.R[0] = dfn[v], tmp.L[1] = 0, tmp.R[1] = n;
T.modify(k, b);
}
int main(){
// freopen("tour.in", "r", stdin);
// freopen("tour.out", "w", stdout);
Case = read(), n = read(), m = read();
for(int i = 1; i < n; i++){
int u = read(), v = read();
e[u].pb(v), e[v].pb(u);
}
dfs1(1, 0), dfs2(1, 1);
for(int i = 1; i <= n; i++) P[i].sum = read(), P[i].x[0] = dfn[i], P[i].x[1] = dep[i];
T.build();
while(m--){
int opt = read(), x = read(), y = 0, k = 0, b = 0;
if(opt == 2 || opt == 4) y = read();
if(opt ^ 1) k = read(), b = read();
if(opt == 1){
tmp.sum = 0, tmp.L[0] = tmp.R[0] = dfn[x], tmp.L[1] = tmp.R[1] = dep[x];
T.query();
printf("%d\n", tmp.sum);
}
if(opt == 2) update(x, y, k, b);
if(opt == 3){
tmp.L[0] = dfn[x], tmp.R[0] = dfn[x] + siz[x] - 1, tmp.L[1] = 0, tmp.R[1] = n;
T.modify(k, b);
}
if(opt == 4){
tmp.L[0] = dfn[x], tmp.R[0] = dfn[x] + siz[x] - 1, tmp.L[1] = dep[x], tmp.R[1] = dep[x] + y;
T.modify(k, b);
int u = x; --y;
for(x = fa[x]; ~y && x; --y, u = x, x = fa[x]){
tmp.L[0] = dfn[x], tmp.R[0] = dfn[u] - 1, tmp.L[1] = dep[x], tmp.R[1] = dep[x] + y;
T.modify(k, b);
if(dfn[u] + siz[u] < dfn[x] + siz[x]){
tmp.L[0] = dfn[u] + siz[u], tmp.R[0] = dfn[x] + siz[x] - 1, tmp.L[1] = dep[x], tmp.R[1] = dep[x] + y;
T.modify(k, b);
}
}
}
}
return 0;
}