DRF限流
DRF的限流(Throttling)是一种重要的API保护机制,用于控制客户端对API接口的请求频率,防止恶意请求、资源过度消耗和服务崩溃。
1、BaseThrottle
所有限流类的基类,位于 rest_framework.throttling 模块中。它主要定义了接口
class BaseThrottle:
def allow_request(self, request, view):
# 核心方法:判断是否允许请求。返回 True 则允许,False 则拒绝。
# 必须由子类重写。
raise NotImplementedError('.allow_request() must be overridden')
def get_ident(self, request):
# 获取请求的唯一标识(如IP地址)。会处理 HTTP_X_FORWARDED_FOR 和 REMOTE_ADDR。
# ... 源码逻辑 ...
return ident
def wait(self):
# 当请求被拒绝时,返回需要等待的秒数(提示用户)。
# 可选实现。
return None
2、SimpleRateThrottle
这是一个继承了 BaseThrottle 的重要类,DRF 的内置限流类都基于它。它实现了基于简单速率的限流逻辑
class SimpleRateThrottle(BaseThrottle):
"""
实现了基于频率限制的核心逻辑,大多数内置限流器基类
"""
def __init__(self):
"""
初始化时通过 get_rate() 和 parse_rate(rate) 方法获取并解析在 settings.py 中配置的速率(如 '100/day'),得到允许的请求次数(num_requests)和时间周期(以秒为单位的 duration)
"""
if not getattr(self, 'rate', None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
def get_cache_key(self, request, view):
"""
必须由子类重写。返回一个唯一的字符串作为缓存键,用于区分不同的请求源(如 IP 或用户ID),返回None则代表不限流
"""
raise NotImplementedError('.get_cache_key() must be overridden')
def get_rate(self):
"""
获取配置中的限流配置
"""
def parse_rate(self, rate):
"""
rate: 100/(s/second)
解析速率,返回允许的请求次数, 周期秒数
"""
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
# period[0]是period第一个字符,因此rate的/后面只要是s,m,h,d开头的单词即可
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
def allow_request(self, request, view):
"""
检查核心逻辑
"""
if self.rate is None:
return True
# 获取缓存键 key
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
# 从缓存(通常是 Redis 或 Django 的缓存后端)中获取该 key 对应的请求历史记录(history),这是一个时间戳列表
self.history = self.cache.get(self.key, [])
# 当前时间
self.now = self.timer()
# 清理历史记录中所有早于(当前时间 - duration)的时间戳(因为这些请求已超出时间窗口,不再影响当前限流)
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
# 如果清理后的 history 长度小于 num_requests,说明在时间窗口内请求次数未超限
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self):
"""
将当前时间插入 history 的开头,将新的 history 更新到缓存中,并设置过期时间为 duration
"""
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
def wait(self):
"""
计算并返回还需要等待多少秒后才能再次请求
"""
# 1分钟限流3次,self.history = [11:35:40,11:35:35,11:35:30], self.duration=60, now = 11:35:50, 则需要等待40秒
if self.history:
# 还需要等待的时间计算
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
# 可以访问的次数,走到限流这里,一般都是1
available_requests = self.num_requests - len(self.history) + 1
if available_requests <= 0:
return None
return remaining_duration / float(available_requests)
3、AnonRateThrottle
用于匿名用户。get_cache_key 通常返回 self.get_ident(request)(即客户端的 IP 地址)
class AnonRateThrottle(SimpleRateThrottle):
scope = 'anon'
def get_cache_key(self, request, view):
if request.user and request.user.is_authenticated:
return None
return self.cache_format % {
'scope': self.scope,
'ident': self.get_ident(request)
}
4、UserRateThrottle
用于认证用户。get_cache_key 通常返回 request.user.pk 或 request.user.username(即用户的唯一标识)。如果用户未认证,则可能返回 None(不进行限流)。
class UserRateThrottle(SimpleRateThrottle):
scope = 'user'
def get_cache_key(self, request, view):
if request.user and request.user.is_authenticated:
ident = request.user.pk
else:
ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,
'ident': ident
}
5、ScopedRateThrottle
用于为不同的视图或接口设置不同的限流规则。它使用视图的 throttle_scope 属性与配置中的速率进行匹配。get_cache_key 通常会组合 scope 和用户标识(IP 或用户ID)。
class ScopedRateThrottle(SimpleRateThrottle):
scope_attr = 'throttle_scope'
def __init__(self):
pass
def allow_request(self, request, view):
# 获取视图的throttle_scope属性值
self.scope = getattr(view, self.scope_attr, None)
if not self.scope:
return True
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
# 调用SimpleRateThrottle的allow_request
return super().allow_request(request, view)
def get_cache_key(self, request, view):
if request.user and request.user.is_authenticated:
ident = request.user.pk
else:
ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,
'ident': ident
}
6、自定义限流
你可以通过继承 BaseThrottle或 SimpleRateThrottle来实现更复杂的限流逻辑,继承 SimpleRateThrottle是最常见和简单的方式,只需重写 get_cache_key方法并设置 scope即可
from rest_framework.throttling import SimpleRateThrottle
class IpThrottle(SimpleRateThrottle):
scope = "ip" # 在DEFAULT_THROTTLE_RATES中配置对应的速率
cache = default_cache # 通常使用Django的默认缓存
def get_cache_key(self, request, view):
# 以IP作为标识
ident = self.get_ident(request)
return self.cache_format % {'scope': self.scope, 'ident': ident}
7、配置与使用
7.1 全局配置
# settings.py
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': [
'rest_framework.throttling.AnonRateThrottle',
'rest_framework.throttling.UserRateThrottle',
'myapp.throttling.TieredUserThrottle', # 自定义限流
],
'DEFAULT_THROTTLE_RATES': {
'anon': '100/day', # 匿名用户每天100次
'user': '1000/hour', # 认证用户每小时1000次
'premium': '5000/hour', # 高级用户每小时5000次
}
}
7.2 视图级限流配置
class PublicAPIView(APIView):
"""公共API,只限制匿名用户"""
throttle_classes = [AnonRateThrottle]
def get(self, request):
return Response({"message": "公共API"})
class UserAPIView(APIView):
"""用户API,限制所有用户"""
throttle_classes = [UserRateThrottle]
def get(self, request):
return Response({"message": "用户API"})
class PremiumAPIView(APIView):
"""高级API,使用自定义分层限流"""
throttle_classes = [TieredUserThrottle]
def get(self, request):
return Response({"message": "高级API"})
# 在视图集中使用
class MyViewSet(ModelViewSet):
queryset = MyModel.objects.all()
serializer_class = MyModelSerializer
throttle_classes = [UserRateThrottle]
@action(detail=False, methods=['get'], throttle_classes=[AnonRateThrottle])
def public_list(self, request):
"""公开列表,使用不同的限流策略"""
queryset = self.get_queryset().filter(is_public=True)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
8、DRF限流流程分析
APIView: dispatch->initial->check_throttles
class APIView(View):
def dispatch(self, request, *args, **kwargs):
self.initial(request, *args, **kwargs)
return self.response
def initial(self, request, *args, **kwargs):
self.check_throttles(request)
def get_throttles(self):
# 实例化各个限流类
return [throttle() for throttle in self.throttle_classes]
def check_throttles(self, request):
throttle_durations = []
# 循环调用各个限流类的allow_request
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
# 添加需要等待的时间
throttle_durations.append(throttle.wait())
if throttle_durations:
durations = [
duration for duration in throttle_durations
if duration is not None
]
# 返回多个限流类中需要等大的最大时长
duration = max(durations, default=None)
self.throttled(request, duration)