从0开始编写线段树
常量定义
N=6#线段树范围[1,N]
tree=[0]*(N<<2)
tags=[0]*(N<<2)
空间复杂度O(4n)
左右节点id函数
def left_node(node_id):
return node_id<<1
def right_node(node_id):
return node_id<<1|1
基建树函数
def build(node_id,left,right):
if left==right:
tree[node_id]=nums[left-1]
return
mid=(left+right)>>1
build(left_node(node_id),left,mid)
build(right_node(node_id),mid+1,right)
push_up(node_id)
def push_up(node_id):
tree[node_id]=tree[left_node(node_id)]+tree[right_node(node_id)]
``` graph TD A[开始构建节点] --> B{left==right} B -->|是| C["tree[node_id] = nums[left-1]"] B -->|否| D[计算mid] D --> E[递归构建左子树] D --> F[递归构建右子树] E --> G[push_up更新当前节点] F --> G G --> H[结束] C --> H ```

时间复杂度O(nlogn)
查询函数
def query(query_left,query_right,node_id=1,left=1,right=N):
if (query_left<=left and right<=query_right):
return tree[node_id]
ret=0
mid=(left+right)>>1
if query_left<=mid:
ret+=query(query_left,query_right,left_node(node_id),left,mid)
if query_right>mid:
ret+=query(query_left,query_right,right_node(node_id),mid+1,right)
return ret
对于每个查询,复杂度O(logn)
未优化的区间更新函数
def update(update_left,update_right,add,node_id=1,left=1,right=N):
if(left==right):
tree[left]+=add
return
mid=(left+right)>>1
if update_left<=mid:
update(update_left,update_right,add,left_node(node_id),left,mid)
if update_right>mid:
update(update_left,update_right,add,right_node(node_id),mid+1,right)
push_up(node_id)
时间复杂度O(n^2)
优化的区间更新函数
1.修改query,在查询前先下推表记
def query(query_left,query_right,node_id=1,left=1,right=N):
if (query_left<=left and right<=query_right):
return tree[node_id]
push_down(node_id,left,right)
ret=0
mid=(left+right)>>1
if query_left<=mid:
ret+=query(query_left,query_right,left_node(node_id),left,mid)
if query_right>mid:
ret+=query(query_left,query_right,right_node(node_id),mid+1,right)
return ret

2.配套函数
def push_down(node_id,left,right):
add=tags[node_id]
mid=(left+right)>>1
if (add):
addtag(left_node(node_id),left,mid,add)
addtag(right_node(node_id),mid+1,right,add)
tags[node_id]=0
def addtag(node_id,left,right,add):
tree[node_id]+=add*(right-left+1)
tags[node_id]+=add
3.修改update
def update(update_left,update_right,add,node_id=1,left=1,right=N):
if(update_left<=left and right<=update_right):
addtag(node_id,left,right,add)
return
push_down(node_id,left,right)#updata和query都必须下推
mid=(left+right)>>1
if update_left<=mid:
update(update_left,update_right,add,left_node(node_id),left,mid)
if update_right>mid:
update(update_left,update_right,add,right_node(node_id),mid+1,right)
push_up(node_id)

完整代码
N=6#线段树范围[1,N]
tree=[0]*(N<<2)
tags=[0]*(N<<2)
def left_node(node_id):
return node_id<<1
def right_node(node_id):
return node_id<<1|1
def build(node_id,left,right):
if left==right:
tree[node_id]=nums[left-1]
return
mid=(left+right)>>1
build(left_node(node_id),left,mid)
build(right_node(node_id),mid+1,right)
push_up(node_id)
def push_up(node_id):
tree[node_id]=tree[left_node(node_id)]+tree[right_node(node_id)]
def query(query_left,query_right,node_id=1,left=1,right=N):
if (query_left<=left and right<=query_right):
return tree[node_id]
push_down(node_id,left,right)
ret=0
mid=(left+right)>>1
if query_left<=mid:
ret+=query(query_left,query_right,left_node(node_id),left,mid)
if query_right>mid:
ret+=query(query_left,query_right,right_node(node_id),mid+1,right)
return ret
def update(update_left,update_right,add,node_id=1,left=1,right=N):
if(update_left<=left and right<=update_right):
addtag(node_id,left,right,add)
return
push_down(node_id,left,right)
mid=(left+right)>>1
if update_left<=mid:
update(update_left,update_right,add,left_node(node_id),left,mid)
if update_right>mid:
update(update_left,update_right,add,right_node(node_id),mid+1,right)
push_up(node_id)
def push_down(node_id,left,right):
add=tags[node_id]
mid=(left+right)>>1
if (add):
addtag(left_node(node_id),left,mid,add)
addtag(right_node(node_id),mid+1,right,add)
tags[node_id]=0
def addtag(node_id,left,right,add):
tree[node_id]+=add*(right-left+1)
tags[node_id]+=add
#以下为测试代码
nums=[]
for i in range(N):
nums.append(1)
build(1,1,N)
update(1,6,1)
for i in range(1,N+1):
for j in range(i,N+1):
print("sum([{},{}]) is {}".format(i,j,query(i,j)))
例题

通过代码
n,m=map(int,input().split())
N=n#线段树范围[1,N]
tree=[0]*(N<<2)
tags=[0]*(N<<2)
def left_node(node_id):
return node_id<<1
def right_node(node_id):
return node_id<<1|1
def build(node_id,left,right):
if left==right:
tree[node_id]=nums[left-1]
return
mid=(left+right)>>1
build(left_node(node_id),left,mid)
build(right_node(node_id),mid+1,right)
push_up(node_id)
def push_up(node_id):
tree[node_id]=tree[left_node(node_id)]+tree[right_node(node_id)]
def query(query_left,query_right,node_id=1,left=1,right=N):
if (query_left<=left and right<=query_right):
return tree[node_id]
push_down(node_id,left,right)
ret=0
mid=(left+right)>>1
if query_left<=mid:
ret+=query(query_left,query_right,left_node(node_id),left,mid)
if query_right>mid:
ret+=query(query_left,query_right,right_node(node_id),mid+1,right)
return ret
def update(update_left,update_right,add,node_id=1,left=1,right=N):
if(update_left<=left and right<=update_right):
addtag(node_id,left,right,add)
return
push_down(node_id,left,right)
mid=(left+right)>>1
if update_left<=mid:
update(update_left,update_right,add,left_node(node_id),left,mid)
if update_right>mid:
update(update_left,update_right,add,right_node(node_id),mid+1,right)
push_up(node_id)
def push_down(node_id,left,right):
add=tags[node_id]
mid=(left+right)>>1
if (add):
addtag(left_node(node_id),left,mid,add)
addtag(right_node(node_id),mid+1,right,add)
tags[node_id]=0
def addtag(node_id,left,right,add):
tree[node_id]+=add*(right-left+1)
tags[node_id]+=add
nums=list(map(int,input().split()))
build(1,1,N)
for _ in range(m):
ops=list(map(int,input().split()))
if ops[0]==1:
update(ops[1],ops[2],ops[3])
else:
print(query(ops[1],ops[2]))
程序流程图



浙公网安备 33010602011771号