P9555 「CROI · R1」浣熊的阴阳鱼
树链剖分真可爱。
给出 \(n\) 个点的树,点有点权 \(a_i(a_i\in\{0,1\})\)。支持 \(q\) 次操作:
\(\texttt{1 }u\),\(a_u\leftarrow \lnot a_u\)。
\(\texttt{2 }u\texttt{ }v\),你带着一个最大大小为 \(2\) 的可重集 \(S\) 从 \(u\) 走到 \(v\),初始 \(S=\varnothing\),遇到一个点 \(x\),若 \(\lnot a_x\in S\),则删除 \(\lnot a_x\),得分 \(+1\),否则将 \(a_x\) 插入 \(S\)。求最终的得分。
\(n,q\le 10^5\)。
看到树上修改和路径查询,首先想到树剖。我们发现 \(|S|\le 2\),且 \(S\) 内元素的顺序不影响答案,因此我们可以用一个二元组 \((i,j)(0\le i\le j\le 2)\) 记录 \(S\) 的状态(\(0,1\) 表示放了什么元素,\(2\) 表示没有元素)。为了方便表述,将 \(2\) 也认为是 \(S\) 内的元素,即强制 \(|S|=2\)。
分析一下询问,相当于已经给定了初始状态 \(S=\{a_u,2\}\)。再根据树剖的思想,我们要快速查询一条重链上的信息,即快速查询那条链对应的区间的信息。考虑使用线段树维护。具体地,对于线段树上的一个节点 \(x\)(设对应的区间为 \([l,r]\)),记 \(f_{x,i,j}\) 表示从 \(l\) 位置开始走,经过 \(l\) 位置后 \(S=\{i,j\}\),走到 \(r\) 时的得分,类似地 \(g_{x,i,j}\) 表示从 \(r\) 走到 \(l\) 的得分。还需要记录 \(LtoR_{x,i,j}\) 表示从 \(l\) 位置开始走,经过 \(l\) 位置后 \(S=\{i,j\}\),到达 \(r\) 后 \(S\) 的状态,同理还有 \(RtoL_{x,i,j}\) 表示从 \(r\) 到达 \(l\) 后 \(S\) 的状态。注意,我们强调了 \(S\) 是经过端点后的状态,意味着上述的 \(S\) 都已经受端点的影响,即统计答案的时候不再受到起点的影响(因为已经受过了)。类似地,强调 \(LtoR\) 和 \(RtoL\) 是到达端点后的状态,说明信息已经受到终点的影响,因为状态的定义里保证了它受到起点的影响,所以同时保证了统计了完整的信息。
考虑如何合并区间信息,我们发现需要快速计算出已经到了一个区间末尾,下一步走到另一半区间开头时,\(S\) 的状态,因此还需要记录 \(lc_{x},rc_{x}\) 来存储 \(a_l\) 和 \(a_r\)。所以线段树的节点是张这样的:
#define pii pair<int,int>
struct node{//变量名有部分不同。
int cnt_l[3][3],cnt_r[3][3],lc,rc;//f,g,lc,rc。
pii status_l[3][3],status_r[3][3];//LtoR,RtoL。
}seg[N<<2];
然后计算状态可以通过以下的函数实现(我写的比较暴力,直接枚举 \(12\) 种情况分类讨论):
#define ppi pair<pii,int>
#define mp make_pair
ppi get_status(pii p,int x){//第二维表示得分增量。
if(p==mp(0,0)&&!x){
return mp(p,0);
}
if(p==mp(0,0)&&x){
return mp(mp(0,2),1);
}
if(p==mp(0,1)&&!x){
return mp(mp(0,2),1);
}
if(p==mp(0,1)&&x){
return mp(mp(1,2),1);
}
if(p==mp(0,2)&&!x){
return mp(mp(0,0),0);
}
if(p==mp(0,2)&&x){
return mp(mp(2,2),1);
}
if(p==mp(1,1)&&!x){
return mp(mp(1,2),1);
}
if(p==mp(1,1)&&x){
return mp(mp(1,1),0);
}
if(p==mp(1,2)&&!x){
return mp(mp(2,2),1);
}
if(p==mp(1,2)&&x){
return mp(mp(1,1),0);
}
return mp(mp(x,2),0);//空集就直接放。
}
区间信息具体合并方法为,先从计算当前状态走到区间末尾的信息,然后计算跨过区间时的状态,以及计算跨区间这一步对得分的贡献。然后再从另一半区间,以跨区间时的状态开始走,计算得分。从当前起点走到另一个端点的状态,就是从另一半区间,以跨区间时的状态开始走,走到那个端点时的状态。代码如下:
#define fi first
#define se second
node merge(node l,node r){
node ret;
ret.lc=l.lc;
ret.rc=r.rc;
for(int i=0;i<=2;++i){
for(int j=i;j<=2;++j){
ppi l_start=get_status(r.status_r[i][j],l.rc),r_start=get_status(l.status_l[i][j],r.lc);
ret.cnt_l[i][j]=l.cnt_l[i][j]+r_start.se+r.cnt_l[r_start.fi.fi][r_start.fi.se];
ret.status_l[i][j]=r.status_l[r_start.fi.fi][r_start.fi.se];
ret.cnt_r[i][j]=r.cnt_r[i][j]+l_start.se+l.cnt_r[l_start.fi.fi][l_start.fi.se];
ret.status_r[i][j]=l.status_r[l_start.fi.fi][l_start.fi.se];
}
}
return ret;
}
对于长度为 \(1\) 的区间,有初始化:
seg[x].lc=seg[x].rc=b[l];//b[l] 是那个点的元素值。
for(int i=0;i<=2;++i){
for(int j=i;j<=2;++j){
seg[x].cnt_l[i][j]=seg[x].cnt_r[i][j]=0;//端点已经考虑过且没有遇到新元素,对得分无贡献。
seg[x].status_l[i][j]=seg[x].status_r[i][j]=mp(i,j);//起点终点相同,受到起点的影响即受到了终点的影响。
}
}
那么单点修改也很好维护:
seg[x].lc^=1;
seg[x].rc^=1;
查询就是跳链查询。注意合并信息的顺序,具体可以参考 GSS7。
时间复杂度为 \(\mathcal{O}(q\log ^2n)\),空间复杂度为 \(\mathcal{O}(n)\)。
代码
#include<bits/stdc++.h>
#define ls(x) ((x)<<1)
#define rs(x) ((x)<<1|1)
#define fi first
#define se second
#define pii pair<int,int>
#define ppi pair<pii,int>
#define mp make_pair
using namespace std;
const int N=1e5+5;
int n,q,a[N],siz[N],dep[N],top[N],hson[N],fa[N],dfn[N],id,b[N];
vector<int>g[N];
struct node{
int cnt_l[3][3],cnt_r[3][3],lc,rc;//l to r;r to l;
pii status_l[3][3],status_r[3][3];//start l;start r;
}seg[N<<2];
void dfs1(int u){
siz[u]=1;
for(int v:g[u]){
if(v!=fa[u]){
dep[v]=dep[u]+1;
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
}
}
}
void dfs2(int u){
for(int v:g[u]){
if(v!=fa[u]){
if((siz[v]<<1)>siz[u]){
hson[u]=v;
top[v]=top[u];
}else{
top[v]=v;
}
dfs2(v);
}
}
}
void dfs3(int u){
dfn[u]=++id;
b[id]=a[u];
if(hson[u]){
dfs3(hson[u]);
}
for(int v:g[u]){
if(v!=fa[u]&&v!=hson[u]){
dfs3(v);
}
}
}
ppi get_status(pii p,int x){
if(p==mp(0,0)&&!x){
return mp(p,0);
}
if(p==mp(0,0)&&x){
return mp(mp(0,2),1);
}
if(p==mp(0,1)&&!x){
return mp(mp(0,2),1);
}
if(p==mp(0,1)&&x){
return mp(mp(1,2),1);
}
if(p==mp(0,2)&&!x){
return mp(mp(0,0),0);
}
if(p==mp(0,2)&&x){
return mp(mp(2,2),1);
}
if(p==mp(1,1)&&!x){
return mp(mp(1,2),1);
}
if(p==mp(1,1)&&x){
return mp(mp(1,1),0);
}
if(p==mp(1,2)&&!x){
return mp(mp(2,2),1);
}
if(p==mp(1,2)&&x){
return mp(mp(1,1),0);
}
return mp(mp(x,2),0);
}
node merge(node l,node r){
node ret;
ret.lc=l.lc;
ret.rc=r.rc;
for(int i=0;i<=2;++i){
for(int j=i;j<=2;++j){
ppi l_start=get_status(r.status_r[i][j],l.rc),r_start=get_status(l.status_l[i][j],r.lc);
ret.cnt_l[i][j]=l.cnt_l[i][j]+r_start.se+r.cnt_l[r_start.fi.fi][r_start.fi.se];
ret.status_l[i][j]=r.status_l[r_start.fi.fi][r_start.fi.se];
ret.cnt_r[i][j]=r.cnt_r[i][j]+l_start.se+l.cnt_r[l_start.fi.fi][l_start.fi.se];
ret.status_r[i][j]=l.status_r[l_start.fi.fi][l_start.fi.se];
}
}
return ret;
}
void build(int x,int l,int r){
if(l==r){
seg[x].lc=seg[x].rc=b[l];
for(int i=0;i<=2;++i){
for(int j=i;j<=2;++j){
seg[x].cnt_l[i][j]=seg[x].cnt_r[i][j]=0;
seg[x].status_l[i][j]=seg[x].status_r[i][j]=mp(i,j);
}
}
return;
}
int mid=(l+r)>>1;
build(ls(x),l,mid);
build(rs(x),mid+1,r);
seg[x]=merge(seg[ls(x)],seg[rs(x)]);
}
void modify(int x,int l,int r,int k){
if(l==r){
seg[x].lc^=1;
seg[x].rc^=1;
return;
}
int mid=(l+r)>>1;
if(k<=mid){
modify(ls(x),l,mid,k);
}else{
modify(rs(x),mid+1,r,k);
}
seg[x]=merge(seg[ls(x)],seg[rs(x)]);
}
node query(int x,int l,int r,int ql,int qr){
if(ql<=l&&r<=qr){
return seg[x];
}
int mid=(l+r)>>1;
if(qr<=mid){
return query(ls(x),l,mid,ql,qr);
}
if(ql>mid){
return query(rs(x),mid+1,r,ql,qr);
}
return merge(query(ls(x),l,mid,ql,qr),query(rs(x),mid+1,r,ql,qr));
}
node pathquery(int x,int y){
node info_x,info_y,temp;
bool empty_x=1,empty_y=1;
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]){
temp=query(1,1,n,dfn[top[x]],dfn[x]);
if(empty_x){
info_x=temp;
empty_x=0;
}else{
info_x=merge(temp,info_x);
}
x=fa[top[x]];
}else{
temp=query(1,1,n,dfn[top[y]],dfn[y]);
if(empty_y){
info_y=temp;
empty_y=0;
}else{
info_y=merge(temp,info_y);
}
y=fa[top[y]];
}
}
if(dep[x]<dep[y]){
temp=query(1,1,n,dfn[x],dfn[y]);
if(!empty_x){
swap(info_x.cnt_l,info_x.cnt_r);
swap(info_x.status_l,info_x.status_r);
if(!empty_y){
return merge(info_x,merge(temp,info_y));
}
return merge(info_x,temp);
}
if(!empty_y){
return merge(temp,info_y);
}
return temp;
}else{
temp=query(1,1,n,dfn[y],dfn[x]);
if(!empty_x){
info_x=merge(temp,info_x);
swap(info_x.cnt_l,info_x.cnt_r);
swap(info_x.status_l,info_x.status_r);
if(!empty_y){
return merge(info_x,info_y);
}
return info_x;
}
swap(temp.cnt_l,temp.cnt_r);
swap(temp.status_l,temp.status_r);
if(!empty_y){
return merge(temp,info_y);
}
return temp;
}
}
signed main(){
cin.tie(0);
cout.tie(0);
ios::sync_with_stdio(0);
cin>>n>>q;
for(int i=1;i<=n;++i){
cin>>a[i];
}
for(int i=1,u,v;i<n;++i){
cin>>u>>v;
g[u].emplace_back(v);
g[v].emplace_back(u);
}
dfs1(1);
top[1]=1;
dfs2(1);
dfs3(1);
build(1,1,n);
for(int op,u,v,i=1;i<=q;++i){
cin>>op>>u;
if(op==1){
a[u]^=1;
modify(1,1,n,dfn[u]);
}else{
cin>>v;
node ans=pathquery(u,v);
cout<<ans.cnt_l[a[u]][2]<<'\n';
}
}
return 0;
}

浙公网安备 33010602011771号