P5024 [NOIP2018 提高组] 保卫王国

n个节点的树 每个节点都有p[] 代价  问 相连两点 至少选一个

m 个询问 每次强制规定 两个节点 0/1 求最小代价/ -1

/*
暴力 每次都重置 不合法状态 n^2
预处理+每次都重新 从深度最大的开始跳 n^2
预处理 加上 预处理出 u->root 1 之间 最小的代价 f[x][0/1][i][0/1]  
f[x][0][0][0]=inf
f[x][1][0][0]=dp[fa[x][0]][0]-dp[x][1]
f[x][0][0][1]=dp[fa[x][0]][1]-min(dp[x][1],dp[x][0])
f[x][1][0][1]=dp[fa[x][0]][1]-min(dp[x][1],dp[x][0])

f[x][1][i][0]=min(f[x][1][i-1][0]+f[fa[x][i-1]][0][i-1][0],f[x][1][i-1][1]+f[fa[x][i-1]][1][i-1][0]);
f[x][1][i][1]=min(f[x][1][i-1][0]+f[fa[x][i-1]][0][i-1][1],f[x][1][i-1][1]+f[fa[x][i-1]][1][i-1][1]);
f[x][0][i][1]=min(f[x][0][i-1][0]+f[fa[x][i-1]][0][i-1][1],f[x][0][i-1][1]+f[fa[x][i-1]][1][i-1][1]);                                   
f[x][0][i][0]=min(f[x][0][i-1][0]+f[fa[x][i-1]][0][i-1][0],f[x][0][i-1][1]+f[fa[x][i-1]][1][i-1][0]);  
    
    
u->lcaa+ v->lcaa +lcaa-> root 1
    u->lcaa 
        ll tmp0=u0,tmp1=u1;
        u0=min(tmp0+f[u][0][i][0],tmp1+f[u][1][i][0]);
        u1=min(tmp0+f[u][0][i][1],tmp1+f[u][1][i][1]);
    u,v->lcaa 
        ll tmp0=u0,tmp1=u1,p0=v0,p1=v1;
        u0=min(tmp0+f[u][0][i][0],tmp1+f[u][1][i][0]);
        u1=min(tmp0+f[u][0][i][1],tmp1+f[u][1][i][1]);
        v0=min(p0+f[v][0][i][0],p1+f[v][1][i][0]);
        v1=min(p0+f[v][0][i][1],p1+f[v][1][i][1]);        
        u=fa[u][i],v=fa[v][i];
    lcaa->root 
        ll t0=l0,t1=l1;
        l0=min(t0+f[lcaa][0][i][0],t1+f[lcaa][1][i][0]);
        l1=min(t0+f[lcaa][0][i][1],t1+f[lcaa][1][i][1]);
        
*/
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
//#include<queue>
//#include<vector>
//#include<bits/stdc++.h>
#define ll long long
#define ddd printf("-----------------------\n");
using namespace std;
const ll maxn=1e5+10 ;
const ll inf=0x3f3f3f3f3f3f3f3f;

char s[10];
ll n,m,p[maxn],ans,head[maxn],to[maxn<<1],nxt[maxn<<1],tot;
ll fa[maxn][20],dep[maxn],dp[maxn][3],f[maxn][3][20][3],lg[maxn]={-1};

void add(int a,int b){
    to[++tot]=b,nxt[tot]=head[a],head[a]=tot;
}

void dfs(int x,int faa)
{
    dep[x]=dep[faa]+1,fa[x][0]=faa,dp[x][1]+=p[x];//
    for(int i=1;(1<<(i-1))<=dep[x];i++) fa[x][i]=fa[fa[x][i-1]][i-1];
    
    for(int i=head[x];i;i=nxt[i])
    {
        int v=to[i];if(v==faa) continue;
        dfs(v,x);
        dp[x][0]+=dp[v][1];
        dp[x][1]+=min(dp[v][0],dp[v][1]);
    }
}
void get_f(int x,int faa)
{
    f[x][0][0][0]=inf;
    f[x][1][0][0]=dp[fa[x][0]][0]-dp[x][1];
    f[x][1][0][1]=f[x][0][0][1]=dp[fa[x][0]][1]-min(dp[x][1],dp[x][0]);
    
    for(int i=1;(1<<i)<=dep[x];i++){//
        f[x][1][i][0]=min(f[x][1][i-1][0]+f[fa[x][i-1]][0][i-1][0],f[x][1][i-1][1]+f[fa[x][i-1]][1][i-1][0]);
        f[x][1][i][1]=min(f[x][1][i-1][0]+f[fa[x][i-1]][0][i-1][1],f[x][1][i-1][1]+f[fa[x][i-1]][1][i-1][1]);
        f[x][0][i][1]=min(f[x][0][i-1][0]+f[fa[x][i-1]][0][i-1][1],f[x][0][i-1][1]+f[fa[x][i-1]][1][i-1][1]);                                   
        f[x][0][i][0]=min(f[x][0][i-1][0]+f[fa[x][i-1]][0][i-1][0],f[x][0][i-1][1]+f[fa[x][i-1]][1][i-1][0]);  
    }
    
    for(int i=head[x];i;i=nxt[i])
        if(to[i]!=faa) get_f(to[i],x);
}

void lca(int u,int x,int v,int y)
{
    if(dep[u]<dep[v])    swap(u,v),swap(x,y);
    ll u1=inf,u0=inf,v1=inf,v0=inf,l1=inf,l0=inf,lcaa;
    x? u1=dp[u][1]:u0=dp[u][0];// 以u为根的子树 0/1 最小值 
    y? v1=dp[v][1]:v0=dp[v][0];
    
    for(int i=lg[dep[u]-dep[v]];i>=0;i--)
    {
        if(dep[u]-(1<<i)>=dep[v]){
            ll tmp0=u0,tmp1=u1;
            u0=min(tmp0+f[u][0][i][0],tmp1+f[u][1][i][0]);
            u1=min(tmp0+f[u][0][i][1],tmp1+f[u][1][i][1]);
            u=fa[u][i];
        }
    }
    
    if(u==v) lcaa=u, y? l1=u1:l0=u0 ;    //
    else{
        for(int i=lg[dep[u]];i>=0;i--)
        {
            if(fa[u][i]!=fa[v][i])
            {
                ll tmp0=u0,tmp1=u1,p0=v0,p1=v1;
                u0=min(tmp0+f[u][0][i][0],tmp1+f[u][1][i][0]);
                u1=min(tmp0+f[u][0][i][1],tmp1+f[u][1][i][1]);
                v0=min(p0+f[v][0][i][0],p1+f[v][1][i][0]);
                v1=min(p0+f[v][0][i][1],p1+f[v][1][i][1]);        
                u=fa[u][i],v=fa[v][i];
            }
        }
        lcaa=fa[u][0];
        l0=dp[lcaa][0]-dp[u][1]-dp[v][1]+u1+v1;
        l1=dp[lcaa][1]-min(dp[u][0],dp[u][1])-min(dp[v][0],dp[v][1])+min(u0,u1)+min(v0,v1);        
    }
    
    if(lcaa==1) ans=min(l0,l1);
    else {
        for(int i=lg[dep[lcaa]];i>=0;i--)
        {
            if(dep[lcaa]-(1<<i)>dep[1])
            {    
                ll t0=l0,t1=l1;
                l0=min(t0+f[lcaa][0][i][0],t1+f[lcaa][1][i][0]);
                l1=min(t0+f[lcaa][0][i][1],t1+f[lcaa][1][i][1]);
                lcaa=fa[lcaa][i];
            }
        }
        ans=min(dp[1][0]-dp[lcaa][1]+l1,dp[1][1]-min(dp[lcaa][1],dp[lcaa][0])+min(l0,l1) );            
    }    
}

int main()
{
    ios::sync_with_stdio(false);
    cin>>n>>m>>s;
    for(int i=1;i<=n;i++) cin>>p[i],lg[i]=lg[i>>1]+1;
    for(int i=1;i<=n-1;i++){
        int a,b;cin>>a>>b;
        add(a,b),add(b,a);
    }
    dfs(1,0);
    get_f(1,0);
    
    for(int i=1;i<=m;i++){
        ans=-1;
        int a,x,b,y;cin>>a>>x>>b>>y;
        lca(a,x,b,y);
        if(ans<inf) cout<<ans<<'\n';
        else cout<<"-1\n";
    }
    return 0;
}



/*
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
//#include<queue>
//#include<vector>
#include<bits/stdc++.h>
#define ll long long
#define ddd printf("-----------------------\n");
using namespace std;
const int maxn=1e5+10 ;
const int inf=0x3f3f3f3f;
const int mod=2003;

int n,q,a[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot;
ll f[maxn][3];
char s[10];

void add(int u,int v){
    to[++tot]=v,nxt[tot]=head[u],head[u]=tot;
}

void dfs(int u,int faa)
{
    f[u][1]+=a[u];//inf
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];if(v==faa) continue;
        dfs(v,u);
        f[u][0]+=f[v][1];
        f[u][1]+=min(f[v][1],f[v][0]);
    }
}
int main()
{
    ios::sync_with_stdio(false);
    cin>>n>>q>>s;
    for(int i=1;i<=n;i++) cin>>a[i];
    for(int i=1;i<=n-1;i++){
        int a,b;cin>>a>>b;
        add(a,b),add(b,a);
    }
    for(int i=1;i<=q;i++){
        memset(f,0,sizeof(f));
        
        int a,x,b,y;cin>>a>>x>>b>>y;
        f[a][1-x]=inf,f[b][1-y]=inf;
        dfs(1,1);
        if(min(f[1][1],f[1][0])>=inf) cout<<"-1\n";
        else cout<<min(f[1][1],f[1][0])<<'\n';
    }
    return 0;
}
*/

 

posted @ 2023-09-28 10:59  JMXZ  阅读(11)  评论(0)    收藏  举报