BZOJ1977: [BeiJing2010组队]次小生成树 Tree

Description

小 C 最近学了很多最小生成树的算法,Prim 算法、Kurskal 算法、消圈算法等等。 正当小 C 洋洋得意之时,小 P 又来泼小 C 冷水了。小 P 说,让小 C 求出一个无向图的次小生成树,而且这个次小生成树还得是严格次小的,也就是说: 如果最小生成树选择的边集是 EM,严格次小生成树选择的边集是 ES,那么需要满足:(value(e) 表示边 e的权值) 这下小 C 蒙了,他找到了你,希望你帮他解决这个问题。

Input

第一行包含两个整数N 和M,表示无向图的点数与边数。 接下来 M行,每行 3个数x y z 表示,点 x 和点y之间有一条边,边的权值为z。

Output

包含一行,仅一个数,表示严格次小生成树的边权和。(数据保证必定存在严格次小生成树)

Sample Input

5 6
1 2 1
1 3 2
2 4 3
3 5 4
3 4 3
4 5 6

Sample Output

11

HINT

数据中无向图无自环; 50% 的数据N≤2 000 M≤3 000; 80% 的数据N≤50 000 M≤100 000; 100% 的数据N≤100 000 M≤300 000 ,边权值非负且不超过 10^9 。

Solution

好恶心的一道题...
首先要知道有这么一个定理:

严格次小MST一定只有一条边和MST不同。

然后这题就可以做了。
先随便找出来一棵MST,然后把树建出来,称这\(n-1\)条边为“树边”,其他的边为“非树边”。
则如果把一条非树边\((x,y)\)接到树上,那么就会和树上x到y的路径产生一个环,为了保持树的形态,所以必须删掉树上x到y的路径上的一条边,而因为我萌要维护的是严格次小MST,所以需要知道树上x到y路径中的最大值和次大值(如果非树边\((x,y)\)的边权等于最大值,就必须换掉次大值,因为我萌要求的是严格次小MST)。
如何求树链的最大值和次大值?这个问题可以树上倍增解决。(当然也可以树剖,但是复杂度是两个log,而且更难写)。
\(g[x,i,0/1]\)表示节点\(x\)向上\(2^i\)个祖先路径中的最大值和次大值。
则显然有\(g[x,i,0]=edge(i,fa[i]),g[x,i,1]=-∞\)\(g[x,i,0]=max(g[x,i-1,0],g[f[x,i-1],i-1,0])\)
对次大值分类讨论一下:

\[g[x][i][1]= \begin{cases} &max(g[x,i-1,1],g[f[x,i-1],i-1,1])(g[x,i-1,0]=g[f[x,i-1],i-1,0])\\ &max(g[x,i-1,1],g[f[x,i-1],i-1,0])(g[x,i-1,0]>g[f[x,i-1],i-1,0])\\ &max(g[x,i-1,0],g[f[x,i-1],i-1,1])(g[x,i-1,0]<g[f[x,i-1],i-1,0]) \end{cases} \]

于是预处理完之后,分别处理每条非树边\((x,y)\),在求x和y的lca的同时类似预处理那样把路径上的最大值和次大值求出来(方法完全类似所以不想写了...,反正就是分个三类)
对每次替换得到的次小MST取个min,就是答案了。
代码贼长...

#include <bits/stdc++.h>
#define ll long long
#define il inline
const ll inf = 1e18;

namespace io {

#define in(a) a = read()
#define out(a) write(a)
#define outn(a) out(a), putchar('\n')

#define I_int ll
inline I_int read() {
    I_int x = 0, f = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return x * f;
}
char F[200];
inline void write(I_int x) {
    if (x == 0) return (void) (putchar('0'));
    I_int tmp = x > 0 ? x : -x;
    if (x < 0) putchar('-');
    int cnt = 0;
    while (tmp > 0) {
        F[cnt++] = tmp % 10 + '0';
        tmp /= 10;
    }
    while (cnt > 0) putchar(F[--cnt]);
}
#undef I_int

}
using namespace io;

using namespace std;

#define N 300010

int n = read(), m = read(), lim;
int cnt, head[N], fa[N], f[N][20], dep[N];
ll g[N][25][2];
struct Node {
    int x, y, v, flag;
}a[N];
struct edge {
    int to, nxt, v;
}e[N<<3];

void ins(int u, int v, int w) {
    e[++cnt] = (edge) {v, head[u], w};
    head[u] = cnt;
}

bool cmp(Node a, Node b) {
    return a.v < b.v;
}

void dfs(int u) {
    for(int i = head[u]; i; i = e[i].nxt) {
        if(e[i].to == f[u][0]) continue; int v = e[i].to;
        f[v][0] = u; g[v][0][0] = 1ll*e[i].v, g[v][0][1] = -inf;
        dep[v] = dep[u] + 1;
        for(int j = 1; j <= lim; ++j) {
            f[v][j] = f[f[v][j-1]][j-1];
            g[v][j][0] = max(g[v][j-1][0], g[f[v][j-1]][j-1][0]);
            if(g[v][j-1][0] == g[f[v][j-1]][j-1][0]) g[v][j][1] = max(g[v][j-1][1], g[f[v][j-1]][j-1][1]);
            else if(g[v][j-1][0] > g[f[v][j-1]][j-1][0]) g[v][j][1] = max(g[v][j-1][1], g[f[v][j-1]][j-1][0]);
            else g[v][j][1] = max(g[v][j-1][0], g[f[v][j-1]][j-1][1]); 
        }
        dfs(v);
    }
}

void lca(int x, int y, ll &t1, ll &t2) {
    t1 = -inf; t2 = -inf; 
	if(dep[x] < dep[y]) swap(x, y);
    for(int i = lim; i >= 0; --i) {
        if(dep[f[x][i]] >= dep[y]) {
            if(t1 == g[x][i][0]) t2 = max(t2, g[x][i][1]);
            else if(t1 < g[x][i][0]) t2 = max(t1, g[x][i][1]), t1 = g[x][i][0];
            else t2 = max(t2, g[x][i][0]);
            x = f[x][i];
        }
    }
    if(x == y) return;
    for(int i = lim; i >= 0; --i) {
        if(f[x][i] != f[y][i]) {
            if(t1 == g[x][i][0]) t2 = max(t2, g[x][i][1]);
            else if(t1 < g[x][i][0]) t2 = max(t1, g[x][i][1]), t1 = g[x][i][0];
            else t2 = max(t2, g[x][i][0]);
            x = f[x][i];
            
            if(t1 == g[y][i][0]) t2 = max(t2, g[y][i][1]);
            else if(t1 < g[y][i][0]) t2 = max(t1, g[y][i][1]), t1 = g[y][i][0];
            else t2 = max(t2, g[y][i][0]);
            y = f[y][i];
        }
    }
    
    if(t1 == g[x][0][0]) t2 = max(t2, g[x][0][1]);
    else if(t1 < g[x][0][0]) t2 = max(t1, g[x][0][1]), t1 = g[x][0][0];
    else if(t2 < g[x][0][0]) t2 = g[x][0][0];
    
	if(t1 == g[y][0][0]) t2 = max(t2, g[y][0][1]);
    else if(t1 < g[y][0][0]) t2 = max(t1, g[y][0][1]), t1 = g[y][0][0];
    else if(t2 < g[y][0][0]) t2 = g[y][0][0];
}

int find(int x) {
    if(fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}

int main() {
    for(int i = 1; i <= n; ++i) fa[i] = i;
    for(int i = 1; i <= m; ++i) {
        int x = read(), y = read(), v = read();
        a[i] = (Node){x, y, v, 0};
    }
    sort(a+1,a+m+1,cmp); ll sum = 0;
    for(int i = 1, tot = 0; tot < n - 1 && i <= m; ++i) {
        int x = find(a[i].x), y = find(a[i].y);
        if(x != y) {
            fa[y] = x;
            sum += 1ll*a[i].v;
            ins(a[i].x, a[i].y, a[i].v);
            ins(a[i].y, a[i].x, a[i].v);
            a[i].flag = 1;
            ++tot;
        }
    }
    lim = (int)(log(n) / log(2)) + 1; 
    for(int i = 1; i <= lim; ++i) g[1][i][0] = g[1][i][1] = -inf;
    dep[1] = 1; dfs(1);
    ll ans = inf;
    for(int i = 1; i <= m; ++i) {
        if(a[i].flag) continue;
        ll mx = 0, se_mx = 0;
        lca(a[i].x, a[i].y, mx, se_mx);
        if(a[i].v == mx) ans = min(ans, sum + (ll)a[i].v - se_mx);
        else ans = min(ans, sum + (ll)a[i].v - mx);
    }
    outn(ans);
}
posted @ 2019-03-30 14:46  henry_y  阅读(148)  评论(0编辑  收藏  举报