堆、优先队列和堆排序01:二叉堆和普通堆排序
动态数组实现二叉堆
二叉堆的定义
二叉堆是一棵完全二叉树:元素按层从上到下、每层从左到右依次排列成树,只有最后一层可能不满,且空缺只能出现在最后一层的右侧(即树的右下方)
堆中每个节点的值都大于等于其孩子的值(最大堆)
根节点索引从1开始,从上到下、从左到右依次用数组存储时:父节点索引为n,则子节点为2n和2n + 1;子节点索引为n,则父节点为n / 2(整数除法,向下取整)
根节点索引从0开始,父节点索引为n时,子节点为2n + 1和2n + 2;子节点为n时,父节点为(n - 1)/2
siftUp()和siftDown()
import java.util.Random;
public class Algorithm {
    /**
     * Smoke test for MaxHeap: pushes one million random ints, then drains the heap
     * with extractMax(). Every call returns the current maximum, so the drained
     * sequence must be non-increasing.
     *
     * Draining a heap into a separate array like this is the "plain" (out-of-place)
     * heap sort: O(n log n) time, plus O(n) extra space for the heap and O(n) for
     * the output array.
     */
    public static void main(String[] args) {
        final int count = 1000000;
        MaxHeap<Integer> heap = new MaxHeap<Integer>();
        Random rng = new Random();

        int pushed = 0;
        while (pushed < count) {
            heap.add(rng.nextInt(Integer.MAX_VALUE));
            pushed++;
        }

        int[] drained = new int[count];
        for (int k = 0; k < drained.length; k++) {
            drained[k] = heap.extractMax();
        }

        // Each element must be >= its successor.
        for (int k = 1; k < count; k++) {
            if (drained[k - 1] < drained[k]) {
                throw new IllegalArgumentException("二叉堆实现失败");
            }
        }
        System.out.println("Test completed!");
    }
}
/**
 * Binary max-heap stored level by level in a dynamic array, root at index 0.
 * For a node at index i, its children are at 2i + 1 and 2i + 2, and its parent is at (i - 1) / 2.
 */
class MaxHeap<E extends Comparable<E>> {
    private Array<E> heap;

    /** Creates an empty heap whose backing array starts with {@code capacity} slots. */
    public MaxHeap(int capacity) {
        heap = new Array<>(capacity);
    }

    /** Creates an empty heap with a default initial capacity of 10. */
    public MaxHeap() {
        this(10);
    }

    /** Returns the number of elements in the heap. */
    public int size() {
        return heap.getSize();
    }

    /** Returns true when the heap holds no elements. */
    public boolean isEmpty() {
        return heap.getSize() == 0;
    }

    /** Index of the parent of {@code index}; the root (index 0) has none. */
    private int parent(int index) {
        if (index == 0) {
            throw new IllegalArgumentException("根节点没有父节点");
        }
        return (index - 1) / 2;
    }

    /** Index of the left child of {@code index}. */
    private int leftChild(int index) {
        return 2 * index + 1;
    }

    /** Index of the right child of {@code index} (one past the left child). */
    private int rightChild(int index) {
        return leftChild(index) + 1;
    }

    /**
     * Inserts {@code e}. The element is appended at the end, which keeps the tree complete,
     * and then sifted up until its parent is no smaller.
     */
    public void add(E e) {
        heap.addLast(e);
        siftUp(heap.getSize() - 1);
    }

    /** Moves the element at {@code index} toward the root while it is larger than its parent. */
    private void siftUp(int index) {
        int child = index;
        while (child > 0) {
            int up = parent(child);
            if (heap.get(child).compareTo(heap.get(up)) <= 0) {
                break;
            }
            heap.swap(child, up);
            child = up;
        }
    }

    /**
     * Returns the largest element without removing it.
     *
     * @throws IllegalArgumentException if the heap is empty
     */
    public E findMax() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空");
        }
        return heap.get(0);
    }

    /**
     * Removes and returns the largest element (the root).
     * Merging the two subtrees directly would be awkward. Instead, the last element is
     * swapped into the root slot, the old root is dropped from the end, and the new
     * root is sifted down.
     *
     * @throws IllegalArgumentException if the heap is empty
     */
    public E extractMax() {
        E top = findMax();
        int last = size() - 1;
        heap.swap(0, last);
        heap.remove(last);
        siftDown(0);
        return top;
    }

    /**
     * Moves the element at {@code index} toward the leaves, swapping with its larger child
     * until it is no smaller than both children. A node with no left child is a leaf.
     */
    private void siftDown(int index) {
        int cur = index;
        int n = size();
        while (true) {
            int left = leftChild(cur);
            if (left >= n) {
                return;
            }
            int right = rightChild(cur);
            int bigger = left;
            if (right < n && heap.get(left).compareTo(heap.get(right)) < 0) {
                bigger = right;
            }
            if (heap.get(cur).compareTo(heap.get(bigger)) >= 0) {
                return;
            }
            heap.swap(cur, bigger);
            cur = bigger;
        }
    }
}
/**
* 动态数组实现二叉堆
*/
class Array<E>{
private E[] data;
private int size;
public Array(int capacity){
data = (E[]) new Object[capacity];
size = 0;
}
public Array(){
data = (E[]) new Object[10];
size = 0;
}
public int getSize(){
return size;
}
public void swap(int index1, int index2){
E temp;
temp = data[index1];
data[index1] = data[index2];
data[index2] = temp;
}
public E get(int index){
if (index < 0 || index >= size){
throw new IllegalArgumentException("索引值非法");
}
return data[index];
}
public void add(E e, int index){
if (index < 0 || index > size){
throw new IllegalArgumentException("索引值非法");
}
if (size == data.length){
resize(2 * data.length);
}
for (int i = size - 1; i >= index; i--) {
data[i + 1] = data[i];
}
data[index] = e;
size++;
}
public void addLast(E e){
add(e, size);
}
public void remove(int index){
if (index < 0 || index >= size){
throw new IllegalArgumentException("索引值非法");
}
for (int i = index + 1; i < size; i++) {
data[i - 1] = data[i];
}
size--;
data[size] = null;
if (size == data.length / 2 && data.length / 2 != 0){
resize(data.length / 2);
}
}
public void resize(int newCapacity){
E[] temp = (E[])new Object[newCapacity];
for (int i = 0; i < size; i++) {
temp[i] = data[i];
}
data = temp;
}
@Override
public String toString(){
StringBuilder str = new StringBuilder();
str.append("[");
for (int i = 0; i < size; i++) {
str.append(data[i]);
if (i != size - 1){
str.append(", ");
}
}
str.append("]");
return str.toString();
}
}
时间复杂度:siftUp()和siftDown()沿树高移动,时间复杂度为O(logn),因此add()和extractMax()也是O(logn)(add()遇到动态数组扩容时为均摊O(logn))
而普通堆排序需要先将n个元素依次加入堆,再将它们逐个从堆中取出,每次操作O(logn),总时间复杂度为O(nlogn)
普通堆排序、归并排序和快速排序性能对比
import java.util.Arrays;
import java.util.Random;
public class Algorithm {
    /** Algorithms benchmarked on every input, in the order they are run. */
    private static final String[] ALGORITHMS = {"HeapSort", "QuickSort3Ways", "MergeSort"};

    /**
     * Compares plain heap sort, three-way quicksort and merge sort on a random array
     * and on an already-sorted array, at sizes 100,000 and 1,000,000. Every algorithm
     * sorts its own copy of the same input, so the timings are comparable.
     */
    public static void main(String[] args) {
        int[] sizes = {100000, 1000000};
        for (int size : sizes) {
            Integer[] randomInput = ArrayGenerator.generatorRandomArray(size, size);
            Integer[] sortedInput = ArrayGenerator.generatorSortedArray(size, size);
            benchmark("测试随机数组排序性能", randomInput);
            benchmark("测试有序数组排序性能", sortedInput);
        }
    }

    /** Prints {@code title}, then times each algorithm on a fresh copy of {@code source}. */
    private static void benchmark(String title, Integer[] source) {
        System.out.println(title);
        System.out.println();
        for (String name : ALGORITHMS) {
            Verify.testTime(name, Arrays.copyOf(source, source.length));
        }
        System.out.println();
    }
}
/**
 * "Plain" (out-of-place) heap sort. Every element is pushed into a MaxHeap, then the
 * heap is drained into arr from the last slot to the first. Each extractMax() returns
 * the current maximum, so arr ends up in ascending order.
 * Time O(n log n); extra space O(n) for the heap.
 */
class HeapSort {
    private HeapSort() {}

    /** Sorts {@code arr} in ascending order. */
    public static <E extends Comparable<E>> void sort(E[] arr) {
        // Presize the heap to the input length so its backing array never has to
        // grow (and copy) while it is being filled.
        MaxHeap<E> maxHeap = new MaxHeap<>(arr.length);
        for (E e : arr) {
            maxHeap.add(e);
        }
        for (int i = arr.length - 1; i >= 0; i--) {
            arr[i] = maxHeap.extractMax();
        }
    }
}
class QuickSort {
private QuickSort() {}
public static<E extends Comparable<E>> void sort3ways(E[] arr){
Random random = new Random();
E temp = null;
sort3ways(arr, 0, arr.length - 1, temp, random);
}
public static<E extends Comparable<E>> void sort3ways(E[] arr, int left, int right, E temp, Random random){
if (left >= right){
return;
}
int[] res = partition3ways(arr, left, right, temp, random);
sort3ways(arr, left, res[0], temp, random);
sort3ways(arr, res[1], right, temp, random);
}
public static<E extends Comparable<E>> int[] partition3ways(E[] arr, int left, int right, E temp, Random random){
int p = random.nextInt(right - left + 1) + left;
swap(arr, p, left, temp);
int i = left + 1;
int lt = left;
int gt = right + 1;
while (i < gt){
if (arr[i].compareTo(arr[left]) < 0){
lt++;
swap(arr, lt, i, temp);
i++;
}
else if (arr[i].compareTo(arr[left]) == 0){
i++;
}
else if (arr[i].compareTo(arr[left]) > 0){
gt--;
swap(arr, gt, i, temp);
}
}
swap(arr, lt, left, temp);
int[] res = {lt - 1, gt};
return res;
}
public static<E extends Comparable<E>> void swap(E[] arr, int i, int j, E temp){
temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
}
class MergeSort {
private MergeSort(){}
public static<E extends Comparable<E>> void sort(E[] arr){
E[] temp = Arrays.copyOf(arr, arr.length);
sort(arr, 0, arr.length - 1, temp);
}
private static<E extends Comparable<E>> void sort(E[] arr, int left, int right, E[] temp){
if (left >= right){
return;
}
int mid = left + (right - left) / 2;
sort(arr, left, mid, temp);
sort(arr, mid + 1, right, temp);
if (arr[mid].compareTo(arr[mid + 1]) > 0) {
merge(arr, left, mid, right, temp);
}
}
public static<E extends Comparable<E>> void merge(E[] arr, int left, int mid, int right, E[] temp){
int i = left;
int j = mid + 1;
System.arraycopy(arr, left, temp, left, right - left + 1);
for (int n = left; n < right + 1; n++) {
if (i == mid + 1){
arr[n] = temp[j];
j++;
}
else if (j == right + 1) {
arr[n] = temp[i];
i++;
}
else if (temp[i].compareTo(temp[j]) <= 0) {
arr[n] = temp[i];
i++;
}
else{
arr[n] = temp[j];
j++;
}
}
}
}
/**
 * Binary max-heap stored level by level in a dynamic array, root at index 0.
 * For a node at index i, its children are at 2i + 1 and 2i + 2, and its parent is at (i - 1) / 2.
 */
class MaxHeap<E extends Comparable<E>> {
    private Array<E> heap;

    /** Creates an empty heap whose backing array starts with {@code capacity} slots. */
    public MaxHeap(int capacity) {
        heap = new Array<>(capacity);
    }

    /** Creates an empty heap with a default initial capacity of 10. */
    public MaxHeap() {
        this(10);
    }

    /** Returns the number of elements in the heap. */
    public int size() {
        return heap.getSize();
    }

    /** Returns true when the heap holds no elements. */
    public boolean isEmpty() {
        return heap.getSize() == 0;
    }

    /** Index of the parent of {@code index}; the root (index 0) has none. */
    private int parent(int index) {
        if (index == 0) {
            throw new IllegalArgumentException("根节点没有父节点");
        }
        return (index - 1) / 2;
    }

    /** Index of the left child of {@code index}. */
    private int leftChild(int index) {
        return 2 * index + 1;
    }

    /** Index of the right child of {@code index} (one past the left child). */
    private int rightChild(int index) {
        return leftChild(index) + 1;
    }

    /** Inserts {@code e}: appends it at the end, then sifts it up into place. */
    public void add(E e) {
        heap.addLast(e);
        siftUp(heap.getSize() - 1);
    }

    /** Moves the element at {@code index} toward the root while it is larger than its parent. */
    private void siftUp(int index) {
        int child = index;
        while (child > 0) {
            int up = parent(child);
            if (heap.get(child).compareTo(heap.get(up)) <= 0) {
                break;
            }
            heap.swap(child, up);
            child = up;
        }
    }

    /**
     * Returns the largest element without removing it.
     *
     * @throws IllegalArgumentException if the heap is empty
     */
    public E findMax() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空");
        }
        return heap.get(0);
    }

    /**
     * Removes and returns the largest element. The last element is swapped into
     * the root slot, the old root is dropped from the end, and the new root is sifted down.
     *
     * @throws IllegalArgumentException if the heap is empty
     */
    public E extractMax() {
        E top = findMax();
        int last = size() - 1;
        heap.swap(0, last);
        heap.remove(last);
        siftDown(0);
        return top;
    }

    /** Moves the element at {@code index} down, swapping with its larger child, until neither child is larger. */
    private void siftDown(int index) {
        int cur = index;
        int n = size();
        while (true) {
            int left = leftChild(cur);
            if (left >= n) {
                return;
            }
            int right = rightChild(cur);
            int bigger = left;
            if (right < n && heap.get(left).compareTo(heap.get(right)) < 0) {
                bigger = right;
            }
            if (heap.get(cur).compareTo(heap.get(bigger)) >= 0) {
                return;
            }
            heap.swap(cur, bigger);
            cur = bigger;
        }
    }
}
class Array<E>{
private E[] data;
private int size;
public Array(int capacity){
data = (E[]) new Object[capacity];
size = 0;
}
public Array(){
data = (E[]) new Object[10];
size = 0;
}
public int getSize(){
return size;
}
public void swap(int index1, int index2){
E temp;
temp = data[index1];
data[index1] = data[index2];
data[index2] = temp;
}
public E get(int index){
if (index < 0 || index >= size){
throw new IllegalArgumentException("索引值非法");
}
return data[index];
}
public void add(E e, int index){
if (index < 0 || index > size){
throw new IllegalArgumentException("索引值非法");
}
if (size == data.length){
resize(2 * data.length);
}
for (int i = size - 1; i >= index; i--) {
data[i + 1] = data[i];
}
data[index] = e;
size++;
}
public void addLast(E e){
add(e, size);
}
public void remove(int index){
if (index < 0 || index >= size){
throw new IllegalArgumentException("索引值非法");
}
for (int i = index + 1; i < size; i++) {
data[i - 1] = data[i];
}
size--;
data[size] = null;
if (size == data.length / 2 && data.length / 2 != 0){
resize(data.length / 2);
}
}
public void resize(int newCapacity){
E[] temp = (E[])new Object[newCapacity];
for (int i = 0; i < size; i++) {
temp[i] = data[i];
}
data = temp;
}
@Override
public String toString(){
StringBuilder str = new StringBuilder();
str.append("[");
for (int i = 0; i < size; i++) {
str.append(data[i]);
if (i != size - 1){
str.append(", ");
}
}
str.append("]");
return str.toString();
}
}
/** Builds Integer test inputs for the sorting benchmarks. */
class ArrayGenerator {
    private ArrayGenerator() {}

    /**
     * Returns {@code n} random integers, each drawn uniformly from [0, maxBound).
     * A new unseeded Random is used on every call.
     */
    public static Integer[] generatorRandomArray(Integer n, Integer maxBound) {
        Integer[] values = new Integer[n];
        Random rng = new Random();
        int idx = 0;
        while (idx < values.length) {
            values[idx] = rng.nextInt(maxBound);
            idx++;
        }
        return values;
    }

    /**
     * Returns the ascending sequence 0, 1, ..., n - 1.
     * {@code maxBound} is not used; it only mirrors generatorRandomArray's signature.
     */
    public static Integer[] generatorSortedArray(Integer n, Integer maxBound) {
        Integer[] values = new Integer[n];
        for (int v = 0; v < values.length; v++) {
            values[v] = v;
        }
        return values;
    }
}
/** Correctness check and timing harness for the sorting benchmarks. */
class Verify {
    private Verify() {}

    /** Returns true if {@code arr} is in non-decreasing order according to compareTo. */
    public static <E extends Comparable<E>> boolean isSorted(E[] arr) {
        for (int i = 1; i < arr.length; i++) {
            if (arr[i - 1].compareTo(arr[i]) > 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sorts {@code arr} in place with the named algorithm, checks the result, and prints
     * the elapsed wall-clock time in seconds.
     *
     * @param algorithmName one of "HeapSort", "QuickSort3Ways", "MergeSort"
     * @throws IllegalArgumentException if {@code algorithmName} is not one of the supported names.
     *         Without this check an unknown name would time a no-op and then pass or fail
     *         depending only on whether the input already happened to be sorted.
     * @throws RuntimeException if the sorted output is not in order
     */
    public static <E extends Comparable<E>> void testTime(String algorithmName, E[] arr) {
        long startTime = System.nanoTime();
        switch (algorithmName) {
            case "HeapSort":
                HeapSort.sort(arr);
                break;
            case "QuickSort3Ways":
                QuickSort.sort3ways(arr);
                break;
            case "MergeSort":
                MergeSort.sort(arr);
                break;
            default:
                throw new IllegalArgumentException("未知的排序算法: " + algorithmName);
        }
        long endTime = System.nanoTime();
        if (!isSorted(arr)) {
            throw new RuntimeException(algorithmName + "算法排序失败!");
        }
        System.out.println(String.format("%s算法,测试用例为%d,执行时间:%f秒", algorithmName, arr.length, (endTime - startTime) / 1000000000.0));
    }
}