严格次小生成树 java实现--洛谷 P4180

题目链接:https://www.luogu.com.cn/problem/P4180
或者LOJ:https://loj.ac/p/10133
LOJ 上有所有的测试用例,可以下载下来看看。
在LOJ上用例都过了,但是提交时候会wrong answer....我也不知道为什么。

这个也算是模板题。最大生成树+LCA。
1.求出最小生成树sum,并且根据生成树所用的边,建立一个新的图关系。
由于最小生成树和次小生成树之间只有一条边的差距(一个我也不知道怎么证明的正确结论),所以,
2.预先处理数据,对于最小生成树的图,使用倍增得到跳跃路径和路径上的最大值+次大值。(因为要计算严格最小生成树,所以除了存储路径上最大边之外,也要存储一下次大边,以防万一在后面计算的时候,出现次小生成树的结果和最小相同,变成非严格最小生成树)
3.对于不在最小生成树上的边,两边距离为d。因为边上两点在最小生成树上,所以根据2的数据,可以得到最小生成树路径上两点之间的最大和次大值a,b.
计算次小生成树ans = sum-max(a,b)+d.
4.遍历每一条不在生成树上的边,都计算ans,取最小值作为严格次小生成树的值。

在提交时候,发现建立新图的时候,使用链式前向星最后两个用例会超时,使用邻接表不会。
在构建新图,得到层次关系的时候,最好使用广度优先遍历吧,深度搜索,如果是递归写的话,有可能会栈溢出。
洛谷AC代码:

import java.io.*;
import java.util.*;

public class Main {
    static int n,m,maxLevel;
    static boolean[] isEdgeUsed;
    static int[][] edges,max1,max2,jump;
    static LinkedList<int[]>[] adj;
    static int[] point,pointLevel;
    static long sum;
    static long result;
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        m = Integer.parseInt(st.nextToken());

        isEdgeUsed = new boolean[m+1];
        point = new int[n+1];
        pointLevel = new int[n+1];
        edges = new int[m+1][4];
        adj = new LinkedList[n+1];
        max1 = new int[n+1][20];
        max2 = new int[n+1][20];
        jump = new int[n+1][20];

        for(int i = 0;i<=n;i++){
            point[i] = i;
            adj[i] = new LinkedList<>();
        }
        PriorityQueue<int[]> pq = new PriorityQueue<>(new Comparator<int[]>() {
            @Override
            public int compare(int[] ints, int[] t1) {
                if(ints[2] == t1[2]){
                    if(ints[0] == t1[0]){
                        return ints[1]-t1[1];
                    }
                    return ints[0]-t1[0];
                }
                return ints[2]-t1[2];
            }
        });

        for(int i = 1;i<=m;i++){
            st = new StringTokenizer(br.readLine());
            edges[i] = new int[]{
                    Integer.parseInt(st.nextToken()),
                    Integer.parseInt(st.nextToken()),
                    Integer.parseInt(st.nextToken()),
                    i
            };
            pq.add(edges[i]);
        }
        int cnt = 0;
        int[] poll;
        while (!pq.isEmpty()){
            if(cnt == n-1){
                break;
            }
            poll = pq.poll();
            if(!connect(poll[0],poll[1])){
                union(poll[0],poll[1]);
                adj[poll[0]].add(new int[]{poll[1],poll[2]});
                adj[poll[1]].add(new int[]{poll[0],poll[2]});
                cnt++;
                isEdgeUsed[poll[3]] = true;
                sum+=poll[2];
            }
        }

        bfs();
        int tempLevel = 0;
        while ((1<<tempLevel)<=maxLevel){
            tempLevel++;
        }
        maxLevel = tempLevel;
        st();
        result = Long.MAX_VALUE;

        for(int i = 1;i<=m;i++){
            if(!isEdgeUsed[i]){
                int from = edges[i][0];
                int to = edges[i][1];
                int currDist = edges[i][2];
                int lca = LCA(from,to);
                int q1 = questMax(from,lca,currDist);
                int q2 = questMax(to,lca,currDist);
                result = Math.min(result,sum-Math.max(q1,q2)+currDist);
            }
        }
        System.out.println(result);
    }

    private static int questMax(int from,int to,int currDist){
        int result = Integer.MIN_VALUE;
        for(int i = maxLevel;i>=0;i--){
            if(pointLevel[jump[from][i]]>=pointLevel[to]){
                if(currDist!=max1[from][i]){
                    result = Math.max(result,max1[from][i]);
                }
                else{
                    result = Math.max(result,max2[from][i]);
                }
                from = jump[from][i];
            }
        }
        return result;
    }

    private static int LCA(int from,int to){
        if(pointLevel[from]<pointLevel[to]){
            int temp = from;
            from = to;
            to = temp;
        }
        for(int i = maxLevel;i>=0;i--){
            if(pointLevel[jump[from][i]]>=pointLevel[to]){
                from = jump[from][i];
            }
        }
        if(from == to){
            return from;
        }
        for(int i = maxLevel;i>=0;i--){
            if(jump[from][i]!=jump[to][i]){
                from = jump[from][i];
                to = jump[to][i];
            }
        }
        return jump[from][0];
    }

    private static void st() throws IOException {
        for(int j = 1;j<=maxLevel;j++){
            for(int i = 1;i<=n;i++){
                jump[i][j] = jump[jump[i][j-1]][j-1];
                max1[i][j] = Math.max(max1[i][j-1],max1[jump[i][j-1]][j-1]);
                max2[i][j] = Math.max(max2[i][j-1],max2[jump[i][j-1]][j-1]);
                if(max1[i][j-1]>max1[jump[i][j-1]][j-1]){
                    max2[i][j] = Math.max(max2[i][j],max1[jump[i][j-1]][j-1]);
                }
                else if(max1[i][j-1]<max1[jump[i][j-1]][j-1]){
                    max2[i][j] = Math.max(max2[i][j],max1[i][j-1]);
                }
            }
        }
    }

    private static void bfs(){
        pointLevel[1] = 1;
        jump[1][0] = 0;
        max2[1][0] = Integer.MIN_VALUE;
        ArrayDeque<Integer> ad = new ArrayDeque<>();
        boolean[] isVisited = new boolean[n+1];
        isVisited[1] = true;
        ad.addLast(1);
        int current;
        while (!ad.isEmpty()){
            current = ad.pollFirst();
            for(int[] next:adj[current]){
                if(!isVisited[next[0]]){
                    isVisited[next[0]] = true;
                    pointLevel[next[0]] = pointLevel[current]+1;
                    maxLevel = Math.max(maxLevel,pointLevel[next[0]]);
                    max2[next[0]][0] = Integer.MIN_VALUE;
                    max1[next[0]][0] = next[1];
                    jump[next[0]][0] = current;
                    ad.addLast(next[0]);
                }
            }
        }
    }

    private static int find(int a){
        if(a == point[a]){
            return point[a];
        }
        return point[a] = find(point[a]);
    }
    private static void union(int a,int b){
        int A = point[a];
        int B = point[b];
        point[A] = B;
    }
    private static boolean connect(int a, int b){
        return find(a) == find(b);
    }
}

posted @ 2021-04-30 17:51  Monstro  阅读(74)  评论(0)    收藏  举报