class TreeNode:
    """A single segment-tree node: the segment [left, right] and its maximum."""

    def __init__(self, left=-1, right=-1, mx=0):
        # Defaults let Tree pre-allocate empty slots with TreeNode(); a left
        # bound of -1 marks a slot that Tree.build() never filled.
        self.left = left
        self.right = right
        self.mx = mx

    def __repr__(self):
        return "TreeNode(left={!r}, right={!r}, mx={!r})".format(
            self.left, self.right, self.mx
        )


# Segment tree (range-maximum) class.
# Methods whose names start with "_" are the recursive implementations.
class Tree:
    """Range-maximum segment tree over positions 1..n, stored as a heap array.

    Node ``i`` has children ``2 * i`` and ``2 * i + 1``; the root is node 1.

    NOTE(review): values are read 1-based. Position ``k`` takes ``arr[k]``
    and ``arr[0]`` is never read, matching the "arr[k]" wording of the
    update comment. A 0-based list needs a leading placeholder; confirm
    this against callers.
    """

    def __init__(self, n, arr):
        self.n = n
        # 4 * n slots always suffice for a 1-rooted heap layout; max(n, 1)
        # keeps the root slot addressable when n == 0.
        self.max_size = 4 * max(n, 1)
        self.tree = [TreeNode() for _ in range(self.max_size)]
        self.arr = arr

    # Tree indices start at 1.
    def _build(self, index, left, right):
        """Fill node ``index`` for segment [left, right] and, recursively, its subtree."""
        node = self.tree[index]
        node.left = left
        node.right = right
        if left == right:
            node.mx = self.arr[left]
            return
        mid = (left + right) // 2
        self._build(index * 2, left, mid)
        self._build(index * 2 + 1, mid + 1, right)
        node.mx = max(self.tree[index * 2].mx, self.tree[index * 2 + 1].mx)

    def build(self):
        """Build the tree from ``self.arr``; a tree with n < 1 stays empty."""
        if self.n >= 1:
            self._build(1, 1, self.n)

    def _update(self, ind, k, v):
        """Point update: set position ``k`` to ``v`` within the subtree at ``ind``.

        Raises IndexError if ``k`` lies outside that node's segment.
        Only the tree changes; ``self.arr`` is not modified.
        """
        node = self.tree[ind]
        # Guards against descending into unbuilt slots, whose -1 sentinel
        # could otherwise match k == -1 and corrupt a real leaf.
        if not node.left <= k <= node.right:
            raise IndexError(
                "position {} outside segment [{}, {}]".format(k, node.left, node.right)
            )
        if node.left == node.right:
            node.mx = v
            return
        mid = (node.left + node.right) // 2
        if k <= mid:
            self._update(ind * 2, k, v)
        else:
            self._update(ind * 2 + 1, k, v)
        # Refresh this node's maximum on the way back up.
        node.mx = max(self.tree[ind * 2].mx, self.tree[ind * 2 + 1].mx)

    def update(self, k, v):
        """Set position ``k`` (1-based) to ``v``; raises IndexError if out of range."""
        self._update(1, k, v)

    def _query(self, ind, l, r):
        """Return the max over [l, r] ∩ node segment, using fully covered nodes.

        Returns -inf (the identity for max) when the ranges are disjoint.
        """
        node = self.tree[ind]
        if r < node.left or l > node.right:
            return float("-inf")
        if l <= node.left and node.right <= r:
            return node.mx
        mid = (node.left + node.right) // 2
        res = float("-inf")
        if l <= mid:
            res = max(res, self._query(ind * 2, l, r))
        if r > mid:
            res = max(res, self._query(ind * 2 + 1, l, r))
        return res

    def _query2(self, ind, l, r):
        """Return the max over exactly [l, r], splitting the query at each midpoint.

        Requires node.left <= l <= r <= node.right for the starting node;
        raises ValueError otherwise.
        """
        node = self.tree[ind]
        if not node.left <= l <= r <= node.right:
            raise ValueError(
                "query range [{}, {}] not within segment [{}, {}]".format(
                    l, r, node.left, node.right
                )
            )
        if node.left == l and node.right == r:
            return node.mx
        mid = (node.left + node.right) // 2
        if r <= mid:
            return self._query2(ind * 2, l, r)
        if l > mid:
            return self._query2(ind * 2 + 1, l, r)
        return max(self._query2(ind * 2, l, mid), self._query2(ind * 2 + 1, mid + 1, r))

    def query(self, ql, qr):
        """Return the max over positions [ql, qr]; -inf if the range misses 1..n or is empty."""
        return self._query(1, ql, qr)

    def _show_arr(self, i):
        """Print the leaf values under node ``i`` depth-first, left to right, space-separated."""
        if i >= len(self.tree):
            return
        node = self.tree[i]
        if node.left == -1:
            # Slot never filled by build().
            return
        if node.left == node.right:
            print(node.mx, end=" ")
            return
        self._show_arr(i * 2)
        self._show_arr(i * 2 + 1)

    def show_arr(self):
        """Print the current leaf values in position order, reflecting any updates."""
        self._show_arr(1)