Solution 5: Top K
问题描述
输入一个整形数组和K,输出数组中前K大的元素们。
解决思路
思路1:排序
如果用快排,平均时间复杂度为O(nlogn),最坏时间复杂度为O(n^2);空间复杂度为O(logn)~O(n);
如果用堆排,时间复杂度为O(nlogn),空间复杂度为O(1).
注意:
Java中Arrays.sort()方法默认实现为归并排序,时间O(nlogn),空间O(n)。
思路2:借用快排的partition函数
较思路1的改进在于,不一定要完全将整个数组进行排序,快排中的partition函数能够保证partition后的元素位置之前的元素均大于等于(或小于等于)该指向元素。
平均时间复杂度为O(n),最坏时间复杂度和快排的一样O(n^2)。
思路3:大数据下的堆排
如果场景为数据量很大,或者甚至是无穷量的数据时,此时可借用堆排的思想。
具体做法为,如果是输出前K大,那么需要维护一个大小为K的最小堆,之后的元素与堆顶元素进行比较,如果更大则进入堆中,再调整堆。
时间复杂度为O(n*logk + k),空间复杂度为O(k)。
程序
public class TopK {
// sort
public List<Integer> getTopKBySort(int[] nums, int k) {
List<Integer> res = new ArrayList<Integer>();
if (nums == null || nums.length == 0 || nums.length < k || k <= 0) {
return res;
}
Arrays.sort(nums);
for (int i = 0; i < k; i++) {
res.add(nums[i]);
}
return res;
}
// partition
public List<Integer> getTopKByPartition(int[] nums, int k) {
List<Integer> res = new ArrayList<Integer>();
if (nums == null || nums.length == 0 || nums.length < k || k <= 0) {
return res;
}
int part = partition(nums, 0, nums.length - 1);
while (true) {
if (part == k - 1) {
for (int i = 0; i < k; i++) {
res.add(nums[i]);
}
break;
} else if (part < k - 1) {
part = partition(nums, part + 1, nums.length - 1);
} else {
part = partition(nums, 0, part - 1);
}
}
return res;
}
private int partition(int[] nums, int begin, int end) {
int low = begin - 1, high = end;
int pivot = nums[end];
while (true) {
while (low < high && nums[++low] >= pivot) {
;
}
while (low < high && nums[--high] <= pivot) {
;
}
if (low >= high) {
break;
}
swap(nums, low, high);
}
swap(nums, low, end);
return low;
}
private void swap(int[] nums, int low, int high) {
int tmp = nums[low];
nums[low] = nums[high];
nums[high] = tmp;
}
// heap
public int[] getTopKByHeap(int[] nums, int k) {
if (nums == null || nums.length == 0 || nums.length < k || k <= 0) {
return null;
}
int[] res = new int[k];
for (int i = 0; i < nums.length; i++) {
if (i < k) {
res[i] = nums[i];
} else if (i == k) {
buildMinHeap(res);
} else {
if (nums[i] > res[0]) {
res[0] = nums[i];
fixMaxDown(res, 0);
}
}
}
return res;
}
private void fixMaxDown(int[] heap, int i) {
int tmp = heap[i];
int j = 2*i +1;
while (j < heap.length) {
while (j+1 < heap.length && heap[j+1] < heap[j]) {
++j;
}
if (tmp<heap[j]) {
break;
}
heap[i] = heap[j];
i = j;
j = 2*i + 1;
}
heap[i] = tmp;
}
private void buildMinHeap(int[] heap) {
for (int i = heap.length/2 - 1; i >= 0; i--) {
fixMaxDown(heap, i);
}
}
}

浙公网安备 33010602011771号