window.cnblogsConfig = { homeTopImg: [ "https://cdn.luogu.com.cn/upload/image_hosting/clcd8ydf.png", "https://cdn.luogu.com.cn/upload/image_hosting/clcd8ydf.png" ], }

KM 算法

竟然还有模板题题解开放,必须水一发。

Km 算法:可以理解为在匈牙利算法上做的拓展。解决的是有完美匹配的二分图上,匹配的权值最大。

定义:

顶标:对于点有顶标 \(l_x\)。且必须满足 \(w(x,y) \le l_x + l_y\)

相等子图:对一个子图内所有的边和点满足 \(l_x + l_y = w(x,y)\)

那么显然当构造出一个相等子图之后,权值就是 \(\sum_x l_x\)

算法流程:

首先先把左边点的顶标设为与其相连的边的边权最大值。右边点的顶标为 \(0\)

每一次加入一个点。像匈牙利算法一样去搜最大匹配(匹配必须为相等子图)。

失配的时候:我们将左边匹配了的点 \(-v\),右边匹配了的点 \(+v\)

对于已经匹配的点是不影响的。

而对于那个新进来的点,相当于此时他可以与其他点试着匹配。

而难点就在于如何求这个 \(v\)

因为原来满足的 \(w(x,y) \le l_x + l_y\)。所以 \(v \le l_x+l_y-w(x,y)\)

所以我们在未匹配的点中找出最小的可行的 \(v\)。然后用 \(v\) 去调整这些点的顶标。

然后,结束了。

你写了个 DFS 版本的,恭喜你 TLE 55。

大概长这样:

int n,m;
int mp[N][N];
int lx[N],ly[N];// 顶标
int cx[N],cy[N];// 匹配点
int vx[N],vy[N];// 是否在增广轨
int n1,n2;// 点数
int mn[N];// 边权和顶标最小的差值
// 等效于顶点的顶标需要至少增加多少才会有相等子图
bool dfs(int x){
    vx[x]=1;
    For(it,1,n2)if(!vy[it] && mp[x][it] < 1e18){
        int tmp = lx[x]+ly[it]-mp[x][it];
        if(tmp==0){//相等子图
            vy[it]=1;
            if(cy[it]==-1||dfs(cy[it])){
                cx[x]=it,cy[it]=x;return 1;
            }
        }else if(tmp>0){
            mn[it]=min(mn[it],tmp);
        }
    }return 0;
}
int km(){
    mem(cx,-1),mem(cy,-1);
    mem(lx,0),mem(ly,0);
    For(i,1,n1)For(j,1,n2){//初始左边的顶标
        if(mp[i][j] > 1e18)continue;
        lx[i] = max(lx[i],mp[i][j]);
    }For(i,1,n1){
        mem(mn,0x3f);
        while(true){
            mem(vx,0),mem(vy,0);
            if(dfs(i))break;
            int Min=linf;//剩下的点中,至少加多少才有可能相等子图
            For(j,1,n2){
                if(!vy[j]) Min=min(Min,mn[j]);
            }For(j,1,n1) if(vx[j])lx[j]-=Min;
            For(j,1,n2) if(vy[j]) ly[j]+=Min; else mn[j]-=Min;
        }
    }int ans = 0;
    For(i,1,n1){
        if(cx[i]!=-1)ans += mp[i][cx[i]];
    }return ans;
}
void solve(){
    cin >> n >> m;
    n1=n2=n;
    mem(mp,0x3f);
    For(i,1,m){
        int x,y,z;cin >>x>>y>>z;
        mp[x][y] = z;
    }cout << km()<<endl;
    For(i,1,n1){
        cout << cy[i]<<" ";
    }cout << endl;
}

但实际上,你得用 BFS。

那么优化一下时间复杂度就来到了 \(O(n^2 + nm)\)

int n,m;
int mp[N][N];
int lx[N],ly[N];// 顶标
int vis[N],pre[N],mat[N];
int slack[N];
int mn[N];// 边权和顶标最小的差值
// 等效于顶点的顶标需要至少增加多少才会有相等子图

int km(){
    mem(ly,0);
    For(i,1,n) lx[i] = -linf;
    For(i,1,n)For(j,1,n){//初始左边的顶标
        if(mp[i][j] > 1e18)continue;
        lx[i] = max(lx[i],mp[i][j]);
    }For(i,1,n){
        int p=0,q=0,id=0;
        mat[0]=i;//初始点
        int v= linf;
        mem(slack,0x3f),mem(pre,0);
        mem(vis,0);
        while(mat[q]){
            q=id;p=mat[q];vis[q]=1;v = linf;
            For(j,1,n) if(!vis[j]){
                if(slack[j] > lx[p] + ly[j] - mp[p][j]) 
                    slack[j] = lx[p] + ly[j] - mp[p][j],pre[j]=q;
                if(slack[j] < v) v=slack[j],id = j;
            }For(j,0,n)if(vis[j])lx[mat[j]]-=v,ly[j]+=v; else slack[j]-=v;
        }for(;q;q=pre[q])mat[q]=mat[pre[q]];
    }int ans = 0;
    For(i,1,n){
        ans += lx[i] + ly[i];
    }return ans;
}
void solve(){
    cin >> n >> m;
    For(i,1,n)For(j,1,n) mp[i][j] = -linf;
    For(i,1,m){
        int x,y,z;cin >>x>>y>>z;
        mp[x][y] =z;
    }cout << km()<<endl;
    For(i,1,n){
        cout << mat[i]<<" ";
    }cout << endl;

posted @ 2026-02-07 13:34  gsczl71  阅读(13)  评论(0)    收藏  举报