一种比较简洁的线段树写法

线段树是一种高效的数据结构,旨在快速的处理区间的修改和查询问题,单次操作复杂度为O(logn)
但我们今天的重点并非讲解线段树的诞生及其原理,而是总结一下线段树代码以及基础的用法
只有原理是写不出代码的,那是纯纯的发明创造

image



先贴板子,可能不是最先进的,但我相信我的版本一定是比较简洁易懂的
(简洁 != 短,比起无脑压行,我相信适当的多写几行是有助于理解的。
本人喜欢将数据结构封装起来,所以用到了一点其他的知识,但是很简单)

struct SegTree {
    #define lc u << 1
    #define rc u << 1 | 1

    typedef struct node {
        int l, r;
        i64 val, tag;   
    } node;
    
    std::vector<i64> a;
    std::vector<node> tree;

    SegTree (int size) {
        a.resize(size + 10);
        tree.resize(4 * size + 10);
    }

    void push_up (int u) {
        /* tree[u].val = tree[lc].val ?? tree[rc].val */
    }

    void push_down (int u) {
        if (tree[u].tag) {
        // Left:
            
        // Right:

        // Final
            tree[u].tag = 0;
        }
    }

    void build (int u, int l, int r) {
        tree[u].l = l;
        tree[u].r = r;
        tree[u].tag = 0;

        if (l == r) {
            tree[u].val = a[l];
            return;
        }

        int mid = tree[u].l + tree[u].r >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        push_up(u);
    }

    void update (int u, int x, int y, i64 k) {
        if (x <= tree[u].l && tree[u].r <= y) {
            // tree[u].val = ?? 
            // tree[u].tag = ??
            return;
        }
        push_down(u);

        int mid = tree[u].l + tree[u].r >> 1;
        if (x <= mid) update(lc, x, y, k);
        if (y >  mid) update(rc, x, y, k);
        push_up(u);
    }

    i64 query (int u, int x, int y) {
        if (x <= tree[u].l && tree[u].r <= y) {
            return tree[u].val;
        }
        push_down(u);

        int mid = tree[u].l + tree[u].r >> 1;
        i64 ans = 0;
        if (x <= mid) /* ans <-- query(lc, x, y); */
        if (y >  mid) /* ans <-- query(rc, x, y); */
        return ans;
    }
};


线段树入门第一步:结构体和宏定义
线段树是一棵满二叉树,因此,如果当前节点为u,那么我们就可以用u * 2和u * 2 + 1来访问他的两个子节点
因此,我们加入这两句宏定义:
#define lc u << 1
#define rc u << 1 | 1
这里的 << 和 | 是位运算符,其中u<< 1就是u * 2,u << 1 | 1就是u * 2 + 1
我们用lc代表左子节点,rc代表右子节点

接下来,我们用结构体来定义线段树的节点

  typedef struct node {
      int l, r;
      i64 val, tag;   
  } node;

l和r表示当前节点表示的区间,val是你要求的东西(比如区间和、区间最值),tag是懒标记
别问为什么,照着这么写就行了,用结构体封装,后续会方便很多



线段树入门第二步:构造函数
如果想要在使用线段树就像使用STL容器那么自然,那么我们就手动封装起来整个线段树,让他“成为STL容器”

  struct SegTree {
      std::vector<i64> a;
      std::vector<node> tree;

      SegTree (int size) {
          a.resize(size + 10);
          tree.resize(4 * size + 10);
      }
  };

我们在内部定义数组a和线段树数组tree,数组a就是你需要操作的目标序列
线段树数组用先前的node结构体,线段树要开原数组a的四倍空间
我们使用构造函数,重新设置数组和线段树数组的空间,这样的话,在主函数中我们可以这么写:

  int main () {
      int n;
      std::cin >> n;

      SegTree T(n);
      for (int i = 1; i <= n; i++) {
          std::cin >> T.a[i];
      }
      T.build(1, 1, n);

      return 0;
  }

这里的SegTree T(n);就是开了一个针对长度为n的序列的线段树
使用for循环输入T.a数组中的每一位,再调用T.build()函数去建立线段树
而接下来,我们就要进行线段树的建树操作



线段树入门第三步:建立线段树
建立线段树的时候,我们需要通过递归的方式来实现
当建立代表[l, r]这个区间的节点时
接下来就需要去建立代表[l, mid]和[mid + 1, r]的两个节点了
这里的mid = l + r >> 1(线段树为了减少代码量,一些地方采用位运算,这里的 >> 1 就是 ÷ 2 的意思)

当l == r的时候,就代表着到了叶子节点,此时没有子节点,要赋值并return
再层层return的时候,我们要执行push_up操作进行向上传递,更新父节点的值
push_up()函数同样是由自己定义的,去实现自己需要的操作

比如说,我们要求区间和,那我们的push_up()函数就是
tree[u].val = tree[lc].val + tree[rc].val
如果是区间最值,那么
tree[u].val = std::max(tree[lc].val, tree[rc].val)
也可以是区间gcd
tree[u].val = std::__gcd(tree[lc].val, tree[rc].val)

代码如下:

  void push_up (int u) {
      /* tree[u].val = tree[lc].val ?? tree[rc].val */
  }

  void build (int u, int l, int r) {
      tree[u].l = l;
      tree[u].r = r;
      tree[u].tag = 0;

      if (l == r) {
          tree[u].val = a[l];
          return;
      }

      int mid = tree[u].l + tree[u].r >> 1;
      build(lc, l, mid);
      build(rc, mid + 1, r);
      push_up(u);
  }


线段树入门第四步:query函数
给定了需要查询的区间[x, y],现在执行查询操作,从[1, n]区间开始层层向下查找
假设当前查到了代表[l, r]区间的节点,那么我们进行判断

  1. 如果区间[l, r]被完全包含在[x, y]查询的区间内,那么return
  2. 否则,我们要继续向下查找

那么向下查找的时候,是向左,还是向右呢?
我们此时需要再加一个判断,定义mid = tree[u].l + tree[u].r >> 1

  1. 如果x <= mid,那么需要向左子节点查询
  2. 如果y > mid,那么需要向右子节点查询

我们需要定义一个ans,然后在return过程中层层更新,代码如下:
(这里的push_down()操作我们稍后会讲)

  i64 query (int u, int x, int y) {
      if (x <= tree[u].l && tree[u].r <= y) {
          return tree[u].val;
      }
      push_down(u);

      int mid = tree[u].l + tree[u].r >> 1;
      i64 ans = 0;
      if (x <= mid) /* ans <-- query(lc, x, y); */
      if (y >  mid) /* ans <-- query(rc, x, y); */
      return ans;
  }


线段树入门第五步:update和push_down函数
我们到了线段树最重要的部分,这部分是线段树能够始终在O(logn)复杂度完成区间修改的核心

懒标记
试想一下,如果我们给定了一个区间[x, y]需要进行修改/更新
那当我们找到了被[x, y]完全包含的区间[l, r],我们还要不要向下继续更新?
如果不向下更新,那么我们后续查询到下面的节点时,获取的答案就是错误的
但如果我们更新了,更新操作的复杂度就不再是O(logn)了,会退化成O(n),失去了意义
此时我们引入了懒标记

懒标记:顾名思义,懒
如果我们更新到某个被完全包含的区间,不是向下更新,而是打上一个标记,告诉程序:我这里有待更新,会怎么样呢?
当程序下次查询到下面的节点时,因为被告知有待更新,所以会对下面进行更新操作,这样是不是就保证了答案的正确性?
所以,懒标记的作用就是,告诉我们此处需要更新的值,至于如何更新,就看我们的操作了

push_down操作分为三步:

  1. 修改左子节点的val值和tag值
  2. 修改右子节点的val值和tag值
  3. 将当前节点的tag值置为0或者初始值

向下更新完之后,记得向上也要更新,代码如下:

  void push_down (int u) {
      if (tree[u].tag) {
      // Left:
          
      // Right:

      // Final
          tree[u].tag = 0;
      }
  }

  void update (int u, int x, int y, i64 k) {
      if (x <= tree[u].l && tree[u].r <= y) {
          // tree[u].val = ?? 
          // tree[u].tag = ??
          return;
      }
      push_down(u);

      int mid = tree[u].l + tree[u].r >> 1;
      if (x <= mid) update(lc, x, y, k);
      if (y >  mid) update(rc, x, y, k);
      push_up(u);
  }


现在给出一个区间和线段树的模板,可通过洛谷P3372【模板】线段树 1
下面的代码中,重点的地方会通过注释的方法讲解

struct SegTree {
	#define lc u << 1
	#define rc u << 1 | 1

	typedef struct node {
		int l, r;
		i64 sum, tag;
	} node;

	std::vector<i64> a;
	std::vector<node> tree;

	SegTree (int n) {
		a.resize(n + 10);
		tree.resize(4 * n + 10);
	}

	void push_up (int u) {
        // 因为我们求的是区间和,所以在向上更新时,大区间的区间和 = 两个小区间的区间和相加
        // 也就是父节点区间和 = 左子节点区间和 + 右子节点区间和
		tree[u].sum = tree[lc].sum + tree[rc].sum;
	}

	void push_down (int u) {
        // push_down函数是很重要的一部分,注意这里的写法
        // 当懒标记存在的时候,我们进行下传的操作,如果不存在就不用了
		if (tree[u].tag) {
            // 修改左子节点的情况
            // 由于我们是区间和线段树,我们这里的更新操作是“区间内每个数字 + k”
            // 因此我们的tree[lc].sum += 区间长度 * tree[u].tag
            // 区间长度 = (tree[lc].r - tree[lc].l + 1)
            // tree[u].tag是经过一层层传下来的值,意思是,当前区间每个数字需要加上tree[u].tag
			tree[lc].tag += tree[u].tag;
			tree[lc].sum += (tree[lc].r - tree[lc].l + 1) * tree[u].tag;
			
            // 右子节点同理,不再赘述
			tree[rc].tag += tree[u].tag;
			tree[rc].sum += (tree[rc].r - tree[rc].l + 1) * tree[u].tag;

            // 更新完了,就让当前节点的tag值置为0,意思是下次不用再下传啦
			tree[u].tag = 0;
		}
	}

	void build (int u, int l, int r) {
		tree[u].l = l;
		tree[u].r = r;
		tree[u].tag = 0;
			
		if (l == r) {
			tree[u].sum = a[l];
			return;
		}
		
		int mid = tree[u].l + tree[u].r >> 1;
		build(lc, l, mid);
		build(rc, mid + 1, r);
		push_up(u);
	}

	void update (int u, int x, int y, i64 num) {
		if (x <= tree[u].l && tree[u].r <= y) {
            // 当更新操作更新到“被完全包含的区间”时,就不要继续向下走了
            // 只更新当前节点的区间和,在这里打上标记
            // 在后续push_down操作时,以此标记向下传递
            // 注意是 += 而不是 = !!!
            // 因为我们的tag值可能不是0,如果是直接赋值,相当于无视了先前的“修改申请”
			tree[u].sum += (tree[u].r - tree[u].l + 1) * num;
			tree[u].tag += num;
			return;
		}
		push_down(u);
		
		int mid = tree[u].l + tree[u].r >> 1;
		if (x <= mid) update(lc, x, y, num);
		if (y > mid)  update(rc, x, y, num); 
		push_up(u);
	}

	i64 query (int u, int x, int y) {
		if (x <= tree[u].l && tree[u].r <= y) {
			return tree[u].sum;
		}
		push_down(u);
		
		int mid = tree[u].l + tree[u].r >> 1; 
		i64 ans = 0;
		if (x <= mid) ans += query(lc, x, y);
		if (y > mid)  ans += query(rc, x, y);
		return ans; 
	}
};
posted @ 2025-07-25 21:57  _彩云归  阅读(167)  评论(0)    收藏  举报