BZOJ5341[Ctsc2018]暴力写挂——边分治+虚树+树形DP

题目链接:

CSTC2018暴力写挂

 

题目大意:给出n个点结构不同的两棵树,边有边权(有负权边及0边),要求找到一个点对(a,b)满足dep(a)+dep(b)-dep(lca)-dep'(lca)最大,其中dep为第一棵树中的深度,dep'为第二棵树中的深度,lca为两点的最近公共祖先。注意:a与b可以相同!

本题讲解两种做法,其中第一种做法常数较小且比较好写,第二种做法思路比较奇特。为了方便讲解,设两点在第一棵树中的距离为$dis(x,y)$

解法一

题中给的式子显然不能直接做,我们将它变换一下(设答案的点对为(x,y)):

$dep(x)+dep(y)-dep(lca)-dep'(lca)$

$=\frac{1}{2}(dep(x)+dep(y)-2*dep(lca)+dep(x)+dep(y))-dep'(lca)$

$=\frac{1}{2}(dis(x,y)+dep(x)+dep(y))-dep'(lca)$

我们对第一棵树多叉树转二叉树(详见边分治讲解)然后进行边分治。

对于每次分治,设一个点到分治中心边的距离为$d(x)$,那么式子就变成了$\frac{1}{2}(d(x)+d(y)+value+dep(x)+dep(y))-dep'(lca)$,其中$value$为分治中心边的边权。

我们将一个点在第二棵树上的点权看成是$v(x)=d(x)+dep(x)$,答案就是$v(x)+v(y)+value-dep'(lca)$。

对于每次边分治将分治联通块内所有点在第二棵树上的建出虚树,同时将分治联通块以分治中心边为界限分成两部分,将一部分的点标为黑点,将另一部分的点标为白点。

那么对于虚树中的一个点,以它为$lca$的最大答案就是在它的两个不同子树中分别选出一个黑点和一个白点使这两个点的点权和最大。

我们在虚树上进行树形DP,每个点维护这个点在虚树上的子树中黑点的最大点权及白点的最大点权。

对于每个点回溯时分别将它每个子树的信息合并上来并在合并时更新答案,这样就保证了当前点一定是所维护的最大点权黑点和最大点权白点的lca。

注意求$LCA$时不能比较带权深度而要比较不带权深度。

#include<set>
#include<map>
#include<queue>
#include<cmath>
#include<stack>
#include<cstdio>
#include<vector>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
#define INF (1<<30)
#define pr pair<int,ll>
using namespace std;
char *p1,*p2,buf[100000];
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
int rd() {int x=0,f=1; char c=nc(); while(!isdigit(c)) {if(c=='-') f=-1; c=nc();} while(isdigit(c)) x=(((x<<2)+x)<<1)+(c^48),c=nc(); return x*f;}
ll rd2() {ll x=0,f=1; char c=nc(); while(!isdigit(c)) {if(c=='-') f=-1; c=nc();} while(isdigit(c)) x=(((x<<2)+x)<<1)+(c^48),c=nc(); return x*f;}
int n,m;
int x,y;
ll ans;
ll z;
int cnt;
int col[400010];
struct miku
{
    int x;
    ll val;
}t[400010];
namespace virtual_tree
{
    ll res;
    int tot;
    int top;
    int dfn=0;
    ll mid_edge;
    ll d[400010];
    ll v[400010];
    int s[400010];
    ll val[800010];
    int st[400010];
    int to[800010];
    int lg[800010];
    int vis[400010];
    int dep[400010];
    int head[400010];
    int next[800010];
    int dp[400010][2];
    int f[800010][20];
    vector<int>q[400010];
    inline void add(int x,int y,ll z)
    {
        next[++tot]=head[x];
        head[x]=tot;
        to[tot]=y;
        val[tot]=z;
    }
    inline bool cmp(const miku &a,const miku &b)
    {
        return s[a.x]<s[b.x];
    }
    void dfs(int x,int fa)
    {
        f[++dfn][0]=x;
        s[x]=dfn;
        for(int i=head[x];i;i=next[i])
        {
            if(to[i]!=fa)
            {
                dep[to[i]]=dep[x]+1;
                d[to[i]]=d[x]+val[i];
                dfs(to[i],x);
                f[++dfn][0]=x;
            }
        }
    }
    inline int mn(int x,int y)
    {
        return dep[x]<dep[y]?x:y;
    }
    inline void ST()
    {
        for(int i=2;i<=dfn;i++)
        {
            lg[i]=lg[i>>1]+1;
        }
        for(int j=1;j<=19;j++)
        {
            for(int i=1;i+(1<<j)-1<=dfn;i++)
            {
                f[i][j]=mn(f[i][j-1],f[i+(1<<(j-1))][j-1]);
            }
        }
    }
    inline int lca(int x,int y)
    {
        x=s[x],y=s[y];
        if(x>y)
        {
            swap(x,y);
        }
        int len=lg[y-x+1];
        return mn(f[x][len],f[y-(1<<len)+1][len]);
    }
    inline void insert(int x)
    {
        int fa=lca(x,st[top]);
        if(!vis[fa])
        {
            vis[fa]=1;
            v[fa]=-1ll<<60;
            dp[fa][0]=dp[fa][1]=0;
        }
        while(top>1&&dep[st[top-1]]>=dep[fa])
        {
            q[st[top-1]].push_back(st[top]);
            top--;
        }
        if(fa!=st[top])
        {
            q[fa].push_back(st[top]);
            st[top]=fa;
        }
        st[++top]=x;
    }
    inline int merge(int x,int y)
    {
        if(!x||!y)
        {
            return x+y;
        }
        return v[x]>v[y]?x:y;
    }
    inline void query(int x,int y)
    {
        if(!x||!y)
        {
            return ;
        }
        res=max(res,v[x]+v[y]);
    }
    void tree_dp(int x)
    {
        int len=q[x].size();
        for(int i=0;i<len;i++)
        {
            int to=q[x][i];
            tree_dp(to);
            res=-1ll<<60;
            query(dp[x][0],dp[to][1]);
            query(dp[x][1],dp[to][0]);
            ans=max(ans,res+mid_edge-2*d[x]);
            dp[x][0]=merge(dp[x][0],dp[to][0]);
            dp[x][1]=merge(dp[x][1],dp[to][1]);
        }
        vis[x]=0;
        q[x].clear();
    }
    inline void build(ll value)
    {
        mid_edge=value;
        for(int i=1;i<=cnt;i++)
        {
            vis[t[i].x]=1;
            v[t[i].x]=t[i].val;
            dp[t[i].x][col[t[i].x]-1]=t[i].x;
            dp[t[i].x][(col[t[i].x]-1)^1]=0;
            col[t[i].x]=0;
        }
        sort(t+1,t+1+cnt,cmp);
        top=0;
        if(t[1].x!=1)
        {
            st[++top]=1;
        }
        for(int i=1;i<=cnt;i++)
        {
            insert(t[i].x);
        }
        while(top>1)
        {
            q[st[top-1]].push_back(st[top]);
            top--;
        }
        tree_dp(1);
    }
    inline void work()
    {
        dfs(1,0);
        ST();
    }
};
namespace edge_partation
{
    int tot;
    int num;
    int root;
    ll d[800010];
    ll val[1600010];
    int to[1600010];
    int vis[800010];
    int head[800010];
    int size[800010];
    int next[1600010];
    vector<pr>v[400010];
    inline void push(int x,int y,ll z)
    {
        v[x].push_back(make_pair(y,z));
    }
    inline void add(int x,int y,ll z)
    {
        next[++tot]=head[x];
        head[x]=tot;
        to[tot]=y;
        val[tot]=z;
    }
    void rebuild(int x,int fa)
    {
        int tmp=0;
        int last=0;
        int len=v[x].size();
        for(int i=0;i<len;i++)
        {
            int to=v[x][i].first;
            int val=v[x][i].second;
            if(to==fa)
            {
                continue;
            }
            tmp++;
            if(tmp==1)
            {
                add(x,to,val);
                add(to,x,val);
                last=x;
            }
            else if(tmp==len-(x!=1))
            {
                add(last,to,val);
                add(to,last,val);
            }
            else
            {
                m++;
                add(last,m,0);
                add(m,last,0);
                last=m;
                add(m,to,val);
                add(to,m,val);
            }
        }
        for(int i=0;i<len;i++)
        {
            if(v[x][i].first==fa)
            {
                continue;
            }
            rebuild(v[x][i].first,x);
        }
    }
    void dfs(int x,int fa)
    {
        for(int i=head[x];i;i=next[i])
        {
            if(to[i]!=fa)
            {
                d[to[i]]=d[x]+val[i];
                dfs(to[i],x);
            }
        }
    }
    void getroot(int x,int fa,int sum)
    {
        size[x]=1;
        for(int i=head[x];i;i=next[i])
        {
            if(!vis[i>>1]&&to[i]!=fa)
            {
                getroot(to[i],x,sum);
                size[x]+=size[to[i]];
                int mx_size=max(size[to[i]],sum-size[to[i]]);
                if(mx_size<num)
                {
                    num=mx_size;
                    root=i;
                }
            }
        }
    }
    void dfs2(int x,int fa,ll dep,int opt)
    {
        if(x<=n)
        {
            col[x]=opt;
            t[++cnt]=(miku){x,d[x]+dep};
        }
        for(int i=head[x];i;i=next[i])
        {
            if(!vis[i>>1]&&to[i]!=fa)
            {
                dfs2(to[i],x,dep+val[i],opt);
            }
        }
    }
    void partation(int x,int sum)
    {
        num=INF;
        getroot(x,0,sum);
        if(num==INF)
        {
            return ;
        }
        int now=root;
        vis[now>>1]=1;
        cnt=0;
        dfs2(to[now],0,0ll,1);
        dfs2(to[now^1],0,0ll,2);
        virtual_tree::build(val[now]);
        int sz=size[to[now]];
        partation(to[now],sz);
        partation(to[now^1],sum-sz);
    }
    inline void work()
    {
        tot=1;
        rebuild(1,0);
        dfs(1,0);
        partation(1,m);
    }
};
int main()
{
    m=n=rd();
    ans=-1ll<<60;
    for(int i=1;i<n;i++)
    {
        x=rd(),y=rd(),z=rd2();
        edge_partation::push(x,y,z);
        edge_partation::push(y,x,z);
    }
    for(int i=1;i<n;i++)
    {
        x=rd(),y=rd(),z=rd2();
        virtual_tree::add(x,y,z);
        virtual_tree::add(y,x,z);
    }
    virtual_tree::work();
    edge_partation::work();
    ans>>=1;
    for(int i=1;i<=n;i++)
    {
        ans=max(ans,edge_partation::d[i]-virtual_tree::d[i]);
    }
    printf("%lld",ans);
}

解法二

同样将原式转化一下:

$dep(x)+dep(y)-dep(lca)-dep'(lca)$

$=dep(y)+(dep(x)-dep(lca))-dep'(lca)$(这好像除了加了个括号、换了下位置啥都没变啊QAQ)

咳咳...我们先不关注这个,同样对于第一棵树进行边分治,对于一次分治,将当前分治联通块以分治中心边为界分成了两部分,我们设这两部分为$S$和$T$。

如果每次边分治时都以当前联通块在原树中深度最小的点为根找分治中心边,那么分治中心边就是一条从原树的父节点连向子节点的边,那么$S$和$T$中就一定有一个是原树的一棵子树。

我们假设$T$是原树的一棵子树,那么显然$S$距离根节点(即$1$号点)更近。

假设$x\subseteq S,y\subset T,y'\subseteq T$,那么$x$与$y$的$lca$一定不在$T$中,也就是说$lca(x,y)$一定等于$lca(x,y')$。

所以对于$S$中的一个点$x$,它与$T$中任意一个点的$lca$都是相同的,假设分治中心边属于$T$的端点为$v$,那么只要求出$lca(x,v)$就求出了$x$与任意$T$中节点的$lca$。

那么对于$S$中的点,我们将它在第二棵树中的点权设为$v(x)=dep(x)-dep(lca(x,v))$;

对于$T$中的点,我们将它在第二棵树中的点权设为$v(x)=dep(x)$。

再像解法一那样对于分治联通块在第二棵树中建出虚树并在虚树上树形DP即可。

#include<set>
#include<map>
#include<queue>
#include<cmath>
#include<stack>
#include<cstdio>
#include<vector>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
#define INF (1<<30)
#define pr pair<int,ll>
using namespace std;
char *p1,*p2,buf[100000];
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
int rd() {int x=0,f=1; char c=nc(); while(!isdigit(c)) {if(c=='-') f=-1; c=nc();} while(isdigit(c)) x=(((x<<2)+x)<<1)+(c^48),c=nc(); return x*f;}
ll rd2() {ll x=0,f=1; char c=nc(); while(!isdigit(c)) {if(c=='-') f=-1; c=nc();} while(isdigit(c)) x=(((x<<2)+x)<<1)+(c^48),c=nc(); return x*f;}
int n,m;
int x,y;
ll ans;
ll z;
int cnt;
int col[400010];
struct miku
{
    int x;
    ll val;
}t[400010];
namespace virtual_tree
{
    ll res;
    int tot;
    int top;
    int dfn=0;
    ll d[400010];
    ll v[400010];
    int s[400010];
    ll val[800010];
    int st[400010];
    int to[800010];
    int lg[800010];
    int vis[400010];
    int dep[400010];
    int head[400010];
    int next[800010];
    int dp[400010][2];
    int f[800010][20];
    vector<int>q[400010];
    inline void add(int x,int y,ll z)
    {
        next[++tot]=head[x];
        head[x]=tot;
        to[tot]=y;
        val[tot]=z;
    }
    inline bool cmp(const miku &a,const miku &b)
    {
        return s[a.x]<s[b.x];
    }
    inline void dfs(int x,int fa)
    {
        f[++dfn][0]=x;
        s[x]=dfn;
        for(int i=head[x];i;i=next[i])
        {
            if(to[i]!=fa)
            {
                dep[to[i]]=dep[x]+1;
                d[to[i]]=d[x]+val[i];
                dfs(to[i],x);
                f[++dfn][0]=x;
            }
        }
    }
    inline int mn(int x,int y)
    {
        return dep[x]<dep[y]?x:y;
    }
    inline void ST()
    {
        for(int i=2;i<=dfn;i++)
        {
            lg[i]=lg[i>>1]+1;
        }
        for(int j=1;j<=19;j++)
        {
            for(int i=1;i+(1<<j)-1<=dfn;i++)
            {
                f[i][j]=mn(f[i][j-1],f[i+(1<<(j-1))][j-1]);
            }
        }
    }
    inline int lca(int x,int y)
    {
        x=s[x],y=s[y];
        if(x>y)
        {
            swap(x,y);
        }
        int len=lg[y-x+1];
        return mn(f[x][len],f[y-(1<<len)+1][len]);
    }
    inline void insert(int x)
    {
        int fa=lca(x,st[top]);
        if(!vis[fa])
        {
            vis[fa]=1;
            v[fa]=-1ll<<60;
            dp[fa][0]=dp[fa][1]=0;
        }
        while(top>1&&dep[st[top-1]]>=dep[fa])
        {
            q[st[top-1]].push_back(st[top]);
            top--;
        }
        if(fa!=st[top])
        {
            q[fa].push_back(st[top]);
            st[top]=fa;
        }
        st[++top]=x;
    }
    inline int merge(int x,int y)
    {
        if(!x||!y)
        {
            return x+y;
        }
        return v[x]>v[y]?x:y;
    }
    inline void query(int x,int y)
    {
        if(!x||!y)
        {
            return ;
        }
        res=max(res,v[x]+v[y]);
    }
    inline void tree_dp(int x)
    {
        int len=q[x].size();
        for(int i=0;i<len;i++)
        {
            int to=q[x][i];
            tree_dp(to);
            res=-1ll<<60;
            query(dp[x][0],dp[to][1]);
            query(dp[x][1],dp[to][0]);
            ans=max(ans,res-d[x]);
            dp[x][0]=merge(dp[x][0],dp[to][0]);
            dp[x][1]=merge(dp[x][1],dp[to][1]);
        }
        vis[x]=0;
        q[x].clear();
    }
    inline void build()
    {
        for(int i=1;i<=cnt;i++)
        {
            vis[t[i].x]=1;
            v[t[i].x]=t[i].val;
            dp[t[i].x][col[t[i].x]-1]=t[i].x;
            dp[t[i].x][(col[t[i].x]-1)^1]=0;
            col[t[i].x]=0;
        }
        sort(t+1,t+1+cnt,cmp);
        top=0;
        if(t[1].x!=1)
        {
            st[++top]=1;
        }
        for(int i=1;i<=cnt;i++)
        {
            insert(t[i].x);
        }
        while(top>1)
        {
            q[st[top-1]].push_back(st[top]);
            top--;
        }
        tree_dp(1);
    }
    inline void work()
    {
        dfs(1,0);
        ST();
    }
};
namespace edge_partation
{
    int tot;
    int num;
    int root;
    int dfn=0;
    ll d[800010];
    int s[800010];
    int lg[1600010];
    ll val[1600010];
    int to[1600010];
    int vis[800010];
    int dep[800010];
    int head[800010];
    int size[800010];
    int next[1600010];
    int f[1600010][21];
    vector<pr>v[400010];
    inline void push(int x,int y,ll z)
    {
        v[x].push_back(make_pair(y,z));
    }
    inline void add(int x,int y,ll z)
    {
        next[++tot]=head[x];
        head[x]=tot;
        to[tot]=y;
        val[tot]=z;
    }
    inline void rebuild(int x,int fa)
    {
        int tmp=0;
        int last=0;
        int len=v[x].size();
        for(int i=0;i<len;i++)
        {
            int to=v[x][i].first;
            int val=v[x][i].second;
            if(to==fa)
            {
                continue;
            }
            tmp++;
            if(tmp==1)
            {
                add(x,to,val);
                add(to,x,val);
                last=x;
            }
            else if(tmp==len-(x!=1))
            {
                add(last,to,val);
                add(to,last,val);
            }
            else
            {
                m++;
                add(last,m,0);
                add(m,last,0);
                last=m;
                add(m,to,val);
                add(to,m,val);
            }
        }
        for(int i=0;i<len;i++)
        {
            if(v[x][i].first==fa)
            {
                continue;
            }
            rebuild(v[x][i].first,x);
        }
    }
    inline void dfs(int x,int fa)
    {
        f[++dfn][0]=x;
        s[x]=dfn;
        for(int i=head[x];i;i=next[i])
        {
            if(to[i]!=fa)
            {
                dep[to[i]]=dep[x]+1;
                d[to[i]]=d[x]+val[i];
                dfs(to[i],x);
                f[++dfn][0]=x;
            }
        }
    }
    inline int mn(int x,int y)
    {
        return dep[x]<dep[y]?x:y;
    }
    inline void ST()
    {
        for(int i=2;i<=dfn;i++)
        {
            lg[i]=lg[i>>1]+1;
        }
        for(int j=1;j<=20;j++)
        {
            for(int i=1;i+(1<<j)-1<=dfn;i++)
            {
                f[i][j]=mn(f[i][j-1],f[i+(1<<(j-1))][j-1]);
            }
        }
    }
    inline int lca(int x,int y)
    {
        x=s[x],y=s[y];
        if(x>y)
        {
            swap(x,y);
        }
        int len=lg[y-x+1];
        return mn(f[x][len],f[y-(1<<len)+1][len]);
    }
    inline void getroot(int x,int fa,int sum)
    {
        size[x]=1;
        for(int i=head[x];i;i=next[i])
        {
            if(!vis[i>>1]&&to[i]!=fa)
            {
                getroot(to[i],x,sum);
                size[x]+=size[to[i]];
                int mx_size=max(size[to[i]],sum-size[to[i]]);
                if(mx_size<num)
                {
                    num=mx_size;
                    root=i;
                }
            }
        }
    }
    inline void dfs2(int x,int fa,int rt,int opt)
    {
        if(x<=n)
        {
            col[x]=opt;
            ll value=rt?d[x]-d[lca(x,rt)]:d[x];
            t[++cnt]=(miku){x,value};
        }
        for(int i=head[x];i;i=next[i])
        {
            if(!vis[i>>1]&&to[i]!=fa)
            {
                dfs2(to[i],x,rt,opt);
            }
        }
    }
    inline void partation(int x,int sum)
    {
        num=INF;
        getroot(x,0,sum);
        if(num==INF)
        {
            return ;
        }
        int now=root;
        vis[now>>1]=1;
        cnt=0;
        dfs2(x,0,to[now],1);
        dfs2(to[now],0,0,2);
        virtual_tree::build();
        int sz=size[to[now]];
        partation(to[now],sz);
        partation(x,sum-sz);
    }
    inline void work()
    {
        tot=1;
        rebuild(1,0);
        dfs(1,0);
        ST();
        partation(1,m);
    }
};
int main()
{
    m=n=rd();
    ans=-1ll<<60;
    for(int i=1;i<n;i++)
    {
        x=rd(),y=rd(),z=rd2();
        edge_partation::push(x,y,z);
        edge_partation::push(y,x,z);
    }
    for(int i=1;i<n;i++)
    {
        x=rd(),y=rd(),z=rd2();
        virtual_tree::add(x,y,z);
        virtual_tree::add(y,x,z);
    }
    virtual_tree::work();
    edge_partation::work();
    for(int i=1;i<=n;i++)
    {
        ans=max(ans,edge_partation::d[i]-virtual_tree::d[i]);
    }
    printf("%lld",ans);
}
posted @ 2018-12-26 20:54  The_Virtuoso  阅读(533)  评论(0编辑  收藏  举报