HDU-2255-奔小康赚大钱(KM算法)

这是一道 KM 算法（求二分图最大权完美匹配）的模板题，在此记录一下。

  • KM算法
    Accepted 2255 468MS 1756K 1600 B G++
    #include "bits/stdc++.h"
    using namespace std;
    // Sentinel "infinity" for slack / label computations. 0x3f3f3f3f is chosen
    // because filling every byte of an int with 0x3f produces exactly this value.
    const int INF = 0x3f3f3f3f;
    // Array capacity. Vertices are 1-indexed (indices 1..n are used), so n must
    // be at most MAXN - 1. NOTE(review): sized for the problem bound n <= 300.
    const int MAXN = 305;
    // mp[x][y]: weight of the edge between left vertex x and right vertex y.
    int mp[MAXN][MAXN];
    // match[y]: left vertex currently matched to right vertex y (-1 if free).
    // slack[y]: smallest ex_x[x] + ex_y[y] - mp[x][y] seen over the left
    //           vertices x visited while augmenting the current left vertex.
    int match[MAXN], slack[MAXN];
    // Vertex labels of the left (ex_x) and right (ex_y) sides; an edge (x, y)
    // is usable ("tight") when ex_x[x] + ex_y[y] == mp[x][y].
    int ex_x[MAXN], ex_y[MAXN];
    // Visited flags for left / right vertices during one DFS round.
    bool vis_x[MAXN], vis_y[MAXN];
    // Number of vertices on each side of the current test case.
    int n;
    // Searches for an augmenting path that starts at left vertex x and uses only
    // tight edges (ex_x[x] + ex_y[v] == mp[x][v]).
    //
    // Side effects: marks x in vis_x and every right vertex it reaches in vis_y;
    // for each non-tight edge to an unvisited right vertex v, lowers slack[v] to
    // that edge's gap. On success, flips the path by rewriting match[].
    // Returns true iff an augmenting path was found.
    bool dfs(int x) {
        vis_x[x] = true;
        for (int v = 1; v <= n; ++v) {
            if (!vis_y[v]) {
                const int delta = ex_x[x] + ex_y[v] - mp[x][v];
                if (delta != 0) {
                    // Edge is not tight yet: remember how far it is from tight.
                    if (delta < slack[v]) {
                        slack[v] = delta;
                    }
                } else {
                    vis_y[v] = true;
                    const int owner = match[v];
                    // v is free, or its current partner can be re-routed elsewhere.
                    if (owner == -1 || dfs(owner)) {
                        match[v] = x;
                        return true;
                    }
                }
            }
        }
        return false;
    }
    // Kuhn-Munkres (KM) algorithm: maximum-weight perfect matching on the
    // complete bipartite graph whose weights are mp[1..n][1..n]
    // (row index = left vertex, column index = right vertex).
    //
    // Labels ex_x[] / ex_y[] start feasible (ex_x[x] + ex_y[y] >= mp[x][y]) and
    // are kept feasible. Each left vertex i is augmented in turn. Whenever dfs(i)
    // fails, the labels are shifted by the smallest slack d so that at least one
    // new edge becomes tight, and the search is retried.
    //
    // Returns the total weight of the resulting matching.
    // Writes match[] (match[y] = left vertex assigned to right vertex y), ex_x[],
    // ex_y[], slack[], vis_x[] and vis_y[].
    int KM() {
        memset(match, -1, sizeof(match));
        memset(ex_y, 0, sizeof(ex_y));
        // Initial left label = the largest weight in that row. Taking it straight
        // from the row (instead of memset(ex_x, ~INF, ...), which only works
        // because every byte of ~INF is 0xC0) is valid for any int weights.
        for (int i = 1; i <= n; i++) {
            ex_x[i] = *max_element(mp[i] + 1, mp[i] + n + 1);
        }
        for (int i = 1; i <= n; i++) {
            // Slack is reset once per left vertex. Across the retry rounds below,
            // the previously visited vertices stay connected by tight edges, so
            // the stored slack (shifted by d) remains valid.
            fill(slack + 1, slack + n + 1, INF);
            while (true) {
                memset(vis_x, false, sizeof(vis_x));
                memset(vis_y, false, sizeof(vis_y));
                if (dfs(i)) {
                    break;
                }
                // Largest label change that keeps feasibility: it makes the edge
                // with the minimum slack (to an unvisited right vertex) tight.
                int d = INF;
                for (int j = 1; j <= n; j++) {
                    if (!vis_y[j]) {
                        d = min(d, slack[j]);
                    }
                }
                for (int j = 1; j <= n; j++) {
                    if (vis_x[j]) {
                        ex_x[j] -= d;
                    }
                    if (vis_y[j]) {
                        ex_y[j] += d;
                    } else {
                        // Visited left labels dropped by d, so gaps to unvisited
                        // right vertices shrink by d as well.
                        slack[j] -= d;
                    }
                }
            }
        }
        int res = 0;
        for (int y = 1; y <= n; y++) {
            res += mp[match[y]][y];
        }
        return res;
    }
    // Reads test cases until input runs out. Each case is n followed by an n x n
    // weight matrix (row i = left vertex i). For each case, prints the maximum
    // total weight of a perfect matching as computed by KM().
    // NOTE(review): assumes n <= MAXN - 1; the input is not range-checked.
    int main() {
        // scanf returns the number of converted items: 0 on a matching failure
        // and EOF at end of input. `~scanf(...)` is false only for EOF (-1), so
        // malformed input (return 0) would loop forever; require one success.
        while (scanf("%d", &n) == 1) {
            for (int i = 1; i <= n; i++) {
                for (int j = 1; j <= n; j++) {
                    if (scanf("%d", &mp[i][j]) != 1) {
                        // Truncated/malformed matrix: stop rather than run KM
                        // on partially stale data.
                        return 0;
                    }
                }
            }
            printf("%d\n", KM());
        }
        return 0;
    }

     

posted @ 2019-03-08 16:16  Jathon-cnblogs  阅读(339)  评论(0)  编辑  收藏  举报