懒标记线段树

1. 操作

符号 含义
\(nums\) 原数组
\(d\) 线段树节点维护值
\(lazytag\) 线段树节点懒标记值
\(p\) 当前节点
\(s\) 查询区间的开始
\(e\) 查询区间的结尾
\(l\) 节点区间的开始
\(r\) 节点区间的结尾
  • 一般习惯:
    • 建树从下标 \(1\) 开始
    • \(mid = (l + r) >> 1\)

2. 建立线段树

\(nums = \{1,2,3\}\) 为例子,首先是提供的 api ,伪代码如下所示。对于一个节点,我们从根节点开始,递归构造整个树到叶子节点,并且从叶子节点 \(pushup\) 到根节点。对于一个节点:

  • 若当前是叶子节点,即 \(s=e\) ,那么 \(d_{p} = nums_{p}\)
  • 若当前的节点是非叶子节点,即 \(s < e\) ,那么将区间分割为两个部分即 \([s, mid]\)\([mid + 1, l]\) , 递归到左右两个孩子进行建树。
void pushup(int p){
	d[p] = d[p << 1] + d[(p << 1) | 1];
}
void build(int s, int e, int p){
	if (s == e){
		d[p] = nums[p];
		return;
	}
	int mid = (s + e) >> 1;
	build(s, mid, p << 1);
	build(mid + 1, e, (p << 1) | 1);
	pushup(p);
}

3. 区间查询

对于一个节点,其存储的是 \([l, r]\) 区间内维护的值,而对于一个区间查询 \([s, e]\) ,若有查询区间 \([s, e]\) 完全覆盖当前节点区间 \([l, r]\)\(s\leq l\)\(r\leq e\) 查询的值即为 \(d_{p}\) 。否则,区间将一分为二,在分别查询。

  • \([l, mid]\) 左儿子对应的下标为 \(p << 1\)
  • \([mid + 1, r]\) 右儿子的对应下标为 \((p << 1) | 1\)
    如果每一次更新区间值,都会使得整个线段树向下更新到根节点。在区间更新的时候应该使用懒标记线段树,延迟整个节点的信息更新。
    带懒标记的线段树,实际上是父节点暂时记录了下推到子节点的信息。在查询时,才将延迟更新的节点信息加载到子节点。
int query(int l, int r, int s, int e, int p){
	if (l >= s and r <= e){
		return  d[p];
	}
	auto pushdown = [&](int p){
		if (lazy[p]){
			d[p << 1] += lazy[p] * (mid - s + 1);
			d[(p << 1) | 1] += lazy[p] * (e - mid);
			lazy[p << 1] += lazy[p];
			lazy[(p << 1) | 1] += lazy[p];
		}
		lazy[p] = 0;
	};
	pushdown(p);
	int mid = (s + e) >> 1;
	int ret = 0;
	if (s <= mid){
		ret += query(l, r, s, mid, p << 1);
	}
	if (l >= mid){
		ret += query(l, r, mid, l, (p << 1) | 1);
	}
	return ret;
}

4.区间修改

区间修改的过程是产生新的懒标记,而区间查询是将懒标记向下传递。

void modify(int l, int r, int x, int s, int e, int p){
	if (l <= s and r >= e){
		d[p] += (e - s + 1) * x;
		lazy[p] += x;
		return;
	}
	int mid = (s + e) >> 1;
	pushdown(p);
	if (s <= mid){
		modify(l, r, x, s, mid, p << 1);
	}
	if (e >= mid){
		modify(l, r, x, mid, e, (p << 1) | 1);
	}
	d[p] = d[p << 1] + d[(p << 1) | 1];
}

更新数组后处理求和查询

  • 每个节点存储区间求和值和区间长度即可。
class LazySegmentTree:
    __slots__ = ["op_X", "e_X", "mapping", "compose", "id_M", "N", "log", "N0", "data", "lazy"]

    def __init__(self, op_X, e_X, mapping, compose, id_M, N, array=None):
        self.e_X = e_X
        self.op_X = op_X
        self.mapping = mapping
        self.compose = compose
        self.id_M = id_M
        self.N = N
        self.log = (N - 1).bit_length()
        self.N0 = 1 << self.log
        self.data = [e_X] * (2 * self.N0)
        self.lazy = [id_M] * self.N0
        if array is not None:
            assert N == len(array)
            self.data[self.N0:self.N0 + self.N] = array
            for i in range(self.N0 - 1, 0, -1):
                self.update(i)

    def Set(self, p, x):
        assert 0 <= p < self.N
        p += self.N0
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        self.data[p] = x
        for i in range(1, self.log + 1):
            self.update(p >> i)
    
    def prod(self, l, r):
        if l == r: return self.e_X
        l += self.N0
        r += self.N0
        for i in range(self.log, 0, -1):
            if (l >> i) << i != l:
                self.push(l >> i)
            if (r >> i) << i != r:
                self.push(r >> i)

        sml = smr = self.e_X
        while l < r:
            if l & 1:
                sml = self.op_X(sml, self.data[l])
                l += 1
            if r & 1:
                r -= 1
                smr = self.op_X(self.data[r], smr)
            l >>= 1
            r >>= 1
        return self.op_X(sml, smr)

    def all_prod(self):
        return self.data[1]

    def apply(self, p, f):
        p += self.N0
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        self.data[p] = self.mapping(f, self.data[p])
        for i in range(1, self.log + 1):
            self.update(p >> i)

    def apply(self, l, r, f):
        if l == r: return
        l += self.N0
        r += self.N0
        for i in range(self.log, 0, -1):
            if (l >> i) << i != l:
                self.push(l >> i)
            if (r >> i) << i != r:
                self.push((r - 1) >> i)

        l2, r2 = l, r
        while l < r:
            if l & 1:
                self.all_apply(l, f)
                l += 1
            if r & 1:
                r -= 1
                self.all_apply(r, f)
            l >>= 1
            r >>= 1

        l, r = l2, r2
        for i in range(1, self.log + 1):
            if (l >> i) << i != l:
                self.update(l >> i)
            if (r >> i) << i != r:
                self.update((r - 1) >> i)


    def update(self, k):
        self.data[k] = self.op_X(self.data[2 * k], self.data[2 * k + 1])

    def all_apply(self, k, f):
        self.data[k] = self.mapping(f, self.data[k])
        if k < self.N0:
            self.lazy[k] = self.compose(f, self.lazy[k])

    def push(self, k): 
        if self.lazy[k] is self.id_M: return
        self.data[2 * k] = self.mapping(self.lazy[k], self.data[2 * k])
        self.data[2 * k + 1] = self.mapping(self.lazy[k], self.data[2 * k + 1])
        if 2 * k < self.N0:
            self.lazy[2 * k] = self.compose(self.lazy[k], self.lazy[2 * k])
            self.lazy[2 * k + 1] = self.compose(self.lazy[k], self.lazy[2 * k + 1])
        self.lazy[k] = self.id_M

e_X = [0, 0]
id_M = 0

def op_X(X, Y):
    return [X[0] + Y[0], X[1] + Y[1]]

def compose(f, g):
    return f ^ g

def mapping(f, X):
    if f == 1:
        X[0] = X[1] - X[0]
    return X

class Solution:
    def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
        n = len(nums1)
        nums1 = [[x, 1] for x in nums1]
        st = LazySegmentTree(op_X, e_X, mapping, compose, id_M, n, nums1)
        s = sum(nums2)
        ans = []
        for idx, a, b in queries:
            if idx == 1:
                st.apply(a, b + 1, 1)
            elif idx == 2:
                s += st.all_prod()[0] * a
            else:
                ans.append(s)
        return ans
posted @ 2023-07-27 17:46  Wasser007  阅读(55)  评论(0编辑  收藏  举报