堆、优先队列和堆排序01:二叉堆和普通堆排序

动态数组实现二叉堆

二叉堆的定义

二叉堆是一棵完全二叉树:除最后一层外,每一层都是满的;最后一层的节点从左到右依次排列,因此空缺的位置只会出现在最后一层的右侧(即树的右下方)。

堆中每个节点的值都大于等于其孩子节点的值(最大堆),因此根节点就是整个堆中的最大值。

根节点索引从1开始时,按从上到下、从左到右的顺序依次存入数组:父节点索引为n时,子节点索引为2n和2n + 1;子节点索引为n时,父节点索引为n / 2(整数除法,向下取整)。

根节点索引从0开始时(下面的代码采用这种方式),父节点索引为n时,子节点索引为2n + 1和2n + 2;子节点索引为n时,父节点索引为(n - 1) / 2(整数除法,向下取整)。

siftUp()和siftDown()

import java.util.Random;

public class Algorithm {

    /**
     * Self-check for MaxHeap: pushes one million random integers into a max-heap, then drains
     * it with extractMax(). The drained sequence must be non-increasing, and the heap must be
     * empty afterwards.
     *
     * <p>Because extractMax() always returns the current maximum, collecting the extracted
     * values into an array yields the elements in sorted (descending) order; this is the idea
     * behind "plain" heap sort. It is not an in-place sort: the heap and the temp array each
     * hold all n elements, i.e. O(n) extra space, while the n insertions plus n extractions
     * take O(n log n) time in total.
     */
    public static void main(String[] args) {

        int n = 1000000;
        MaxHeap<Integer> maxHeap = new MaxHeap<>();
        Random random = new Random();

        for (int i = 0; i < n; i++) {
            maxHeap.add(random.nextInt(Integer.MAX_VALUE));
        }

        int[] temp = new int[n];

        for (int i = 0; i < n; i++) {
            temp[i] = maxHeap.extractMax();
        }

        for (int i = 0; i < n - 1; i++) {
            // A larger value after a smaller one means the heap property was violated.
            if (temp[i] < temp[i + 1]) {
                throw new IllegalStateException("二叉堆实现失败");
            }
        }

        // Every inserted element must have been handed back exactly once.
        if (!maxHeap.isEmpty()) {
            throw new IllegalStateException("二叉堆实现失败");
        }

        System.out.println("Test completed!");
    }
}

class MaxHeap<E extends Comparable<E>> {

    /**
     * Backing storage: the complete binary tree laid out level by level, left to right,
     * with the root at index 0.
     */
    private Array<E> heap;

    /** Creates an empty heap whose backing array starts with the given capacity. */
    public MaxHeap(int capacity) {
        heap = new Array<>(capacity);
    }

    /** Creates an empty heap with an initial backing capacity of 10. */
    public MaxHeap() {
        this(10);
    }

    /** Returns the number of elements currently stored. */
    public int size() {
        return heap.getSize();
    }

    /** Returns true when the heap holds no elements. */
    public boolean isEmpty() {
        return heap.getSize() == 0;
    }

    /**
     * Index of the parent of {@code index} (0-based layout).
     *
     * @throws IllegalArgumentException for the root, which has no parent
     */
    private int parent(int index) {
        if (index != 0) {
            return (index - 1) / 2;
        }
        throw new IllegalArgumentException("根节点没有父节点");
    }

    /** Index of the left child of {@code index}. */
    private int leftChild(int index) {
        return index * 2 + 1;
    }

    /** Index of the right child of {@code index}. */
    private int rightChild(int index) {
        return leftChild(index) + 1;
    }

    /**
     * Inserts {@code e}. The element is appended at the end of the array, which keeps the tree
     * complete, and is then floated up past every smaller ancestor (sift-up).
     */
    public void add(E e) {
        heap.addLast(e);
        siftUp(heap.getSize() - 1);
    }

    /** Moves the element at {@code index} up while it is strictly larger than its parent. */
    private void siftUp(int index) {
        int cur = index;
        while (cur > 0) {
            int up = parent(cur);
            if (heap.get(cur).compareTo(heap.get(up)) <= 0) {
                return;
            }
            heap.swap(cur, up);
            cur = up;
        }
    }

    /**
     * Returns (without removing) the largest element, which is the root.
     *
     * @throws IllegalArgumentException if the heap is empty
     */
    public E findMax() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空");
        }
        return heap.get(0);
    }

    /**
     * Removes and returns the largest element. Rather than merging the two subtrees, the last
     * element is swapped into the root slot, the old root is dropped from the end, and the new
     * root is pushed down below every larger child (sift-down).
     *
     * @throws IllegalArgumentException if the heap is empty
     */
    public E extractMax() {
        E top = findMax();
        int last = size() - 1;
        heap.swap(0, last);
        heap.remove(last);
        siftDown(0);
        return top;
    }

    /**
     * Pushes the element at {@code index} down, swapping it with its larger child (the left one
     * on ties) until it is at least as large as both children or becomes a leaf.
     */
    private void siftDown(int index) {
        int cur = index;
        int count = size();
        for (int child = leftChild(cur); child < count; child = leftChild(cur)) {
            int right = child + 1;
            if (right < count && heap.get(child).compareTo(heap.get(right)) < 0) {
                child = right;
            }
            if (heap.get(cur).compareTo(heap.get(child)) >= 0) {
                break;
            }
            heap.swap(cur, child);
            cur = child;
        }
    }
}

/**
 * 动态数组实现二叉堆
 */
class Array<E>{

    private E[] data;
    private int size;

    public Array(int capacity){

        data = (E[]) new Object[capacity];
        size = 0;
    }

    public Array(){

        data = (E[]) new Object[10];
        size = 0;
    }

    public int getSize(){

        return size;
    }

    public void swap(int index1, int index2){

        E temp;
        temp = data[index1];
        data[index1] = data[index2];
        data[index2] = temp;
    }

    public E get(int index){

        if (index < 0 || index >= size){
            throw new IllegalArgumentException("索引值非法");
        }

        return data[index];
    }

    public void add(E e, int index){

        if (index < 0 || index > size){
            throw new IllegalArgumentException("索引值非法");
        }

        if (size == data.length){
            resize(2 * data.length);
        }

        for (int i = size - 1; i >= index; i--) {
            data[i + 1] = data[i];
        }

        data[index] = e;
        size++;
    }

    public void addLast(E e){
        add(e, size);
    }

    public void remove(int index){

        if (index < 0 || index >= size){
            throw new IllegalArgumentException("索引值非法");
        }

        for (int i = index + 1; i < size; i++) {
            data[i - 1] = data[i];
        }

        size--;
        data[size] = null;

        if (size == data.length / 2 && data.length / 2 != 0){
            resize(data.length / 2);
        }
    }

    public void resize(int newCapacity){

        E[] temp = (E[])new Object[newCapacity];
        
        for (int i = 0; i < size; i++) {
            temp[i] = data[i];
        }

        data = temp;
    }

    @Override
    public String toString(){

        StringBuilder str = new StringBuilder();
        str.append("[");

        for (int i = 0; i < size; i++) {

            str.append(data[i]);
            
            if (i != size - 1){
                str.append(", ");
            }
        }

        str.append("]");

        return str.toString();
    }
}

时间复杂度:siftUp()和siftDown()每次最多沿着树高移动一层,而包含n个元素的完全二叉树高度为O(logn),因此这两个方法的时间复杂度为O(logn),add()和extractMax()也都是O(logn)

而普通堆排序需要先将n个元素逐一加入堆,再将n个元素逐一从堆中取出,每次操作都是O(logn),总的时间复杂度为O(nlogn);由于堆本身要额外保存全部n个元素,额外空间复杂度为O(n),不是原地排序

普通堆排序、归并排序和快速排序性能对比

import java.util.Arrays;
import java.util.Random;

public class Algorithm {

    /**
     * Benchmarks heap sort, three-way quick sort and merge sort at 100,000 and 1,000,000
     * elements, first on random input and then on already-sorted input. Each algorithm sorts
     * its own copy of the same data, so the timings are comparable.
     */
    public static void main(String[] args) {

        Integer[] testScale = {100000, 1000000};

        for (Integer n : testScale) {
            Integer[] randomArr = ArrayGenerator.generatorRandomArray(n, n);
            Integer[] sortedArr = ArrayGenerator.generatorSortedArray(n, n);

            runAll("测试随机数组排序性能", randomArr);
            runAll("测试有序数组排序性能", sortedArr);
        }
    }

    /** Prints {@code title}, then times every algorithm on a fresh copy of {@code source}. */
    private static void runAll(String title, Integer[] source) {
        System.out.println(title);
        System.out.println();

        String[] algorithms = {"HeapSort", "QuickSort3Ways", "MergeSort"};
        for (String name : algorithms) {
            Verify.testTime(name, Arrays.copyOf(source, source.length));
        }

        System.out.println();
    }
}

/**
 * "Plain" heap sort: every element is pushed into a MaxHeap, then the heap is drained from
 * the largest element downwards while filling {@code arr} from the back, so the array ends up
 * in ascending order. Not in-place: the heap keeps a second copy of all n elements.
 */
class HeapSort {

    private HeapSort() {}

    /** Sorts {@code arr} into ascending natural order. */
    public static <E extends Comparable<E>> void sort(E[] arr) {
        MaxHeap<E> maxHeap = new MaxHeap<>();

        for (int i = 0; i < arr.length; i++) {
            maxHeap.add(arr[i]);
        }

        int slot = arr.length;
        while (slot > 0) {
            arr[--slot] = maxHeap.extractMax();
        }
    }
}

class QuickSort {

    private QuickSort() {}

    public static<E extends Comparable<E>> void sort3ways(E[] arr){

        Random random = new Random();
        E temp = null;
        sort3ways(arr, 0, arr.length - 1, temp, random);
    }

    public static<E extends Comparable<E>> void sort3ways(E[] arr, int left, int right, E temp, Random random){

        if (left >= right){
            
            return;
        }

        int[] res = partition3ways(arr, left, right, temp, random);
        sort3ways(arr, left, res[0], temp, random);
        sort3ways(arr, res[1], right, temp, random);
    }

    public static<E extends Comparable<E>> int[] partition3ways(E[] arr, int left, int right, E temp, Random random){

        int p = random.nextInt(right - left + 1) + left;
        swap(arr, p, left, temp);
        
        int i = left + 1;
        int lt = left;
        int gt = right + 1;

        while (i < gt){

            if (arr[i].compareTo(arr[left]) < 0){

                lt++;
                swap(arr, lt, i, temp);
                i++;
            }
            else if (arr[i].compareTo(arr[left]) == 0){
                i++;
            }
            else if (arr[i].compareTo(arr[left]) > 0){
                gt--;
                swap(arr, gt, i, temp);
            }
        }

        swap(arr, lt, left, temp);
        int[] res = {lt - 1, gt};

        return res;
    }

    public static<E extends Comparable<E>> void swap(E[] arr, int i, int j, E temp){

        temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }
}

class MergeSort {

    private MergeSort(){}

    public static<E extends Comparable<E>> void sort(E[] arr){

        E[] temp = Arrays.copyOf(arr, arr.length);
        sort(arr, 0, arr.length - 1, temp);
    }

    private static<E extends Comparable<E>> void sort(E[] arr, int left, int right, E[] temp){

        if (left >= right){
            
            return;
        }

        int mid = left + (right - left) / 2;
        sort(arr, left, mid, temp);
        sort(arr, mid + 1, right, temp);

        if (arr[mid].compareTo(arr[mid + 1]) > 0) {
            merge(arr, left, mid, right, temp);
        }
    }

    public static<E extends Comparable<E>> void merge(E[] arr, int left, int mid, int right, E[] temp){

        int i = left;
        int j = mid + 1;
        System.arraycopy(arr, left, temp, left, right - left + 1);

        for (int n = left; n < right + 1; n++) {

            if (i == mid + 1){
                arr[n] = temp[j];
                j++;
            }
            else if (j == right + 1) {
                arr[n] = temp[i];
                i++;
            }
            else if (temp[i].compareTo(temp[j]) <= 0) {
                arr[n] = temp[i];
                i++;
            }
            else{
                arr[n] = temp[j];
                j++;
            }
        }
    }
}

class MaxHeap<E extends Comparable<E>> {

    /**
     * Complete binary tree stored level by level from index 0; the children of slot i live in
     * slots 2i + 1 and 2i + 2.
     */
    private Array<E> heap;

    /** Creates an empty heap whose backing array starts with the given capacity. */
    public MaxHeap(int capacity) {
        heap = new Array<>(capacity);
    }

    /** Creates an empty heap with an initial backing capacity of 10. */
    public MaxHeap() {
        this(10);
    }

    /** Returns the number of elements in the heap. */
    public int size() {
        return heap.getSize();
    }

    /** Returns true when the heap holds no elements. */
    public boolean isEmpty() {
        return size() == 0;
    }

    /**
     * Parent slot of {@code index}.
     *
     * @throws IllegalArgumentException for the root (index 0)
     */
    private int parent(int index) {
        if (index == 0) {
            throw new IllegalArgumentException("根节点没有父节点");
        }
        return (index - 1) / 2;
    }

    /** Left-child slot of {@code index}. */
    private int leftChild(int index) {
        return 2 * index + 1;
    }

    /** Right-child slot of {@code index}. */
    private int rightChild(int index) {
        return 2 * index + 2;
    }

    /** Inserts {@code e} at the first free slot, then restores heap order by sifting it up. */
    public void add(E e) {
        int slot = size();
        heap.add(e, slot);
        siftUp(slot);
    }

    /** Recursively swaps the element at {@code index} with its parent while it is larger. */
    private void siftUp(int index) {
        if (index <= 0) {
            return;
        }
        int p = parent(index);
        if (heap.get(index).compareTo(heap.get(p)) > 0) {
            heap.swap(index, p);
            siftUp(p);
        }
    }

    /**
     * Returns the maximum (the root) without removing it.
     *
     * @throws IllegalArgumentException if the heap is empty
     */
    public E findMax() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空");
        }
        return heap.get(0);
    }

    /**
     * Removes and returns the maximum: the last element replaces the root, the old root is
     * removed from the end, and the new root is sifted down.
     *
     * @throws IllegalArgumentException if the heap is empty
     */
    public E extractMax() {
        E result = findMax();
        int lastIdx = size() - 1;
        heap.swap(0, lastIdx);
        heap.remove(lastIdx);
        siftDown(0);
        return result;
    }

    /** Swaps the element at {@code index} with its larger child until neither child is larger. */
    private void siftDown(int index) {
        int cur = index;
        while (true) {
            int big = largerChild(cur);
            if (big < 0 || heap.get(cur).compareTo(heap.get(big)) >= 0) {
                return;
            }
            heap.swap(cur, big);
            cur = big;
        }
    }

    /** Slot of the larger child of {@code index} (left on ties), or -1 when it is a leaf. */
    private int largerChild(int index) {
        int l = leftChild(index);
        if (l >= size()) {
            return -1;
        }
        int r = rightChild(index);
        return (r < size() && heap.get(l).compareTo(heap.get(r)) < 0) ? r : l;
    }
}

class Array<E>{

    private E[] data;
    private int size;

    public Array(int capacity){

        data = (E[]) new Object[capacity];
        size = 0;
    }

    public Array(){

        data = (E[]) new Object[10];
        size = 0;
    }

    public int getSize(){

        return size;
    }

    public void swap(int index1, int index2){

        E temp;
        
        temp = data[index1];
        data[index1] = data[index2];
        data[index2] = temp;
    }

    public E get(int index){

        if (index < 0 || index >= size){
            throw new IllegalArgumentException("索引值非法");
        }

        return data[index];
    }

    public void add(E e, int index){

        if (index < 0 || index > size){
            throw new IllegalArgumentException("索引值非法");
        }

        if (size == data.length){
            resize(2 * data.length);
        }

        for (int i = size - 1; i >= index; i--) {
            data[i + 1] = data[i];
        }

        data[index] = e;
        size++;
    }

    public void addLast(E e){
        add(e, size);
    }

    public void remove(int index){

        if (index < 0 || index >= size){
            throw new IllegalArgumentException("索引值非法");
        }

        for (int i = index + 1; i < size; i++) {
            data[i - 1] = data[i];
        }

        size--;
        data[size] = null;

        if (size == data.length / 2 && data.length / 2 != 0){
            resize(data.length / 2);
        }
    }

    public void resize(int newCapacity){

        E[] temp = (E[])new Object[newCapacity];

        for (int i = 0; i < size; i++) {
            temp[i] = data[i];
        }

        data = temp;
    }

    @Override
    public String toString(){

        StringBuilder str = new StringBuilder();
        str.append("[");

        for (int i = 0; i < size; i++) {

            str.append(data[i]);
            if (i != size - 1){
                str.append(", ");
            }
        }

        str.append("]");

        return str.toString();
    }
}

/** Builds {@code Integer} arrays used as input for the sorting benchmarks. */
class ArrayGenerator {

    private ArrayGenerator() {}

    /**
     * Returns {@code n} values drawn uniformly from {@code [0, maxBound)} with a fresh
     * {@link Random}.
     */
    public static Integer[] generatorRandomArray(Integer n, Integer maxBound) {
        Random rnd = new Random();
        Integer[] values = new Integer[n];
        for (int k = 0; k < values.length; k++) {
            values[k] = rnd.nextInt(maxBound);
        }
        return values;
    }

    /**
     * Returns the ascending sequence {@code 0, 1, ..., n - 1}. {@code maxBound} is not used;
     * it is accepted only so the signature mirrors generatorRandomArray.
     */
    public static Integer[] generatorSortedArray(Integer n, Integer maxBound) {
        Integer[] values = new Integer[n];
        int k = 0;
        while (k < n) {
            values[k] = k;
            k++;
        }
        return values;
    }
}

/** Helpers that check and time the sorting algorithms used by the benchmark. */
class Verify {

    private Verify() {}

    /** Returns true iff {@code arr} is in non-decreasing order (trivially true for length 0 or 1). */
    public static <E extends Comparable<E>> boolean isSorted(E[] arr) {
        for (int i = 1; i < arr.length; i++) {
            if (arr[i - 1].compareTo(arr[i]) > 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sorts {@code arr} in place with the named algorithm, checks the result, and prints the
     * elapsed wall-clock time in seconds.
     *
     * @param algorithmName one of {@code "HeapSort"}, {@code "QuickSort3Ways"}, {@code "MergeSort"}
     * @param arr           the array to sort (modified in place)
     * @throws IllegalArgumentException if {@code algorithmName} is not one of the names above
     * @throws RuntimeException         if the array is not sorted afterwards
     */
    public static <E extends Comparable<E>> void testTime(String algorithmName, E[] arr) {

        long startTime = System.nanoTime();

        switch (algorithmName) {
            case "HeapSort":
                HeapSort.sort(arr);
                break;
            case "QuickSort3Ways":
                QuickSort.sort3ways(arr);
                break;
            case "MergeSort":
                MergeSort.sort(arr);
                break;
            default:
                // Reject unknown names up front: otherwise nothing would be sorted, and the
                // caller would get either a meaningless timing or a misleading failure message.
                throw new IllegalArgumentException("未知的排序算法: " + algorithmName);
        }

        long endTime = System.nanoTime();

        if (!isSorted(arr)) {
            throw new RuntimeException(algorithmName + "算法排序失败!");
        }

        System.out.println(String.format("%s算法,测试用例为%d,执行时间:%f秒", algorithmName, arr.length, (endTime - startTime) / 1000000000.0));
    }
}
posted @ 2021-10-31 15:55  振袖秋枫问红叶  阅读(57)  评论(0)    收藏  举报