个人线段树写法 & 注意逝项
前言
众所周知由于一些原因,我们有时候需要写一些维护较多东西的线段树,如 P4513 小白逛公园 这种。
这个过程中,不妙的实现(比如说像某位李姓,名字最后一个字是木字旁的性感同学的常见实现),比如随意多开线段树,大量使用 if,大量复制粘贴来完成的,难写难调,很容易爆炸。
那么相反的,合理的实现,就会有清晰的架构,简易的代码。
起因是最近写一道树剖题目,其中的线段树类似小白逛公园类,但是有时需要区间取反以及合并答案时需要整体翻转,涵盖范围相对全。
举例
具体问题:树上,点权为 \(1\) 或者 \(-1\),操作包含:路径乘 \(-1\),路径问长度减前缀 \(\max\) 加后缀 \(\min\)。
分析:类似小白逛公园,由于区间乘 \(-1\) 的操作,我还需要维护前缀 \(\min\) 和后缀 \(\max\)。
另外为了合并两个区间,需要维护区间 \(\text{sum}\)。
分析一下,需要复杂合并 / 下传 / 维护答案。
采用如下实现方式:
-
node开节点,原因见下一条: -
node operator +这样build()、change()里面的pushup(),合并答案都可以解决。 -
node chg(node x)将pushdown()内对两个子节点分别的操作,和change()返回前的修改操作统一。 -
node rev(node x)树上的问题中,链拼起来可能需要翻转顺序,这个就适用。
放个示例看看。
#include<bits/stdc++.h>
using namespace std;
struct node{
long long l,r,len,le,lx,ln,rx,rn,tg;
}t[800110];
node operator + (const node &x,const node &y)
{
return {x.l,y.r,x.len+y.len,x.le+y.le,max(x.lx,x.len+y.lx),min(x.ln,x.len+y.ln),max(y.rx,y.len+x.rx),min(y.rn,y.len+x.rn),0};
}
long long a[100050];
#define mid ((t[o].l+t[o].r)>>1)
#define ls (o<<1)
#define rs ((o<<1)^1)
node chg(node x)
{
return {x.l,x.r,-x.len,x.le,-x.ln,-x.lx,-x.rn,-x.rx,x.tg^1};
}
node rev(node x)
{
return {x.l,x.r,x.len,x.le,x.rx,x.rn,x.lx,x.ln,0};
}
void build(int l,int r,int o)
{
t[o].l=l,t[o].r=r;
if(l==r)
{
if(a[l]==1) t[o].lx=t[o].rx=t[o].len=1;
else t[o].ln=t[o].rn=t[o].len=-1;
t[o].le=1;
return;
}
build(l,mid,ls);
build(mid+1,r,rs);
t[o]=t[ls]+t[rs];
}
void pushdown(int o) //why not spread lol
{
if(t[o].tg)
{
t[o].tg=0;
t[ls]=chg(t[ls]);
t[rs]=chg(t[rs]);
}
}
void change(int l,int r,int o)
{
if(l<=t[o].l&&t[o].r<=r)
{
t[o]=chg(t[o]);
return;
}
pushdown(o);
if(l<=mid) change(l,r,ls);
if(r>mid) change(l,r,rs);
t[o]=t[ls]+t[rs];
}
node ask(int l,int r,int o)
{
if(l<=t[o].l&&t[o].r<=r) return t[o];
pushdown(o);
if(l<=mid&&r>mid) return ask(l,r,ls)+ask(l,r,rs);
if(l<=mid) return ask(l,r,ls);
if(r>mid) return ask(l,r,rs);
}
vector<int> e[114514];
int fa[114514],si[114514],son[114514],dep[114514],ms[114514],dfn[114514],cnt,da[114514],top[114514];
void dfs1(int u)
{
si[u]=1;
for(auto v:e[u]) if(v!=fa[u])
{
fa[v]=u;
dep[v]=dep[u]+1;
dfs1(v);
si[u]+=si[v];
if(si[v]>ms[u])
{
ms[u]=si[v];
son[u]=v;
}
}
}
void dfs2(int u,int topf)
{
cnt++;
dfn[u]=cnt;
a[cnt]=da[u];
top[u]=topf;
if(!son[u]) return;
dfs2(son[u],topf);
for(auto v:e[u])
if(v!=fa[u]&&v!=son[u])
dfs2(v,v);
}
int n,m,i,j,r,u,v,w,x,y,z,in;
signed main()
{
freopen("loser.in","r",stdin);
freopen("loser.out","w",stdout);
cin>>n>>m;r=1;
for(i=1;i<n;i++)
{
cin>>u>>v;
e[u].push_back({v});e[v].push_back({u});
}
for(i=1;i<=n;i++)
cin>>da[i],da[i]=(da[i]?1:-1);
memset(ms,-1,sizeof(ms));
dfs1(r);dfs2(r,r);
build(1,n,1);
for(i=1;i<=m;i++)
{
cin>>in>>x>>y;
if(in==2)
{
node ansx,ansy;
ansx=ansy={0,0,0,0,0,0,0,0};
int rv=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y),swap(ansx,ansy),rv^=1;
ansx=ask(dfn[top[x]],dfn[x],1)+ansx;
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y),swap(ansx,ansy),rv^=1;
ansx=(rev(ansx)+ask(dfn[x],dfn[y],1))+ansy;
if(rv) ansx=rev(ansx);
cout<<ansx.le-ansx.lx+ansx.rn<<endl;
}
if(in==1)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
change(dfn[top[x]],dfn[x],1);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
change(dfn[x],dfn[y],1);
}
}
return 0;
}

浙公网安备 33010602011771号