【学习笔记】[CTSC2018]暴力写挂 (边分树合并/边分治+虚树)
显然我们可以将问题转化为 1 2 ( dep ( x ) + dep ( y ) + dist ( x , y ) − 2 dep ′ ( lca ′ ( x , y ) ) ) \frac{1}{2}(\text{dep}(x)+\text{dep}(y)+\text{dist}(x,y)-2\text{dep}'(\text{lca}'(x,y))) 21(dep(x)+dep(y)+dist(x,y)−2dep′(lca′(x,y))) 。因为要分治所以这里写成两点距离比较方便。
如果边权非负的话直接套用 [WC2018]通道 的做法合并最远点对即可,然而这题边权有负。
考虑两种算法。
算法一
边分治+虚树。
对第一颗树进行边分治,设 x x x到分治中心的距离是 D ( x ) D(x) D(x),显然答案可以写成 dep ( x ) + D ( x ) + dep ( y ) + D ( y ) − 2 dep ′ ( lca ′ ( x , y ) ) \text{dep}(x)+D(x)+\text{dep}(y)+D(y)-2\text{dep}'(\text{lca}'(x,y)) dep(x)+D(x)+dep(y)+D(y)−2dep′(lca′(x,y)),其中 x x x, y y y属于分治中心的两侧。那么直接在第二颗树对应的虚树上统计即可。
如果加上归并排序和 R M Q RMQ RMQ求 lca \text{lca} lca的优化可以做到 O ( n log n O(n\log n O(nlogn)。
算法二
类比点分树,我们可以构造出边分树,其中叶子节点都是原树上的点,非叶子节点是原树上的边。类比
trie
\text{trie}
trie树,我们可以得到一个叶子节点在每次分治过程中被分到了哪一边。那么我们可以对边分树上的叶子节点到根的链进行合并,道理很简单,如果
x
x
x,
y
y
y到根节点的路径上有公共点
z
z
z,那么可以在
z
z
z处统计答案。那么我们只需在第二棵树上枚举
lca
′
(
x
,
y
)
\text{lca}'(x,y)
lca′(x,y),把儿子节点全部合并起来时统计答案即可。
复杂度 O ( n log n ) O(n\log n) O(nlogn)。
我信仰什么,我便实现哪种方法。
因为太懒了所以只写了虚树的做法。
#include<bits/stdc++.h>
#define fi first
#define se second
#define ll long long
#define db double
#define pb push_back
#define inf 0x3f3f3f3f3f3f3f3f
#define int ll
using namespace std;
const int N=2e6+5;
int n,n2,num,dfn[N];
int hd[N],to[N],st[N],nxt[N],vis[N],w[N],tot=1;
int f2[N][20],dep4[N];
ll dep[N],dep1[N],dep2[N],dep3[N];
vector<pair<int,int>>g[N],g3[N];
void add(int x,int y,int z){
to[++tot]=y,st[tot]=x,w[tot]=z,nxt[tot]=hd[x],hd[x]=tot;
to[++tot]=x,st[tot]=y,w[tot]=z,nxt[tot]=hd[y],hd[y]=tot;
}
int Lca2(int x,int y){
if(dep4[x]<dep4[y])swap(x,y);
for(int i=19;i>=0;i--)if(dep4[f2[x][i]]>=dep4[y])x=f2[x][i];
if(x==y)return x;
for(int i=19;i>=0;i--)if(f2[x][i]!=f2[y][i])x=f2[x][i],y=f2[y][i];
return f2[x][0];
}
void dfs(int u,int topf){
f2[u][0]=topf,dep4[u]=dep4[topf]+1;
for(int i=1;i<20;i++)f2[u][i]=f2[f2[u][i-1]][i-1];
int lst=u;
for(auto v:g[u]){
if(v.fi!=topf){
++n2,add(lst,n2,0),add(n2,v.fi,v.se),lst=n2,dep1[v.fi]=dep1[u]+v.se,dfs(v.fi,u);
}
}
}
int edge,D,siz[N],c[N],s[N],cnt,f[N][20];
ll res(-inf);
vector<int>vec;
void dfs2(int u,int topf,int sz){
siz[u]=1;
for(int k=hd[u];k;k=nxt[k]){
int v=to[k];
if(!vis[k]&&v!=topf){
dfs2(v,u,sz),siz[u]+=siz[v];
if(max(siz[v],sz-siz[v])<D)D=max(siz[v],sz-siz[v]),edge=k;
}
}
}
ll dp1[N],dp2[N];
int p[N],m;
vector<int>g2[N];
void dfs3(int u,int topf,int C){
c[u]=C;if(u<=n)p[++m]=u;
for(int k=hd[u];k;k=nxt[k]){
int v=to[k];
if(!vis[k]&&v!=topf)dep2[v]=dep2[u]+w[k],dfs3(v,u,C);
}
}
void Add(int x,int y){
if(dep[x]>dep[y])swap(x,y);
g2[x].pb(y);
}
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
void dfs5(int u,int topf){
dfn[u]=++num,dep[u]=dep[topf]+1,f[u][0]=topf;
for(int i=1;i<20;i++)f[u][i]=f[f[u][i-1]][i-1];
for(auto v:g3[u]){
if(v.fi!=topf)dep3[v.fi]=dep3[u]+v.se,dfs5(v.fi,u);
}
}
int Lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=19;i>=0;i--)if(dep[f[x][i]]>=dep[y])x=f[x][i];
if(x==y)return x;
for(int i=19;i>=0;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
return f[x][0];
}
void build(){
for(int i=1;i<=m;i++)vec.pb(p[i]);
sort(p+1,p+1+m,cmp);
s[cnt=1]=1,vec.pb(1);
for(int i=1;i<=m;i++){
if(p[i]==1)continue;
int u=p[i],v=Lca(u,s[cnt]);
if(v==s[cnt])s[++cnt]=u;
else {
int l=0;
while(dep[s[cnt]]>dep[v]){
if(l)Add(l,s[cnt]);
l=s[cnt--];
}
if(s[cnt]==v){
Add(l,v),s[++cnt]=u;
}
else{
Add(l,v),s[++cnt]=v,vec.pb(v),s[++cnt]=u;
}
}
}for(int i=1;i<cnt;i++)Add(s[i],s[i+1]);
}
void dfs4(int u){
dp1[u]=dp2[u]=-inf;
if(c[u]==1)dp1[u]=dep1[u]+dep2[u];
if(c[u]==2)dp2[u]=dep1[u]+dep2[u];
for(auto v:g2[u]){
dfs4(v),res=max(res,w[edge]+max(dp1[u]+dp2[v],dp2[u]+dp1[v])-2*dep3[u]);
dp1[u]=max(dp1[u],dp1[v]),dp2[u]=max(dp2[u],dp2[v]);
}
}
void solve(int u,int sz){
if(sz==1)return;
edge=-1,D=0x3f3f3f3f,dfs2(u,0,sz);
vis[edge]=vis[edge^1]=1,m=0,dep2[st[edge]]=dep2[to[edge]]=0,dfs3(st[edge],0,1),dfs3(to[edge],0,2);
build(),dfs4(1);
for(int i=0;i<vec.size();i++)g2[vec[i]].clear(),c[vec[i]]=0;vec.clear();
int tmp=edge;
solve(st[tmp],sz-siz[to[tmp]]),solve(to[tmp],siz[to[tmp]]);
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n,n2=n;
for(int i=1;i<n;i++){
int x,y,z;cin>>x>>y>>z;
g[x].pb({y,z}),g[y].pb({x,z});
}
for(int i=1;i<n;i++){
int x,y,z;cin>>x>>y>>z;
g3[x].pb({y,z}),g3[y].pb({x,z});
}
dfs(1,0),dfs5(1,0);
for(int i=1;i<=n;i++)res=max(res,2*(dep1[i]-dep3[i]));
solve(1,n2);
cout<<res/2;
}

浙公网安备 33010602011771号