1.11 上午-线段树 & 线段树合并

前言

勿让将来,辜负曾经

少年立下志向,踏上前去线段树神犇的不归路。然而被线段树虐的体无完肤

正文

知识点

线段树是一个很有用的知识点,用于维护动态的区间信息,而线段树合并,往往用于维护树上信息(一般用于优化 DP)

先聊线段树——

板子就是经典的区间加求区间和问题(啥?你说单点?确定不是 l==r 的情况吗?)。然后就是老生常谈地向下递归分拆区间,打标记随时下传,递归到底后合并上传信息……

然后就延伸出了一堆乱七八糟的信息维护问题(有点恐怖……)

当然,线段树优化 DP,可持久化,树套树,权值线段树等也是极常见的套路

再聊线段树合并——

真的很天才好吧,对于树上结点信息的维护,线段树合并往往具有一些比较优秀的正确性与u复杂度,但是作为一个高阶算法,还是得具体题目具体分析哈!

线段树合并需要用到的工具:动态开点线段树

解决一些树上静态信息或者离线问题,也常用于优化 DP

云落顺手学了一下线段树分裂,线段树上二分,线段树分治……(虽然彼此之间毫无瓜葛,也许分裂与合并之间还有点)

正好安利一下我的另一篇博客,线段树梳理

一题一解

T1 统计和(P2068)

链接

板子,不评价(甚至还是单点修改……)

点击查看代码
#include<bits/stdc++.h>
#define endl '\n'
#define int long long
using namespace std;
const int maxn=1e5+10;
int n,q;
struct Segment_tree{
    struct node{
        int l,r,sum,tag;
    }tr[maxn<<2];
    void pushup(int u){
        tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
        return;
    }
    void pushdown(int u){
        if(!tr[u].tag){
            return;
        }
        tr[u<<1].sum+=((tr[u<<1].r-tr[u<<1].l+1)*tr[u].tag);
        tr[u<<1|1].sum+=((tr[u<<1|1].r-tr[u<<1|1].l+1)*tr[u].tag);
        tr[u<<1].tag+=tr[u].tag;
        tr[u<<1|1].tag+=tr[u].tag;
        tr[u].tag=0;
        return;
    }
    void build(int u,int l,int r){
        tr[u].l=l;
        tr[u].r=r;
        if(l==r){
            return;
        }
        int mid=l+r>>1;
        build(u<<1,l,mid);
        build(u<<1|1,mid+1,r);
        pushup(u);
        return;
    }
    void modify(int u,int ql,int qr,int k){
        int l=tr[u].l,r=tr[u].r;
        if(ql<=l&&qr>=r){
            tr[u].sum+=((r-l+1)*k);
            tr[u].tag+=k;
            return;
        }
        pushdown(u);
        int mid=l+r>>1;
        if(ql<=mid){
            modify(u<<1,ql,qr,k);
        }
        if(qr>mid){
            modify(u<<1|1,ql,qr,k);
        }
        pushup(u);
        return;
    }
    int query(int u,int ql,int qr){
        int l=tr[u].l,r=tr[u].r;
        if(ql<=l&&qr>=r){
            return tr[u].sum;
        }
        pushdown(u);
        int mid=l+r>>1,res=0;
        if(ql<=mid){
            res+=query(u<<1,ql,qr);
        }
        if(qr>mid){
            res+=query(u<<1|1,ql,qr);
        }
        return res;
    }
}Tr;
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>q;
    Tr.build(1,1,n);
    while(q--){
        char opt;
        cin>>opt;
        if(opt=='x'){
            int pos,x;
            cin>>pos>>x;
            Tr.modify(1,pos,pos,x);
        }else{
            int l,r;
            cin>>l>>r;
            int ans=Tr.query(1,l,r);
            cout<<ans<<endl;
        }
    }
    return 0;
}

T2 线段树(P3372)

链接

板子,不评价

点击查看代码
#include<bits/stdc++.h>
#define endl '\n'
#define int long long
using namespace std;
const int maxn=1e5+10;
int n,q,a[maxn];
struct Segment_tree{
    struct node{
        int l,r,sum,tag;
    }tr[maxn<<2];
    void pushup(int u){
        tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
        return;
    }
    void pushdown(int u){
        if(!tr[u].tag){
            return;
        }
        tr[u<<1].sum+=((tr[u<<1].r-tr[u<<1].l+1)*tr[u].tag);
        tr[u<<1|1].sum+=((tr[u<<1|1].r-tr[u<<1|1].l+1)*tr[u].tag);
        tr[u<<1].tag+=tr[u].tag;
        tr[u<<1|1].tag+=tr[u].tag;
        tr[u].tag=0;
        return;
    }
    void build(int u,int l,int r){
        tr[u].l=l;
        tr[u].r=r;
        if(l==r){
            tr[u].sum=a[l];
            return;
        }
        int mid=l+r>>1;
        build(u<<1,l,mid);
        build(u<<1|1,mid+1,r);
        pushup(u);
        return;
    }
    void modify(int u,int ql,int qr,int k){
        int l=tr[u].l,r=tr[u].r;
        if(ql<=l&&qr>=r){
            tr[u].sum+=((r-l+1)*k);
            tr[u].tag+=k;
            return;
        }
        pushdown(u);
        int mid=l+r>>1;
        if(ql<=mid){
            modify(u<<1,ql,qr,k);
        }
        if(qr>mid){
            modify(u<<1|1,ql,qr,k);
        }
        pushup(u);
        return;
    }
    int query(int u,int ql,int qr){
        int l=tr[u].l,r=tr[u].r;
        if(ql<=l&&qr>=r){
            return tr[u].sum;
        }
        pushdown(u);
        int mid=l+r>>1,res=0;
        if(ql<=mid){
            res+=query(u<<1,ql,qr);
        }
        if(qr>mid){
            res+=query(u<<1|1,ql,qr);
        }
        return res;
    }
}Tr;
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>q;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    Tr.build(1,1,n);
    while(q--){
        int opt;
        cin>>opt;
        if(opt==1){
            int l,r,k;
            cin>>l>>r>>k;
            Tr.modify(1,l,r,k);
        }else{
            int l,r;
            cin>>l>>r;
            int ans=Tr.query(1,l,r);
            cout<<ans<<endl;
        }
    }
    return 0;
}

T3 方差(P1471)

链接

众所周知,线段树需要维护区间信息,这个“信息”需要支持合并。不过方差与平均值很难通过简单的四则运算直接合并。于是乎,来推式子吧!

才不要——

平均值是好求的,因为查询的区间长度是个定值,所以直接维护区间和即可。

而众所又周知,方差等于平方的均值减去均值的平方,所以再维护一个平方和即可。

点击查看代码
#include<iostream>
#include<iomanip>
#define endl '\n'
using namespace std;
const int maxn=1e5+10;
int n,m;
double a[maxn];
struct Segment_tree{
	struct node{
		int l,r;
		double sum1,sum2,tag;
	}tr[maxn<<2];
	inline void pushup(int u){
		tr[u].sum1=tr[u<<1].sum1+tr[u<<1|1].sum1;
		tr[u].sum2=tr[u<<1].sum2+tr[u<<1|1].sum2;
		return;
	}
	void pushdown(int id){
		if (tr[id].tag)
		{
			int ls=id<<1,rs=(id<<1)|1;
			tr[ls].sum2+=(tr[ls].r-tr[ls].l+1)*tr[id].tag*tr[id].tag+2*tr[id].tag*tr[ls].sum1;
			tr[ls].sum1+=(tr[ls].r-tr[ls].l+1)*tr[id].tag;
			tr[rs].sum2+=(tr[rs].r-tr[rs].l+1)*tr[id].tag*tr[id].tag+2*tr[id].tag*tr[rs].sum1;
			tr[rs].sum1+=(tr[rs].r-tr[rs].l+1)*tr[id].tag;
			tr[ls].tag+=tr[id].tag,tr[rs].tag+=tr[id].tag;
			tr[id].tag=0;
		}
	}
	inline void build(int u,int l,int r){
		tr[u].l=l;
		tr[u].r=r;
		if(l==r){
			tr[u].sum1=a[l];
			tr[u].sum2=a[l]*a[l];
			return;
		}
		int mid=l+r>>1;
		build(u<<1,l,mid);
		build(u<<1|1,mid+1,r);
		pushup(u);
		return;
	}
	inline void modify(int u,int ql,int qr,double k){
		int l=tr[u].l,r=tr[u].r;
		if(ql<=l&&qr>=r){
			tr[u].tag+=k;
			tr[u].sum2+=(2*tr[u].sum1*k+k*k*(r-l+1));
			tr[u].sum1+=k*(r-l+1);
			return;
		}
		pushdown(u);
		int mid=l+r>>1;
		if(ql<=mid){
			modify(u<<1,ql,qr,k);
		}
		if(qr>mid){
			modify(u<<1|1,ql,qr,k);
		}
		pushup(u);
		return;
	}
	inline double query1(int u,int ql,int qr){
		int l=tr[u].l,r=tr[u].r;
		if(ql<=l&&qr>=r){
			return tr[u].sum1;
		}
		pushdown(u);
		int mid=l+r>>1;
		double res=0;
		if(ql<=mid){
			res+=query1(u<<1,ql,qr);
		}
		if(qr>mid){
			res+=query1(u<<1|1,ql,qr);
		}
		return res;
	}
	inline double query2(int u,int ql,int qr){
		int l=tr[u].l,r=tr[u].r;
		if(ql<=l&&qr>=r){
			return tr[u].sum2;
		}
		pushdown(u);
		int mid=l+r>>1;
		double res=0;
		if(ql<=mid){
			res+=query2(u<<1,ql,qr);
		}
		if(qr>mid){
			res+=query2(u<<1|1,ql,qr);
		}
		return res;
	}
}Tr;
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n;i++){
    	cin>>a[i];
	}
	Tr.build(1,1,n);
	while(m--){
		int opt,l,r;
		cin>>opt>>l>>r;
    	if(opt==1){
    		double x;
    		cin>>x;
    		Tr.modify(1,l,r,x);
		}else if(opt==2){
			double ans=Tr.query1(1,l,r)*1.0/((r-l+1)*1.0);
			cout<<fixed<<setprecision(4)<<ans<<endl;
		}else{
			double res=Tr.query1(1,l,r)*1.0/((r-l+1)*1.0);
			double ans=Tr.query2(1,l,r)*1.0/((r-l+1)*1.0)-res*res*1.0;
			cout<<fixed<<setprecision(4)<<ans<<endl;
		}
	}
    return 0;
}

T4 上帝造题的七分钟 2 / 花神游历各国 (P4145)

链接

我勒个超绝开平方啊……

其实这道题和曾经云落自己出的一道区间取模的题目有异曲同工之妙。需要注意的就是题目条件给到的一些比较优秀的性质。譬如值域的限制,以及向下取整的操作……还有一个类似废话的性质:

\[\sqrt{1}=1 \]

然后你手摸一下这个最大的值域,约 \(6\) 次开平方后就会出现上述情况。并且 \(1\) 不会继续向下更新——只会一遍遍地开方得到自己

So?

正解是,对于每一次修改,暴力(单点修改)处理当前区间所有 \(>1\) 的数。自然而然想到,维护区间最大值。如果该区间内不存在 \(>1\) 的数,即区间最大值为 \(1\),直接 return 剪枝

查询是板,过!

点击查看代码
#include<iostream>
#include<cmath>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=1e5+10;
int n,m,a[maxn];
struct Segment_tree{
	struct node{
		int l,r,sum,mx;
	}tr[maxn<<2];
	inline void pushup(int u){
		tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
		tr[u].mx=max(tr[u<<1].mx,tr[u<<1|1].mx);
		return;
	}
	inline void build(int u,int l,int r){
		tr[u].l=l;
		tr[u].r=r;
		if(l==r){
			tr[u].sum=a[l];
			tr[u].mx=a[l];
			return;
		}
		int mid=l+r>>1;
		build(u<<1,l,mid);
		build(u<<1|1,mid+1,r);
		pushup(u);
		return;
	}
	inline void modify(int u,int ql,int qr){
		int l=tr[u].l,r=tr[u].r;
		if(l==r){
			tr[u].sum=sqrt(tr[u].sum);
			tr[u].mx=sqrt(tr[u].mx);
			return;
		}
		int mid=l+r>>1;
		if(ql<=mid&&tr[u<<1].mx>1){
			modify(u<<1,ql,qr);
		}
		if(qr>mid&&tr[u<<1|1].mx>1){
			modify(u<<1|1,ql,qr);
		}
		pushup(u);
		return;
	}
	inline int query(int u,int ql,int qr){
		int l=tr[u].l,r=tr[u].r;
		if(ql<=l&&qr>=r){
			return tr[u].sum;
		}
		int mid=l+r>>1,res=0;
		if(ql<=mid){
			res+=query(u<<1,ql,qr);
		}
		if(qr>mid){
			res+=query(u<<1|1,ql,qr);
		}
		return res;
	}
}Tr;
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
    cin>>n;
    for(int i=1;i<=n;i++){
    	cin>>a[i];
	}
	Tr.build(1,1,n);
	cin>>m;
	while(m--){
		int opt,l,r;
		cin>>opt>>l>>r;
		if(l>r){
			swap(l,r);
		}
    	if(opt==0){
    		Tr.modify(1,l,r);
		}else{
			cout<<Tr.query(1,l,r)<<endl;
		}
	}
    return 0;
}

T5 逆序对(P1908)

链接

权值树状树组维护二维偏序 or 归并排序 or 恶臭的权值线段树

比较板子不多赘述

点击查看代码
#include<iostream>
#include<algorithm>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=5e5+10;
int n,m,a[maxn],b[maxn];
struct Segment_tree{
	struct node{
		int l,r,val,tag;
	}tr[maxn<<2];
	inline void pushup(int u){
		tr[u].val=tr[u<<1].val+tr[u<<1|1].val;
		return;
	}
	inline void pushdown(int u){
		if(tr[u].tag>0){
			tr[u<<1].val+=(tr[u].tag*(tr[u<<1].r-tr[u<<1].l+1));
			tr[u<<1|1].val+=(tr[u].tag*(tr[u<<1|1].r-tr[u<<1|1].l+1));
			tr[u<<1].tag+=tr[u].tag;
			tr[u<<1|1].tag+=tr[u].tag;
			tr[u].tag=0;
		}
		return;
	}
	inline void build(int u,int l,int r){
		tr[u].l=l;
		tr[u].r=r;
		if(l==r){
			return;
		}
		int mid=l+r>>1;
		build(u<<1,l,mid);
		build(u<<1|1,mid+1,r);
		pushup(u);
		return;
	}
	inline void modify(int u,int ql,int qr,int k){
		int l=tr[u].l,r=tr[u].r;
		if(ql<=l&&qr>=r){
			tr[u].val+=k;
			return;
		}
		pushdown(u);
		int mid=l+r>>1;
		if(ql<=mid){
			modify(u<<1,ql,qr,k);
		}
		if(qr>mid){
			modify(u<<1|1,ql,qr,k);
		}
		pushup(u);
		return;
	}
	inline int query(int u,int ql,int qr){
		int l=tr[u].l,r=tr[u].r;
		if(ql<=l&&qr>=r){
			return tr[u].val;
		}
		pushdown(u);
		int mid=l+r>>1,res=0;
		if(ql<=mid){
			res+=query(u<<1,ql,qr);
		}
		if(qr>mid){
			res+=query(u<<1|1,ql,qr);
		}
		return res;
	}
}Tr;
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
    cin>>n;
    for(int i=1;i<=n;i++){
    	cin>>a[i];
    	b[i]=a[i];
	}
	Tr.build(1,1,n);
	sort(b+1,b+n+1);
	int len=unique(b+1,b+n+1)-b-1;
	int ans=0;
	for(int i=1;i<=n;i++){
		int x=lower_bound(b+1,b+len+1,a[i])-b;
		if(x!=n){
			ans+=Tr.query(1,x+1,n);
		}
		Tr.modify(1,x,x,1);
	}
	cout<<ans<<endl;
    return 0;
}

T6 Promotion Counting P(P3605)

链接

一句话题意:求任一结点 \(u\) 的子树内满足 \(p_v>p_u\) 的结点 \(v\) 数量

比较好想的是类比序列上的逆序对问题,对于每一个结点都开一棵权值线段树。然后当我们递归到该结点的时候,统计其在权值线段树上的前缀和。然而,如果只是静态地将树上的每一个子树拆分出来,一个菊花链就可以卡的找不着北……

注意子树与子树之间信息所对应的权值线段树是可合并的,所以考虑线段树合并进行维护。具体地,对于每一个结点 \(u\),假设我们已经知道了其所有子树所对应的所有线段树信息。那么,我们直接将它们区间一一对应的合并即可,查询答案后直接再将其并入 \(u\) 的线段树

时间复杂度 \(O(n \log n)\)

为什么?肯定会有人抬杠说——线段树是 \(O(n \log n)\) 量级的,一共合并 \(O(n)\) 次,总复杂度应当是 \(O(n^2 \log n)\) 量级的。但是不论是从均摊还是代码实现上,时间复杂度总是有保证的。

  • 均摊。注意到最后合并出来的权值线段树一定是 \(O(n \log n)\) 量级的(因为只有 \(n\) 个结点),而总计合并 \(n\) 次,所以单次合并的均摊时间复杂度为 \(O(\log n)\)
  • 代码实现。注意到线段树合并采用的是动态开点线段树,并且 merge 过程中空结点是直接维护并跳过该结点及其子树对应区间的合并的,感性理解一下总时间复杂度自然是 \(O(n \log n)\) 滴!

递归顺序,自根向叶;合并顺序,自叶向根

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10,inf=1e9;
int n,p[maxn];
int head[maxn],tot;
struct Edge{
    int to,nxt;
}e[maxn<<1];
struct Segment_tree{
    struct node{
        int l,r,sum;
    }tr[maxn<<5];
    int rt[maxn],cnt;
    void pushup(int u){
        tr[u].sum=tr[tr[u].l].sum+tr[tr[u].r].sum;
        return;
    }
    int modify(int u,int l,int r,int pos,int k){
        if(u==0){
            u=++cnt;
        }
        if(l==r){
            tr[u].sum+=k;
            return u;
        }
        int mid=l+r>>1;
        if(pos<=mid){
            tr[u].l=modify(tr[u].l,l,mid,pos,k);
        }else{
            tr[u].r=modify(tr[u].r,mid+1,r,pos,k);
        }
        pushup(u);
        return u;
    }
    int query(int u,int l,int r,int ql,int qr){
        if(u==0){
            return 0;
        }
        if(ql<=l&&qr>=r){
            return tr[u].sum;
        }
        int mid=l+r>>1,res=0;
        if(ql<=mid){
            res+=query(tr[u].l,l,mid,ql,qr);
        }
        if(qr>mid){
            res+=query(tr[u].r,mid+1,r,ql,qr);
        }
        return res;
    }
    int merge(int x,int y,int l,int r){
        if(!x||!y){
            return x+y;
        }
        if(l==r){
            tr[x].sum+=tr[y].sum;
            return x;
        }
        int mid=l+r>>1;
        tr[x].l=merge(tr[x].l,tr[y].l,l,mid);
        tr[x].r=merge(tr[x].r,tr[y].r,mid+1,r);
        pushup(x);
        return x;
    }
}Tr;
int ans[maxn];
inline void add(int u,int v){
    e[++tot].to=v;
    e[tot].nxt=head[u];
    head[u]=tot;
    return;
}
inline void dfs(int u){
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        dfs(v);
        Tr.rt[u]=Tr.merge(Tr.rt[u],Tr.rt[v],1,inf);
    }
    ans[u]=Tr.query(Tr.rt[u],1,inf,p[u]+1,inf);
    Tr.rt[u]=Tr.modify(Tr.rt[u],1,inf,p[u],1);
    return;
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>p[i];
    }
    for(int i=2;i<=n;i++){
        int fa;
        cin>>fa;
        add(fa,i);
    }
    dfs(1);
    for(int i=1;i<=n;i++){
        cout<<ans[i]<<endl;
    }
    return 0;
}

T7 更为厉害(P3899)

链接

一句话题意:求出对于每个结点 \(a\),所有距离 \(a\) 不大于 \(k\) 的结点 \(b\) 并满足结点 \(a,b\) 都是结点 \(c\) 的祖先有序三元组 \((a,b,c)\) 的个数

又是一个树上静态信息维护问题

还是要对所有结点都做相应统计,并且发现,与上一道题类似,对于每个结点 \(a\),结点 \(b,c\) 总是满足一些奇奇怪怪的限制条件。种种现象暗示着我们这又是一道恶臭的经典的线段树合并题目

先考虑将结点 \(a\) 固定,如何计算答案?我们发现,这里需要一个分讨,即讨论 \(b\)\(a\) 的祖孙关系。

  1. \(b\)\(a\) 的祖先,此时只要 \(c\)\(a\) 子树内就对答案有贡献。具体地,贡献为 \(\min \lbrace \text{dep}_u-1,k \rbrace \times (\text{sz}_u - 1)\)

  2. \(a\)\(b\) 的祖先,显然我们需要查询 \(b\) 所可能的区间对应的答案,即统计一个区间 \([\text{dep}_u+1,\text{dep}_u+k]\) 内所有满足条件的 \(c\)

我们惊奇地发现,一个树上问题居然转化成了一个区间问题。二这个题还存在更优秀的性质,即对于上述的情况二,发现结点 \(b\) 与结点 \(c\) 之间的奇奇怪怪的限制条件也可以被刻画成 \(\text{dep}_b < \text{dep}_c\)

当然,有人会说——这玩意加上时间戳的维度就是一个经典二维数点问题,但是嘛,这不是学习线段树合并的做法咩……

于是乎,题目转化为,对于每个 \(a\),计算 \(b \in [\text{dep}_u+1,\text{dep}_u+k]\)\(c\) 的贡献

考虑暴力是怎么做的,统计出每个结点 \(b\) 的位置,然后加上 \(b\) 子树大小 \(-1\)

然后注意到一个词语——区间!这提示我们可以将这些与子树大小有关的信息放在线段树上处理。那么,完成这道题目的流程呼之欲出

  1. 从根结点向下递归,沿递归路径记录父结点、子树大小、深度等信息

  2. 对于遍历到的结点 \(u\),为其开一棵线段树,并在线段树的 \(\text{dep}_u\) 处插入权值 \(\text{sz}_u-1\)

  3. 自下而上,合并线段树的信息

  4. 预处理出所有结点对应的线段树后,回答询问。答案由两部分贡献组成,分别为 \(\min \lbrace \text{dep}_u-1,k \rbrace \times (\text{sz}_u - 1)\) 以及 \(\operatorname{query} (\text{rt}_u,1,n,\text{dep}_u+1,\text{dep}_u+k)\)

一些细节:

  • 动态开点线段树的树组空间大小为 \(O(M \log N)\),其中 \(M\) 表示操作次数(即 modify 次数),\(N\) 表示结点个数(本题中为 \(n\)

  • 此题的线段树合并不能直接由 \(y\) 并向 \(x\),而是应当新开一个线段树根结点 \(now\),把线段树 \(x\) 和线段树 \(y\) 合并向线段树 \(now\)。因为答案的查询是需要 \(x\) 线段树以及 \(y\) 线段树对应的信息的

  • 贡献是 \(\text{sz}_u-1\) 不是 \(\text{sz}_u\)

  • 不开祖宗见 long long

点击查看代码
#include<bits/stdc++.h>
#define endl '\n'
#define int long long
using namespace std;
const int maxn=3e5+10;
int n,q;
int head[maxn],tot;
struct Edge{
	int to,nxt;
}e[maxn<<1];
int sz[maxn],dep[maxn];
struct Segment_tree{
    struct node{
        int l,r,sum;
    }tr[maxn<<6];
    int rt[maxn],cnt;
    void pushup(int u){
        tr[u].sum=tr[tr[u].l].sum+tr[tr[u].r].sum;
        return;
    }
    int modify(int u,int l,int r,int pos,int k){
        if(u==0){
            u=++cnt;
        }
        if(l==r){
            tr[u].sum+=k;
            return u;
        }
        int mid=l+r>>1;
        if(pos<=mid){
            tr[u].l=modify(tr[u].l,l,mid,pos,k);
        }else{
            tr[u].r=modify(tr[u].r,mid+1,r,pos,k);
        }
        pushup(u);
        return u;
    }
    int query(int u,int l,int r,int ql,int qr){
        if(u==0){
            return 0;
        }
        if(ql<=l&&qr>=r){
            return tr[u].sum;
        }
        int mid=l+r>>1,res=0;
        if(ql<=mid){
            res+=query(tr[u].l,l,mid,ql,qr);
        }
        if(qr>mid){
            res+=query(tr[u].r,mid+1,r,ql,qr);
        }
        return res;
    }
    int merge(int x,int y,int l,int r){
        if(!x||!y){
            return x+y;
        }
        if(l==r){
            int now=++cnt;
            tr[now].sum=tr[x].sum+tr[y].sum;
            return now;
        }
        int mid=l+r>>1,now=++cnt;
        tr[now].l=merge(tr[x].l,tr[y].l,l,mid);
        tr[now].r=merge(tr[x].r,tr[y].r,mid+1,r);
        pushup(now);
        return now;
    }
}Tr;
inline void add(int u,int v){
	e[++tot].to=v;
	e[tot].nxt=head[u];
	head[u]=tot;
	return;
}
inline void dfs(int u,int fa){
    dep[u]=dep[fa]+1;
	sz[u]=1;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa){
			continue;
		}
		dfs(v,u);
		sz[u]+=sz[v];
	}
	Tr.rt[u]=Tr.modify(Tr.rt[u],1,n,dep[u],sz[u]-1);
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa){
			continue;
		}
		Tr.rt[u]=Tr.merge(Tr.rt[u],Tr.rt[v],1,n);
	}
	return;
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin>>n>>q;
	for(int i=1;i<=n-1;i++){
		int u,v;
		cin>>u>>v;
		add(u,v);
		add(v,u);
	}
	dfs(1,0);
	while(q--){
		int u,k;
		cin>>u>>k;
		int ans1=1ll*min(dep[u]-1,k)*(sz[u]-1);
		int ans2=Tr.query(Tr.rt[u],1,n,dep[u]+1,dep[u]+k);
		cout<<ans1+ans2<<endl;
	}
	return 0;
}

T8 命运(P6773)

链接

两句话题意:给定一棵 \(n\) 个点的树和 \(m\) 条限制,你可以给树边赋 \(0/1\) 的权值。对于所有限制 \((u,v)\)(保证 \(v\)\(u\) 的祖先)你需要保证 \(u\)\(v\) 上至少有一条边的权值为 \(1\),求赋值方案数

黑题 \(++\)

这是一道经典的线段树合并优化 DP 问题,方案数计算问题不是什么组合数学就是 DP,也有可能两个一块用,再套一个数据结构优化什么的……

圆规正传(玩梗),对于同一个结点的 \(u\) 的所有约束 \((u,v_1),(u,v_2),...,(u,v_p)\),发现只需要满足由 \(v\) 中深度最大的约束即可,即满足该约束 \((u,v_{mx})\),其中 \(mx\) 表示 \(\max \limits_{1 \le i \le tot} \lbrace \text{dep}_{v_i} \rbrace\) 中对应的 \(v_i\)

再结合树的形态,考虑一个与约束最大深度有关的树形 DP

为了方便描述,对于一个约束,我们令其为 \((x,y)\),且 \(y\)\(x\) 的祖先;定义一个结点到点集的映射 \(\operatorname{subtree()}\),其中 \(\operatorname{subtree(u)}\) 表示 \(u\) 子树内所有结点所构成的集合(含 \(u\)

我们记 \(f_{u,i}\) 表示同时满足如下条件的方案数:

  • \((x,y)\) 的既定定义,显然不存在 \(y \in \operatorname{subtree}(u) \land x \notin \operatorname{subtree}(u)\) 的情况

  • 任意的形如 \(x \in \operatorname{subtree}(u) \land y \in \operatorname{subtree}(u)\) 的约束已经被满足

  • 不计算形如 \(x \notin \operatorname{subtree}(u) \land y \notin \operatorname{subtree}(u)\) 的约束所造成的贡献

  • 在所有形如 \(x \in \operatorname{subtree}(u) \land y \notin \operatorname{subtree}(u)\) 的约束中,有 \(i = \max \limits_{y} \lbrace \text{dep}_y \rbrace\)

看着有点累?翻译成人话就是——

在所有与 \(u\) 子树相关的约束中,被包含的约束已经全部被满足,横穿 \(u\) 子树的所有约束 \((x,y)\) 中深度最大的 \(y\) 的深度为 \(i\) 的所有给边赋予权值的方案数

特别地,我们定义 \(f_{u,0}\) 表示不存在部分包含于 \(u\) 子树约束的方案数

是转移时间!自然是 \(v \to u\) 的套路转移方式(\(v\)\(u\) 的儿子)。考虑到对边 \((u,v)\) 赋值的情况不同会带来不同的影响,于是乎进行一下下分类讨论

因为子树 \(v\) 存在一个枚举顺序,我们用 \(f_{u,i}\) 表示当前已经计算出的答案,而 \(f'(u,i)\) 表示插入新子树 \(v\) 后更新的答案

火腿肠大巨说:“可以类比树上背包的思路,\(f_{u,i}\) 表示已经维护好的子树前缀的答案,\(f'_{u,i}\) 表示插入子树 \(v\) 后新前缀的答案”
HTC 的理解好奇妙呢~

  1. \((u,v)\) 赋值为 \(1\)

先扔一个式子:

\[f'_{u,i} \gets \sum_{j=0}^{\text{dep}_u} f_{u,i} \times f_{v,j} \]

并不是很难理解,你注意到当 \((u,v)\) 的边权赋值为 \(1\) 时,所有与 \(v\) 子树有关的约束一定全被满足了,于是乎 \(j\) 也就无所谓了。而对于当前维护的前缀答案 \(f'_{u,i}\),最大深度 \(i\) 无法由 \(v\) 子树来提供,仅能由原来的前缀 \(f_{u,i}\) 来提供,然后就出现了上面的式子咯!(建议手动推理加深理解)

  1. \((u,v)\) 赋值为 \(0\)

可以类比情况一的做法,注意到此时最大深度 \(i\) 既可以由曾经的前缀 \(f_{u,i}\) 提供,也可以由当前的子树 \(v\) 提供,所以——式子如下:

\[f'_{u,i} \gets \sum_{j=0}^{i} f_{u,i} \times f_{v,j} + \sum_{j=0}^{i-1} f_{u,j} \times f_{v,i} \]

第一项表示最大深度仍旧有曾经的前缀提供,此时 \(v\) 子树能贡献到的最大深度有了值域限制,具体地,形如 \([0,i]\)。第二项表示最大深度由子树 \(v\) 来提供,和第一项类似,交换一下 \(i,j\) 即可

注意第二项的 \(j\) 的上界是 \(i-1\) 否则会重复计算前缀与子树 \(v\) 同时贡献最大值的情况

综上所述,我们有转移方程:

\[f'_{u,i} \gets \sum_{j=0}^{\text{dep}_u} f_{u,i} \times f_{v,j} + \sum_{j=0}^{i} f_{u,i} \times f_{v,j} + \sum_{j=0}^{i-1} f_{u,j} \times f_{v,i} \]

依据上面的转移方程,我们有了一个可以获得 \(60pts\)\(O(n^2)\) 做法,并不能满足捏……

但是里面出现大量的类似前缀和形式的式子,所以我们做一步转化——记 \(g_{x,k}\) 满足 \(g_{x,k}=\sum_{j=1}^{k} f_{x,j}\)

那么式子就可以写成如下形式:

\[f'_{u,i} = f_{u,i} \times g_{v,\text{dep}_u} + f_{u,i} \times g_{v,i} + f_{v,i} \times g_{u,i-1} \]

然后,就是一个小 trick:线段树合并维护前缀和。具体来说,我们首先先把转移方程式的第二维度直接搬到线段树 \(u\) 上,用线段树维护 \(f_{u,0} \sim f_{u,n}\) 中任意一个区间的和。记录两个动态更新的 \(su,sv\),表示当前转移方程的 \(g_{u,i}\)\(g_{v,i}\)

实现过程是这样的,线段树合并是要将子结点 \(v\) 合并到父结点 \(u\) 上,需要,也就是两棵线段树的合并,也就是线段树上相对应的结点的合并。我们要对合并的情况进行分类讨论:

由于代码的变量名冲突问题,合并部分的代码 \(x\) 对应 \(u\)\(y\) 对应 \(v\)
第一项是个全局查询(\(\text{dep}_u\) 是个定值),撇出去单独计算

  1. 待合并的两个结点都是空结点,直接 return 0 即可
if(x==0&&y==0){
    return 0;
}
  1. \(u\) 结点是空结点,\(v\) 结点不是空结点。先更新前缀和 \(s_v\),然后回到这个 DP 式子

\[f'_{u,i} \gets f_{u,i} \times g_{v,i} + f_{v,i} \times g_{u,i-1} \]

发现 DP 状态的更新形如乘法修改(具体地,乘上的数就是前缀和 \(su\)),也就是更新一下答案和懒标记

if(x==0){//没有第一项
	sv=(sv+tr[y].sum)%mod;//更新前缀和
    //线段树维护修改操作
	tr[y].tag=(tr[y].tag*su)%mod;
	tr[y].sum=(tr[y].sum*su)%mod;
	return y;
}
  1. 同理,\(v\) 结点是空结点,\(u\) 结点不是空结点。完全可以类比情况二,不多赘述

\[f'_{u,i} \gets f_{u,i} \times g_{v,i} + f_{v,i} \times g_{u,i-1} \]

if(y==0){//没有第二项
	su=(su+tr[x].sum)%mod;//更新前缀和
    //线段树维护修改操作
	tr[x].tag=(tr[x].tag*sv)%mod;
	tr[x].sum=(tr[x].sum*sv)%mod;
	return x;
}
  1. 都不是空结点,如果存在子结点就向下递归,直到叶子结点。还是依照下面的式子进行更新

\[f'_{u,i} \gets f_{u,i} \times g_{v,i} + f_{v,i} \times g_{u,i-1} \]

不过这里有一个代码实现的小细节,因为式子的第二项是需要 \(g_{u,i-1}\) 而不是 \(g_{u,i}\),所以 \(su\) 的更新要放在 DP 状态更新的后面……

if(l==r){
	int cu=tr[x].sum,cv=tr[y].sum;
	sv=(sv+cv)%mod;
	tr[x].sum=(tr[x].sum*sv+tr[y].sum*su)%mod;
	su=(su+cu)%mod;
	return x;
}

向下递归的过程不要忘记下传标记以及上传信息,一个细节——\(su,sv\) 都是在整个线段树合并的过程中动态更新的,所以需要设为全局变量或者传引用进去……

小结一下,这道题目的流程大概是这样的

  1. 审题,发现关于约束的一些奇妙性质,并进一步发掘出深度的影响

  2. 考虑 DP 计数,设计一个和约束最大深度有关的 DP 状态

  3. 对于 DP 的转移,需要一些耐心做分类讨论,以及需要注意以一个重复情况的判断

  4. 得出 \(O(n^2)\) 的解法,考虑线段树合并优化

  5. 观察式子的一些优秀的结构,注意到前缀和!

  6. 依照 DP 转移式将第二维度搬上线段树,统计区间贡献并且维护两组前缀和(第一项可以甩出去整体算)

  7. 合并的时候注意一些细节(比如先更新前缀和还是先更新 DP 状态,上传下传信息等)

  8. 取模等若干细节问题(以及漫长的 debug 过程)

点击查看代码
#include<iostream>
#include<vector>
#define int long long
using namespace std;
const int maxn=5e5+10,mod=998244353;
int n,m;
int head[maxn],tot;
struct Edge{
	int to,nxt;
}e[maxn<<1];
int dep[maxn];
vector<int> p[maxn];
struct Segment_tree{
	struct node{
		int l,r,sum,tag;
	}tr[maxn<<5];
	int rt[maxn],cnt;
	void pushup(int u){
		tr[u].sum=(tr[tr[u].l].sum+tr[tr[u].r].sum)%mod;
		return;
	}
	void pushdown(int u){
		if(tr[u].tag==1){
			return;
		}
		int k=tr[u].tag;
		tr[tr[u].l].sum=k*tr[tr[u].l].sum%mod;
		tr[tr[u].r].sum=k*tr[tr[u].r].sum%mod;
		tr[tr[u].l].tag=k*tr[tr[u].l].tag%mod;
		tr[tr[u].r].tag=k*tr[tr[u].r].tag%mod;
		tr[u].tag=1;
		return;
	}
	int modify(int u,int l,int r,int pos,int k){
		if(u==0){
			u=++tot;
		}
		if(l==r){
			tr[u].tag=1;
			tr[u].sum=k;
			return u;
		}
		pushdown(u);
		int mid=l+r>>1;
		if(pos<=mid){
			tr[u].l=modify(tr[u].l,l,mid,pos,k);
		}else{
			tr[u].r=modify(tr[u].r,mid+1,r,pos,k);
		}
		pushup(u);
		return u;
	}
	int query(int u,int l,int r,int ql,int qr){
		if(ql<=l&&qr>=r){
			return tr[u].sum;
		}
		pushdown(u);
		int mid=l+r>>1,res=0;
		if(ql<=mid){
			res=(res+query(tr[u].l,l,mid,ql,qr))%mod;
		}
		if(qr>mid){
			res=(res+query(tr[u].r,mid+1,r,ql,qr))%mod;
		}
		return res;
	}
	int merge(int x,int y,int l,int r,int &su,int &sv){
		if(x==0&&y==0){
			return 0;
		}
		if(x==0){
			sv=(sv+tr[y].sum)%mod;
			tr[y].tag=(tr[y].tag*su)%mod;
			tr[y].sum=(tr[y].sum*su)%mod;
			return y;
		}
		if(y==0){
			su=(su+tr[x].sum)%mod;
			tr[x].tag=(tr[x].tag*sv)%mod;
			tr[x].sum=(tr[x].sum*sv)%mod;
			return x;
		}
		if(l==r){
			int cu=tr[x].sum,cv=tr[y].sum;
			sv=(sv+cv)%mod;
			tr[x].sum=(tr[x].sum*sv+tr[y].sum*su)%mod;
			su=(su+cu)%mod;
			return x;
		}
		pushdown(x);
		pushdown(y);
		int mid=l+r>>1;
		tr[x].l=merge(tr[x].l,tr[y].l,l,mid,su,sv);
		tr[x].r=merge(tr[x].r,tr[y].r,mid+1,r,su,sv);
		pushup(x);
		return x;
	}
}Tr;
inline void add(int u,int v){
	e[++tot].to=v;
	e[tot].nxt=head[u];
	head[u]=tot;
	return;
}
inline void dfs(int u,int fa){
	dep[u]=dep[fa]+1;
	int d=0;
	for(int i:p[u]){
		d=max(d,dep[i]);
	}
	Tr.rt[u]=Tr.modify(Tr.rt[u],0,n,d,1);
	int su=0,sv=0;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa){
			continue;
		}
		dfs(v,u);
		su=0;
		sv=Tr.query(Tr.rt[v],0,n,0,dep[u]);
		Tr.rt[u]=Tr.merge(Tr.rt[u],Tr.rt[v],0,n,su,sv);
	}
	return;
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin>>n;
	for(int i=1;i<=n-1;i++){
		int u,v;
		cin>>u>>v;
		add(u,v);
		add(v,u);
	}
	cin>>m;
	for(int i=1;i<=m;i++){
		int u,v;
		cin>>u>>v;
		p[v].push_back(u);
	}
	dfs(1,0);
	cout<<Tr.query(Tr.rt[1],0,n,0,0)<<endl;
	return 0;
}

后记

嘻嘻——

完结撒花!

posted @ 2025-02-26 15:52  sunxuhetai  阅读(37)  评论(0)    收藏  举报