线段树合并学习笔记
线段树合并是为了保证「合并两个动态开点线段树的信息」这个操作的复杂度的。
暴力合并两个满二叉树的复杂度一次就是其节点数 \(O(n)\),完全不能接受。
考虑现在有两个动态开点线段树,要将线段树 \(x\) 的信息合并到线段树 \(y\) 上。
对于当前合并的区间 \([l,r]\):
- 若都有左/右儿子,则继续遍历左/右儿子。
- 若 \(x\) 无左/右儿子,则跳过。
- 若 \(y\) 无左/右儿子,则合并后其左/右子树全部来自于 \(x\),将 \(y\) 的左/右儿子编号换成 \(x\) 的即可。
显然对于一个节点,只有两棵线段树都有这个节点的时候会遍历到,复杂度 \(O(\min(cnt_x,cnt_y))\),其中 \(cnt_x\) 是线段树 \(x\) 的节点数量。
以下是一些习题。
luogu P4556 雨天的尾巴
链上加、单点查,看到这个东西就套路地(比如说 情报传递 的 这篇题解 就提到了这个技巧)将其转化为单点加和子树查询。
设链 \((x,y)\) 加,\(\operatorname{lca}(x,y)=d\)。
则转化为 \(x\) 加,\(y\) 加,\(d\) 减,\(fa_d\) 减即可。
树上每个节点开一颗动态开点线段树,每一次对一个点进行操作增加的节点个数是 \(O(\log n)\) 的(合并不会增加节点数量)。
所以总的节点个数就是 \(O(m \log n)\) 的。
因为这个东西是离线的,所以全部操作执行完之后再从下到上将子树信息合并到其根上,最后挨个查询即可。
#include<bits/stdc++.h>
#define sd std::
// #define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=2e5+10,P=1e9+7,V=1e5;
int n,m,rt[N];//rt[i]为节点i的线段树的根节点
struct node
{
int ma,c,l,r;//ma为最大次数,左右儿子
node(int maa,int cc,int ll,int rr)
{
ma=maa,c=cc,l=ll,r=rr;
}
node(){ma=0,c=200000,l=0,r=0;}
}s[N*40];
#define ls(x) s[x].l
#define rs(x) s[x].r
int num;//记录编号
void update(int &k,int l,int r,int x,int y)//x处加y
{
if(!k) k=++num;
if(l==r)
{
s[k].ma+=y;
s[k].c=l;
return;
}
int mid=l+r>>1;
if(x<=mid) update(ls(k),l,mid,x,y);
else update(rs(k),mid+1,r,x,y);
int p=s[ls(k)].ma,q=s[rs(k)].ma;
s[k].ma=sd max(p,q);
s[k].c=(p>q?s[ls(k)].c:p<q?s[rs(k)].c:sd min(s[ls(k)].c,s[rs(k)].c));
}
int cas,ans[N];
void merge(int l,int r,int x,int y)//将节点编号x合并到节点编号y,l-r区间
{
if(l==r)
{
s[y].ma+=s[x].ma;
return;
}
int mid=l+r>>1;
if(!ls(y)) ls(y)=ls(x);//
else if(ls(x)) merge(l,mid,ls(x),ls(y));//
if(!rs(y)) rs(y)=rs(x);//
else if(rs(x)) merge(mid+1,r,rs(x),rs(y));//
int p=s[ls(y)].ma,q=s[rs(y)].ma;
s[y].ma=sd max(p,q);
s[y].c=(p>q?s[ls(y)].c:p<q?s[rs(y)].c:sd min(s[ls(y)].c,s[rs(y)].c));
}
sd vector<int> g[N];
int dep[N],f[N][21];
void dfs1(int u,int fa)
{
for(auto v:g[u])
{
if(v==fa) continue;
dep[v]=dep[u]+1;
f[v][0]=u;
F(i,0,18) f[v][i+1]=f[f[v][i]][i];
dfs1(v,u);
}
}
int lca(int u,int v)
{
if(dep[u]<dep[v]) sd swap(u,v);
ff(i,19,0) if(dep[f[u][i]]>=dep[v]) u=f[u][i];
if(u==v) return u;
ff(i,19,0) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
return f[u][0];
}
void dfs2(int u,int fa)
{
for(auto v:g[u])
{
if(v==fa) continue;
dfs2(v,u);
if(!rt[u]) rt[u]=rt[v];//
else merge(1,V,rt[v],rt[u]);
}
ans[u]=(s[rt[u]].ma>0?s[rt[u]].c:0);
}
void solve()
{
s[0].ma=-200000;
n=read();m=read();
F(i,2,n)
{
int x=read(),y=read();
g[x].emplace_back(y);
g[y].emplace_back(x);
}
dep[1]=1;
dfs1(1,0);
F(i,1,m)
{
int x=read(),y=read(),z=read();
int d=lca(x,y);
update(rt[x],1,V,z,1);
update(rt[y],1,V,z,1);
update(rt[d],1,V,z,-1);
update(rt[f[d][0]],1,V,z,-1);
}
dfs2(1,0);
F(i,1,n) put(ans[i]);
}
int main()
{
int T=1;
// T=read();
while(T--) solve();
return 0;
}
注意线段树合并的时候要考虑以下问题(也可能只是我实现地太撇了):
- 将线段树 \(x\) 合并到 \(y\) 之后 \(y\) 上会挂一些 \(x\) 的节点,然后如果我们继续对 \(y\) 进行其他操作,则有可能改到从线段树 \(x\) 处复制的节点,则 \(x\) 这棵线段树也会受到影响。
上题的处理办法就是遍历到 \(x\) 时立刻记录答案,这个时候 \(x\) 虽然进行了合并,但并没有 \(y\) 上挂了 \(x\) 的节点,就不可能影响到。
- 以下的代码块 1 如果不用代码块 2 的特判会出错,因为不特判就有可能导致节点 \(0\) 下面挂了左右儿子,然后再次遍历到 \(0\) 的时候就会进入左右儿子,显然是错完了的。
if(!ls(y)) ls(y)=ls(x);
else if(ls(x)) merge(l,mid,ls(x),ls(y));
if(!rs(y)) rs(y)=rs(x);
else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
if(!rt[u]) rt[u]=rt[v];
P3224 [HNOI2012] 永无乡
每个节点 \(u\) 维护一个权值线段树,叶子维护与这个节点联通的节点是否有这个重要度的,区间维护某个重要度区间内存在多少个节点与 \(u\) 联通。
初始时显然 \(i\) 的线段树上只有 \(p_i\) 有值表示其只能到达自己。
第一个操作以及初始的边就是将 \(x\) 和 \(y\) 的线段树都变成其合并之后的线段树。
直接将 \(x\) 合并到 \(y\) 上然后并查集维护即可。
第二个操作的第 \(k\) 大是老套路了,线段树上二分即可。
#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=1e5+10,P=1e9+7;
int n,m;
int rt[N],num,p[N];
struct node
{
int l,r;
int cnt;
}s[N*40];
#define ls(k) s[k].l
#define rs(k) s[k].r
int fa[N];
int find(int x)
{
return (x==fa[x]?x:fa[x]=find(fa[x]));
}
void update(int &k,int l,int r,int x)
{
if(!k) k=++num;
if(l==r)
{
s[k].cnt=1;
return;
}
int mid=l+r>>1;
if(x<=mid) update(ls(k),l,mid,x);
else update(rs(k),mid+1,r,x);
s[k].cnt=s[ls(k)].cnt+s[rs(k)].cnt;
}
void merge(int l,int r,int x,int y)//把x合并到y上
{
if(l==r)
{
s[y].cnt|=s[x].cnt;
return;
}
int mid=l+r>>1;
if(!ls(y)) ls(y)=ls(x);
else if(ls(x)) merge(l,mid,ls(x),ls(y));
if(!rs(y)) rs(y)=rs(x);
else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
s[y].cnt=s[ls(y)].cnt+s[rs(y)].cnt;
}
int find(int k,int l,int r,int x)
{
// dbg(l),dbg(r),dg(x);
// dg(s[k].cnt);
if(s[k].cnt<x) return -1;
if(l==r) return p[l];
int mid=l+r>>1;
if(!rs(k)) return find(ls(k),l,mid,x);
if(!ls(k)) return find(rs(k),mid+1,r,x);
if(s[ls(k)].cnt>=x) return find(ls(k),l,mid,x);
return find(rs(k),mid+1,r,x-s[ls(k)].cnt);
}
void solve()
{
n=read();m=read();
F(i,1,n)
{
fa[i]=i;
int x=read();p[x]=i;
update(rt[i],1,n,x);
}
F(i,1,m)
{
int x=read(),y=read();
x=find(x),y=find(y);
if(x!=y)
{
merge(1,n,rt[x],rt[y]);
fa[x]=y;
}
}
int Q=read();
while(Q--)
{
char op[2];
int x,y;
scanf("%s",op);x=read(),y=read();
x=find(x);
if(op[0]=='Q')
{
put(find(rt[x],1,n,y));
}
else
{
y=find(y);
if(x!=y)
{
merge(1,n,rt[x],rt[y]);
fa[x]=y;
}
}
}
}
signed main()
{
int T=1;
// T=read();
while(T--) solve();
return 0;
}
P3605 [USACO17JAN] Promotion Counting P
离散化之后直接每次查线段树上的一段前缀和然后合并即可。
其实感觉还有另外一种做法,将 \(>\) 它的看作 \(1\),否则看作 \(0\)。
从最大值开始,每次子树查一下,然后 \(i\to i+1\) 的状态差别就只有将 \(i\) 这个点设为 \(1\)。
不过毕竟是熟悉线段树合并就打的复杂一点吧。
#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=5e5+10,P=1e9+7,V=1e9;
int n,num,rt[N],p[N],ans[N];
sd vector<int> g[N];
struct node
{
int val,l,r;
}s[N*20];
#define ls(k) s[k].l
#define rs(k) s[k].r
void update(int &k,int l,int r,int x)
{
if(!k) k=++num;
if(l==r)
{
s[k].val=1;
return;
}
int mid=l+r>>1;
if(x<=mid) update(ls(k),l,mid,x);
else update(rs(k),mid+1,r,x);
s[k].val=s[ls(k)].val+s[rs(k)].val;
}
int ask(int k,int l,int r,int x,int y)
{
if(x<=l&&y>=r) return s[k].val;
int mid=l+r>>1,res=0;
if(x<=mid&&ls(k)) res+=ask(ls(k),l,mid,x,y);
if(y>mid&&rs(k))res+=ask(rs(k),mid+1,r,x,y);
return res;
}
void merge(int l,int r,int x,int y)//把 x 合并到 y 上
{
if(l==r)
{
s[y].val+=s[x].val;
return;
}
int mid=l+r>>1;
if(!ls(y)) ls(y)=ls(x);
else if(ls(x)) merge(l,mid,ls(x),ls(y));
if(!rs(y)) rs(y)=rs(x);
else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
s[y].val=s[ls(y)].val+s[rs(y)].val;
}
void dfs(int u)
{
for(auto v:g[u])
{
dfs(v);
ans[u]+=ask(rt[v],1,V,p[u]+1,V);
merge(1,V,rt[v],rt[u]);
}
}
void solve()
{
n=read();
F(i,1,n)
{
p[i]=read();
update(rt[i],1,V,p[i]);
}
F(i,2,n)
{
int x=read();
g[x].emplace_back(i);
}
dfs(1);
F(i,1,n) put(ans[i]);
}
signed main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
int T=1;
// T=read();
while(T--) solve();
return 0;
}
[POI 2011] ROT-Tree Rotations
假设左/右子树的逆序对分别为 \(x,y\),则总的逆序对就是 \(x+y\) 再加上左子树比右子树大的数量。
反过来的话就是右子树比左子树大的数量。
不难发现 \(u\) 的决策不影响 \(fa_u\) 另一棵子树的决策。
于是我们直接从上到下每个点单独考虑,考虑怎么快速算出左子树比右子树大的数量。
直接上线段树合并。
在合并 \([l,r]\) 的时候,顺带计算一下值域 \([l,r]\) 中右子树比左子树大的数量 \(val_{l,r}\)。
不难发现 \(val_{l,r}=val_{l,mid}+val_{mid+1,r}\) 然后加上 \([l,mid]\) 中左子树的数量乘上 \([mid+1,r]\) 右子树的数量。
感觉很好做啊,直接就做完了。
稍微卡了一下空间过了。
#include<bits/stdc++.h>
#define sd std::
// #define int long long
#define ll long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=4e5+10,P=1e9+7;
int n,lson[N],rson[N],id=1,num;
int rt[N];
struct node
{
int val;
int l,r;
}s[N*10];
#define ls(k) s[k].l
#define rs(k) s[k].r
void update(int &k,int l,int r,int x)
{
if(!k) k=++num;
if(l==r)
{
s[k].val=1;
return;
}
int mid=l+r>>1;
if(x<=mid) update(ls(k),l,mid,x);
else update(rs(k),mid+1,r,x);
s[k].val=s[ls(k)].val+s[rs(k)].val;
}
ll calc(int l,int r,int x,int y)//x是左边的,y是右边的
{
if(l==r) return 0;
int mid=l+r>>1;
ll res=(ll)s[rs(y)].val*s[ls(x)].val;
if(ls(x)&&ls(y)) res+=calc(l,mid,ls(x),ls(y));
if(rs(x)&&rs(y)) res+=calc(mid+1,r,rs(x),rs(y));
return res;
}
void merge(int l,int r,int x,int y)
{
if(l==r)
{
s[y].val+=s[x].val;
return;
}
int mid=l+r>>1;
if(!ls(y)) ls(y)=ls(x);
else if(ls(x)) merge(l,mid,ls(x),ls(y));
if(!rs(y)) rs(y)=rs(x);
else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
s[y].val=s[ls(y)].val+s[rs(y)].val;
}
void input(int u)
{
int x=read();
if(x)
{
update(rt[u],1,n,x);
}
else
{
input(lson[u]=++id);
input(rson[u]=++id);
}
}
ll dfs(int u)//sum为节点数量
{
// dg(u);
// dbg(lson[u]),dg(rson[u]);
if(!lson[u]&&!rson[u]) return 0;
ll val=dfs(lson[u])+dfs(rson[u]);
// dbg(val);
ll p1=calc(1,n,rt[lson[u]],rt[rson[u]]);
// dbg(p1);
ll p2=calc(1,n,rt[rson[u]],rt[lson[u]]);
// dg(p2);
val+=sd min(p1,p2);
merge(1,n,rt[lson[u]],rt[rson[u]]);
rt[u]=rt[rson[u]];
return val;
}
void solve()
{
n=read();
input(1);
sd cout<<dfs(1);
}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
int T=1;
// T=read();
while(T--) solve();
return 0;
}
CF208E Blood Cousins
\(u\) 的 \(p\) 级表亲数量就是 \(u\) 的 \(p\) 级祖先的距离为 \(p\) 的儿子数量。
考虑直接给每个点维护其 \(k\) 级儿子,\(v\) 合并到 \(u\) 就是右移一位然后插入一个数。
但是其实不用这么麻烦,考虑线段树维护的叶子 \(val_x\) 代表 \(u\) 子树内 \(dep=x\) 的结点数量。
然后查询的时候就是询问某个点的深度加上 \(k\) 这个深度有多少个节点,就直接问就行。
但是由于众所周知线段树合并不能在线,所以得离线到每个点上做询问。
#include<bits/stdc++.h>
#define sd std::
// #define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=1e5+10,P=1e9+7;
int n,num,ans[N],rt[N];
sd vector<int> g[N],root;
sd vector<pii> q[N];
struct node
{
int l,r,val;
}s[N*20];
#define ls(k) s[k].l
#define rs(k) s[k].r
void update(int &k,int l,int r,int x)
{
if(!k) k=++num;
if(l==r)
{
s[k].val++;
return;
}
int mid=l+r>>1;
if(x<=mid) update(ls(k),l,mid,x);
else update(rs(k),mid+1,r,x);
s[k].val=s[ls(k)].val+s[rs(k)].val;
}
void merge(int l,int r,int x,int y)
{
if(l==r)
{
s[y].val+=s[x].val;
return;
}
int mid=l+r>>1;
if(!ls(y)) ls(y)=ls(x);
else if(ls(x)) merge(l,mid,ls(x),ls(y));
if(!rs(y)) rs(y)=rs(x);
else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
s[y].val=s[ls(y)].val+s[rs(y)].val;
}
int ask(int k,int l,int r,int x)
{
if(l==r) return s[k].val;
int mid=l+r>>1;
if(x<=mid&&ls(k)) return ask(ls(k),l,mid,x);
else if(rs(k)) return ask(rs(k),mid+1,r,x);
return 0;
}
int dep[N],f[N][21];
void dfs1(int u)
{
for(auto v:g[u])
{
dep[v]=dep[u]+1;
f[v][0]=u;
F(i,0,19) f[v][i+1]=f[f[v][i]][i];
dfs1(v);
}
}
void dfs2(int u)
{
update(rt[u],1,n,dep[u]);
for(auto v:g[u])
{
dfs2(v);
merge(1,n,rt[v],rt[u]);
}
for(auto [k,id]:q[u])
{
ans[id]=ask(rt[u],1,n,dep[u]+k);
}
}
int find(int u,int k)//查询祖先
{
int val=dep[u]-k;
ff(i,20,0) if(dep[f[u][i]]>=val) u=f[u][i];
return u;
}
void solve()
{
n=read();
F(i,1,n)
{
int x=read();
if(!x) root.emplace_back(i);
else g[x].emplace_back(i);
}
for(auto u:root)
{
dep[u]=1;
dfs1(u);
}
int Q=read();
F(i,1,Q)
{
int v=read(),p=read();
q[find(v,p)].emplace_back(p,i);
}
for(auto u:root) dfs2(u);
F(i,1,Q) printk(ans[i]?ans[i]-1:0);
}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
int T=1;
// T=read();
while(T--) solve();
return 0;
}
P3899 [湖南集训] 更为厉害
不难发现 \(a,b\) 肯定是祖先和子孙的关系。
可以分为 \(b\) 为 \(a\) 祖先和 \(b\) 为 \(a\) 子孙。
第一种情况很简单,即 \(siz_u\) 乘上和 \(u\) 距离不超过 \(k\) 的祖先。
考虑第二种情况怎么计算。
即计算 \(\sum\limits_{x\in \operatorname{Subtree}(u),dis(u)} (siz_x-1)\)。
考虑直接每个节点开一颗线段树,叶子 \(i\) 就是 \(dep=i\) 的 \(\sum siz-1\) 之和。
然后询问的时候直接查线段树上某一段的和即可。
这个还是比较简单的,直接上线段树合并即可。
#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=3e5+10,P=1e9+7;
int n,Q,siz[N],dep[N],num,rt[N],ans[N];
struct node
{
int val,l,r;
}s[N*20];
#define ls(k) s[k].l
#define rs(k) s[k].r
void update(int &k,int l,int r,int x,int y)
{
if(!k) k=++num;
if(l==r)
{
s[k].val+=y;
return;
}
int mid=l+r>>1;
if(x<=mid) update(ls(k),l,mid,x,y);
else update(rs(k),mid+1,r,x,y);
s[k].val=s[ls(k)].val+s[rs(k)].val;
}
void merge(int l,int r,int x,int y)//x合并到y
{
if(l==r)
{
s[y].val+=s[x].val;
return;
}
int mid=l+r>>1;
if(!ls(y)) ls(y)=ls(x);
else if(ls(x)) merge(l,mid,ls(x),ls(y));
if(!rs(y)) rs(y)=rs(x);
else if(rs(x)) merge(mid+1,r,rs(x),rs(y));
s[y].val=s[ls(y)].val+s[rs(y)].val;
}
int ask(int k,int l,int r,int x,int y)
{
if(x<=l&&y>=r) return s[k].val;
int mid=l+r>>1,res=0;
if(x<=mid&&ls(k)) res+=ask(ls(k),l,mid,x,y);
if(y>mid&&rs(k)) res+=ask(rs(k),mid+1,r,x,y);
return res;
}
sd vector<int> g[N];
sd vector<pii> q[N];
void dfs1(int u,int fa)
{
siz[u]=1;
for(auto v:g[u])
{
if(v==fa) continue;
dep[v]=dep[u]+1;
dfs1(v,u);
siz[u]+=siz[v];
}
}
void dfs2(int u,int fa)
{
update(rt[u],1,n,dep[u],siz[u]-1);
for(auto v:g[u])
{
if(v==fa) continue;
dfs2(v,u);
merge(1,n,rt[v],rt[u]);
}
for(auto [k,id]:q[u])
{
int val=(dep[u]==n?0:ask(rt[u],1,n,dep[u]+1,sd min(n,dep[u]+k)));
ans[id]=sd min(dep[u]-1,k)*(siz[u]-1)+val;
}
}
void solve()
{
n=read(),Q=read();
F(i,2,n)
{
int x=read(),y=read();
g[x].emplace_back(y);
g[y].emplace_back(x);
}
dep[1]=1;
dfs1(1,0);
F(i,1,Q)
{
int p=read(),k=read();
q[p].emplace_back(k,i);
}
dfs2(1,0);
F(i,1,Q) put(ans[i]);
}
signed main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
int T=1;
// T=read();
while(T--) solve();
return 0;
}
P5298 [PKUWC2018] Minimax
最初的想法是考虑每个节点开一颗线段树,每个叶子维护其作为这个值的概率乘以10000。
然后区间就维护这个值域的总概率 \(val_{l,r}\),即选到这个值域的概率是多少。
考虑左右儿子 \(x,y\) 的两颗线段树怎么合并。考虑值域 \([l,r]\) 对 \(u\) 线段树的影响。
将这个影响分为三类:
- \(x,y\) 都选择 \([l,mid]\) 的数。
- \(x,y\) 都选择 \([mid+1,r]\) 的数。
- \(x,y\) 一个选择 \([l,mid]\),一个选择 \([mid+1,r]\) 的数。
前两者递归处理,我们只需要处理跨过中点的贡献。
贡献式子写出来之后我们发现大概就是 \(x\) 的 \([mid+1,r]\) 区间会对 \(u\) 的 \([mid+1,r]\) 区间一一对应的造成其自身乘以 \(valy_{l,mid}\times p_u\) 的贡献。
剩下的同理。
但是我们不能直接在每个 \([l,r]\) 处打乘法标记,一是 \(x,y\) 的乘法标记不互通,而是我们在递归子区间的时候处理的显然应该是没有打标记的结果。那这个乘法标记究竟在哪里下传就很难办。
考虑将这个乘法标记累计起来,如果 \(x,y\) 都有某一方的节点就继续遍历,否则将累计的乘法标记打上去。
#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define ff(i,a,b) for(int i=(a);i>=(b);i--)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define dbg(x) sd cout<<#x<<":"<<x<<" "
#define dg(x) sd cout<<#x<<":"<<x<<"\n"
#define inf 1e10
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=3e5+10,P=998244353,V=1e9;
int n,num,p[N],rt[N],lson[N],rson[N];
struct node
{
int val,l,r,tag;
}s[N*20];
#define ls(k) s[k].l
#define rs(k) s[k].r
int L,R;
void pushdown(int k)
{
int tg=s[k].tag;
if(tg==1) return;
if(ls(k))
{
s[ls(k)].tag=s[ls(k)].tag*tg%P;
s[ls(k)].val=s[ls(k)].val*tg%P;
}
if(rs(k))
{
s[rs(k)].tag=s[rs(k)].tag*tg%P;
s[rs(k)].val=s[rs(k)].val*tg%P;
}
s[k].tag=1;
}
void update(int &k,int l,int r,int x)
{
if(!k) k=++num,s[k].tag=1;
if(l==r)
{
s[k].val=1;
return;
}
int mid=l+r>>1;
if(x<=mid) update(ls(k),l,mid,x);
else update(rs(k),mid+1,r,x);
s[k].val=s[ls(k)].val+s[rs(k)].val;
}
int d;//p[u]
void merge(int l,int r,int x,int &y,int tagx,int tagy)
{
if(!x&&!y) return;
if(!x)
{
s[y].tag=s[y].tag*tagy%P;
s[y].val=s[y].val*tagy%P;
return;
}
if(!y)
{
y=x;
s[y].tag=s[y].tag*tagx%P;
s[y].val=s[y].val*tagx%P;
return;
}
pushdown(x);pushdown(y);
int vly=s[ls(y)].val,vry=s[rs(y)].val,vlx=s[ls(x)].val,vrx=s[rs(x)].val;
L=l,R=r;
int mid=l+r>>1;
merge(l,mid,ls(x),ls(y),(tagx+vry*(1-d+P))%P,(tagy+vrx*(1-d+P))%P);
merge(mid+1,r,rs(x),rs(y),(tagx+vly*d)%P,(tagy+vlx*d)%P);
s[y].val=(s[ls(y)].val+s[rs(y)].val)%P;
}
int cnt,ans;
int now;
void out(int k,int l,int r)
{
if(l==r)
{
++cnt;
(ans+=l*cnt%P*s[k].val%P*s[k].val%P)%=P;
return;
}
L=l,R=r;
pushdown(k);
int mid=l+r>>1;
if(ls(k)) out(ls(k),l,mid);
if(rs(k)) out(rs(k),mid+1,r);
}
void dfs(int u)
{
if(!lson[u]) return;
dfs(lson[u]);
if(rson[u])
{
dfs(rson[u]);
d=p[u];
merge(1,V,rt[rson[u]],rt[lson[u]],0,0);
}
rt[u]=rt[lson[u]];
}
void solve()
{
n=read();read();
F(i,2,n)
{
int x=read();
if(!lson[x]) lson[x]=i;
else if(!rson[x]) rson[x]=i;
}
int inv=796898467;//10000的逆元
F(u,1,n)
{
int x=read();
if(!lson[u]) update(rt[u],1,V,x);
else p[u]=x*inv%P;
}
dfs(1);
out(rt[1],1,V);
put(ans);
}
signed main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
int T=1;
// T=read();
while(T--) solve();
return 0;
}

浙公网安备 33010602011771号