python yield 妙用; send() vs 闭包函数入参

实现可暂停/中断/继续的任务抽象

起因

想到用yield可以保留函数内临时变量,配合next()可以实现可暂停/继续的任务功能:

最小示例

def inner():
    """Yield three frame labels, then hand 'inner' back as the generator's return value.

    The return value is what a delegating ``yield from inner()`` expression
    evaluates to once the three labels have been consumed.
    """
    yield from (f'{frame=}' for frame in range(3))
    return 'inner'

def outer():
    """Relay everything ``inner()`` yields, then yield its return value, then return 'outer'."""
    # `yield from` evaluates to the delegate's return value once it is exhausted.
    inner_result = yield from inner()
    yield inner_result
    return 'outer'

# Drive the generator one step at a time; each next() resumes it where it paused.
gen = outer()
running = True
while running:
    try:
        line = f'{next(gen)=}'
    except StopIteration as e:
        # Exhaustion: the generator's return value travels on the exception.
        line = f'{e.value=}'
        running = False
    print(line)

期望输出:

next(gen)='frame=0'
next(gen)='frame=1'
next(gen)='frame=2'
next(gen)='inner'
e.value='outer'

return vs 手动raise StopIteration

def gen_a():
    """Yield 1 and 2, then finish normally with a return value."""
    yield from (1, 2)
    return "正常返回"  # delivered to the caller as StopIteration.value


def gen_b():
    """Yield 1 and 2, then raise StopIteration by hand (the anti-pattern being demonstrated).

    A StopIteration raised inside a generator body is converted into
    ``RuntimeError: generator raised StopIteration`` (PEP 479), so this does
    not behave like a plain ``return``.
    """
    yield from (1, 2)
    raise StopIteration("手动抛的")  # the hand-raised case under discussion


# 测试
# Demo: exhaust both generators and compare what the caller observes.
cases = [("gen_a", gen_a()), ("gen_b", gen_b())]
for i, (name, g) in enumerate(cases):
    print(f"\n=== {i} {name} ===")
    for _ in range(2):
        print(next(g))  # prints 1, then 2

    # Third next(): the generator finishes here.
    try:
        print(next(g))
    except StopIteration as e:
        print("第一次耗尽 →", f"{e=}")
        print(f"{e.value=}")

    # Ask once more after exhaustion.
    try:
        print(next(g))
    except StopIteration as e:
        print("第二次 next →", e.value)
输出:

=== 0 gen_a ===
1
2
第一次耗尽 → e=StopIteration('正常返回')
e.value='正常返回'
第二次 next → None

=== 1 gen_b ===
1
2
Traceback (most recent call last):
  File "test.py", line 182, in gen_b
    raise StopIteration("手动抛的")  # ← 你写的这种情况
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
StopIteration: 手动抛的

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test.py", line 192, in <module>
    print(next(g))
          ~~~~^^^
RuntimeError: generator raised StopIteration

send() vs 闭包函数入参

def dfs(v=0):
    """Lazily walk an endless binary tree of values, steerable through ``send()``.

    Each call yields ``v`` once. When the consumer answers that yield with a
    non-None value via ``send()``, the value replaces ``v`` for the rest of
    this call. The generator then delegates to ``dfs(value + 0)`` and, after
    that delegate finishes, to ``dfs(value + 1)``.

    Note: the first delegate follows the same pattern and has no stopping
    condition, so every step descends one level deeper into the ``+ 0``
    branch; the ``+ 1`` branch is not reached in practice.
    """
    try:
        sent = yield v
        base = v if sent is None else sent
        for offset in (0, 1):
            yield from dfs(base + offset)
    except StopIteration:
        # `yield from` absorbs a delegate's normal completion, so in practice
        # this only fires when StopIteration is thrown in from outside.
        return

# 测试执行
# Demo run
gen = dfs()
next(gen)  # prime the generator; it yields the initial value 0 (discarded here)
print(gen.send(-1))  # send -1: the root adopts it and enters its first branch, which yields -1
print(gen.send(-2))  # send -2: received by that first branch, which descends again and yields -2 (the second branch is not reached)

也可以写成闭包函数入参写法

def dfs(v=0, func=lambda x: x, *, max_depth=3):
    """Walk a binary tree of values recursively and post-process each level with ``func``.

    A node holding ``value`` has two children, ``value + 0`` and ``value + 1``.
    Because the first child never changes the value, a bound on ``value``
    alone cannot stop the recursion, so the walk is also limited by depth.

    Args:
        v: Value of the root node.
        func: Callable applied to the list built at every level
            (``[value] + left + right``) and to the empty list at a leaf.
            It must return a list, since results are concatenated.
        max_depth: Number of levels that contribute a value; ``0`` makes the
            root itself a leaf.

    Returns:
        Whatever ``func`` returns for the root level.
    """

    def visit(value, depth):
        # Stop on the example value bound or on the depth bound.
        if value > 3 or depth >= max_depth:
            return func([])
        left = visit(value + 0, depth + 1)
        right = visit(value + 1, depth + 1)
        # func sees this node's value followed by the already-processed children.
        return func([value] + left + right)

    return visit(v, 0)


# Demo: func negates every element of the list it receives.
result = dfs(func=lambda lst: [-x for x in lst])
# func runs again at each level on the way up, so an element's sign flips once
# per enclosing level.
print(result)  # prints [0, 0, 0, -1, 1, -1, -2]

这个时候选择哪种写法比较好呢?

豆包:

如果你的需求是处理递归结果(如转换、过滤)、无动态传值—— 闭包写法更符合 Python 的 “显式优于隐式” 原则,代码可读性高,新手易维护,且无生成器的 “启动 / 发送” 等易错细节。
仅当你需要动态控制递归过程(如中途修改 v 的值、暂停递归、实时输出递归节点)—— 生成器的协程特性是不可替代的,适合交互式、惰性执行的场景。

Grok:
总体上建议用生成器的写法

  • 最大的优势:内存友好
    闭包函数如果递归调用,可能会爆栈

  • 如果需要send()改变gen()自身的迭代队列,用生成器
    闭包函数需要额外维护一个迭代队列

  • 如果需要外部变量与状态,生成器可以直接访问上下文
    而在设计通用入参时,闭包函数的第三方业务状态只能通过**kwargs来获取,对类型提示不太友好

# Illustrative sketch only: `gen`, `pause` and `final_val` are not defined in this snippet.
outer_val = 0
# Generator version: the loop body runs in the caller's scope, so it can use outer_val directly.
for v in gen:
	pass # outer_val takes part in a series of calculations
	if pause:
		gen.close() # release resources
		# raise RuntimeError("pause")
	gen.send(final_val)

# Callback version: outside state has to be passed in through **kwargs.
def calc(**kwargs):
	# NOTE(review): `getitem` is not defined here; presumably kwargs.get("outer_val", 0) is meant - confirm.
	outer_val: int = getitem(kwargs, "outer_val", 0)
	pass # outer_val takes part in a series of calculations
	if pause:
		raise RuntimeError("pause")
	return final_val

try:
	gen(func=calc)
except RuntimeError as e:
	# release resources
	# NOTE(review): this except clause has no statement, so the sketch is not runnable as written.

本质上,yield/send/close/throw 是一种闭包函数的语法糖

类似options:{children:[{...}]}深层目录树的解析

from pprint import pprint
from collections import deque
from typing import Any, Tuple
from collections.abc import Iterable, Mapping, Callable, Generator, Sequence
from benedict import benedict

type NestNode[V, K] = Iterable[NestNode[V, K]] | Mapping[K, Any | NestNode[V]]


def dfs[V, K](
    node: NestNode[V, K],
    childKey: Sequence[K] = ["children"],
    raiseKeyError=False,
    canSkip: Callable[[NestNode[V, K]], bool] = lambda n: not n,
    log=False,
):
    """
    深度优先遍历 list[dict[list...]] 或 dict[list[dict...]] 的树结构, 可用`list(dfs(...))[0]`获得 flatten展平的结果

    Args:
        node: node|nodes, dict|list
        childKey: key name. 当只有1个元素(即`len(childKey)!=1`)时,会`yield keyName, node`
        raiseKeyError: 若要求多个childKey在每个结构中**必须同时存在**,则设为`True`
        canSkip (node|nodes): 用于减少for循环次数
        log: 是否打印调试信息
    """
    if not childKey:
        return

    initial_key = childKey[0]

    def _dfs(
        current: NestNode[V, K],
        path: tuple[int, ...],
        keyName: K,
        logical_parent: NestNode[V, K] | None,
    ):
        if isinstance(current, Mapping):
            new = yield current, path, keyName, logical_parent
            if new is not None:
                current = new

            for ckey in childKey:
                try:
                    childs = current[ckey]
                except KeyError as e:
                    if raiseKeyError:
                        raise
                    elif log:
                        print(f"{ckey=} not found in {current=}. {e=}")
                    continue

                if not hasattr(childs, "__iter__") or isinstance(
                    childs, (str, Mapping)
                ):
                    continue

                for j, child in enumerate(childs, 1):
                    if canSkip(child):
                        continue
                    yield from _dfs(child, path + (j,), ckey, current)

        else:
            # list-like container of nodes
            try:
                for j, child in enumerate(current, 1):
                    if canSkip(child):
                        continue
                    yield from _dfs(child, path + (j,), keyName, logical_parent)
            except TypeError:
                pass

    if isinstance(node, Mapping):
        return _dfs(node, (1,), initial_key, None)
    else:

        def forest_gen():
            try:
                for i, sub in enumerate(node, 1):
                    if canSkip(sub):
                        continue
                    yield from _dfs(sub, (i,), initial_key, None)
            except TypeError:
                pass

        return forest_gen()


def bfs[V, K](
    node: NestNode[V, K],
    childKey: Sequence[K] = ["children"],
    raiseKeyError=False,
    canSkip: Callable[[NestNode[V, K]], bool] = lambda n: not n,
    log=False,
):
    """
    广度优先遍历 list[dict[list...]] 或 dict[list[dict...]] 的树结构, 可用`list(bfs(...))[0]`获得 flatten展平的结果

    Args:
        node: node|nodes, dict|list
        childKey: key name. 当只有1个元素(即`len(childKey)!=1`)时,会`yield keyName, node`
        raiseKeyError: 若要求多个childKey在每个结构中**必须同时存在**,则设为`True`
        canSkip (node|nodes): 用于减少for循环次数
        log: 是否打印调试信息.
    """
    if not childKey:
        return

    initial_key = childKey[0]

    q: deque[Tuple[NestNode[V, K], tuple[int, ...], K, NestNode[V, K] | None]] = deque()

    # 初始入队,根节点 logical_parent 为 None
    if isinstance(node, Mapping):
        q.append((node, (1,), initial_key, None))
    else:
        try:
            for i, sub in enumerate(node, 1):
                if canSkip(sub):
                    continue
                q.append((sub, (i,), initial_key, None))
        except TypeError:
            return

    while q:
        current, path, keyName, logical_parent = q.popleft()

        if canSkip(current):
            continue

        # 如果是容器(非 Mapping),展开其子节点(不 yield 容器本身)
        if not isinstance(current, Mapping):
            try:
                for j, child in enumerate(current, 1):
                    if canSkip(child):
                        continue
                    q.append((child, path + (j,), keyName, logical_parent))
                continue
            except TypeError:
                continue

        # 当前是 Mapping 节点,yield 并支持修改
        sent = yield current, path, keyName, logical_parent
        if sent is not None:
            current = sent

        # 处理其子节点
        for ckey in childKey:
            try:
                childs = current[ckey]
            except KeyError as e:
                if raiseKeyError:
                    raise
                elif log:
                    print(f"{ckey=} not found in {current=}. {e=}")
                continue

            if not hasattr(childs, "__iter__") or isinstance(childs, (str, Mapping)):
                continue

            for j, child in enumerate(childs, 1):
                if canSkip(child):
                    continue
                q.append((child, path + (j,), ckey, current))


def dictDict(gen: Generator, idKey="id"):
    """Rebuild the traversed tree as nested dicts keyed by node id.

    Args:
        gen: Iterable of ``(node, path, keyName, logical_parent)`` tuples, as
            produced by ``dfs`` or ``bfs``. Items whose node is not a mapping
            are ignored. A parent must appear before its children.
        idKey: Key under which each node stores its id. Ids must be unique
            across the whole traversal.

    Raises:
        KeyError: A node does not contain ``idKey``.
        ValueError: The same id is encountered a second time.

    Returns:
        Nested dicts of ids, e.g. ``{1: {11: {}}, 2: {}, 3: {31: {}}}`` for
        ``idKey="id"``.
    """
    tree: dict[Any, dict] = {}
    # id -> the dict that collects that node's children
    by_id: dict[Any, dict] = {}

    for node, _, _, parent in gen:
        if not isinstance(node, Mapping):
            continue

        try:
            nid = node[idKey]
        except KeyError:
            raise KeyError(f"node 内不存在 {idKey=}")

        if nid in by_id:
            raise ValueError(f"id {nid} 在树中重复")

        by_id[nid] = branch = {}

        # Root-level nodes attach to the result itself, others to their parent's dict.
        siblings = tree if parent is None else by_id[parent[idKey]]
        siblings[nid] = branch

    return tree


def test():
    """Pretty-print the dfs/bfs traversals of a small comment thread and the id trees built from them."""
    nested_reply = {
        "id": 12,
        "text": "回复12",
        "children": [{"id": 121, "text": "回复121"}],
    }
    thread = [
        {
            "id": 1,
            "text": "主评论1",
            "children": [{"id": 11, "text": "回复11"}, nested_reply],
        },
        {"id": 2, "text": "主评论2"},
    ]

    # Each step is computed right before it is printed, one after another.
    steps = (
        lambda: list(dfs(thread)),
        lambda: list(bfs(thread)),
        lambda: dictDict(bfs(thread)),
        lambda: dictDict(dfs(thread)),
    )
    for step in steps:
        pprint(step())
    # A follow-up experiment with benedict (printing key paths) was left disabled here.


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    test()

posted @ 2025-07-21 11:54  Nolca  阅读(17)  评论(0)    收藏  举报