查找两个已排序数组的中位数(Median of Two Sorted Arrays)
https://leetcode.com/problems/median-of-two-sorted-arrays/
一个常见的思路就是利用递归查找第K个数来实现。
1 int findKth(vector<int>& nums1, int p1, vector<int>& nums2, int p2, int k) { 2 int len1 = nums1.size() - p1, len2 = nums2.size() - p2; 3 if (len1 < len2) { 4 return findKth(nums2, p2, nums1, p1, k); 5 } 6 7 if (p2 == nums2.size()) { 8 return nums1[p1 + k]; 9 } 10 11 if (k == 0) { 12 return min(nums1[p1], nums2[p2]); 13 } 14 15 int m = (k + 1) / 2; 16 int k2 = min(len2, m); 17 int k1 = k + 1 - k2; 18 if (nums2[p2 + k2 - 1] < nums1[p1 + k1 - 1]) { 19 return findKth(nums1, p1, nums2, p2 + k2, k1 - 1); 20 } 21 else { 22 return findKth(nums1, p1 + k1, nums2, p2, k2 - 1); 23 } 24 }
分析:
FindKth传入的参数k代表如果将2个数组看做一个数组,返回以0为初始下标的开始的第k个元素,所以其实是第k+1个元素。
假设要找两个数组中第k大的数,我们可以假设前(k+1)/2个数存在于数组1,剩下的k +1 - (k+1)/2个数存在于数组2,也就是16-17行:
int k2 = min(len2, m); int k1 = k + 1 - k2;
然后比较两个数组这两部分的最后一个元素,也就是18-23行:
if (nums2[p2 + k2 - 1] < nums1[p1 + k1 - 1]) { return findKth(nums1, p1, nums2, p2 + k2, k1 - 1); } else { return findKth(nums1, p1 + k1, nums2, p2, k2 - 1); }
1. 如果nums2的最后一个元素较小,那么nums2的前k2个元素肯定存在于最小的前k个数里,且第k小的数肯定不在这k2个元素里,否则nums2的第k2-1个元素应该不仅比前k个元素大,而且应该比nums1中的前k - k2个元素大,这样它才是两个数组中第k大的元素,但nums2[p2 + k2 - 1] < nums1[p1 + k1 - 1],矛盾。
这样我们已经找到了最小的k个元素这个集合里的j个元素,剩下的k-k2个元素肯定存在于剩下的元素集合里,所以只需继续在两个数组剩下的元素里找到最小的k - k2个元素,我们就找到了第k小的元素。
2. 同理,如果nums2[p2 + k2 - 1] >= nums1[p1 + k1 - 1], 说明a数组里的部分是两个数组两部分里较小的k - j个元素,第k个元素也不存在于这个部分,所以我们只需在两数组剩余的部分查找第k2大的元素即可。
接下来解释一下开头的几个判断:
if (len1 < len2) { return findKth(nums2, p2, nums1, p1, k); }
我们始终把剩余长度较大的数组作为第一个数组,这样是为了简化后续判断逻辑。
if (p2 == nums2.size()) { return nums1[p1 + k]; }
在前一个递归调用中,nums2剩余部分的长度已经不足k/2,所以我们只需在这次调用中直接返回a的第k大元素 。
if (k == 0) { return min(nums1[p1], nums2[p2]); }
循环至最后一个元素时,返回nums1和nums2剩余部分中较小的元素作为递归终止。
完整代码:
1 class Solution{ 2 private: 3 int findKth(vector<int>& nums1, int p1, vector<int>& nums2, int p2, int k) { 4 int len1 = nums1.size() - p1, len2 = nums2.size() - p2; 5 if (len1 < len2) { 6 return findKth(nums2, p2, nums1, p1, k); 7 } 8 9 if (p2 == nums2.size()) { 10 return nums1[p1 + k]; 11 } 12 13 if (k == 0) { 14 return min(nums1[p1], nums2[p2]); 15 } 16 17 int m = (k + 1) / 2; 18 int k2 = min(len2, m); 19 int k1 = k + 1 - k2; 20 if (nums2[p2 + k2 - 1] < nums1[p1 + k1 - 1]) { 21 return findKth(nums1, p1, nums2, p2 + k2, k1 - 1); 22 } 23 else { 24 return findKth(nums1, p1 + k1, nums2, p2, k2 - 1); 25 } 26 } 27 public: 28 double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) { 29 double median = 0; 30 int totalSize = nums1.size() + nums2.size(); 31 if(totalSize > 0){ 32 if(totalSize & 1){ 33 // Odd 34 median = findKth(nums1, 0, nums2, 0, totalSize/2); 35 }else{ 36 // Even 37 median = (findKth(nums1, 0, nums2, 0, totalSize / 2) + findKth(nums1, 0, nums2, 0, totalSize / 2 - 1)) / 2.0; 38 } 39 } 40 41 return median; 42 } 43 };
浙公网安备 33010602011771号