3.22 树链剖分
3.22 树链剖分好题分享
P4374 [USACO18OPEN] Disruption P
这道题我做的时候还是紫的,现在蓝了。
题意
给一棵树,再给若干条新边,每条新边有一个边权。对于每条树上的边,求它被砍掉之后使得树重新联通的最小代价。
题解
枚举每条树边直接求显然是很困难的,所以我们考虑计算每条非树边对树边的贡献。
不难发现,每条非树边会跟树上的若干条边形成唯一的一个环,那么这条边就可能在环上任意一条树边被砍掉之后成为替代它的边,也就是会对这些边产生贡献。
于是做法便明晰了:对树进行树链剖分,建线段树维护最小值,对于每条非树边,更新其两个端点之间的树链的答案。
#include <cstdio>
#define N 100005
#define M 800005
int min(int x,int y) {return x<y?x:y;}
int max(int x,int y) {return x>y?x:y;}
int n,m;
int hed[N],tal[N],nxt[N],cnte;
void adde(int u,int v)
{
tal[++cnte]=v;
nxt[cnte]=hed[u];
hed[u]=cnte;
}
struct sgt
{
int d[M],tg[M],ls[M],rs[M],idx;
#define mid (lb+rb>>1)
#define pushup(x) d[x]=max(d[ls[x]],d[rs[x]])
int upd(int x,int y)
{
if(x==-1) return y;
if(y==-1) return x;
return min(x,y);
}
int newnode()
{
int nx=++idx;
d[nx]=tg[nx]=-1;
return nx;
}
void pushdown(int x)
{
if(tg[x]==-1) return;
if(!ls[x]) ls[x]=newnode();
if(!rs[x]) rs[x]=newnode();
d[ls[x]]=upd(d[ls[x]],tg[x]);
d[rs[x]]=upd(d[rs[x]],tg[x]);
tg[ls[x]]=upd(tg[ls[x]],tg[x]);
tg[rs[x]]=upd(tg[rs[x]],tg[x]);
tg[x]=-1;
}
void modify(int &x,int k,int l,int r,int lb,int rb)
{
if(!x) x=newnode();
if(l<=lb&&rb<=r)
{
d[x]=upd(d[x],k);
tg[x]=upd(tg[x],k);
return;
}
pushdown(x);
if(l<=mid) modify(ls[x],k,l,r,lb,mid);
if(r>mid) modify(rs[x],k,l,r,mid+1,rb);
pushup(x);
}
int query(int x,int t,int lb,int rb)
{
if(!x) return -1;
if(lb==rb) return d[x];
pushdown(x);
if(t<=mid) return query(ls[x],t,lb,mid);
else return query(rs[x],t,mid+1,rb);
}
#undef mid
#undef pushup
} tr;
int rt,dfn[N],dep[N];
struct HLD
{
int fa[N],son[N],siz[N],top[N],idx;
void dfs1(int x)
{
siz[x]=1;
for(int i=hed[x];i;i=nxt[i])
if(!siz[tal[i]])
{
fa[tal[i]]=x,dep[tal[i]]=dep[x]+1;
dfs1(tal[i]);
siz[x]+=siz[tal[i]];
if(siz[tal[i]]>siz[son[x]]) son[x]=tal[i];
}
}
void dfs2(int x,int tp)
{
if(!x) return;
dfn[x]=++idx;
top[x]=tp;
dfs2(son[x],tp);
for(int i=hed[x];i;i=nxt[i])
if(!top[tal[i]]) dfs2(tal[i],tal[i]);
}
void init()
{
dfs1(1);
dfs2(1,1);
}
void modify(int u,int v,int w)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) {int tmp=u;u=v,v=tmp;}
tr.modify(rt,w,dfn[top[u]],dfn[u],1,n);
u=fa[top[u]];
}
if(dep[u]>dep[v]) {int tmp=u;u=v,v=tmp;};
tr.modify(rt,w,dfn[u]+1,dfn[v],1,n);
}
} hld;
int t1[N],t2[N];
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
scanf("%d%d",&t1[i],&t2[i]);
adde(t1[i],t2[i]);
adde(t2[i],t1[i]);
}
hld.init();
for(int i=1;i<=m;i++)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
hld.modify(u,v,w);
}
for(int i=1;i<n;i++)
{
if(dep[t2[i]]>dep[t1[i]])
{
int tmp=t2[i];
t2[i]=t1[i],t1[i]=tmp;
}
printf("%d\n",tr.query(rt,dfn[t1[i]],1,n));
}
}
P3401 洛谷树
题意
树,单点修改,查询树链上每个子区间的异或和之和。
题解
首先考虑序列上的这个问题该怎么解决。
注意到题目最后一行写了:
对于 \(100\%\) 的数据,所有边权小于等于 \(1023\)。
这使我们联想到按位维护。
对于第 \(k\) 个二进制位,每个子区间的异或和的这一位,要么是 \(1\),要么是 \(0\)。所以要求每个子区间的异或和之和,只需要求有多少个子区间的异或和是 \(1\),然后把这个数量乘以 \(2^k\) 加入答案。
于是问题得到了简化:给一 \(01\) 序列,维护单点修改,求区间有多少个子区间异或和为 \(1\)。
这令我们联想到小白逛公园这道题。于是可以采用与这题类似的做法,用线段树维护每个节点对应区间的前后缀信息。
具体地,记录它本身的答案、异或和,以及它有多少个异或和为 \(1/0\) 的前后缀,于是 pushup 就能写出来了。
回到本题,由于是树上问题,所以使用树链剖分。需要注意边权转点权的问题。
注意树剖 query 的时候需要先求两边的链,再 reverse 一下,最后合并答案。
#include <cstdio>
#define N 30005
#define int long long
int n,q,cnte,hed[N],tal[N<<1],wt[N<<1],a[N],nxt[N<<1],rt[15];
void adde(int u,int v,int w) {tal[++cnte]=v,wt[cnte]=w,nxt[cnte]=hed[u],hed[u]=cnte;}
struct seq
{
int sum,d0,d1,l0,l1,r0,r1;
seq(int x=0):sum(x),d0(!x),d1(x),l0(!x),l1(x),r0(!x),r1(x) {}
seq rev()
{
seq ret;
ret.sum=sum,ret.d0=d0,ret.d1=d1,ret.l0=r0,ret.l1=r1,ret.r0=l0,ret.r1=l1;
return ret;
}
} ans[2][15];
seq merge(seq x,seq y)
{
if(x.sum==-1) return y;
if(y.sum==-1) return x;
seq ret;
ret.sum=x.sum^y.sum;
ret.d0=x.d0+y.d0+x.r0*y.l0+x.r1*y.l1;
ret.d1=x.d1+y.d1+x.r0*y.l1+x.r1*y.l0;
ret.l0=x.l0,ret.l1=x.l1;
if(x.sum) ret.l0+=y.l1,ret.l1+=y.l0;
else ret.l0+=y.l0,ret.l1+=y.l1;
ret.r0=y.r0,ret.r1=y.r1;
if(y.sum) ret.r0+=x.r1,ret.r1+=x.r0;
else ret.r0+=x.r0,ret.r1+=x.r1;
return ret;
}
int dfn[N],li[N],dep[N],fa[N],son[N],siz[N],top[N],idx;
struct sgt
{
seq d[N<<5];
int ls[N<<5],rs[N<<5],id;
#define mid (lb+rb>>1)
int build(int t,int lb,int rb)
{
int x=++id;
if(lb==rb) {d[x]=a[li[lb]]>>t&1;return x;}
ls[x]=build(t,lb,mid),rs[x]=build(t,mid+1,rb);
d[x]=merge(d[ls[x]],d[rs[x]]);
return x;
}
void modify(int x,int t,int lb,int rb)
{
if(lb==rb)
{
d[x]=d[x].sum^1;
return;
}
(t<=mid)?modify(ls[x],t,lb,mid):modify(rs[x],t,mid+1,rb);
d[x]=merge(d[ls[x]],d[rs[x]]);
}
seq query(int x,int l,int r,int lb,int rb)
{
if(l>r) return -1;
if(l<=lb&&rb<=r) return d[x];
if(r<=mid) return query(ls[x],l,r,lb,mid);
if(l>mid) return query(rs[x],l,r,mid+1,rb);
return merge(query(ls[x],l,r,lb,mid),query(rs[x],l,r,mid+1,rb));
}
#undef mid
} tr;
void dfs1(int x)
{
siz[x]=1;
for(int i=hed[x];i;i=nxt[i]) if(!siz[tal[i]])
{
fa[tal[i]]=x,dep[tal[i]]=dep[x]+1,a[tal[i]]=wt[i];
dfs1(tal[i]);
siz[x]+=siz[tal[i]];
if(siz[tal[i]]>siz[son[x]]) son[x]=tal[i];
}
}
void dfs2(int x,int tp)
{
if(!x) return;
li[dfn[x]=++idx]=x;
dfs2(son[x],top[x]=tp);
for(int i=hed[x];i;i=nxt[i]) if(!top[tal[i]]) dfs2(tal[i],tal[i]);
}
int query(int x,int y)
{
for(int i=0;i<15;i++) ans[0][i]=ans[1][i]=-1;
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])
{
for(int i=0;i<15;i++)
ans[0][i]=merge(tr.query(rt[i],dfn[top[x]],dfn[x],1,n),ans[0][i]);
x=fa[top[x]];
}
else
{
for(int i=0;i<15;i++)
ans[1][i]=merge(tr.query(rt[i],dfn[top[y]],dfn[y],1,n),ans[1][i]);
y=fa[top[y]];
}
}
if(dep[x]>dep[y]) for(int i=0;i<15;i++)
ans[0][i]=merge(tr.query(rt[i],dfn[y]+1,dfn[x],1,n),ans[0][i]);
else for(int i=0;i<15;i++)
ans[1][i]=merge(tr.query(rt[i],dfn[x]+1,dfn[y],1,n),ans[1][i]);
for(int i=0;i<15;i++)
ans[0][i]=merge(ans[0][i].rev(),ans[1][i]);
int ret=0;
for(int i=0;i<15;i++) if(ans[0][i].sum!=-1) ret+=ans[0][i].d1<<i;
return ret;
}
main()
{
scanf("%lld%lld",&n,&q);
for(int i=1,u,v,w;i<n;i++) scanf("%lld%lld%lld",&u,&v,&w),adde(u,v,w),adde(v,u,w);
dfs1(1),dfs2(1,1);
for(int i=0;i<15;i++) rt[i]=tr.build(i,1,n);
while(q--)
{
int op,u,v,w;
scanf("%lld%lld%lld",&op,&u,&v);
if(op==1) printf("%lld\n",query(u,v));
if(op==2)
{
scanf("%lld",&w);
if(fa[u]==v) {int tmp=u;u=v,v=tmp;}
int y=a[v]^w;
for(int i=0;i<15;i++) if(y>>i&1) tr.modify(rt[i],dfn[v],1,n);
a[v]=w;
}
}
}

浙公网安备 33010602011771号