线段树
基本知识
基本用途:对序列进行维护,支持查询和修改指令
1.线段树的每一个节点都代表一个区间
2.线段树具有唯一的根节点,代表的区间是整个统计范围
3.线段树的每个叶节点都代表长度为1的元区间
4.对于每个内部节点\([l,r]\),它的左节点是\([l,mid]\),右节点是\([mid + 1,r]\),其中\(mid = (l + r) >> 1\)(向下取整)
节点编号方法:对于编号为\(x\)的节点:左子节点编号为\(2x\),右子节点编号为\(2x + 1\)
注意
保存线段树的数组长度要不小于\(4N\)
线段树的建树
在区间\([1,N]\)上建立一棵线段树。每个叶节点保存\(a[i]\)的值
由于线段树的二叉树结构 --- 我们可以很方便地从下往上传递信息
代码:
struct Tree {
int l,r;
long long sum,add;
}t[maxn * 4];//用结构体数组保存线段树(四倍空间)
void build(int k,int l,int r) {
t[k].l = l;
t[k].r = r;//节点k代表区间[l,r]
if (l == r) {
t[k].sum = a[l];//叶节点
return;
}
int mid = (l + r) >> 1;
int ls = k << 1;//左儿子
int rs = ls + 1;//右儿子
build(ls,l,mid);//左子树
build(rs,mid + 1,r);//右子树
t[k].sum = t[ls].sum + t[rs].sum;
}
build(1,1,n);//调用入口
线段树的单点修改
eg.我们需要把\(a[x]\)的值修改为\(v\)
从根节点出发,递归找到代表区间\([x,x]\)的叶节点,然后从下往上更新\([x,x]\)以及它的所有祖先节点的信息
时间复杂度:\(O(logN)\)
代码:
void change(int k,int x,int v) {
if (t[k].l == t[k].r) { //找到叶节点就更新它的信息
t[k].sum = v;
return;
}
int mid = (t[k].l + t[k].r) >> 1;
int ls = k << 1;
int rs = ls + 1;
if (x <= mid) //x属于左半区间
change(ls,x,v);
else //x属于右半区间
change(rs,x,v);
t[k].sum = t[ls].sum + t[rs].sum;//从下往上更新信息
}
线段树的区间查询
eg.查询\(a\)序列在区间\([l,r]\)上的最大值
从根节点开始,递归执行以下过程:
1.若\([l,r]\)完全覆盖当前节点代表的区间,就直接回溯,记录该节点的dat值为候选答案
2.若左子节点与\([l,r]\)有重叠部分,则递归访问左子节点
3.若右子节点与\([l,r]\)有重叠部分,则递归访问右子节点
该查询过程会把询问区间\([l,r]\)在线段树上分成\(O(logN)\)个节点,并取它们的最大值作为最终答案
代码:
long long ask(int k,int l,int r) {
if (l <= t[k].l && r >= t[k].r) return t[k].sum; //完全覆盖的情况
int mid = (t[k].l + t[k].r) >> 1;
long long val = -0x3f3f3f3f; //因为要取最大值,所以赋值为负无穷
int ls = k << 1;
int rs = ls + 1;
if (l <= mid) val = max(val,ask(ls,l,r)); //和左子节点有重叠
if (r > mid) val = max(val,ask(rs,l,r)); //和右子节点有重叠
return val;
}
upd 2021.7.14
延迟标记
为什么需要?
如果一个节点\(k\)代表的区间\([k_l,k_r]\)被当前需要修改的区间\([l,r]\)完全覆盖,即\(l \le p_l \le p_r \le r\),若逐一对以\(k\)点为根的子树中的每一个点进行更新,单次区间修改的时间复杂度则会达到\(O(n)\)。
但是如果我们这样做了更新,而在之后的查询指令中并没有用到\([l,r]\)的子区间,那么更新\(k\)的整棵子树是没有用的。
面对这样的可能情况,我们就需要用到延迟标记。在\(k\)点代表的区间被修改区间完全覆盖时,我们更新\(k\)点的信息,并在回溯之前给\(k\)点加上一个标记,表示该节点曾经被修改,但其子节点尚未被更新。
如果在后续的指令中,需要从节点\(k\)向下递归,这时我们需要检查\(k\)是否有标记,若有标记则更新\(k\)的两个子节点,并给这两个子节点打上延迟标记,再清除\(k\)点的标记。
代码:
void pushdown(int k) { //延迟标记的向下传递
if (t[k].add) { //节点k有标记
int ls = k << 1;
int rs = ls + 1;
t[ls].sum += t[k].add * (t[ls].r - t[ls].l + 1); //更新左子节点信息
t[rs].sum += t[k].add * (t[rs].r - t[rs].l + 1); // 更新右子节点信息
t[ls].add += t[k].add; //给左子节点打延迟标记
t[rs].add += t[k].add; //给右子节点打延迟标记
t[k].add = 0; //清除k点的标记
}
}
void update(int k,int l,int r,int v) { //给区间l~r每个数都加上v
if (l <= t[k].l && r >= t[k].r) { //若k点代表的区间被当前需要修改的区间完全覆盖就执行操作
t[k].sum += v * (t[k].r - t[k].l + 1); //更新k节点信息
t[k].add += v; return; //给k点打延迟标记
}
pushdown(k); //延迟标记下传
int mid = (t[k].l + t[k].r) >> 1;
int ls = k << 1;
int rs = ls + 1;
if (l <= mid) update(ls,l,r,v);
if (r > mid) update(rs,l,r,v);
t[k].sum = t[ls].sum + t[rs].sum;
}
long long query(int k,int l,int r) { //询问l~r的区间和
if (l <= t[k].l && r >= t[k].r) return t[k].sum; //若k点代表的区间被当前询问的区间完全覆盖,就直接返回k点中储存的值
pushdown(k); //延迟标记下传
int mid = (t[k].l + t[k].r) >> 1;
int ls = k << 1;
int rs = ls + 1;
long long val = 0;
if (l <= mid) val += query(ls,l,r);
if (r > mid) val += query(rs,l,r);
return val;
}

浙公网安备 33010602011771号