从0开始编写线段树

常量定义
N=6#线段树范围[1,N]

tree=[0]*(N<<2)
tags=[0]*(N<<2)

空间复杂度O(4n)

左右节点id函数
def left_node(node_id):
    return node_id<<1

def right_node(node_id):
    return node_id<<1|1
基建树函数
def build(node_id,left,right):
    if left==right:
        tree[node_id]=nums[left-1]
        return
    mid=(left+right)>>1
    build(left_node(node_id),left,mid)
    build(right_node(node_id),mid+1,right)
    push_up(node_id)
  
def push_up(node_id):
    tree[node_id]=tree[left_node(node_id)]+tree[right_node(node_id)]

1768569367305
时间复杂度O(nlogn)

查询函数
def query(query_left,query_right,node_id=1,left=1,right=N):
    if (query_left<=left and right<=query_right):
        return tree[node_id]
    ret=0
    mid=(left+right)>>1
    if query_left<=mid:
        ret+=query(query_left,query_right,left_node(node_id),left,mid)
    if query_right>mid:
        ret+=query(query_left,query_right,right_node(node_id),mid+1,right)
    return ret

对于每个查询,复杂度O(logn)

未优化的区间更新函数
def update(update_left,update_right,add,node_id=1,left=1,right=N):
    if(left==right):
        tree[left]+=add
        return
    mid=(left+right)>>1
    if update_left<=mid:
        update(update_left,update_right,add,left_node(node_id),left,mid)
    if update_right>mid:
        update(update_left,update_right,add,right_node(node_id),mid+1,right)
    push_up(node_id)

时间复杂度O(n^2)

优化的区间更新函数

1.修改query,在查询前先下推表记

def query(query_left,query_right,node_id=1,left=1,right=N):
    if (query_left<=left and right<=query_right):
        return tree[node_id]
    push_down(node_id,left,right)
    ret=0
    mid=(left+right)>>1
    if query_left<=mid:
        ret+=query(query_left,query_right,left_node(node_id),left,mid)
    if query_right>mid:
        ret+=query(query_left,query_right,right_node(node_id),mid+1,right)
    return ret

1768573373250

2.配套函数

def push_down(node_id,left,right):
    add=tags[node_id]
    mid=(left+right)>>1
    if (add):
        addtag(left_node(node_id),left,mid,add)
        addtag(right_node(node_id),mid+1,right,add)
        tags[node_id]=0

def addtag(node_id,left,right,add):
    tree[node_id]+=add*(right-left+1)
    tags[node_id]+=add

3.修改update

def update(update_left,update_right,add,node_id=1,left=1,right=N):
    if(update_left<=left and right<=update_right):
        addtag(node_id,left,right,add)
        return
    push_down(node_id,left,right)#updata和query都必须下推
    mid=(left+right)>>1
    if update_left<=mid:
        update(update_left,update_right,add,left_node(node_id),left,mid)
    if update_right>mid:
        update(update_left,update_right,add,right_node(node_id),mid+1,right)
    push_up(node_id)

1768573442379

完整代码
N=6#线段树范围[1,N]

tree=[0]*(N<<2)
tags=[0]*(N<<2)

def left_node(node_id):
    return node_id<<1

def right_node(node_id):
    return node_id<<1|1

def build(node_id,left,right):
    if left==right:
        tree[node_id]=nums[left-1]
        return
    mid=(left+right)>>1
    build(left_node(node_id),left,mid)
    build(right_node(node_id),mid+1,right)
    push_up(node_id)
  
def push_up(node_id):
    tree[node_id]=tree[left_node(node_id)]+tree[right_node(node_id)]

def query(query_left,query_right,node_id=1,left=1,right=N):
    if (query_left<=left and right<=query_right):
        return tree[node_id]
    push_down(node_id,left,right)
    ret=0
    mid=(left+right)>>1
    if query_left<=mid:
        ret+=query(query_left,query_right,left_node(node_id),left,mid)
    if query_right>mid:
        ret+=query(query_left,query_right,right_node(node_id),mid+1,right)
    return ret

def update(update_left,update_right,add,node_id=1,left=1,right=N):
    if(update_left<=left and right<=update_right):
        addtag(node_id,left,right,add)
        return
    push_down(node_id,left,right)
    mid=(left+right)>>1
    if update_left<=mid:
        update(update_left,update_right,add,left_node(node_id),left,mid)
    if update_right>mid:
        update(update_left,update_right,add,right_node(node_id),mid+1,right)
    push_up(node_id)

def push_down(node_id,left,right):
    add=tags[node_id]
    mid=(left+right)>>1
    if (add):
        addtag(left_node(node_id),left,mid,add)
        addtag(right_node(node_id),mid+1,right,add)
        tags[node_id]=0

def addtag(node_id,left,right,add):
    tree[node_id]+=add*(right-left+1)
    tags[node_id]+=add
#以下为测试代码
nums=[]
for i in range(N):
    nums.append(1)
  
build(1,1,N)

update(1,6,1)
for i in range(1,N+1):
    for j in range(i,N+1):
        print("sum([{},{}]) is {}".format(i,j,query(i,j)))
例题

1768572446618

通过代码

n,m=map(int,input().split())
N=n#线段树范围[1,N]

tree=[0]*(N<<2)
tags=[0]*(N<<2)

def left_node(node_id):
    return node_id<<1

def right_node(node_id):
    return node_id<<1|1

def build(node_id,left,right):
    if left==right:
        tree[node_id]=nums[left-1]
        return
    mid=(left+right)>>1
    build(left_node(node_id),left,mid)
    build(right_node(node_id),mid+1,right)
    push_up(node_id)
  
def push_up(node_id):
    tree[node_id]=tree[left_node(node_id)]+tree[right_node(node_id)]

def query(query_left,query_right,node_id=1,left=1,right=N):
    if (query_left<=left and right<=query_right):
        return tree[node_id]
    push_down(node_id,left,right)
    ret=0
    mid=(left+right)>>1
    if query_left<=mid:
        ret+=query(query_left,query_right,left_node(node_id),left,mid)
    if query_right>mid:
        ret+=query(query_left,query_right,right_node(node_id),mid+1,right)
    return ret

def update(update_left,update_right,add,node_id=1,left=1,right=N):
    if(update_left<=left and right<=update_right):
        addtag(node_id,left,right,add)
        return
    push_down(node_id,left,right)
    mid=(left+right)>>1
    if update_left<=mid:
        update(update_left,update_right,add,left_node(node_id),left,mid)
    if update_right>mid:
        update(update_left,update_right,add,right_node(node_id),mid+1,right)
    push_up(node_id)

def push_down(node_id,left,right):
    add=tags[node_id]
    mid=(left+right)>>1
    if (add):
        addtag(left_node(node_id),left,mid,add)
        addtag(right_node(node_id),mid+1,right,add)
        tags[node_id]=0

def addtag(node_id,left,right,add):
    tree[node_id]+=add*(right-left+1)
    tags[node_id]+=add
  
nums=list(map(int,input().split()))
build(1,1,N)
for _ in range(m):
    ops=list(map(int,input().split()))
    if ops[0]==1:
        update(ops[1],ops[2],ops[3])
    else:
        print(query(ops[1],ops[2]))

程序流程图

1768573182670

1768572216312

posted @ 2026-01-16 22:42  404CatNotFound  阅读(10)  评论(1)    收藏  举报