1 package org.example.algorithm.datastruct;
2
3 import java.util.ArrayList;
4 import java.util.Comparator;
5 import java.util.List;
6
7 public class Heap<T> {
8
9 private final Comparator<T> comparator;
10
11 private List<T> keys = new ArrayList<>();
12
13 public Heap() {
14 comparator = null;
15 }
16
17 public Heap(Comparator<T> comparator) {
18 this.comparator = comparator;
19 }
20
21 public Heap(List<T> keys) {
22 comparator = null;
23 this.keys.addAll(keys);
24 buildMaxHeap();
25 }
26
27 public Heap(List<T> keys, Comparator<T> comparator) {
28 this.comparator = comparator;
29 this.keys.addAll(keys);
30 buildMaxHeap();
31 }
32
33 public T top() {
34 return keys.get(0);
35 }
36
37 public T extract() {
38 if (keys.isEmpty())
39 return null;
40 T max = keys.get(0);
41 T last = keys.remove(keys.size() - 1);
42 if (!keys.isEmpty()) {
43 keys.set(0, last);
44 }
45 maxHeapify(0);
46 return max;
47 }
48
49 public void insert(T key) {
50 if (key == null)
51 throw new NullPointerException();
52 if (keys.isEmpty()) {
53 keys.add(key);
54 } else {
55 T last = keys.get(keys.size() - 1);
56 keys.add(compare(last, key) < 0 ? last : key);
57 }
58 increaseKey(keys.size() - 1, key);
59 }
60
61 private static int parent(int index) {
62 return (index - 1) >> 1;
63 }
64
65 private static int left(int index) {
66 return (index << 1) + 1;
67 }
68
69 private static int right(int index) {
70 return (index + 1) << 1;
71 }
72
73 private void maxHeapify(int index) {
74 int left = left(index);
75 int right = right(index);
76 int largest = index;
77 if (left < keys.size() && compare(keys.get(left), keys.get(largest)) > 0)
78 largest = left;
79 if (right < keys.size() && compare(keys.get(right), keys.get(largest)) > 0)
80 largest = right;
81 if (largest != index) {
82 swap(largest, index);
83 maxHeapify(largest);
84 }
85 }
86
87 private void buildMaxHeap() {
88 int n = (keys.size() - 1) >> 1;
89 for (int i = n; i >= 0; i--) {
90 maxHeapify(i);
91 }
92 }
93
94 private void increaseKey(int index, T key) {
95 if (key == null || compare(key, keys.get(index)) < 0)
96 throw new RuntimeException("New key is smaller than current key.");
97 keys.set(index, key);
98 while (index > 0 && compare(keys.get(parent(index)), keys.get(index)) < 0) {
99 swap(index, parent(index));
100 index = parent(index);
101 }
102 }
103
104 @SuppressWarnings("unchecked")
105 private int compare(T key1, T key2) {
106 return comparator != null ? comparator.compare(key1, key2) :
107 ((Comparable<T>)key1).compareTo(key2);
108 }
109
110 private void swap(int index1, int index2) {
111 T temp = keys.get(index1);
112 keys.set(index1, keys.get(index2));
113 keys.set(index2, temp);
114 }
115 }