线段树合并学习笔记

线段树合并是为了保证「合并两个动态开点线段树的信息」这个操作的复杂度的。

暴力合并两个满二叉树的复杂度一次就是其节点数 \(O(n)\),完全不能接受。


考虑现在有两个动态开点线段树,要将线段树 \(x\) 的信息合并到线段树 \(y\) 上。

对于当前合并的区间 \([l,r]\)

  • 若都有左/右儿子,则继续遍历左/右儿子。
  • \(x\) 无左/右儿子,则跳过。
  • \(y\) 无左/右儿子,则合并后其左/右子树全部来自于 \(x\),将 \(y\) 的左/右儿子编号换成 \(x\) 的即可。

显然对于一个节点,只有两棵线段树都有这个节点的时候会遍历到,复杂度 \(O(\min(cnt_x,cnt_y))\),其中 \(cnt_x\) 是线段树 \(x\) 的节点数量。

以下是一些习题。

luogu P4556 雨天的尾巴

链上加、单点查,看到这个东西就套路地(比如说 情报传递这篇题解 就提到了这个技巧)将其转化为单点加和子树查询。

设链 \((x,y)\) 加,\(\operatorname{lca}(x,y)=d\)

则转化为 \(x\) 加,\(y\) 加,\(d\) 减,\(fa_d\) 减即可。

树上每个节点开一颗动态开点线段树,每一次对一个点进行操作增加的节点个数是 \(O(\log n)\) 的(合并不会增加节点数量)。

所以总的节点个数就是 \(O(m \log n)\) 的。

因为这个东西是离线的,所以全部操作执行完之后再从下到上将子树信息合并到其根上,最后挨个查询即可。

#include<bits/stdc++.h>
#define sd std::
// #define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=2e5+10,P=1e9+7,V=1e5;
int n,m,rt[N];//rt[i]为节点i的线段树的根节点
struct node
{
	int ma,c,l,r;//ma为最大次数,左右儿子
	node(int maa,int cc,int ll,int rr)
	{
		ma=maa,c=cc,l=ll,r=rr;
	}
	node(){ma=0,c=200000,l=0,r=0;}
}s[N*40];
#define ls(x) s[x].l
#define rs(x) s[x].r
int num;//记录编号
void update(int &k,int l,int r,int x,int y)//x处加y
{
	if(!k) k=++num;
	if(l==r)
	{
		s[k].ma+=y;
		s[k].c=l;
		return;
	}
	int mid=l+r>>1;
	if(x<=mid) update(ls(k),l,mid,x,y);
	else update(rs(k),mid+1,r,x,y);
	int p=s[ls(k)].ma,q=s[rs(k)].ma;
	s[k].ma=sd max(p,q);
	s[k].c=(p>q?s[ls(k)].c:p<q?s[rs(k)].c:sd min(s[ls(k)].c,s[rs(k)].c));
}
int cas,ans[N];
void merge(int l,int r,int x,int y)//将节点编号x合并到节点编号y,l-r区间
{
	if(l==r)
	{
		s[y].ma+=s[x].ma;
		return;
	}
	int mid=l+r>>1;
	if(!ls(y)) ls(y)=ls(x);//
	else if(ls(x)) merge(l,mid,ls(x),ls(y));//
	if(!rs(y)) rs(y)=rs(x);//
	else if(rs(x)) merge(mid+1,r,rs(x),rs(y));//
	int p=s[ls(y)].ma,q=s[rs(y)].ma;
	s[y].ma=sd max(p,q);
	s[y].c=(p>q?s[ls(y)].c:p<q?s[rs(y)].c:sd min(s[ls(y)].c,s[rs(y)].c));
}
sd vector<int> g[N];
int dep[N],f[N][21];
void dfs1(int u,int fa)
{
	for(auto v:g[u])
	{
		if(v==fa) continue;
		dep[v]=dep[u]+1;
		f[v][0]=u;
		F(i,0,18) f[v][i+1]=f[f[v][i]][i];
		dfs1(v,u);
	}
}
int lca(int u,int v)
{
	if(dep[u]<dep[v]) sd swap(u,v);
	ff(i,19,0) if(dep[f[u][i]]>=dep[v]) u=f[u][i];
	if(u==v) return u;
	ff(i,19,0) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
	return f[u][0];
}
void dfs2(int u,int fa)
{
	for(auto v:g[u])
	{
		if(v==fa) continue;
		dfs2(v,u);
		if(!rt[u]) rt[u]=rt[v];//
		else merge(1,V,rt[v],rt[u]);
	}
	ans[u]=(s[rt[u]].ma>0?s[rt[u]].c:0);
}
void solve()
{
	s[0].ma=-200000;
	n=read();m=read();
	F(i,2,n)
	{
		int x=read(),y=read();
		g[x].emplace_back(y);
		g[y].emplace_back(x);
	}
	dep[1]=1;
	dfs1(1,0);
	F(i,1,m)
	{
		int x=read(),y=read(),z=read();
		int d=lca(x,y);
		update(rt[x],1,V,z,1);
		update(rt[y],1,V,z,1);
		update(rt[d],1,V,z,-1);
		update(rt[f[d][0]],1,V,z,-1);
	}
	dfs2(1,0);
	F(i,1,n) put(ans[i]);
}
int main()
{
	int T=1;
	// T=read();
	while(T--) solve();
    return 0;
}

注意线段树合并的时候要考虑以下问题(也可能只是我实现地太撇了):

  1. 将线段树 \(x\) 合并到 \(y\) 之后 \(y\) 上会挂一些 \(x\) 的节点,然后如果我们继续对 \(y\) 进行其他操作,则有可能改到从线段树 \(x\) 处复制的节点,则 \(x\) 这棵线段树也会受到影响。

上题的处理办法就是遍历到 \(x\) 时立刻记录答案,这个时候 \(x\) 虽然进行了合并,但并没有 \(y\) 上挂了 \(x\) 的节点,就不可能影响到。

  1. 以下的代码块 1 如果不用代码块 2 的特判会出错,因为不特判就有可能导致节点 \(0\) 下面挂了左右儿子,然后再次遍历到 \(0\) 的时候就会进入左右儿子,显然是错完了的。
if(!ls(y)) ls(y)=ls(x);
else if(ls(x)) merge(l,mid,ls(x),ls(y));
if(!rs(y)) rs(y)=rs(x);
else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
if(!rt[u]) rt[u]=rt[v];

P3224 [HNOI2012] 永无乡

每个节点 \(u\) 维护一个权值线段树,叶子维护与这个节点联通的节点是否有这个重要度的,区间维护某个重要度区间内存在多少个节点与 \(u\) 联通。

初始时显然 \(i\) 的线段树上只有 \(p_i\) 有值表示其只能到达自己。

第一个操作以及初始的边就是将 \(x\)\(y\) 的线段树都变成其合并之后的线段树。

直接将 \(x\) 合并到 \(y\) 上然后并查集维护即可。

第二个操作的第 \(k\) 大是老套路了,线段树上二分即可。

#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=1e5+10,P=1e9+7;
int n,m;
int rt[N],num,p[N];
struct node
{
	int l,r;
	int cnt;
}s[N*40];
#define ls(k) s[k].l
#define rs(k) s[k].r
int fa[N];
int find(int x)
{
	return (x==fa[x]?x:fa[x]=find(fa[x]));
}
void update(int &k,int l,int r,int x)
{
	if(!k) k=++num;
	if(l==r)
	{
		s[k].cnt=1;
		return;
	}
	int mid=l+r>>1;
	if(x<=mid) update(ls(k),l,mid,x);
	else update(rs(k),mid+1,r,x);
	s[k].cnt=s[ls(k)].cnt+s[rs(k)].cnt;
}
void merge(int l,int r,int x,int y)//把x合并到y上
{
	if(l==r)
	{
		s[y].cnt|=s[x].cnt;
		return;
	}
	int mid=l+r>>1;
	if(!ls(y)) ls(y)=ls(x);
	else if(ls(x)) merge(l,mid,ls(x),ls(y));
	if(!rs(y)) rs(y)=rs(x);
	else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
	s[y].cnt=s[ls(y)].cnt+s[rs(y)].cnt;
}
int find(int k,int l,int r,int x)
{
	// dbg(l),dbg(r),dg(x);
	// dg(s[k].cnt);
	if(s[k].cnt<x) return -1;
	if(l==r) return p[l];
	int mid=l+r>>1;
	if(!rs(k)) return find(ls(k),l,mid,x);
	if(!ls(k)) return find(rs(k),mid+1,r,x);
	if(s[ls(k)].cnt>=x) return find(ls(k),l,mid,x);
	return find(rs(k),mid+1,r,x-s[ls(k)].cnt);
}
void solve()
{
	n=read();m=read();
	F(i,1,n)
	{
		fa[i]=i;
		int x=read();p[x]=i;
		update(rt[i],1,n,x);
	}
	F(i,1,m)
	{
		int x=read(),y=read();
		x=find(x),y=find(y);
		if(x!=y)
		{
			merge(1,n,rt[x],rt[y]);
			fa[x]=y;
		}
	}
	int Q=read();
	while(Q--)
	{
		char op[2];
		int x,y;
		scanf("%s",op);x=read(),y=read();
		x=find(x);
		if(op[0]=='Q')
		{
			put(find(rt[x],1,n,y));
		}
		else
		{
			y=find(y);
			if(x!=y)
			{
				merge(1,n,rt[x],rt[y]);
				fa[x]=y;
			}
		}
	}
}
signed main()
{
	int T=1;
	// T=read();
	while(T--) solve();
    return 0;
}

P3605 [USACO17JAN] Promotion Counting P

离散化之后直接每次查线段树上的一段前缀和然后合并即可。

其实感觉还有另外一种做法,将 \(>\) 它的看作 \(1\),否则看作 \(0\)

从最大值开始,每次子树查一下,然后 \(i\to i+1\) 的状态差别就只有将 \(i\) 这个点设为 \(1\)

不过毕竟是熟悉线段树合并就打的复杂一点吧。

#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=5e5+10,P=1e9+7,V=1e9;
int n,num,rt[N],p[N],ans[N];
sd vector<int> g[N];
struct node
{
	int val,l,r;
}s[N*20];
#define ls(k) s[k].l
#define rs(k) s[k].r
void update(int &k,int l,int r,int x)
{
	if(!k) k=++num;
	if(l==r)
	{
		s[k].val=1;
		return;
	}
	int mid=l+r>>1;
	if(x<=mid) update(ls(k),l,mid,x);
	else update(rs(k),mid+1,r,x);
	s[k].val=s[ls(k)].val+s[rs(k)].val;
}
int ask(int k,int l,int r,int x,int y)
{
	if(x<=l&&y>=r) return s[k].val;
	int mid=l+r>>1,res=0;
	if(x<=mid&&ls(k)) res+=ask(ls(k),l,mid,x,y);
	if(y>mid&&rs(k))res+=ask(rs(k),mid+1,r,x,y);
	return res;
}
void merge(int l,int r,int x,int y)//把 x 合并到 y 上
{
	if(l==r)
	{
		s[y].val+=s[x].val;
		return;
	}
	int mid=l+r>>1;
	if(!ls(y)) ls(y)=ls(x);
	else if(ls(x)) merge(l,mid,ls(x),ls(y));
	if(!rs(y)) rs(y)=rs(x);
	else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
	s[y].val=s[ls(y)].val+s[rs(y)].val;
}
void dfs(int u)
{
	for(auto v:g[u])
	{
		dfs(v);
		ans[u]+=ask(rt[v],1,V,p[u]+1,V);
		merge(1,V,rt[v],rt[u]);
	}
}
void solve()
{
	n=read();
	F(i,1,n)
	{
		p[i]=read();
		update(rt[i],1,V,p[i]);
	}
	F(i,2,n)
	{
		int x=read();
		g[x].emplace_back(i);
	}
	dfs(1);
	F(i,1,n) put(ans[i]);
}
signed main()
{
// 	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	int T=1;
	// T=read();
	while(T--) solve();
    return 0;
}

[POI 2011] ROT-Tree Rotations

假设左/右子树的逆序对分别为 \(x,y\),则总的逆序对就是 \(x+y\) 再加上左子树比右子树大的数量。

反过来的话就是右子树比左子树大的数量。

不难发现 \(u\) 的决策不影响 \(fa_u\) 另一棵子树的决策。

于是我们直接从上到下每个点单独考虑,考虑怎么快速算出左子树比右子树大的数量。

直接上线段树合并。

在合并 \([l,r]\) 的时候,顺带计算一下值域 \([l,r]\) 中右子树比左子树大的数量 \(val_{l,r}\)

不难发现 \(val_{l,r}=val_{l,mid}+val_{mid+1,r}\) 然后加上 \([l,mid]\) 中左子树的数量乘上 \([mid+1,r]\) 右子树的数量。

感觉很好做啊,直接就做完了。

稍微卡了一下空间过了。

#include<bits/stdc++.h>
#define sd std::
// #define int long long
#define ll long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=4e5+10,P=1e9+7;
int n,lson[N],rson[N],id=1,num;
int rt[N];
struct node
{
	int val;
	int l,r;
}s[N*10];
#define ls(k) s[k].l
#define rs(k) s[k].r
void update(int &k,int l,int r,int x)
{
	if(!k) k=++num;
	if(l==r)
	{
		s[k].val=1;
		return;
	}
	int mid=l+r>>1;
	if(x<=mid) update(ls(k),l,mid,x);
	else update(rs(k),mid+1,r,x);
	s[k].val=s[ls(k)].val+s[rs(k)].val;
}
ll calc(int l,int r,int x,int y)//x是左边的,y是右边的
{
	if(l==r) return 0;
	int mid=l+r>>1;
	ll res=(ll)s[rs(y)].val*s[ls(x)].val;
	if(ls(x)&&ls(y)) res+=calc(l,mid,ls(x),ls(y));
	if(rs(x)&&rs(y)) res+=calc(mid+1,r,rs(x),rs(y));
	return res;
}
void merge(int l,int r,int x,int y)
{
	if(l==r)
	{
		s[y].val+=s[x].val;
		return;
	}
	int mid=l+r>>1;
	if(!ls(y)) ls(y)=ls(x);
	else if(ls(x)) merge(l,mid,ls(x),ls(y));
	if(!rs(y)) rs(y)=rs(x);
	else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
	s[y].val=s[ls(y)].val+s[rs(y)].val;
}
void input(int u)
{
	int x=read();
	if(x)
	{
		update(rt[u],1,n,x);
	}
	else
	{
		input(lson[u]=++id);
		input(rson[u]=++id);
	}
}
ll dfs(int u)//sum为节点数量
{
	// dg(u);
	// dbg(lson[u]),dg(rson[u]);
	if(!lson[u]&&!rson[u]) return 0;
	ll val=dfs(lson[u])+dfs(rson[u]);
	// dbg(val);
	ll p1=calc(1,n,rt[lson[u]],rt[rson[u]]);
	// dbg(p1);
	ll p2=calc(1,n,rt[rson[u]],rt[lson[u]]);
	// dg(p2);
	val+=sd min(p1,p2);
	merge(1,n,rt[lson[u]],rt[rson[u]]);
	rt[u]=rt[rson[u]];
	return val;
}
void solve()
{
	n=read();
	input(1);
	sd cout<<dfs(1);
}
int main()
{
// 	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	int T=1;
	// T=read();
	while(T--) solve();
    return 0;
}

CF208E Blood Cousins

\(u\)\(p\) 级表亲数量就是 \(u\)\(p\) 级祖先的距离为 \(p\) 的儿子数量。

考虑直接给每个点维护其 \(k\) 级儿子,\(v\) 合并到 \(u\) 就是右移一位然后插入一个数。

但是其实不用这么麻烦,考虑线段树维护的叶子 \(val_x\) 代表 \(u\) 子树内 \(dep=x\) 的结点数量。

然后查询的时候就是询问某个点的深度加上 \(k\) 这个深度有多少个节点,就直接问就行。

但是由于众所周知线段树合并不能在线,所以得离线到每个点上做询问。

#include<bits/stdc++.h>
#define sd std::
// #define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=1e5+10,P=1e9+7;
int n,num,ans[N],rt[N];
sd vector<int> g[N],root;
sd vector<pii> q[N];
struct node
{
	int l,r,val;
}s[N*20];
#define ls(k) s[k].l
#define rs(k) s[k].r
void update(int &k,int l,int r,int x)
{
	if(!k) k=++num;
	if(l==r)
	{
		s[k].val++;
		return;
	}
	int mid=l+r>>1;
	if(x<=mid) update(ls(k),l,mid,x);
	else update(rs(k),mid+1,r,x);
	s[k].val=s[ls(k)].val+s[rs(k)].val;
}
void merge(int l,int r,int x,int y)
{
	if(l==r)
	{
		s[y].val+=s[x].val;
		return;
	}
	int mid=l+r>>1;
	if(!ls(y)) ls(y)=ls(x);
	else if(ls(x)) merge(l,mid,ls(x),ls(y));
	if(!rs(y)) rs(y)=rs(x);
	else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
	s[y].val=s[ls(y)].val+s[rs(y)].val;
}
int ask(int k,int l,int r,int x)
{
	if(l==r) return s[k].val;
	int mid=l+r>>1;
	if(x<=mid&&ls(k)) return ask(ls(k),l,mid,x);
	else if(rs(k)) return ask(rs(k),mid+1,r,x);
	return 0;
}
int dep[N],f[N][21];
void dfs1(int u)
{
	for(auto v:g[u])
	{
		dep[v]=dep[u]+1;
		f[v][0]=u;
		F(i,0,19) f[v][i+1]=f[f[v][i]][i];
		dfs1(v);
	}
}
void dfs2(int u)
{
	update(rt[u],1,n,dep[u]);
	for(auto v:g[u])
	{
		dfs2(v);
		merge(1,n,rt[v],rt[u]);
	}
	for(auto [k,id]:q[u])
	{
		ans[id]=ask(rt[u],1,n,dep[u]+k);
	}
}
int find(int u,int k)//查询祖先
{
	int val=dep[u]-k;
	ff(i,20,0) if(dep[f[u][i]]>=val) u=f[u][i];
	return u;
}

void solve()
{
	n=read();
	F(i,1,n)
	{
		int x=read();
		if(!x) root.emplace_back(i);
		else g[x].emplace_back(i);
	}
	for(auto u:root)
	{
		dep[u]=1;
		dfs1(u);
	}
	int Q=read();
	F(i,1,Q)
	{
		int v=read(),p=read();
		q[find(v,p)].emplace_back(p,i);
	}
	for(auto u:root) dfs2(u);
	F(i,1,Q) printk(ans[i]?ans[i]-1:0);
}
int main()
{
// 	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	int T=1;
	// T=read();
	while(T--) solve();
    return 0;
}

P3899 [湖南集训] 更为厉害

不难发现 \(a,b\) 肯定是祖先和子孙的关系。

可以分为 \(b\)\(a\) 祖先和 \(b\)\(a\) 子孙。

第一种情况很简单,即 \(siz_u\) 乘上和 \(u\) 距离不超过 \(k\) 的祖先。

考虑第二种情况怎么计算。

即计算 \(\sum\limits_{x\in \operatorname{Subtree}(u),dis(u)} (siz_x-1)\)

考虑直接每个节点开一颗线段树,叶子 \(i\) 就是 \(dep=i\)\(\sum siz-1\) 之和。

然后询问的时候直接查线段树上某一段的和即可。

这个还是比较简单的,直接上线段树合并即可。

#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=3e5+10,P=1e9+7;
int n,Q,siz[N],dep[N],num,rt[N],ans[N];
struct node
{
	int val,l,r;
}s[N*20];
#define ls(k) s[k].l
#define rs(k) s[k].r
void update(int &k,int l,int r,int x,int y)
{
	if(!k) k=++num;
	if(l==r)
	{
		s[k].val+=y;
		return;
	}
	int mid=l+r>>1;
	if(x<=mid) update(ls(k),l,mid,x,y);
	else update(rs(k),mid+1,r,x,y);
	s[k].val=s[ls(k)].val+s[rs(k)].val;
}
void merge(int l,int r,int x,int y)//x合并到y
{
	if(l==r)
	{
		s[y].val+=s[x].val;
		return;
	}
	int mid=l+r>>1;
	if(!ls(y)) ls(y)=ls(x);
	else if(ls(x)) merge(l,mid,ls(x),ls(y));
	if(!rs(y)) rs(y)=rs(x);
	else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
	s[y].val=s[ls(y)].val+s[rs(y)].val;
}
int ask(int k,int l,int r,int x,int y)
{
	if(x<=l&&y>=r) return s[k].val;
	int mid=l+r>>1,res=0;
	if(x<=mid&&ls(k)) res+=ask(ls(k),l,mid,x,y);
	if(y>mid&&rs(k)) res+=ask(rs(k),mid+1,r,x,y);
	return res;
}
sd vector<int> g[N];
sd vector<pii> q[N];
void dfs1(int u,int fa)
{
	siz[u]=1;
	for(auto v:g[u])
	{
		if(v==fa) continue;
		dep[v]=dep[u]+1;
		dfs1(v,u);
		siz[u]+=siz[v];
	}
}
void dfs2(int u,int fa)
{
	update(rt[u],1,n,dep[u],siz[u]-1);
	for(auto v:g[u])
	{
		if(v==fa) continue;
		dfs2(v,u);
		merge(1,n,rt[v],rt[u]);
	}
	for(auto [k,id]:q[u])
	{
		int val=(dep[u]==n?0:ask(rt[u],1,n,dep[u]+1,sd min(n,dep[u]+k)));
		ans[id]=sd min(dep[u]-1,k)*(siz[u]-1)+val;
	}
	
}
void solve()
{
	n=read(),Q=read();
	F(i,2,n)
	{
		int x=read(),y=read();
		g[x].emplace_back(y);
		g[y].emplace_back(x);
	}
	dep[1]=1;
	dfs1(1,0);
	F(i,1,Q)
	{
		int p=read(),k=read();
		q[p].emplace_back(k,i);
	}
	dfs2(1,0);
	F(i,1,Q) put(ans[i]);
}
signed main()
{
// 	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	int T=1;
	// T=read();
	while(T--) solve();
    return 0;
}

P5298 [PKUWC2018] Minimax

最初的想法是考虑每个节点开一颗线段树,每个叶子维护其作为这个值的概率乘以10000。

然后区间就维护这个值域的总概率 \(val_{l,r}\),即选到这个值域的概率是多少。

考虑左右儿子 \(x,y\) 的两颗线段树怎么合并。考虑值域 \([l,r]\)\(u\) 线段树的影响。

将这个影响分为三类:

  • \(x,y\) 都选择 \([l,mid]\) 的数。
  • \(x,y\) 都选择 \([mid+1,r]\) 的数。
  • \(x,y\) 一个选择 \([l,mid]\),一个选择 \([mid+1,r]\) 的数。

前两者递归处理,我们只需要处理跨过中点的贡献。

贡献式子写出来之后我们发现大概就是 \(x\)\([mid+1,r]\) 区间会对 \(u\)\([mid+1,r]\) 区间一一对应的造成其自身乘以 \(valy_{l,mid}\times p_u\) 的贡献。

剩下的同理。

但是我们不能直接在每个 \([l,r]\) 处打乘法标记,一是 \(x,y\) 的乘法标记不互通,而是我们在递归子区间的时候处理的显然应该是没有打标记的结果。那这个乘法标记究竟在哪里下传就很难办。

考虑将这个乘法标记累计起来,如果 \(x,y\) 都有某一方的节点就继续遍历,否则将累计的乘法标记打上去。

#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=3e5+10,P=998244353,V=1e9;
int n,num,p[N],rt[N],lson[N],rson[N];
struct node
{
	int val,l,r,tag;
}s[N*20];
#define ls(k) s[k].l
#define rs(k) s[k].r
int L,R;
void pushdown(int k)
{
	int tg=s[k].tag;
	if(tg==1) return;
	if(ls(k))
	{
		s[ls(k)].tag=s[ls(k)].tag*tg%P;
		s[ls(k)].val=s[ls(k)].val*tg%P;
	}
	if(rs(k))
	{
		s[rs(k)].tag=s[rs(k)].tag*tg%P;
		s[rs(k)].val=s[rs(k)].val*tg%P;
	}
	s[k].tag=1;
}
void update(int &k,int l,int r,int x)
{
	if(!k) k=++num,s[k].tag=1;
	if(l==r)
	{
		s[k].val=1;
		return;
	}
	int mid=l+r>>1;
	if(x<=mid) update(ls(k),l,mid,x);
	else update(rs(k),mid+1,r,x);
	s[k].val=s[ls(k)].val+s[rs(k)].val;
}
int d;//p[u]
void merge(int l,int r,int x,int &y,int tagx,int tagy)
{
	if(!x&&!y) return;
	if(!x)
	{
		s[y].tag=s[y].tag*tagy%P;
		s[y].val=s[y].val*tagy%P;
		return;
	}
	if(!y)
	{
		y=x;
		s[y].tag=s[y].tag*tagx%P;
		s[y].val=s[y].val*tagx%P;
		return;
	}
	pushdown(x);pushdown(y);
	int vly=s[ls(y)].val,vry=s[rs(y)].val,vlx=s[ls(x)].val,vrx=s[rs(x)].val;
	L=l,R=r;
	int mid=l+r>>1;
	merge(l,mid,ls(x),ls(y),(tagx+vry*(1-d+P))%P,(tagy+vrx*(1-d+P))%P);
	merge(mid+1,r,rs(x),rs(y),(tagx+vly*d)%P,(tagy+vlx*d)%P);
	s[y].val=(s[ls(y)].val+s[rs(y)].val)%P;
}
int cnt,ans;
int now;
void out(int k,int l,int r)
{
	if(l==r)
	{
		++cnt;
		(ans+=l*cnt%P*s[k].val%P*s[k].val%P)%=P;
		return;
	}
	L=l,R=r;
	pushdown(k);
	int mid=l+r>>1;
	if(ls(k)) out(ls(k),l,mid);
	if(rs(k)) out(rs(k),mid+1,r);
}
void dfs(int u)
{
	if(!lson[u]) return;
	dfs(lson[u]);
	if(rson[u])
	{
		dfs(rson[u]);
		d=p[u];
		merge(1,V,rt[rson[u]],rt[lson[u]],0,0);
	}
	rt[u]=rt[lson[u]];
}
void solve()
{
	n=read();read();
	F(i,2,n)
	{
		int x=read();
		if(!lson[x]) lson[x]=i;
		else if(!rson[x]) rson[x]=i;
	}
	int inv=796898467;//10000的逆元
	F(u,1,n) 
	{
		int x=read();
		if(!lson[u]) update(rt[u],1,V,x);
		else p[u]=x*inv%P;
	}
	dfs(1);
	out(rt[1],1,V);
	put(ans);
}
signed main()
{
// 	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	int T=1;
	// T=read();
	while(T--) solve();
    return 0;
}
posted @ 2025-09-11 20:52  _E_M_T  阅读(9)  评论(0)    收藏  举报