代码随想录算法训练营|Day 22

Day 22

第七章 回溯算法part01

理论基础

其实在讲解二叉树的时候,就给大家介绍过回溯,这次正式开启回溯算法,大家可以先看视频,对回溯算法有一个整体的了解。

题目链接/文章讲解:https://programmercarl.com/回溯算法理论基础.html
视频讲解:https://www.bilibili.com/video/BV1cy4y167mM

Notes:

回溯算法三部曲:

  1. 递归函数参数
  2. 递归终止条件
  3. 单层搜索的逻辑

只要有递归,就会有回溯

回溯法是纯暴力搜索,并不高效。但有些问题能依靠回溯暴力搜索出来已经很好了。

问题类型:
img

  • 组合问题:给定集合,在集合中找出大小为2的组合

  • 切割问题:给定字符串,有几种切割方式
    带附加条件:给定字符串,如何切割才能保证它的子串都是回文子串?有几种切割方式?

  • 子集问题:把子集合列出

  • 排列问题:组合:强调没有顺序。
    集合{1,2},只有一种组合,那就是{1,2}
    排列有两种: [1,2], [2,1]

  • 棋盘问题:n皇后,解数独

理解回溯法:

img

回溯法可以抽象为一个树形结构-> n叉树

回溯是一个递归的过程,而递归一定有终止

树的深度,就是递归的深度

回溯法的模版:

void backtracking(参数){
    if (终止条件){
        收集结果;
        return;
    }
    for (选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {
    处理节点;
    backtracking(路径,选择列表); // 递归
    回溯,撤销处理结果
}
    return;
}

终止的时候,就到我们收集结果的时候 (只有子集问题是在每个节点都要去收集结果

单层搜索的逻辑:一般情况下是个for循环

处理节点:for循环参数用来处理集合中的每一个元素(for循环遍历的是集合中每个元素->对应节点所有子节点的个数

img

77. 组合

对着 在 回溯算法理论基础 给出的 代码模板,来做本题组合问题,大家就会发现 写回溯算法套路。

在回溯算法解决实际问题的过程中,大家会有各种疑问,先看视频介绍,基本可以解决大家的疑惑。

本题关于剪枝操作是大家要理解的重点,因为后面很多回溯算法解决的题目,都是这个剪枝套路。

题目链接/文章讲解:https://programmercarl.com/0077.组合.html
视频讲解:https://www.bilibili.com/video/BV1ti4y1L7cv
剪枝操作:https://www.bilibili.com/video/BV1wi4y157er

class Solution:
    def combine(self, n: int, k: int) -> List[List[int]]:
        res=[]
        self.backtracking(n,k,1,[],res)
        return res
    def backtracking(self, n, k, start, sub_res,res) -> None:
        if len(sub_res) == k:
            res.append(sub_res[:])
            return
        for i in range(start, n+2-(k-len(sub_res))):
            sub_res.append(i)
            self.backtracking(n,k,i+1,sub_res, res)
            sub_res.pop()

回溯三部曲

  • 递归函数 返回值+参数

    • 必定有 n, k
    • startIndex ->记录下一层递归,搜索的起始位置
    • img
  • 回溯函数 终止条件

    • img
  • 单层搜索的过程

    • img

剪枝优化

img


“可以剪枝的地方就在递归中每一层的for循环所选择的起始位置。

如果for循环选择的起始位置之后的元素个数 已经不足 我们需要的元素个数了,那么就没有必要搜索了。”


img

img

优化之后的for循环是:

for (int i = startIndex; i <= n - (k - path.size()) + 1; i++) // i为本次搜索的起始位置

img

216.组合总和III

如果把 组合问题理解了,本题就容易一些了。

题目链接/文章讲解:https://programmercarl.com/0216.组合总和III.html
视频讲解:https://www.bilibili.com/video/BV1wg411873x

代码随想录解法

class Solution:
    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        result = []  # 存放结果集
        self.backtracking(n, k, 0, 1, [], result)
        return result

    def backtracking(self, targetSum, k, currentSum, startIndex, path, result):
        #不够k个已超过和
        if currentSum > targetSum:  # 剪枝操作
            return  # 如果currentSum已经超过targetSum,则直接返回
        #够k个:等和 or 不等和
        if len(path) == k:

            if currentSum == targetSum:
                result.append(path[:])
            return
        for i in range(startIndex, 9 - (k - len(path)) + 2):  # 剪枝
            currentSum += i  # 处理
            path.append(i)  # 处理
            self.backtracking(targetSum, k, currentSum, i + 1, path, result)  # 注意i+1调整startIndex
            currentSum -= i  # 回溯
            path.pop()  # 回溯

可以过的解法

使用到目前为止的sum

class Solution:
    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        res=[]
        self.backtracking(k,0,n,1,[],res)
        return res
    def backtracking(self, k, curr_sum, end_sum,start, sub_res, res) -> List[List[int]]:
        if curr_sum > end_sum:
            return
        if len(sub_res) == k and curr_sum == end_sum:
            res.append(sub_res[:])
            return
        for i in range(start, 9+2-(k-len(sub_res))):    
            sub_res.append(i)
            self.backtracking(k,sum(sub_res),end_sum, i+1,sub_res, res)
            sub_res.pop()
        

使用到目前为止的差

class Solution:
    def backtracking(self, k,n, diff, start, sub, res):
        if diff == 0 and len(sub) == k:
            res.append(sub[:])
            return
        if start > diff:
            return
        for i in range(start, 9-(k-len(sub))+2):
            sub.append(i)
            self.backtracking(k, n, diff-i, i+1, sub, res)
            sub.pop()
    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        res = []
        self.backtracking(k, n, n, 1, [], res)
        return res

comments:
不算“彻底”。你现在的两处剪枝思路是对的,但还缺了几个关键点:

  • 你没有在循环里用 i > diff 直接停止(后面的更大数只会更大),因此会多跑不少无效分支;
  • 你没有做“可行性区间”剪枝:选了 i 之后,剩下 r-1 个数(严格递增、范围在 i+1..9)的最小可能和最大可能和可以提前算出来,如果目标剩余和不在这个区间,分支可以立刻跳过;
  • 进入递归后也可以先判断 diff < 0k 已用完但 diff ≠ 0 直接返回。

下面给出一个在你代码基础上补齐这些剪枝的版本(思路不变,剪枝更“到位”):

class Solution:
    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        res = []

        def dfs(start: int, k_left: int, diff: int, path: list[int]) -> None:
            ## k_left:剩余个数
            ## diff:剩余差

            # 命中
            if diff == 0 and k_left == 0:
                res.append(path[:])
                return
            # 不可行
            ## 没选满k个数但是 diff已经<= 0
            ## 选择了k个数但是 凑不够n
            if diff <= 0 or k_left == 0:
                return

            ###没选满k个,也没凑满n的情况
            ###仍然可以剪枝
            # k = 3, n= 9
            # sub = [1, 7]
    
            r = k_left  # 还需要选 r 个数
            # 上界:保证还能选出 r 个不同的数(包含 9-r+1)
            upper = 10 - r

            # 额外上界:i 也不必超过 diff
            for i in range(start, min(upper, diff) + 1):
                # 可行性区间剪枝:选择 i 之后,还需 r-1 个数
                r1 = r - 1
                # 剩余 r1 个数的最小可能和: (i+1) + (i+2) + ... + (i+r1)
                min_after = r1 * (2 * (i + 1) + (r1 - 1)) // 2
                # 剩余 r1 个数的最大可能和: 9 + 8 + ... + (9 - r1 + 1)
                max_after = r1 * (19 - r1) // 2
                rest = diff - i
                if rest < min_after or rest > max_after:
                    continue

                path.append(i)
                dfs(i + 1, r1, rest, path)
                path.pop()

        dfs(1, k, n, [])
        return res

为什么这些剪枝能“彻底”?

  1. i > diff 时直接不再尝试更大的 i,因为序列递增,和只会更大——这一条保证了单层循环的无效搜索被截断。
  2. “可行性区间”保证了选择 i 后,剩余 r-1 个数确实存在可达到 rest = diff - i 的解;不在区间就 continue,避免了大量深层无效递归。
  3. 入口的 diff <= 0 / k_left == 0 的快速返回,砍掉显然不可能的分支。

复杂度方面:理论上问题规模只有 1..9、长度 k,最坏仍是组合数量级 O(C(9,k)),但上述三处剪枝把大多数无效分支都挡掉了,基本接近最优搜索量。
空间复杂度 O(k)(递归栈 + 路径)。

如果你想只在你原框架上做最小改动,也可以保留你的函数签名,只加两行关键剪枝:

for i in range(start, 10 - (k - len(sub))):  # 和你原来等价
    if i > diff:  # ✅ 单层停止
        break
    # ✅ 可行性区间
    r1 = k - (len(sub) + 1)
    min_after = r1 * (2 * (i + 1) + (r1 - 1)) // 2
    max_after = r1 * (19 - r1) // 2
    rest = diff - i
    if rest < min_after or rest > max_after:
        continue
    ...

这样,你的解法就算是“剪枝到位”了。

我帮你拆开解释一下这几行。它的目的就是:判断在选择了当前数 i 之后,剩下的 r1 = r-1 个数是否还有可能凑出目标和 rest = diff - i


1. 剩余要选的数量

r1 = r - 1
  • 你当前选了 i,所以还需要选 r1 个数。

2. 剩余数的最小可能和

min_after = r1 * (2 * (i + 1) + (r1 - 1)) // 2

这里用了等差数列求和公式。

  • 剩下的数必须比 i 大(严格递增)。
  • 所以最小的选择是:i+1, i+2, ..., i+r1
  • 这是一个首项 a1 = i+1,末项 an = i+r1,项数 r1 的等差数列。
  • 和公式:(a1 + an) * r1 // 2
    化简:( (i+1) + (i+r1) ) * r1 // 2 = (2i + r1 + 1) * r1 // 2
  • 代码里写的 r1 * (2*(i+1) + (r1-1)) // 2,其实就是这个展开式。

直观: 这是剩下 r1 个数能达到的最小和


3. 剩余数的最大可能和

max_after = r1 * (19 - r1) // 2
  • 最大的选择就是从 9 往下取 r1 个数:9, 8, ..., 9-r1+1
  • 首项 a1 = 9-r1+1 = 10-r1,末项 an = 9,项数 r1
  • 和公式:(a1 + an) * r1 // 2 = (10-r1 + 9) * r1 // 2 = (19-r1)*r1 // 2

直观: 这是剩下 r1 个数能达到的最大和


4. 判断是否可行

rest = diff - i
if rest < min_after or rest > max_after:
    continue
  • rest 是我们必须用剩下 r1 个数拼出的目标和。
  • 如果 rest 小于 min_after(目标太小)或者大于 max_after(目标太大),就说明无解,直接 continue 跳过这一支。

有两类不同的剪枝在配合:

  1. 数量可行性(count guard)
    需要还选 r 个数时,循环上界要保证能取到这么多不同的数:
    i ≤ 10 - r(也就是 upper = 10 - r)。
    这能避免出现“根本凑不够 k 个数”的分支。例如:还需要 3 个数(r=3),那么 upper=7,循环不会让你从 8 或 9 开始,自然就不会走进“数量不够”的死路。

  2. 和的可行性(sum-range guard)
    选了 i 后还需 r1 = r-1 个数:

  • 最小可能和:min_after = (i+1) + (i+2) + ... + (i+r1)
  • 最大可能和:max_after = 9 + 8 + ... + (9-r1+1)

rest = diff - i。如果 rest < min_afterrest > max_after,就剪掉。

这样把“数量可行性”和“和的可行性”分开理解,你会更清楚每一层在剪什么分支。

17.电话号码的字母组合

本题大家刚开始做会有点难度,先自己思考20min,没思路就直接看题解。

题目链接/文章讲解:https://programmercarl.com/0017.电话号码的字母组合.html
视频讲解:https://www.bilibili.com/video/BV1yV4y1V7Ug

class Solution:
    def letterCombinations(self, digits: str) -> List[str]:
        keyboards = [" ", " ", "abc", "def", "ghi", "jkl", "mno", "pqrs", "tuv", "wxyz"]
        res = []
        #控制我们到digits的第几个位置:index
        def _dfs(digits, index, sub):
            if index == len(digits):
                res.append(sub)
                return
            options = keyboards[int(digits[index])]
            for letter in options:
                _dfs(digits, index+1, sub+letter)
        if digits:
            _dfs(digits, 0, "")
        return res  

img

class Solution:
    def letterCombinations(self, digits: str) -> List[str]:
        map = ["","", "abc","def","ghi","jkl","mno","pqrs","tuv","wxyz"]
        res = []
        if digits:
            self.backtracking(digits,map, 0,"",res)
        return res
    
    def backtracking(self,digits,map,index,sub_res,res):
        if len(sub_res) == len(digits):
            res.append(sub_res)
            return 
        
        all_ele = map[int(digits[index])]
        for i in range(len(all_ele)):
            sub_res += all_ele[i]
            self.backtracking(digits,map,index+1, sub_res,res)
            sub_res = sub_res[:-1]
posted @ 2025-08-31 19:32  ForeverEver333  阅读(13)  评论(0)    收藏  举报