LIS&LCS

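The longest increasing subsequence (LIS) and the longest common subsequence (LCS) are two classic dynamic programming (DP) problems. This post covers how to solve both, starting with the standard DP solutions and then moving on to some alternative approaches.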

1. LIS

1.1 Length of the LIS (LeetCode 300)

1.1.1 Dynamic programming solution

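The hardest part of a DP problem is choosing the subproblem. Once the subproblem is right, the state transition equation is usually simple. Take the length of the longest increasing subsequence as an example. Given an array $nums$ of length $n$, let $DP[i]$ be the length of the longest increasing subsequence that ends with $nums[i]$. The recurrence is:$$DP[j]=\max\{DP[i] \mid 0 \leq i < j,\ nums[i] < nums[j]\} + 1$$Here the maximum over an empty set is taken to be $0$, so $nums[j]$ on its own has length 1. The LIS length is $\max_j DP[j]$ rather than $DP[n-1]$, because the LIS may end anywhere.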

The recurrence is straightforward, and so is the code. The time complexity is $O(n^2)$ and the space complexity is $O(n)$. The code is below.

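```python
from typing import List


class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        if not nums:
            return 0
        # dp[j]: length of the longest increasing subsequence ending at nums[j]
        dp = [1] * len(nums)
        for j in range(len(nums)):
            # best dp[i] among earlier, smaller elements (0 if none), plus nums[j] itself
            dp[j] = max((dp[i] for i in range(j) if nums[i] < nums[j]), default=0) + 1
        return max(dp)
```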
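To see the recurrence on concrete data, here is a small sketch that returns the whole $DP$ array instead of only its maximum. The helper name `lis_dp_table` and the sample inputs are illustrative choices, not part of the original solution.

```python
from typing import List


def lis_dp_table(nums: List[int]) -> List[int]:
    """Return the whole DP array: dp[j] is the length of the longest
    increasing subsequence that ends exactly at nums[j]."""
    dp = [1] * len(nums)
    for j in range(len(nums)):
        for i in range(j):
            # nums[j] can extend any increasing subsequence ending at a smaller nums[i]
            if nums[i] < nums[j]:
                dp[j] = max(dp[j], dp[i] + 1)
    return dp


table = lis_dp_table([10, 9, 2, 5, 3, 7, 101, 18])
print(table)       # [1, 1, 1, 2, 2, 3, 4, 4]
print(max(table))  # 4
```

On `[1, 2, 3, 0]` the table is `[1, 2, 3, 1]`. The last entry is 1 while the answer is 3, which is why the solution returns `max(dp)`.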

1.1.2 Greedy + binary search solution

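The greedy + binary search solution is much less obvious than the intuitive DP above. Maintain an array $greedy$, where $greedy[i]$ is the smallest possible last element of an increasing subsequence of length $i+1$ (array indices start at 0). The smaller the last element of an increasing subsequence, the easier it is to append further elements to it, and so the longer it can grow. $greedy$ is always strictly increasing, which can be shown by contradiction. Suppose $greedy[i] \geq greedy[i+1]$. Take an increasing subsequence of length $i+2$ that ends in $greedy[i+1]$ and drop its last element. What remains has length $i+1$ and ends in a value smaller than $greedy[i+1] \leq greedy[i]$, which contradicts the minimality of $greedy[i]$. When the scan ends, the length of $greedy$ is the LIS length. For each element of $nums$, $greedy$ is updated as follows (in Python notation):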

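  1. If $greedy$ is empty or $nums[i] > greedy[-1]$, append $nums[i]$ to the end of $greedy$: $greedy.append(nums[i])$.
  2. Otherwise, binary-search $greedy$ for the first position whose value is greater than or equal to $nums[i]$, call it $j$, and set $greedy[j] = nums[i]$.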
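The sketch below applies the two rules above and records a snapshot of $greedy$ after every element, which makes the update rule easy to follow. `greedy_trace` is a hypothetical helper written only for illustration.

```python
import bisect
from typing import List


def greedy_trace(nums: List[int]) -> List[List[int]]:
    """Return a copy of greedy after each element of nums is processed."""
    greedy: List[int] = []
    snapshots: List[List[int]] = []
    for x in nums:
        if not greedy or x > greedy[-1]:
            greedy.append(x)  # rule 1: x extends the longest subsequence found so far
        else:
            j = bisect.bisect_left(greedy, x)  # rule 2: first position holding a value >= x
            greedy[j] = x                      # lower that tail to x
        snapshots.append(list(greedy))
    return snapshots


for step in greedy_trace([10, 9, 2, 5, 3, 7, 101, 18]):
    print(step)
```

For `[10, 9, 2, 5, 3, 7, 101, 18]` the snapshots end with `[2, 3, 7, 18]`, so the length is 4. Note that $greedy$ is not necessarily an increasing subsequence of the input. For `[3, 4, 1]` it ends as `[1, 4]`, even though 1 comes after 4. Only its length is meaningful.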

The code is as follows:

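```python
import bisect
from typing import List


class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        greedy = []  # greedy[i]: smallest tail of an increasing subsequence of length i + 1
        for num in nums:
            if not greedy or num > greedy[-1]:
                greedy.append(num)  # rule 1: extend
            else:
                # rule 2: replace the first element >= num
                greedy[bisect.bisect_left(greedy, num)] = num
        return len(greedy)
```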

1.1.3 Binary indexed tree and segment tree solutions

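This approach uses the same DP as in $1.1.1$. In the naive version, however, computing the $max$ takes $O(n)$ time per element. That maximum is taken over all earlier elements with a smaller value. After discretizing the values to ranks, it becomes a prefix-maximum query over the ranks below the rank of $nums[j]$. A binary indexed tree (Fenwick tree) or a segment tree answers this query, and the matching update, in $O(\log n)$, for $O(n \log n)$ overall. Below, the DP is rewritten with each of them. If you are not familiar with binary indexed trees and segment trees, see <<稀疏表、树状数组和线段树>> (Sparse Tables, Binary Indexed Trees and Segment Trees).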
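To make the reduction concrete, the sketch below keeps a rank-indexed array `best` and answers the prefix maximum by slicing. That costs $O(k)$ per element for $k$ distinct values. It is only an illustration (the name `lis_by_rank` is hypothetical). The binary indexed tree and segment tree versions replace exactly the prefix-maximum step with an $O(\log k)$ query.

```python
from typing import List


def lis_by_rank(nums: List[int]) -> int:
    """Rank-based LIS DP with a linear prefix maximum (illustration only)."""
    # discretization: map each distinct value to its 0-based rank
    ranks = {v: r for r, v in enumerate(sorted(set(nums)))}
    # best[r]: longest increasing subsequence seen so far that ends with the value of rank r
    best = [0] * len(ranks)
    for x in nums:
        r = ranks[x]
        # "all earlier, strictly smaller values" == prefix max over ranks [0, r)
        length = max(best[:r], default=0) + 1
        best[r] = max(best[r], length)
    return max(best, default=0)


print(lis_by_rank([10, 9, 2, 5, 3, 7, 101, 18]))  # 4
```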

```python
# Binary indexed tree (Fenwick tree) solution
from typing import List


class BIT:
    """Fenwick tree that maintains prefix maxima over positions 1..n-1."""

    def __init__(self, n):
        self.bit = [0] * n

    def add(self, pos, val):
        # raise the value at pos to at least val
        while pos < len(self.bit):
            self.bit[pos] = max(self.bit[pos], val)
            pos += pos & (-pos)

    def findMax(self, pos):
        # maximum over positions 1..pos (0 when pos == 0)
        res = 0
        while pos > 0:
            res = max(self.bit[pos], res)
            pos -= pos & (-pos)
        return res


class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        # discretize first: map each value to its 1-based rank
        a = sorted(set(nums))
        m = {val: idx + 1 for idx, val in enumerate(a)}
        bit = BIT(len(a) + 1)
        res = 0
        for num in nums:
            pos = m[num]
            # longest subsequence ending in a strictly smaller value, extended by num
            val = bit.findMax(pos - 1) + 1
            res = max(val, res)
            bit.add(pos, val)
        return res
```
```python
# Segment tree solution, following the "point update, range query" segment tree template
from typing import List


class SegmentTree:
    def __init__(self, n):
        # each node stores [left bound, right bound, max value]
        self.STree = [[0, 0, 0] for _ in range(4 * n)]
        self._build(1, 0, n - 1)

    def _build(self, rt, start, end):
        self.STree[rt][0], self.STree[rt][1] = start, end
        if start == end:
            return
        left, right, mid = rt << 1, rt << 1 | 1, (start + end) >> 1
        self._build(left, start, mid)
        self._build(right, mid + 1, end)

    def add(self, rt, pos, val):
        # set leaf pos to val and update maxima on the path; returns this node's max
        if self.STree[rt][0] == pos and self.STree[rt][1] == pos:
            self.STree[rt][2] = val
            return val
        left, right, mid = rt << 1, rt << 1 | 1, (self.STree[rt][0] + self.STree[rt][1]) >> 1
        if pos <= mid:
            left_result = self.add(left, pos, val)
            self.STree[rt][2] = max(self.STree[rt][2], left_result)
        else:
            right_result = self.add(right, pos, val)
            self.STree[rt][2] = max(self.STree[rt][2], right_result)
        return self.STree[rt][2]

    def findMax(self, rt, qstart, qend):
        # maximum over positions [qstart, qend]; an empty range yields 0
        if qstart > qend:
            return 0
        if self.STree[rt][0] == qstart and self.STree[rt][1] == qend:
            return self.STree[rt][2]
        left, right, mid = rt << 1, rt << 1 | 1, (self.STree[rt][0] + self.STree[rt][1]) >> 1
        if qend <= mid:
            return self.findMax(left, qstart, qend)
        elif qstart > mid:
            return self.findMax(right, qstart, qend)
        else:
            return max(self.findMax(left, qstart, mid), self.findMax(right, mid + 1, qend))


class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        if not nums:
            return 0
        # discretize first: map each value to its 0-based rank
        a = sorted(set(nums))
        m = {val: idx for idx, val in enumerate(a)}
        stree = SegmentTree(len(a))
        res = 0
        for num in nums:
            pos = m[num]
            # best length over strictly smaller values (ranks 0..pos-1), plus num
            val = stree.findMax(1, 0, pos - 1) + 1
            res = max(val, res)
            stree.add(1, pos, val)
        return res
```

1.2 Number of LIS (LeetCode 673)

1.2.1 Dynamic programming solution

```python
class Solution(object):
    def findNumberOfLIS(self, nums):
        N = len(nums)
        if N <= 1:
            return N
        lengths = [0] * N  # lengths[i] + 1 = length of the longest increasing subsequence ending at nums[i]
        counts = [1] * N   # counts[i] = number of such longest subsequences ending at nums[i]

        for j, num in enumerate(nums):
            for i in range(j):
                if nums[i] < nums[j]:
                    if lengths[i] >= lengths[j]:
                        # a strictly longer chain through nums[i]: reset the count
                        lengths[j] = 1 + lengths[i]
                        counts[j] = counts[i]
                    elif lengths[i] + 1 == lengths[j]:
                        # another chain of the current best length
                        counts[j] += counts[i]

        longest = max(lengths)
        return sum(c for i, c in enumerate(counts) if lengths[i] == longest)
```
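The code above implements the recurrence below. It is restated here with 1-based lengths $len[i]$, so `lengths[i]` in the code equals $len[i] - 1$, and with counts $cnt[i]$ of the longest increasing subsequences ending at $nums[i]$. The symbols $len$, $cnt$ and $L$ are introduced only for this restatement.

$$len[j] = \max\{len[i] \mid 0 \le i < j,\ nums[i] < nums[j]\} + 1$$

$$cnt[j] = \sum_{\substack{0 \le i < j,\ nums[i] < nums[j] \\ len[i] + 1 = len[j]}} cnt[i]$$

As in 1.1.1, the maximum over an empty set is $0$, and $cnt[j] = 1$ when no $i$ qualifies, because $nums[j]$ alone is the only such subsequence. With $L = \max_j len[j]$, the answer is $\sum_{j:\ len[j] = L} cnt[j]$. For example, $nums = [1, 3, 5, 4, 7]$ gives $len = [1, 2, 3, 3, 4]$ and $cnt = [1, 1, 1, 1, 2]$, so the answer is 2.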

1.2.2 Segment tree solution

```python
class Node(object):
    """Lazily built segment tree node over the value range [start, end]."""

    def __init__(self, start, end):
        self.range_left, self.range_right = start, end
        self._left = self._right = None
        self.val = 0, 1  # (length, count)

    @property
    def range_mid(self):
        # floor division keeps the bound an integer (also for negative values)
        return (self.range_left + self.range_right) // 2

    @property
    def left(self):
        self._left = self._left or Node(self.range_left, self.range_mid)
        return self._left

    @property
    def right(self):
        self._right = self._right or Node(self.range_mid + 1, self.range_right)
        return self._right


def merge(v1, v2):
    # keep the longer length; on a tie, add the counts
    if v1[0] == v2[0]:
        if v1[0] == 0:
            return (0, 1)
        return v1[0], v1[1] + v2[1]
    return max(v1, v2)


def insert(node, key, val):
    if node.range_left == node.range_right:
        node.val = merge(val, node.val)
        return
    if key <= node.range_mid:
        insert(node.left, key, val)
    else:
        insert(node.right, key, val)
    node.val = merge(node.left.val, node.right.val)


def query(node, key):
    # combined (length, count) over all values <= key
    if node.range_right <= key:
        return node.val
    elif node.range_left > key:
        return 0, 1
    else:
        return merge(query(node.left, key), query(node.right, key))


class Solution(object):
    def findNumberOfLIS(self, nums):
        if not nums:
            return 0
        root = Node(min(nums), max(nums))
        for num in nums:
            length, count = query(root, num - 1)
            insert(root, num, (length + 1, count))
        return root.val[1]
```
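In this tree, each node's `val` is a pair (length, count) that describes the best increasing subsequences ending with a value inside the node's range. `query(root, num - 1)` combines all ranges of values smaller than `num`. The pair `(0, 1)` acts as the "nothing found" value: adding 1 to its length gives the subsequence consisting of `num` alone, counted once. The standalone copy of `merge` below uses the same logic as above, with type hints and the alias `Pair` added for illustration, and shows how pairs combine.

```python
from typing import Tuple

Pair = Tuple[int, int]  # (LIS length, number of LIS with that length)


def merge(v1: Pair, v2: Pair) -> Pair:
    if v1[0] == v2[0]:
        if v1[0] == 0:
            return (0, 1)  # both sides empty: stay neutral
        return v1[0], v1[1] + v2[1]  # same length: counts add up
    return max(v1, v2)  # lengths differ: tuple max keeps the longer one


print(merge((2, 3), (2, 1)))  # (2, 4)
print(merge((3, 1), (2, 5)))  # (3, 1)
print(merge((0, 1), (2, 5)))  # (2, 5)
```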

2. LCS (LeetCode 1143)

2.1 Dynamic programming solution

```python
class Solution:
    def longestCommonSubsequence(self, text1: str, text2: str) -> int:
        # one rolling row: dp[j] = LCS length of the processed prefix of text1 and text2[:j]
        dp = [0] * (len(text2) + 1)
        for i in range(len(text1)):
            pre = dp[0]  # previous row's value at column j (the diagonal neighbour)
            for j in range(len(text2)):
                if text1[i] == text2[j]:
                    dp[j+1], pre = pre + 1, dp[j+1]
                else:
                    dp[j+1], pre = max(dp[j+1], dp[j]), dp[j+1]
        return dp[-1]
```
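The solution above stores only one row of the usual two-dimensional table. For comparison, here is a sketch of the unrolled version, where `dp[i][j]` is the LCS length of `text1[:i]` and `text2[:j]`. The function name `lcs_2d` is hypothetical. In the one-row code, reading `dp[j+1]` before it is overwritten gives the cell above. `dp[j]` has already been overwritten, so it gives the cell to the left. `pre` keeps the diagonal cell.

```python
def lcs_2d(text1: str, text2: str) -> int:
    m, n = len(text1), len(text2)
    # dp[i][j]: LCS length of text1[:i] and text2[:j]; row 0 and column 0 stay 0
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if text1[i - 1] == text2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1  # extend the diagonal
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])  # drop a character from one side
    return dp[m][n]


print(lcs_2d("abcde", "ace"))  # 3
```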

2.2 Greedy + binary search solution

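I found this solution among the community Solutions for this problem (link). The idea is rather unusual and worth a look: it reduces LCS to LIS. For every character of s1, list the positions where that character occurs in s2 in decreasing order, then concatenate these lists in s1's order. A strictly increasing subsequence of the resulting sequence picks at most one position per character of s1, because positions within one block decrease. Its positions increase in s2 while its blocks follow s1's order, so it corresponds to a common subsequence, and vice versa. The LCS length is therefore the LIS length of that sequence, computed with the greedy + binary search method from 1.1.2. The cost grows with the number of matching position pairs, up to $O(l_1 l_2 \log l_2)$ in the worst case, for example when both strings consist of a single repeated character.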

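```cpp
#include <algorithm>
#include <string>
#include <vector>
using namespace std;

class Solution {
public:
    int longestCommonSubsequence(string s1, string s2) {
        int l1 = s1.size(), l2 = s2.size();
        // counting[c]: positions of character c in s2, in increasing order
        vector<vector<int>> counting(256);
        for (int i = 0; i < l2; i++) {
            counting[(unsigned char)s2[i]].push_back(i);
        }
        // lis[k] (k >= 1): smallest last position of an increasing chain of length k; lis[0] = -1 is a sentinel
        vector<int> lis;
        lis.push_back(-1);
        for (int i = 0; i < l1; i++) {
            const vector<int>& pos = counting[(unsigned char)s1[i]];
            // visit positions in decreasing order so one character of s1 is matched at most once
            for (int j = (int)pos.size() - 1; j >= 0; j--) {
                int n = pos[j];
                if (n > lis.back())
                    lis.push_back(n);
                else
                    *lower_bound(lis.begin(), lis.end(), n) = n;
            }
        }
        return lis.size() - 1;
    }
};
```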
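For readers following the rest of the post in Python, here is a sketch of the same reduction ported from the C++ code above. `lcs_via_lis` is a hypothetical name. `tails` plays the role of $greedy$ from 1.1.2, without the `-1` sentinel that the C++ version uses.

```python
from bisect import bisect_left
from collections import defaultdict


def lcs_via_lis(s1: str, s2: str) -> int:
    # positions[c]: indices of character c in s2, in increasing order
    positions = defaultdict(list)
    for idx, ch in enumerate(s2):
        positions[ch].append(idx)
    tails = []  # tails[k]: smallest last s2-position of an increasing chain of length k + 1
    for ch in s1:
        # decreasing order: at most one position per character of s1 can be used
        for p in reversed(positions[ch]):
            k = bisect_left(tails, p)
            if k == len(tails):
                tails.append(p)
            else:
                tails[k] = p
    return len(tails)


print(lcs_via_lis("abcde", "ace"))  # 3
```

Visiting the positions in increasing order instead would be wrong. With `s1 = "a"` and `s2 = "aa"`, both positions 0 and 1 would be kept and the result would be 2 instead of 1.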
posted @ 2020-03-09 01:00  wory