LIS&LCS
最长上升子序列(LIS)和最长公共子序列(LCS)是DP算法里比较经典的问题了。今天来说说这两个问题的解法,包括常规的动态规划解法,还有一些拓展性的解法。
1. LIS
1.1 LIS长度(Leetcode 300)
1.1.1 动态规划解法
DP问题的最大难点就是选择子问题,子问题选对了,状态转移方程往往比较简单。这里以求解最长上升子序列的长度为例,假设有长度为$n$的数组$nums$,我们用$DP[j]$表示以$nums[j]$结尾的最长上升子序列的长度,递推关系可以表示成:$$DP[j]=\max\{DP[i] \mid 0 \leq i < j,\ nums[i] < nums[j]\} + 1$$其中当不存在满足条件的$i$时,$\max$取$0$,即$DP[j]=1$;最终答案为$\max_{0 \leq j < n} DP[j]$。
这个递推关系是比较明显的,代码写起来也简单。时间复杂度为$O(n^2)$,空间复杂度为$O(n)$,下面贴出代码。
class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        """Return the length of the longest strictly increasing subsequence.

        Quadratic DP: ``best[j]`` holds the length of the longest strictly
        increasing subsequence that ends exactly at ``nums[j]``.
        Time O(n^2), extra space O(n).
        """
        if not nums:
            return 0

        best = []
        for j, current in enumerate(nums):
            # Longest chain among earlier, strictly smaller elements
            # (stays 0 when there is none).
            longest_before = 0
            for i in range(j):
                if nums[i] < current and best[i] > longest_before:
                    longest_before = best[i]
            best.append(longest_before + 1)

        return max(best)
1.1.2 贪心+二分解法
相比于上面的直观的动态规划解法,贪心+二分的解法就不是那么容易想到了。设置一个贪心数组$greedy$,$greedy[i]$表示长度为$i+1$(因为数组下标从0开始)的上升子序列结尾元素的最小值。显然对于一个上升序列,其结尾的元素越小,越有利于后面接其他的元素,序列也就越有可能变得更长。可以用反证法证明,$greedy$一定是严格升序的,程序结束时,$greedy$数组的长度就是$LIS$的长度。遍历$nums$中的每个元素,$greedy$数组的更新策略如下(用python描述):
- 若$greedy$为空,或者$nums[i] > greedy[-1]$,直接将$nums[i]$加到$greedy$数组的尾部:$greedy.append(nums[i])$
- 否则,在$greedy$中通过二分查询第一个大于等于$nums[i]$的位置,记为$j$,令$greedy[j] = nums[i]$。
代码如下...
class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        """Return the length of the longest strictly increasing subsequence.

        Greedy + binary search: ``tails[k]`` is the smallest possible last
        element of a strictly increasing subsequence of length ``k + 1``.
        ``tails`` stays sorted, so each element is placed with a binary
        search. Time O(n log n).
        """
        tails = []
        for value in nums:
            # First slot whose tail is >= value; equals len(tails) when the
            # value is larger than every current tail (or tails is empty).
            slot = bisect.bisect_left(tails, value)
            if slot == len(tails):
                tails.append(value)
            else:
                tails[slot] = value
        return len(tails)
1.1.3 树状数组和线段树解法
这种解法的思路就是$1.1.1$里面提到的动态规划解法,但是给出的动态规划朴素版本中每次计算$\max$的时间复杂度为$O(n)$,而这种带条件的最值用树状数组或线段树维护时,单次查询和更新的时间复杂度为$O(\log n)$,总时间复杂度为$O(n \log n)$。这里分别用树状数组和线段树来重写一下动态规划算法。对树状数组和线段树不了解的读者可以参考《稀疏表、树状数组和线段树》。
# Binary indexed tree (Fenwick tree) solution
class BIT:
    """Fenwick tree over 1-based positions that answers prefix-maximum queries."""

    def __init__(self, n):
        # Slot 0 is unused; valid positions are 1 .. n - 1.
        self.bit = [0 for _ in range(n)]

    def add(self, pos, val):
        """Raise the stored maximum at ``pos`` (1-based) to at least ``val``."""
        size = len(self.bit)
        while pos < size:
            if val > self.bit[pos]:
                self.bit[pos] = val
            pos += pos & -pos  # jump to the next node covering ``pos``

    def findMax(self, pos):
        """Return the maximum over positions 1 .. ``pos`` (0 when ``pos`` <= 0)."""
        best = 0
        while pos > 0:
            best = max(best, self.bit[pos])
            pos &= pos - 1  # drop the lowest set bit
        return best


class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        """Return the length of the longest strictly increasing subsequence.

        Same DP as the quadratic version, but the "best length among smaller
        values" lookup is a Fenwick-tree prefix maximum, giving O(n log n).
        """
        # Coordinate compression: map each distinct value to a rank in 1..k.
        distinct = sorted(set(nums))
        rank = {}
        for idx, value in enumerate(distinct, start=1):
            rank[value] = idx

        tree = BIT(len(distinct) + 1)
        longest = 0
        for num in nums:
            pos = rank[num]
            # Best chain ending in a strictly smaller value, extended by num.
            length = tree.findMax(pos - 1) + 1
            longest = max(longest, length)
            tree.add(pos, length)
        return longest
# Segment tree solution, based on the "point update, range query" template
class SegmentTree:
    """Array-backed segment tree over positions 0 .. n - 1 storing range maxima."""

    def __init__(self, n):
        # Each node is [range_start, range_end, max_value]; node 1 is the root.
        self.STree = [[0, 0, 0] for _ in range(4 * n)]
        self._build(1, 0, n - 1)

    def _build(self, rt, start, end):
        """Record the covered range of node ``rt`` and of its descendants."""
        node = self.STree[rt]
        node[0] = start
        node[1] = end
        if start != end:
            mid = (start + end) >> 1
            self._build(2 * rt, start, mid)
            self._build(2 * rt + 1, mid + 1, end)

    def add(self, rt, pos, val):
        """Set leaf ``pos`` to ``val`` and refresh maxima on the path to it.

        Returns the maximum stored at node ``rt`` after the update.
        """
        node = self.STree[rt]
        lo, hi = node[0], node[1]
        if lo == pos and hi == pos:
            node[2] = val
            return val
        mid = (lo + hi) >> 1
        child = 2 * rt if pos <= mid else 2 * rt + 1
        child_max = self.add(child, pos, val)
        if child_max > node[2]:
            node[2] = child_max
        return node[2]

    def findMax(self, rt, qstart, qend):
        """Return the maximum over [qstart, qend]; 0 for an empty range."""
        if qstart > qend:
            return 0
        node = self.STree[rt]
        if (node[0], node[1]) == (qstart, qend):
            return node[2]
        mid = (node[0] + node[1]) >> 1
        if qend <= mid:
            return self.findMax(2 * rt, qstart, qend)
        if qstart > mid:
            return self.findMax(2 * rt + 1, qstart, qend)
        # The query straddles the midpoint: split it between both children.
        left_best = self.findMax(2 * rt, qstart, mid)
        right_best = self.findMax(2 * rt + 1, mid + 1, qend)
        return max(left_best, right_best)


class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        """Return the length of the longest strictly increasing subsequence.

        DP over compressed values; the best length among strictly smaller
        values is a segment-tree range maximum, giving O(n log n).
        """
        if not nums:
            return 0
        # Coordinate compression: map each distinct value to a rank in 0..k-1.
        distinct = sorted(set(nums))
        rank = {}
        for idx, value in enumerate(distinct):
            rank[value] = idx

        tree = SegmentTree(len(distinct))
        longest = 0
        for num in nums:
            pos = rank[num]
            length = tree.findMax(1, 0, pos - 1) + 1
            longest = max(longest, length)
            tree.add(1, pos, length)
        return longest
1.2 LIS数量(LeetCode 673)
1.2.1 动态规划解法
class Solution(object):
    def findNumberOfLIS(self, nums):
        """Return how many longest strictly increasing subsequences ``nums`` has.

        Quadratic DP. For every index ``j`` two values are tracked:

        * ``lengths[j]`` -- the number of elements that precede ``nums[j]`` in
          the longest increasing subsequence ending at ``nums[j]`` (that is,
          the subsequence length minus one);
        * ``counts[j]`` -- how many subsequences of that length end at
          ``nums[j]``.

        Returns 0 for an empty list and 1 for a single element.
        """
        n = len(nums)
        if n <= 1:
            return n

        lengths = [0] * n
        counts = [1] * n

        # ``range`` instead of the Python 2-only ``xrange`` so this runs on
        # Python 3.
        for j in range(n):
            for i in range(j):
                if nums[i] < nums[j]:
                    if lengths[i] >= lengths[j]:
                        # Extending the chain at i gives a strictly longer
                        # chain at j: restart the count from i's count.
                        lengths[j] = 1 + lengths[i]
                        counts[j] = counts[i]
                    elif lengths[i] + 1 == lengths[j]:
                        # Another way to reach the current best length.
                        counts[j] += counts[i]

        longest = max(lengths)
        return sum(c for length, c in zip(lengths, counts) if length == longest)
1.2.2 线段树解法
class Node(object):
    """Lazily built segment-tree node covering the value range [start, end].

    ``val`` is a ``(length, count)`` pair: the length of the longest
    increasing subsequence ending at a value in this range, and how many such
    subsequences there are. ``(0, 1)`` stands for the single empty subsequence.
    """

    def __init__(self, start, end):
        self.range_left, self.range_right = start, end
        self._left = self._right = None
        self.val = 0, 1  # length, count

    @property
    def range_mid(self):
        # Floor division keeps the midpoint an int on Python 3; with true
        # division the ranges turn into floats and never shrink to a leaf.
        return (self.range_left + self.range_right) // 2

    @property
    def left(self):
        # Children are created on first access.
        self._left = self._left or Node(self.range_left, self.range_mid)
        return self._left

    @property
    def right(self):
        self._right = self._right or Node(self.range_mid + 1, self.range_right)
        return self._right


def merge(v1, v2):
    """Combine two ``(length, count)`` pairs.

    Equal lengths add their counts (two empty results stay ``(0, 1)``);
    otherwise the pair with the greater length wins.
    """
    if v1[0] == v2[0]:
        if v1[0] == 0:
            return (0, 1)
        return v1[0], v1[1] + v2[1]
    return max(v1, v2)


def insert(node, key, val):
    """Merge ``val`` into the leaf for ``key`` and refresh its ancestors."""
    if node.range_left == node.range_right:
        node.val = merge(val, node.val)
        return
    if key <= node.range_mid:
        insert(node.left, key, val)
    else:
        insert(node.right, key, val)
    node.val = merge(node.left.val, node.right.val)


def query(node, key):
    """Return the merged ``(length, count)`` over all values <= ``key``."""
    if node.range_right <= key:
        return node.val
    elif node.range_left > key:
        return 0, 1
    else:
        return merge(query(node.left, key), query(node.right, key))


class Solution(object):
    def findNumberOfLIS(self, nums):
        """Return how many longest strictly increasing subsequences ``nums`` has.

        For each number, the best ``(length, count)`` among strictly smaller
        values is queried from the tree, extended by one, and inserted back
        at the number's own position. The root then holds the overall answer.
        """
        if not nums:
            return 0
        root = Node(min(nums), max(nums))
        for num in nums:
            length, count = query(root, num - 1)
            insert(root, num, (length + 1, count))
        return root.val[1]
2. LCS(LeetCode 1143)
2.1 动态规划解法
class Solution:
    def longestCommonSubsequence(self, text1: str, text2: str) -> int:
        """Return the length of the longest common subsequence of two strings.

        Classic LCS table compressed to a single row: while processing one
        character of ``text1``, ``row[j]`` holds the LCS length of the text1
        prefix seen so far and the first ``j`` characters of ``text2``.
        Time O(len(text1) * len(text2)), extra space O(len(text2)).
        """
        row = [0] * (len(text2) + 1)
        for ch1 in text1:
            # Value of the previous row at column j - 1 (the "diagonal" cell).
            diagonal = row[0]
            for j, ch2 in enumerate(text2, start=1):
                above = row[j]  # previous row, same column
                if ch1 == ch2:
                    row[j] = diagonal + 1
                else:
                    row[j] = max(above, row[j - 1])
                diagonal = above
        return row[-1]
2.2 贪心+二分解法
这个解法是在这个题的一个Solution中发现的(链接),思路比较奇特,有兴趣可以去看一下。
class Solution {
public:
    // Returns the length of the longest common subsequence of s1 and s2.
    //
    // Reduction to LIS: every character of s1 is replaced by the list of its
    // positions in s2, and the answer is the length of the longest strictly
    // increasing sequence of positions, found with the greedy + binary-search
    // technique. Characters are used directly as indices into a 128-entry
    // table, so the input is expected to be 7-bit ASCII.
    int longestCommonSubsequence(string s1, string s2) {
        // positions[c] = indices of s2 holding character c, in increasing order.
        vector<vector<int>> positions(128);
        for (size_t idx = 0; idx < s2.size(); ++idx) {
            positions[s2[idx]].push_back(static_cast<int>(idx));
        }

        // tails[k] = smallest last position of an increasing sequence of
        // length k; tails[0] = -1 is a sentinel below every real position.
        vector<int> tails{-1};
        for (char c : s1) {
            const vector<int>& hits = positions[c];
            // Visit positions from largest to smallest so that two matches of
            // the same s1 character can never extend one another.
            for (auto it = hits.rbegin(); it != hits.rend(); ++it) {
                const int pos = *it;
                auto slot = lower_bound(tails.begin(), tails.end(), pos);
                if (slot == tails.end()) {
                    tails.push_back(pos);
                } else {
                    *slot = pos;
                }
            }
        }
        return static_cast<int>(tails.size()) - 1;
    }
};

浙公网安备 33010602011771号