python中的迭代器、生成器和装饰器

python中的迭代器、生成器和装饰器

1. 迭代器(Iterator)

迭代器是Python中一个重要的概念,它提供了一种统一的方式来访问各种集合类型中的元素。

1.1 迭代器基础

迭代器是一个实现了__iter__()和__next__()方法的对象:

__iter__():返回迭代器对象本身
__next__():返回下一个元素,如果没有更多元素则抛出StopIteration异常

# 使用内置函数iter()和next()
my_list = [1, 2, 3, 4, 5]
iterator = iter(my_list)  # 获取迭代器

print(next(iterator))  # 输出: 1
print(next(iterator))  # 输出: 2
print(next(iterator))  # 输出: 3
print(next(iterator))  # 输出: 4
print(next(iterator))  # 输出: 5
# print(next(iterator))  # 抛出StopIteration异常

迭代器的最大优势在于它的惰性计算特性:只有在需要时才计算下一个值,而不是一次性生成所有值。这使得它们特别适合处理大型数据集和无限序列。

当我们使用for循环遍历一个可迭代对象时,Python会在背后做以下工作:

  1. 调用iter()函数获取迭代器对象
  2. 不断调用next()函数获取下一个元素
  3. 捕获StopIteration异常以确定迭代结束
# for循环的内部实现类似于:
def for_loop_simulation(iterable):
    # 获取迭代器
    iterator = iter(iterable)
    while True:
        try:
            # 获取下一个值
            item = next(iterator)
            # 处理值
            print(item)
        except StopIteration:
            # 没有更多元素,退出循环
            break

# 使用自定义for循环模拟
for_loop_simulation([1, 2, 3, 4, 5])

1.2 创建自定义迭代器

让我们创建一个简单的自定义迭代器,生成斐波那契数列:

class Fibonacci:
    def __init__(self, limit):
        self.limit = limit
        self.a, self.b = 0, 1
        self.count = 0
    
    def __iter__(self):
        return self
    
    def __next__(self):
        if self.count >= self.limit:
            raise StopIteration
        
        result = self.a
        self.a, self.b = self.b, self.a + self.b
        self.count += 1
        
        return result

# 使用自定义迭代器
fib = Fibonacci(8)
for num in fib:
    print(num, end=" ")  # 输出: 0 1 1 2 3 5 8 13

1.3 可迭代对象vs迭代器

  • 可迭代对象(Iterable):实现了__iter__()方法,可以通过iter()函数获取迭代器
  • 迭代器(Iterator):同时实现了__iter__()和__next__()方法

Python中大多数集合类型(如列表、元组、字符串、字典等)都是可迭代对象,但不是迭代器。

# 判断对象是否是可迭代对象或迭代器

from collections.abc import Iterable, Iterator

print(isinstance([1, 2, 3], Iterable))   # True
print(isinstance([1, 2, 3], Iterator))   # False
print(isinstance(iter([1, 2, 3]), Iterator))  # True

1.4 迭代器的实际应用

1.4.1 数据流处理

迭代器特别适合处理数据流,因为它们可以逐个处理元素而不需要将整个数据集加载到内存中:

def process_large_file(file_path):
    """逐行处理大型文件"""
    with open(file_path, 'r') as f:
        for line in f:  # 文件对象是一个迭代器,逐行产生内容
            # 处理单行数据
            processed_line = line.strip().upper()
            yield processed_line

# 假设我们有一个很大的日志文件
# for line in process_large_file("large_log.txt"):
#     if "ERROR" in line:
#         print(line)

1.4.2 自定义数据过滤器

使用迭代器创建数据过滤管道:

class NumberFilter:
    """过滤数字序列的迭代器"""
    
    def __init__(self, iterable, predicate):
        self.iterator = iter(iterable)
        self.predicate = predicate
    
    def __iter__(self):
        return self
    
    def __next__(self):
        while True:
            item = next(self.iterator)  # 可能抛出StopIteration
            if self.predicate(item):
                return item
            # 如果不符合条件,继续获取下一个元素

# 使用示例
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# 过滤出偶数
even_filter = NumberFilter(numbers, lambda x: x % 2 == 0)
print("偶数:", end=" ")
for num in even_filter:
    print(num, end=" ")  # 输出: 2 4 6 8 10

# 过滤出大于5的数
greater_than_five = NumberFilter(numbers, lambda x: x > 5)
print("\n大于5的数:", end=" ")
for num in greater_than_five:
    print(num, end=" ")  # 输出: 6 7 8 9 10

# 链式过滤:找出大于5的偶数
even_greater_than_five = NumberFilter(
    NumberFilter(numbers, lambda x: x > 5),
    lambda x: x % 2 == 0
)
print("\n大于5的偶数:", end=" ")
for num in even_greater_than_five:
    print(num, end=" ")  # 输出: 6 8 10

1.4.3 无限序列

迭代器可以表示无限序列,只有在需要时才生成值:

class InfiniteCounter:
    """产生无限递增序列的迭代器"""
    
    def __init__(self, start=0, step=1):
        self.value = start
        self.step = step
    
    def __iter__(self):
        return self
    
    def __next__(self):
        current = self.value
        self.value += self.step
        return current

# 使用无限序列(结合islice来获取有限数量的元素)
from itertools import islice

counter = InfiniteCounter()
# 获取前10个元素
first_ten = list(islice(counter, 10))
print(first_ten)  # 输出: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# 重新创建一个以100开始,步长为10的计数器
step_counter = InfiniteCounter(100, 10)
# 获取前5个元素
first_five = list(islice(step_counter, 5))
print(first_five)  # 输出: [100, 110, 120, 130, 140]

2. 生成器(Generator)

生成器是一种特殊的迭代器,它提供了一种更简单的方式来创建迭代器,而无需实现__iter__()和__next__()方法。

2.1 生成器函数

通过在函数中使用yield语句,可以创建一个生成器函数:

def fibonacci_generator(limit):
    a, b = 0, 1
    count = 0
    

    while count < limit:
        yield a
        a, b = b, a + b
        count += 1

# 使用生成器函数
for num in fibonacci_generator(8):
    print(num, end=" ")  # 输出: 0 1 1 2 3 5 8 13

当调用生成器函数时,它不会立即执行函数体,而是返回一个生成器对象。每次调用next()函数或使用for循环迭代时,函数会执行到yield语句,然后暂停并返回结果,保存当前的状态,等待下一次调用。

这种执行暂停和恢复的能力是生成器的核心特性,它使得生成器非常适合处理大量数据或无限序列,而不会一次性占用大量内存。

2.1.1 生成器函数的工作原理

为了更好地理解生成器函数的工作原理,我们来看一个简单的示例,并分析其执行过程:

def simple_generator():
    print("第一次yield之前")
    yield 1
    print("第二次yield之前")
    yield 2
    print("第三次yield之前")
    yield 3
    print("生成器函数结束")

# 创建生成器对象
gen = simple_generator()

# 第一次调用next()
print("调用next()第一次")
print(f"获得值: {next(gen)}")  # 执行到第一个yield并返回1

# 第二次调用next()
print("调用next()第二次")
print(f"获得值: {next(gen)}")  # 从第一个yield之后继续执行到第二个yield并返回2

# 第三次调用next()
print("调用next()第三次")
print(f"获得值: {next(gen)}")  # 从第二个yield之后继续执行到第三个yield并返回3

# 第四次调用next()
try:
    print("调用next()第四次")
    print(f"获得值: {next(gen)}")  # 从第三个yield之后继续执行,但没有更多yield,抛出StopIteration
except StopIteration:
    print("生成器已耗尽,抛出StopIteration异常")

执行结果:

调用next()第一次
第一次yield之前
获得值: 1
调用next()第二次
第二次yield之前
获得值: 2
调用next()第三次
第三次yield之前
获得值: 3
调用next()第四次
生成器函数结束
生成器已耗尽,抛出StopIteration异常

从这个执行过程可以看出:

  1. 调用生成器函数不会立即执行函数体,而是返回一个生成器对象
  2. 每次调用next(),生成器函数会从上次暂停的地方继续执行,直到遇到下一个yield语句
  3. yield语句暂停函数执行并返回一个值
  4. 当函数执行完毕(没有更多的yield语句)时,生成器会抛出StopIteration异常

2.2 生成器表达式

类似于列表推导式,生成器表达式提供了一种更简洁的创建生成器的方式:

# 列表推导式:立即计算所有值并存储在内存中
squares_list = [x**2 for x in range(10)]

# 生成器表达式:按需计算值,不占用大量内存
squares_gen = (x**2 for x in range(10))

print(squares_list)     # 输出完整列表: [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
print(squares_gen)      # 输出生成器对象: <generator object <genexpr> at 0x...>

# 使用生成器
for num in squares_gen:
    print(num, end=" ")  # 输出: 0 1 4 9 16 25 36 49 64 81

生成器表达式的语法与列表推导式非常相似,只是使用圆括号()而不是方括号[]。

2.2.1 性能比较:列表推导式 vs 生成器表达式

生成器表达式对于处理大量数据时能显著减少内存使用,提高性能:

import sys
import time

# 测试列表推导式和生成器表达式的内存占用和性能差异

def compare_memory_usage():
    # 使用列表推导式(一次性创建包含100万个元素的列表)
    start_time = time.time()
    list_comp = [i for i in range(1_000_000)]
    list_time = time.time() - start_time
    list_size = sys.getsizeof(list_comp)
    

    # 使用生成器表达式(创建生成器,按需生成元素)
    start_time = time.time()
    gen_exp = (i for i in range(1_000_000))
    gen_time = time.time() - start_time
    gen_size = sys.getsizeof(gen_exp)
    
    print(f"列表推导式:")
    print(f"  - 创建时间: {list_time:.6f} 秒")
    print(f"  - 内存占用: {list_size:,} 字节")
    
    print(f"生成器表达式:")
    print(f"  - 创建时间: {gen_time:.6f} 秒")
    print(f"  - 内存占用: {gen_size:,} 字节")
    
    print(f"内存节省: {list_size / gen_size:.1f}倍")

# compare_memory_usage()

这个比较显示,生成器表达式的内存占用远小于列表推导式,尤其是处理大型数据集时。

2.3 生成器的优势

生成器的主要优势在于内存效率。对于处理大量数据或无限序列时,生成器可以按需生成值,而不需要一次性将所有数据加载到内存中。

# 计算大文件中的行数

def count_lines(file_path):
    with open(file_path, 'r') as f:
        return sum(1 for _ in f)

# 这个版本更高效,因为它不会将整个文件加载到内存中

2.3.1 大数据处理案例

假设我们需要处理一个大型日志文件,提取所有错误信息并分析:

def error_lines(file_path):
    """生成器函数:从日志文件中提取包含'ERROR'的行"""
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            if 'ERROR' in line:
                yield line_num, line.strip()

def analyze_errors(file_path):
    """分析日志文件中的错误"""
    error_count = 0
    error_types = {}
    

    for line_num, line in error_lines(file_path):
        error_count += 1
        
        # 假设错误格式为: ERROR: [Type] - Message
        # 提取错误类型
        try:
            error_type = line.split('[')[1].split(']')[0]
            error_types[error_type] = error_types.get(error_type, 0) + 1
        except IndexError:
            pass
    
    return error_count, error_types

# 使用示例
# error_count, error_types = analyze_errors('large_log.txt')
# print(f"总错误数: {error_count}")
# for error_type, count in sorted(error_types.items(), key=lambda x: x[1], reverse=True):
#     print(f"{error_type}: {count}次")

使用生成器处理大型文件的优势:

  1. 无需一次性加载整个文件到内存
  2. 可以及时开始处理数据,而不必等待全部文件读取完成
  3. 如果只需要部分结果,可以随时停止处理

2.3.2 处理无限序列

生成器非常适合表示无限序列,因为它们只在需要时才生成值:

def infinite_primes():
    """生成无限质数序列"""
    # 初始化一个用于存储已找到质数的列表
    primes = []
    num = 2
    

    while True:
        # 检查num是否是质数
        is_prime = True
        
        # 只需检查已知质数的整除性
        for prime in primes:
            if prime * prime > num:  # 优化:只需检查到sqrt(num)
                break
            if num % prime == 0:
                is_prime = False
                break
        
        if is_prime:
            primes.append(num)
            yield num
            
        num += 1

# 获取前10个质数
prime_gen = infinite_primes()
first_ten_primes = [next(prime_gen) for _ in range(10)]
print(first_ten_primes)  # 输出: [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

# 再获取5个质数
next_five_primes = [next(prime_gen) for _ in range(5)]
print(next_five_primes)  # 输出: [31, 37, 41, 43, 47]

2.4 生成器方法

生成器对象还支持一些特殊方法,如send()、throw()和close():

def echo_generator():
    value = yield "Ready"
    while True:
        value = yield f"Echo: {value}"

# 使用send()方法与生成器通信
gen = echo_generator()
print(next(gen))        # 输出: Ready
print(gen.send("Hello"))  # 输出: Echo: Hello
print(gen.send("Python"))  # 输出: Echo: Python
gen.close()             # 关闭生成器

2.4.1 深入理解send()方法

send()方法允许我们向生成器发送值,这个值会成为yield表达式的结果:

def data_processor():
    """一个可以接收外部数据的生成器"""
    data = None
    

    while True:
        # yield 返回当前结果,并等待下一个输入
        # 输入通过send()方法传入,并赋值给data
        processed_data = yield f"处理结果: {data}"
        
        # 当收到None时退出循环
        if processed_data is None:
            break
            
        # 处理数据(这里简单地将数据转换为大写)
        data = processed_data.upper()
    
    # 生成器完成时的返回消息
    return "数据处理完成"

# 使用示例
processor = data_processor()

# 第一次调用next()启动生成器
# 执行到第一个yield并返回,此时data为None
print(next(processor))  # 输出: 处理结果: None

# 发送数据并获取处理结果
print(processor.send("hello"))  # 输出: 处理结果: HELLO
print(processor.send("python"))  # 输出: 处理结果: PYTHON
print(processor.send("generator"))  # 输出: 处理结果: GENERATOR

try: 
    # 发送None表示处理完成
    processor.send(None)
except StopIteration as e:
    # 捕获异常并获取生成器的返回值
    print(f"生成器返回: {e.value}")  # 输出: 生成器返回: 数据处理完成

2.4.2 使用throw()方法

throw()方法允许向生成器内部抛出异常:

def resilient_generator():
    """一个能够处理异常的生成器"""
    try:
        yield "正常值 1"
        yield "正常值 2"
        yield "正常值 3"
    except ValueError:
        yield "捕获到ValueError"
    except Exception as e:
        yield f"捕获到其他异常: {type(e).__name__}"
    finally:
        yield "清理资源"

# 使用示例
gen = resilient_generator()
print(next(gen))  # 输出: 正常值 1

# 向生成器抛出ValueError
print(gen.throw(ValueError))  # 输出: 捕获到ValueError

# 继续执行
print(next(gen))  # 输出: 清理资源

# 再次创建生成器
gen = resilient_generator()
print(next(gen))  # 输出: 正常值 1
print(next(gen))  # 输出: 正常值 2

# 向生成器抛出TypeError
print(gen.throw(TypeError))  # 输出: 捕获到其他异常: TypeError

# 继续执行
print(next(gen))  # 输出: 清理资源

2.4.3 使用close()方法

close()方法用于关闭生成器,它会在生成器内部引发GeneratorExit异常:

def closeable_generator():
    """一个可以优雅关闭的生成器"""
    try:
        yield "值 1"
        yield "值 2"
        yield "值 3"
    except GeneratorExit:
        print("生成器收到关闭请求,正在清理资源...")
        # 这里可以执行清理操作
    finally:
        print("生成器已关闭")

# 使用示例
gen = closeable_generator()
print(next(gen))  # 输出: 值 1
print(next(gen))  # 输出: 值 2
gen.close()       # 输出: 生成器收到关闭请求,正在清理资源... 

                  #       生成器已关闭

2.5 生成器管道

生成器可以组合成管道,每个生成器处理上一个生成器产生的数据:

def read_large_csv(file_path):
    """读取大型CSV文件的生成器"""
    with open(file_path, 'r') as f:
        for line in f:
            # 将CSV行分割为字段
            yield line.strip().split(',')

def filter_rows(rows, column_index, value):
    """过滤特定列等于指定值的行"""
    for row in rows:
        if row[column_index] == value:
            yield row

def transform_rows(rows, transformer):
    """对每行应用转换函数"""
    for row in rows:
        yield transformer(row)

def process_sales_data(filename):
    """处理销售数据的生成器管道"""
    # 1. 读取CSV文件
    rows = read_large_csv(filename)

    # 2. 只保留"已完成"状态的订单
    completed_orders = filter_rows(rows, 3, "已完成")
    
    # 3. 计算每行的总金额(数量 * 单价)
    def calculate_total(row):
        quantity = int(row[1])
        price = float(row[2])
        row.append(str(quantity * price))
        return row
    
    orders_with_total = transform_rows(completed_orders, calculate_total)
    
    # 返回处理后的数据
    return orders_with_total

# 使用示例
# 假设有一个sales.csv文件,格式:产品ID,数量,单价,状态
# for row in process_sales_data('sales.csv'):
#     print(f"产品: {row[0]}, 数量: {row[1]}, 单价: {row[2]}, 状态: {row[3]}, 总额: {row[4]}")

生成器管道的优势:

  • 内存效率 - 一次只处理一条数据
  • 模块化 - 每个处理步骤独立且可复用
  • 惰性计算 - 只有请求数据时才执行计算
  • 简洁清晰 - 代码易于理解和维护

2.6 实际应用案例:网站爬虫

下面是一个使用生成器实现的简单网站爬虫示例:

import requests
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup

def get_links(url):
    """从网页提取所有链接"""
    try:
        response = requests.get(url, timeout=5)
        soup = BeautifulSoup(response.text, 'html.parser')
        for a_tag in soup.find_all('a', href=True):
            href = a_tag['href']
            # 转换为绝对URL
            full_url = urljoin(url, href)
            # 只返回同一域名下的URL
            if urlparse(full_url).netloc == urlparse(url).netloc:
                yield full_url
    except Exception as e:
        print(f"获取链接时出错: {e}")

def crawl_site(start_url, max_pages=10):
    """爬取网站"""
    visited = set()
    to_visit = [start_url]
    page_count = 0

    while to_visit and page_count < max_pages:
        current_url = to_visit.pop(0)
        
        if current_url in visited:
            continue
            
        print(f"正在爬取: {current_url}")
        visited.add(current_url)
        page_count += 1
        
        # 提取页面内容
        try:
            response = requests.get(current_url, timeout=5)
            yield current_url, response.text
            
            # 添加新链接到待访问列表
            for link in get_links(current_url):
                if link not in visited:
                    to_visit.append(link)
        except Exception as e:
            print(f"爬取 {current_url} 时出错: {e}")

# 使用示例
# 爬取Python官网
# for url, content in crawl_site('https://www.python.org', max_pages=5):
#     # 提取页面标题
#     soup = BeautifulSoup(content, 'html.parser')
#     title = soup.title.string if soup.title else "无标题"
#     print(f"URL: {url}")
#     print(f"标题: {title}")
#     print("-" * 50)

这个爬虫示例展示了生成器的几个优势:

  • 边爬取边处理,无需等待所有页面下载完成。
  • 内存占用小,即使爬取大量页面。
  • 代码结构清晰,易于扩展。

3. 装饰器(Decorator)

装饰器是Python中强大的工具,它允许你修改或增强函数或类的行为,而无需修改其代码。

3.1 函数装饰器基础

装饰器本质上是一个接受函数并返回函数的高阶函数:

def simple_decorator(func):
    def wrapper(*args, **kwargs):
        print("Before function call")
        result = func(*args, **kwargs)
        print("After function call")
        return result
    return wrapper

# 使用装饰器
@simple_decorator
def greet(name):
    print(f"Hello, {name}!")

greet("Alice")
# 输出:
# Before function call
# Hello, Alice!
# After function call

使用@decorator语法等价于:greet = simple_decorator(greet)

3.1.1 装饰器工作原理详解

让我们深入理解装饰器的工作原理:

def my_decorator(func):
    # 这个函数将包装原始函数
    def wrapper(*args, **kwargs):
        print(f"装饰器: 正在调用 {func.__name__}")
        # 调用原始函数
        result = func(*args, **kwargs)
        print(f"装饰器: {func.__name__} 调用完成")
        return result
    # 返回包装函数
    return wrapper

# 手动应用装饰器
def say_hello(name):
    print(f"Hello, {name}!")

# 这相当于使用 @my_decorator
decorated_say_hello = my_decorator(say_hello)
decorated_say_hello("手动装饰")  # 调用装饰后的函数

# 使用@语法应用装饰器
@my_decorator
def say_goodbye(name):
    print(f"Goodbye, {name}!")

say_goodbye("语法糖装饰")  # 调用装饰后的函数

执行流程:

  1. 当Python解释器遇到@my_decorator时,它会先执行my_decorator函数
  2. my_decorator函数接收原始函数say_goodbye作为参数
  3. my_decorator返回wrapper函数
  4. Python将say_goodbye名称重新绑定到返回的wrapper函数
  5. 当我们调用say_goodbye("语法糖装饰")时,实际上调用的是wrapper("语法糖装饰")

3.2 带参数的装饰器

我们可以创建接受参数的装饰器:

def repeat(n):
    def decorator(func):
        def wrapper(*args, **kwargs):
            result = None
            for _ in range(n):
                result = func(*args, **kwargs)
            return result
        return wrapper
    return decorator

@repeat(3)
def say_hello(name):
    print(f"Hello, {name}!")

say_hello("Bob")

# 输出:
# Hello, Bob!
# Hello, Bob!
# Hello, Bob!

3.2.1 带参数装饰器的工作原理

带参数的装饰器实际上是一个返回装饰器的函数。让我们分解这个过程:

def repeat(n):
    # 外部函数,接收装饰器参数
    print(f"定义装饰器,重复次数:{n}")
    

    def decorator(func):
        # 中间函数,接收被装饰函数
        print(f"装饰 {func.__name__} 函数")
        
        def wrapper(*args, **kwargs):
            # 内部函数,实际调用时执行
            print(f"开始调用 {func.__name__}, 将重复 {n} 次")
            result = None
            for i in range(n):
                print(f"第 {i+1} 次调用")
                result = func(*args, **kwargs)
            print(f"结束调用 {func.__name__}")
            return result
            
        return wrapper
        
    # 返回实际的装饰器函数
    return decorator

# 观察装饰器的执行过程
print("定义函数前")

@repeat(3)
def test_function(message):
    print(f"测试函数输出: {message}")
    return message

print("定义函数后,调用前")
result = test_function("Hello World")
print(f"调用结果: {result}")

执行流程:

  1. 首先执行repeat(3),返回decorator函数
  2. decorator函数接收test_function作为参数
  3. decorator返回wrapper函数
  4. Python将test_function名称重新绑定到wrapper函数
  5. 当我们调用test_function("Hello World")时,实际调用的是wrapper("Hello World")

3.3 保留函数元数据

使用装饰器会导致被装饰函数丢失原始的元数据(如名称、文档字符串等)。为了解决这个问题,可以使用functools.wraps装饰器:

from functools import wraps

def log_function_call(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        print(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
        result = func(*args, **kwargs)
        print(f"{func.__name__} returned: {result}")
        return result
    return wrapper

@log_function_call
def add(a, b):
    """Add two numbers and return the result."""
    return a + b

print(add.__name__)  # 输出: add(保留了原始函数名)
print(add.__doc__)   # 输出: Add two numbers and return the result.(保留了文档字符串)

add(3, 5)

# 输出:
# Calling add with args: (3, 5), kwargs: {}
# add returned: 8

3.3.1 不使用wraps的问题

为了理解functools.wraps的重要性,让我们看看不使用它会发生什么:

def without_wraps(func):
    def wrapper(*args, **kwargs):
        """This is wrapper function."""
        return func(*args, **kwargs)
    return wrapper

def with_wraps(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        """This is wrapper function."""
        return func(*args, **kwargs)
    return wrapper

@without_wraps
def func_a():
    """This is function A."""
    pass

@with_wraps
def func_b():
    """This is function B."""
    pass

print("不使用wraps:")
print(f"名称: {func_a.__name__}")  # 输出: wrapper
print(f"文档: {func_a.__doc__}")   # 输出: This is wrapper function.
print(f"模块: {func_a.__module__}")

print("\n使用wraps:")
print(f"名称: {func_b.__name__}")  # 输出: func_b
print(f"文档: {func_b.__doc__}")   # 输出: This is function B.
print(f"模块: {func_b.__module__}")

functools.wraps的作用是将原始函数的元数据复制到包装函数,包括:

  • __name__:函数名
  • __doc__:文档字符串
  • __module__:函数定义所在的模块
  • __annotations__:函数注解
  • __qualname__:函数的限定名称

3.4 类方法装饰器

Python提供了三个内置的类方法装饰器:@classmethod、@staticmethod和@property:

class Temperature:
    def __init__(self, celsius):
        self._celsius = celsius
    

    # 属性装饰器
    @property
    def celsius(self):
        return self._celsius
    
    @celsius.setter
    def celsius(self, value):
        if value < -273.15:
            raise ValueError("Temperature below absolute zero is not possible")
        self._celsius = value
    
    @property
    def fahrenheit(self):
        return self._celsius * 9/5 + 32
    
    @fahrenheit.setter
    def fahrenheit(self, value):
        self.celsius = (value - 32) * 5/9
    
    # 类方法
    @classmethod
    def from_fahrenheit(cls, value):
        return cls((value - 32) * 5/9)
    
    # 静态方法
    @staticmethod
    def is_valid_temperature(value):
        return value >= -273.15

# 使用
temp = Temperature(25)
print(temp.celsius)      # 输出: 25
print(temp.fahrenheit)   # 输出: 77.0

temp.celsius = 30
print(temp.fahrenheit)   # 输出: 86.0

temp.fahrenheit = 68
print(temp.celsius)      # 输出: 20.0

# 使用类方法创建实例
temp2 = Temperature.from_fahrenheit(32)
print(temp2.celsius)     # 输出: 0.0

# 使用静态方法
print(Temperature.is_valid_temperature(-300))  # 输出: False

3.4.1 深入理解 @property

@property装饰器允许我们像访问属性一样访问方法,并在需要时添加验证逻辑:

class Person:
    def __init__(self, first_name, last_name, age):
        self._first_name = first_name
        self._last_name = last_name
        self._age = age
    

    @property
    def full_name(self):
        return f"{self._first_name} {self._last_name}"
    
    @property
    def age(self):
        return self._age
    
    @age.setter
    def age(self, value):
        if not isinstance(value, int):
            raise TypeError("年龄必须是整数")
        if value < 0 or value > 150:
            raise ValueError("年龄必须在0到150之间")
        self._age = value
    
    @property
    def is_adult(self):
        return self._age >= 18

# 使用示例
person = Person("张", "三", 25)
print(person.full_name)  # 输出: 张 三
print(person.is_adult)   # 输出: True

# 修改年龄
person.age = 17
print(person.is_adult)   # 输出: False

try:
    person.age = -5      # 抛出ValueError
except ValueError as e:
    print(f"错误: {e}")

try:
    person.full_name = "李四"  # 抛出AttributeError
except AttributeError as e:
    print(f"错误: {e}")

@property的主要用途:

  • 封装 - 隐藏内部实现,提供受控的访问。
  • 验证 - 在设置值之前进行验证。
  • 计算属性 - 创建不需要存储但可以计算的属性(如full_name)。
  • 向后兼容 - 可以将方法改为属性而不破坏现有代码。

3.5 类装饰器

除了函数装饰器,Python还支持类装饰器:

def singleton(cls):
    instances = {}
    

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]
    
    return get_instance

@singleton
class Database:
    def __init__(self, url):
        self.url = url
        print(f"Database initialized with URL: {url}")
    

    def connect(self):
        print(f"Connected to database at {self.url}")

# 无论创建多少次实例,都只会有一个实例
db1 = Database("localhost:5432")
db2 = Database("example.com:5432")  # 不会真正创建新实例

print(db1 is db2)  # 输出: True
db1.connect()      # 输出: Connected to database at localhost:5432

3.5.1 类作为装饰器

类本身也可以作为装饰器,只需实现__call__方法:

class CountCalls:
    def __init__(self, func):
        self.func = func
        self.count = 0
        # 复制原始函数的元数据
        functools.update_wrapper(self, func)
    

    def __call__(self, *args, **kwargs):
        self.count += 1
        print(f"{self.func.__name__} 已被调用 {self.count} 次")
        return self.func(*args, **kwargs)

@CountCalls
def say_hello(name):
    return f"Hello, {name}!"

print(say_hello("Alice"))  # 输出: say_hello 已被调用 1 次; Hello, Alice!
print(say_hello("Bob"))    # 输出: say_hello 已被调用 2 次; Hello, Bob!

3.6 实用装饰器示例

3.6.1 计时装饰器

测量函数执行时间的装饰器:

import time
from functools import wraps

def timing_decorator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} 执行耗时: {end_time - start_time:.6f} 秒")
        return result
    return wrapper

@timing_decorator
def slow_function(n):
    """一个耗时的函数"""
    time.sleep(n)  # 模拟耗时操作
    return f"完成 {n} 秒的操作"

slow_function(1.5)  # 输出: slow_function 执行耗时: 1.500123 秒

3.6.2 缓存装饰器

缓存函数结果以提高性能的装饰器:

def memoize(func):
    """缓存函数结果的简单装饰器"""
    cache = {}
    

    @wraps(func)
    def wrapper(*args):
        # 使用参数作为缓存的键
        if args not in cache:
            cache[args] = func(*args)
            print(f"计算并缓存结果: {args}")
        else:
            print(f"使用缓存结果: {args}")
        return cache[args]
    
    return wrapper

@memoize
def fibonacci(n):
    """计算斐波那契数列的第n项"""
    if n < 2:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

print(fibonacci(10))  # 首次计算
print(fibonacci(10))  # 使用缓存
print(fibonacci(5))   # 从计算fibonacci(10)时已经缓存

注意:Python标准库中的functools.lru_cache提供了更完善的缓存装饰器。

3.6.3 重试装饰器

在失败时自动重试函数的装饰器:

import time
import random
from functools import wraps

def retry(max_attempts=3, delay=1):
    """重试装饰器:在失败时自动重试函数,最多重试max_attempts次"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            attempts = 0
            while attempts < max_attempts:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    attempts += 1
                    if attempts == max_attempts:
                        raise  # 达到最大重试次数,重新抛出异常
                    

                    print(f"尝试 {attempts}/{max_attempts} 失败: {e}")
                    print(f"等待 {delay} 秒后重试...")
                    time.sleep(delay)
            
        return wrapper
    return decorator

@retry(max_attempts=5, delay=0.5)
def unstable_network_call(success_rate=0.3):
    """模拟不稳定的网络调用,有一定概率成功"""
    if random.random() > success_rate:
        raise ConnectionError("网络连接失败")
    return "数据获取成功"

try:
    result = unstable_network_call(0.2)
    print(f"结果: {result}")
except ConnectionError:
    print("所有重试都失败了")

3.6.4 参数验证装饰器

验证函数参数的装饰器:

from functools import wraps

def validate_types(**expected_types):
    """验证函数参数类型的装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # 获取参数名称
            func_args = func.__code__.co_varnames[:func.__code__.co_argcount]
            

            # 合并位置参数和关键字参数
            all_args = dict(zip(func_args, args))
            all_args.update(kwargs)
            
            # 验证类型
            for arg_name, expected_type in expected_types.items():
                if arg_name in all_args:
                    actual_value = all_args[arg_name]
                    if not isinstance(actual_value, expected_type):
                        raise TypeError(
                            f"参数 '{arg_name}' 的类型必须是 {expected_type.__name__},"
                            f"但获得了 {type(actual_value).__name__}"
                        )
            
            return func(*args, **kwargs)
        return wrapper
    return decorator

@validate_types(name=str, age=int)
def greet_person(name, age):
    return f"你好,{name}!你今年{age}岁。"

# 正确的调用
print(greet_person("张三", 25))  # 输出: 你好,张三!你今年25岁。

# 类型错误的调用
try:
    print(greet_person("李四", "二十"))
except TypeError as e:
    print(f"错误: {e}")  # 输出错误信息

3.6.5 权限控制装饰器

检查用户权限的装饰器:

from functools import wraps

# 模拟用户会话
class UserSession:
    def __init__(self, user_id, roles=None):
        self.user_id = user_id
        self.roles = roles or []

# 全局会话对象
current_session = None

def login(user_id, roles=None):
    global current_session
    current_session = UserSession(user_id, roles)
    return current_session

def logout():
    global current_session
    current_session = None

def require_roles(*required_roles):
    """验证当前用户是否拥有所需角色的装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not current_session:
                raise PermissionError("用户未登录")
            

            # 检查用户是否拥有所需的所有角色
            missing_roles = [role for role in required_roles if role not in current_session.roles]
            
            if missing_roles:
                raise PermissionError(
                    f"权限不足。需要角色: {', '.join(missing_roles)}"
                )
            
            return func(*args, **kwargs)
        return wrapper
    return decorator

# 应用示例
@require_roles("admin")
def delete_user(user_id):
    return f"用户 {user_id} 已删除"

@require_roles("editor", "content_manager")
def publish_article(article_id):
    return f"文章 {article_id} 已发布"

# 测试
try:
    # 未登录
    delete_user(123)
except PermissionError as e:
    print(f"错误: {e}")

# 以普通用户身份登录
login("user123", ["user"])

try:
    # 尝试执行管理员操作
    delete_user(123)
except PermissionError as e:
    print(f"错误: {e}")

# 以管理员身份登录
login("admin456", ["admin", "user"])

# 现在可以执行管理员操作
print(delete_user(123))

# 尝试需要多个角色的操作
try:
    publish_article(456)
except PermissionError as e:
    print(f"错误: {e}")

# 登录具有所需所有角色的用户
login("editor789", ["editor", "content_manager", "user"])

# 现在可以执行该操作
print(publish_article(456))
posted @ 2025-04-30 16:05  零の守墓人  阅读(34)  评论(0)    收藏  举报