JoyBeanRobber

导航

动态规划问题

今天写python课程的题目是,有一题是这样的:
如何计算将一个字符序列转换为另一个字符序列所需的最短步骤数,插入、删除、修改单个字符均视为一步操作。

我的想法是:

第一步:从匹配串第一个字符开始遍历,每次从目标串中第一个字符开始遍历,若目标串存在未被匹配过的与该字符相同的字符,记录索引,否则记录-1。最终得到匹配序列。

def minimum_mewtations(typed, source, limit):
    match_list = []
    for i in typed:
        matched = False
        for j in range(len(source)):
            if source[j] == i:
                if j not in match_list:
                    match_list.append(j)
                    matched = True
                    break
        if not matched:
            match_list.append(-1)
    print(match_list)

比如从“ckiteus”到“kittens”,匹配序列为:-1,0,1,2,4,-1,6

第二步:再遍历这个序列,尝试进行操作,比如将第一个-1删除,第二个-1左小右大就是修改,然后2和4之间需要增加......

实际动手实现第二步时发现有很多问题,首先是要明确最大递增子序列长度,这个也许可以实现,然后发现的致命问题是:匹配序列有些情况下无法很好地表达匹配情况,我设置了“若目标串存在未被匹配过的与该字符相同的字符”,以匹配在目标串中多次出现的字符,比如“kittens”中的“tt”,但是可能会造成字符占用情况:比如,从“eittens”到“kittens”,匹配序列为:4,1,2,3,-1,5,6,在进行匹配串的第二个e的匹配时,发生了目标串中的e被提前占用的情况,这个真不好解决,也许可以改为使用拉链表,那么从字符到序列就没意义了。

 

我放弃了靠自己解决这道题的想法,上网查答案,发现应该依靠动态规划的方法解决。

动态规划是什么?就是将问题拆分为一个个子问题,和普通递归算法不同的是:计算过程中,存储计算结果以减少重复计算。

比如经典的斐波那契数列:

def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)

执行fib(10),需要计算fib(9)和fib(8),而计算fib(9)时,又需要计算一遍fib(8),这显然是计算资源的浪费。

使用动态规划实现斐波那契数列求解:(迭代和递归)

def fib(n):
    dp = [0]*(n+1)
    dp[1] = 1
    for i in range(2, n+1):
        dp[i] = dp[i-1] + dp[i-2]
    return dp[n]
def fib(n):
    memo = {}

    def helper(k):
        if k in memo:
            return memo[k]
        if k <= 1:
            return k
        memo[k] = helper(k-1) + helper(k-2)
        return memo[k]
    return helper(n)

接下来,尝试用动态规划思想解决之前第二步遇到的问题,即求最大递增子序列,来进行一次练手:

示例:

输入:nums = [10,9,2,5,3,7,101,18]
输出:4
解释:最长递增子序列是 [2,3,7,101],因此长度为 4 。
输入:nums = [0,1,0,3,2,3]
输出:4

思考:以下将最大递增子序列简称为 lis(longest increaing subseqeue),向序列的一端加入一个新元素,其 lis 会发生什么变化呢,要么仍然是未加入元素前的 lis ,要么是以新元素为最后一个元素的 lis 。等等,换一个角度思考,以原序列中某一个元素 i 结尾的 lis 的长度,应该是在以 i 之前的、比 i 值更小的元素们结尾的 lis 的最大长度加一,方程表达就是 dp[ i ] = max(dp[ j ]) + 1 (0 <= j < i & num[ j ] < num[ i ])。所以要存储以每个元素结尾的lis长度,此时我觉得在我脑袋里已经实现这个算法了。

代码:

def find_lis(seq):
    dp = [1] * len(seq)
    for i in range(1, len(seq)):
        previous_lis_list = [dp[j] for j in range(i) if seq[j] < seq[i]]
        if previous_lis_list:
            max_dp = max(previous_lis_list)
            dp[i] = max_dp + 1
    return max(dp)

更标准的写法:

def find_lis(seq):
    if not seq:
        return 0
    dp = [1] * len(seq)
    for i in range(1, len(seq)):
        for j in range(i):
            if seq[j] < seq[i]:
                dp[i] = max(dp[i], dp[j]+1)
    return max(dp)

 

接下来,尝试解决字符串转换步骤数计算的问题:
思考(自己没有推导出来,只是做答案解析):三种操作插入、删除、修改,分析可得修改是插入和删除的叠加,联想到矢量的叠加 (1,0)+(0,1)=(1,1)。

设置一个二维数组dp[m][n],行数m为匹配字符串的长度+1,列数n为目标字符串的长度+1。某一位置 dp[i][j] 的值是从匹配串 str1[:i] 到目标串 str2[:j] 所需要的最短步骤数。

初始化边界:显然 dp[0][0]是从空字符串到空字符串,最短步骤为0,第一排和第一列,是从空字符串到字符串,和从字符串到空字符串,所以 dp[i][0] = i,dp[0][j] = j

如何确定dp[i][j]的值(即从匹配串 str1[:i] 到目标串 str2[:j] 所需要的最短步骤数)?

若 str1[i-1]==str2[j-1],说明 str1[:i] 到与 str2[:j] 最后一个元素相同,那么 dp[i][j] = dp[i-1][j-1]

否则,str2[:j] 可有以下三种方法得到:

1.将从 str1[:i] 到 str2[:j] 的转换改变为从 str1[:i-1] 到 str2[:j] 的转换,然后 str1[i] 删除末尾元素

2.将从 str1[:i] 到 str2[:j] 的转换改变为从 str1[:i-1] 到 str2[:j-1] 的转换,然后str1[i] 替换末尾元素

3.将从 str1[:i] 到 str2[:j] 的转换改变为从 str1[:i] 到 str2[:j-1] 的转换,然后str1[i] 在末尾插入元素

因此,从 str1[:i] 到 str2[:j] 转换的步骤最小值为:

min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]) + 1

完整代码如下,为了方便理解,将dp矩阵每一步的过程可视化:

def minimum_mewtations(typed, source):
    m, n = len(typed), len(source)
    dp = [[0] * (n+1) for _ in range(m+1)]
    for i in range(m+1):
        dp[i][0] = i
    for j in range(n+1):
        dp[0][j] = j
    display_dp(dp, typed, source)

    for i in range(1, m+1):
        for j in range(1, n+1):
            if typed[i-1] == source[j-1]:
                dp[i][j] = dp[i-1][j-1]
            else:
                dp[i][j] = min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]) + 1
            display_dp(dp, typed, source)
    return dp[m][n]


RED = "\033[31m"
RESET = "\033[0m"


def display_dp(dp, str1, str2):
    print("        ", end='')
    for s in str2:
        print(f"{RED+s+RESET}", end='   ')
    print('\n')

    def print_list(val_list):
        for val in val_list:
            print(val, end='   ')
        print('\n')
    str1 = " "+str1
    for i in range(len(str1)):
        print(f"{RED+str1[i]+RESET}", end='   ')
        print_list(dp[i])
    print('\n')

 

 

 

 

 

 

 

 

 

 

 

 



posted on 2025-07-31 22:17  欢乐豆掠夺者  阅读(9)  评论(0)    收藏  举报