Minimal Steiner Tree ACM

上图论课的时候无意之间看到了这个,然后花了几天的时间学习了下,接下来做一个总结。

一般斯坦纳树问题是指(来自百度百科):

斯坦纳树问题是组合优化问题,与最小生成树相似,是最短网络的一种。最小生成树是在给定的点集和边中寻求最短网络使所有点连通。而最小斯坦纳树允许在给定点外增加额外的点,使生成的最短网络开销最小。

然后听说已经被证明为是NP问题了,在ACM竞赛中我们不研究这个,我们研究更简单一些的问题。

对于图G(V,E),其中V表示图中点集,E表示图中边集。设A是V的某个子集,求至少包含A中所有点的最小子树。

这个问题也是NP难问题。所以|A|必须很小才行。于是对于|A|很小的时候,有两种做法。

 

做法一: 根据性质暴力。

有一段很经典的做法(from here):

首先用floyd算法求出两两之间的最短路径。
然后把所有点都两两链接起来,权值就是它们的最短路径。
假设指定必须连接的点有K个。
那么MinimalSteinerTree 树中的内点最多有K-2个。
在纸上画一下就知道了,内点最多的情况就是树为满二叉树的情况。
而由于之前的floyd算法。把整个图给“缩”了一下。
所以树最多有K-2+K个点。
枚举所有可能的点集。对每个点集求最小生成树。取最小值即可。

如果这种方法是正确的,那么我们只需要在所有N(N=|A|)个点中选择K-2个点,然后进行最小生成树算法即可。复杂度为C(N,K-2)*K^2,对于N和K很小的时候还是可以的。

接下来就来证明,为什么做法一是正确的。

前提1:因为图中已经使用Floyd算法,所以图中不存在度为2的点

前提2:根节点全部为A集合中的点

推论1:内点的度数大于3

然后根据数据结构中一点树的基础知识我们可以得到公式:

设内点有X个

2*(K+X-1) >= K+3*X

=》X <= K - 2

然后有同学可能会问,A集合中的点也可能不为根节点啊。稍微想下我们可以知道,如果存在Y个A集合中的点不为根节点,则对于公式左边没有影响,且必然会增大公式右边,所以使得X更小。

当时,这种思路虽然我觉得很好,但是这种方法在ACM中却都会超时。。。现在的出题人那里还会出纯的模板题。确实这方法局限性太大,学习学习思路就好了。

 

方法二:状态压缩DP

(学习了这个,了解到DP果真不是很适合我这种脑子不好用的玩家)

dp[mask][i],其中是以i为根,包含mask中的点的最小生成树的权值。(其实主要可以看成,这个树含有mask集合,且必有i这个点,mask集合是|A|的某个子集)

mask是点的集合,可以用二进制进行状态压缩。

 

在得知dp[mask-1 ~ 1][1...N]的情况下,如何推出dp[mask][1...N]呢?

两个步骤实现:

step1推出:

a = min{dp[m1][i] + dp[m2][i]},其中m1 | m2 = mask。 其中m1和m2就是mask的两个互补子集

step2推出

b = min(dp[mask][j] + d[j][i])

模板可以看我的这边博客。http://www.cnblogs.com/chenhuan001/p/4960239.html

接下来我简单来证明下这样做为什么是对的。

 

要得到dp[tmask][i] 也就是状态为tmask,且一定含有i这个点。 有两种情况

第一种,对于最优状态,i这个点为内点,那么一定可以找到tmask的两种互斥子集m1和m2,使得dp[tmask][i]=dp[m1][i]+dp[m2][i]

第二种,对于最优状态,i这个点为叶子节点,那么第一种方法已经无法找出,所有找到一个没有i节点的状态,添加一条指向i的边来进行转移。

这两种情况都考虑后,就能说明转移是正确的。

 

poj 3123

//
//  main.cpp
//  poj3123
//
//  Created by 陈加寿 on 15/11/10.
//  Copyright (c) 2015年 陈加寿. All rights reserved.
//

#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <string>
#include <map>
using namespace std;
#define N 33
#define INF 1000000000

map<string,int> Maphash;
int Maphashid;

int mat[N][N];
int needconnect[5][2];

int cntcnt=0;
int mi;

void Init()
{
    Maphash.clear();
    Maphashid=0;
    for(int i=1;i<N;i++)
        for(int j=1;j<N;j++)
            if(i==j) mat[i][j]=0;
            else mat[i][j]=INF;
}

//来一发spfa
int dp[1010][N];
int que[1001000];
int qf,qd;
int spfamark[N];


int SteinerTreeDP(int mat[N][N],int maxid,int *sameset,int size,int* mark)
{
    //所有的标点从1到maxid
    //SteinerTree 所必须的点集为 sameset[0] - sameset[size-1]
    memset(dp,0,sizeof(dp));
    for(int i=1;i<(1<<size);i++)
        for(int j=1;j<=maxid;j++)
            dp[i][j]=INF;
    //初始化都为INF
    
    for (int i=0; i<size; i++) {
        dp[(1<<i)][sameset[i]]=0;
    }
    
    for (int i=1;i<(1<<size);i++)
    {
        int to[10];
        int tcnt=0;
        for(int j=0;j<size;j++)
        {
            if( (i&(1<<j)) != 0 )
            {
                to[tcnt]=j;
                tcnt++;
            }
        }
        for(int j = 1;j < (1<<tcnt)-1 ;j++)//对于内部的每一种情况.
        {
            int tmp=0;
            for(int j1=0;j1<tcnt;j1++)
            {
                if( ((1<<j1)&j) != 0)
                {
                    tmp |= (1<<to[j1]);
                }
            }
            int othertmp= (~tmp)&i;
            if( (tmp|othertmp) != i) printf("error! tmp|other != i\n"),needconnect[100][100];
            for(int k=1;k<=maxid;k++)
                dp[i][k]=min(dp[i][k],dp[tmp][k]+dp[othertmp][k]);//是含有,并不是只含有!
        }
        //然后开始SPFA
        //一开始把所有不为INF的点入队列
        qf=qd=0;
        memset(spfamark, 0, sizeof(spfamark));
        for(int j=1;j<=maxid;j++)
        {
            if(dp[i][j]!=INF)
            {
                que[qf++]=j;
                spfamark[j]=1;
            }
        }
        while(qf>qd)
        {
            int cur=que[qd++];
            spfamark[cur]=0;
            for(int j=1;j<=maxid;j++)
            {
                if( mat[cur][j]!=INF && j!=cur && dp[i][j] > dp[i][cur]+mat[cur][j] )
                {
                    dp[i][j] = dp[i][cur] + mat[cur][j];
                    if(spfamark[j]==0)
                    {
                        spfamark[j]=1;
                        que[qf++]=j;
                    }
                }
            }
        }
        
        
    }
    int tmin=INF;
    for(int i=1;i<=maxid;i++)
        tmin=min(tmin,dp[(1<<size)-1][i]);
    return tmin;
}

int getmin(int *sameset,int size)
{
    if(size>8) return 0;
    if(size<2) return 0;
    //两个的时候不要抗拒
//    if(size==2)
//    {
//        return mat[ sameset[0] ][ sameset[1] ];
//    }
    int save[N];
    int cnt=0;
    for(int i=1;i<=Maphashid;i++)
    {
        int sign=0;
        for(int j=0;j<size;j++)
        {
            if(sameset[j]==i) {
                sign=1;
                break;
            }
        }
        if(sign==0)
            save[cnt++]=i;
    }
    //然后dfs找size-2种可能。
    int mark[N];
    memset(mark,0,sizeof(mark));
    for(int i=0;i<size;i++)
        mark[ sameset[i] ]=1;
    mi=INF;
    return SteinerTreeDP(mat,Maphashid,sameset,size,mark);
}


int main(int argc, const char * argv[]) {
    // insert code here...
    int n,m;
    while (scanf("%d%d",&n,&m)&&(n+m)) {
        Init();
        for(int i=0;i<n;i++)
        {
            string tmp;
            cin>>tmp;
            if(Maphash[tmp]==0)
                Maphash[tmp] = ++Maphashid;
        }
        for(int i=0;i<m;i++)
        {
            string a,b;
            int hasha,hashb,len;
            cin>>a>>b>>len;
            hasha=Maphash[a];
            hashb=Maphash[b];
            
            mat[hasha][hashb]=mat[hashb][hasha]=min(mat[hasha][hashb],len);
        }
        for(int i=0;i<4;i++)
        {
            //四个
            string a,b;
            cin>>a>>b;
            needconnect[i][0]=Maphash[a];
            needconnect[i][1]=Maphash[b];
        }
        
        //floyd 还没有用
        
        for(int i=1;i<=Maphashid;i++)
            for(int j=1;j<=Maphashid;j++)
                for(int k=1;k<=Maphashid;k++)
                    if(mat[i][j] > mat[i][k]+mat[k][j])
                    {
                        mat[i][j] = mat[i][k]+mat[k][j];
                    }
        
        int sameset[10];
        int setnum;
        int ans=INF;
        for(int i=0;i<=0;i++)
            for(int i1=0;i1<=i+1;i1++)
                for(int i2=0;i2<=i1+1;i2++)
                    for(int i3=0;i3<=i2+1;i3++)
                    {
                        //if(i1==0&&i2==0&&i3==0&&i==0) continue;//这个减掉了竟然还超时 我就抄!
                        int sum=0;
                        for(int j=0;j<4;j++)
                        {
                            //取出同一集合
                            setnum=0;
                            if(i==j)
                            {
                                sameset[setnum++]=needconnect[0][0];
                                sameset[setnum++]=needconnect[0][1];
                            }
                            if(i1==j)
                            {
                                sameset[setnum++]=needconnect[1][0];
                                sameset[setnum++]=needconnect[1][1];
                            }
                            if(i2==j)
                            {
                                sameset[setnum++]=needconnect[2][0];
                                sameset[setnum++]=needconnect[2][1];
                            }
                            if(i3==j)
                            {
                                sameset[setnum++]=needconnect[3][0];
                                sameset[setnum++]=needconnect[3][1];
                            }
                            if(setnum==0||setnum==1) continue;
                            //去重复
                            sort(sameset,sameset+setnum);
                            int tsetnum=1;
                            for(int k=1;k<setnum;k++)
                            {
                                if(sameset[k] != sameset[k-1])
                                {
                                    sameset[tsetnum]=sameset[k];
                                    tsetnum++;
                                }
                            }
                            
                            sum += getmin(sameset,tsetnum);
                        }
                        ans=min(ans,sum);
                    }
        printf("%d\n",ans);
        //printf("cntcnt: %d\n",cntcnt);
        
    }
    return 0;
}
View Code

 

zoj 3613

//
//  main.cpp
//  zoj3613
//
//  Created by 陈加寿 on 15/11/12.
//  Copyright (c) 2015年 陈加寿. All rights reserved.
//

#include <iostream>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <algorithm>
using namespace std;
#define N 202
#define INF 100000000
#define K 8

int R[5];
int F[5];
int palne[N];
int flagr[N];
int mat[N][N];
int cntr=0;
int cntf=0;
int maxnum=0;
int micost=0;


//SteinerTree 邻接矩阵模板。(稠密图)时间复杂度 O(N*2^K*(2^K+N))
int dp[(1<<K)+1][N];
int midp[(1<<K)+1];
int STV[N];


int SteinerTreeDP(int mat[N][N],int maxid,int *sameset,int size)
{
    //mat为表示距离的邻接矩阵
    //所有的标点从1到maxid
    //SteinerTree 所必须的点集为 sameset[0] 到 sameset[size-1]
    //函数放回最小Steiner Tree的值
    
    for(int i=1;i<(1<<size);i++)
        for(int j=1;j<=maxid;j++)
            dp[i][j]=INF;
    
    for (int i=0; i<size; i++) {
        dp[(1<<i)][sameset[i]]=0;
    }
    for (int i=1;i<(1<<size);i++)
    {
        //step 1
        for(int kk=1;kk<=maxid;kk++)
        {
            STV[kk]=0;
            for(int j = (i-1)&i ; j ;j = (j-1)&i)
            {
                dp[i][kk] = min(dp[i][kk],dp[j][kk]+dp[(~j)&i][kk]);
            }
        }
        //step 2
        int kk,stmin=INF,stminid=0;
        for (int j = 0; stmin = INF, j < maxid; j++)
        {
            for (kk = 1; kk <= maxid; kk++)
                if (dp[i][kk] <= stmin && !STV[kk])
                    stmin = dp[i][stminid = kk];
            
            for (STV[stminid]=1,kk = 1; kk <= maxid; kk++)
                if(STV[kk]==0) dp[i][kk] = min(dp[i][kk], dp[i][stminid] + mat[stminid][kk]);
        }
    }
    
    int tmin=INF;
    for(int j=1;j<=maxid;j++)
        tmin=min(tmin,dp[(1<<size)-1][j]);
    return tmin;
}

int check(int s)
{
    int cnt=0;
    int i;
    for(i=0;i<cntf;i++)
        if( ((1<<i)&s)!=0 ) cnt += palne[ F[i] ];
    for(int j=0;j<cntr;j++)
        if( ((1<<(cntf+j))&s) != 0 ) cnt --;
    if( cnt>=0 ) return 1;
    return 0;
}

int getwei(int s)
{
    int cnt=0;
    for(int j=0;j<cntr;j++)
    {
        if( ((1<<(cntf+j))&s) != 0 ) cnt ++;
    }
    return cnt;
}

int main(int argc, const char * argv[]) {
    int n;
    while(scanf( "%d",&n )!=EOF)
    {
        maxnum=0;
        micost=0;
        cntr=0;
        cntf=0;
        int saveans=0;
        for(int i=1;i<=n;i++)
        {
            scanf("%d%d",&palne[i],&flagr[i]);
            if(flagr[i]==1&&palne[i]!=0)
            {
                saveans++;
                flagr[i]=0;
                palne[i]--;
                //自厂自销
            }
            if(flagr[i]==1)
            {
                R[cntr++]=i;
            }
            if(palne[i]!=0)
            {
                F[cntf++]=i;
            }
        }
        int m;
        scanf("%d",&m);
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++)
            {
                if(i==j) mat[i][j]=0;
                else mat[i][j]=INF;
            }
        
        for(int i=0;i<m;i++)
        {
            int x,y,w;
            scanf("%d%d%d",&x,&y,&w);
            mat[x][y]=mat[y][x]=min(mat[x][y],w);
        }
        int sameset[20];
        int size=0;
        for(int i=0;i<cntf;i++)
            sameset[ size++ ] = F[i];
        for(int i=0;i<cntr;i++)
            sameset[ size++ ] = R[i];
        //等下在这里再写一个用连接表的。
        SteinerTreeDP(mat, n, sameset, size);
        
        int mask=((1<<size)-1);
        for(int i=1;i<=mask;i++)
        {
            midp[i]=INF;
            for(int j=1;j<=n;j++)
                midp[i]=min(midp[i] , dp[i][j]);
        }
        //然后最后的一个DP
        for(int i=1;i<=mask;i++)
        {
            //第一开始要满足要求
            if( check(i)==0 ) continue;
            for(int j=(i-1)&i;j;j=(j-1)&i)
            {
                if( check(j)&&check(i-j) ) midp[i]=min(midp[i],midp[j]+midp[i-j]);
            }
            int tnum=getwei(i);
            if(tnum>=maxnum)
            {
                if(tnum>maxnum)
                {
                    maxnum=tnum;
                    micost=midp[i];
                }
                else if(midp[i] < micost)
                {
                    micost=midp[i];
                }
            }
        }
        
        printf("%d %d\n",maxnum+saveans,micost);
    }
    return 0;
}
View Code

hdu4085和zoj 3613很类似,记住最后再来个DP就可以省很多代码量。

 

posted @ 2015-11-13 15:46  chenhuan001  阅读(1098)  评论(0编辑  收藏  举报