欢迎来到Louis的博客

人生三从境界:昨夜西风凋碧树,独上高楼,望尽天涯路。 衣带渐宽终不悔,为伊消得人憔悴。 众里寻他千百度,蓦然回首,那人却在灯火阑珊处。
扩大
缩小

04-节流(访问限制)

一.自定义节流

限制单个IP60秒只能访问3次

1.节流器是单个组件,独立放在utils目录中

utils/throttle.py

from rest_framework.throttling import BaseThrottle
import time

VISIT_RECORD = {}  # 保存访问记录


class VisitThrottle(BaseThrottle):
    def __init__(self):
        self.history = None  # 初始化访问记录

    def allow_request(self, request, view):
        # 获取用户IP
        remote_addr = self.get_ident(request)
        ctime = time.time()
        if remote_addr not in VISIT_RECORD:
            VISIT_RECORD[remote_addr] = [ctime, ]
            return True  # 表示可以访问
        # 获取当前ip的历史访问记录
        history = VISIT_RECORD.get(remote_addr)
        # 初始化访问记录
        self.history = history

        # 如果有历史访问记录,并最早一次访问记录离当前记录超过60S,就删除最早的那个访问记录,
        while history and history[-1] < ctime - 60:
            history.pop()

        # 如果访问不操过三次,就把当前访问记录插入到第一个位置
        if len(history) < 3:
            history.insert(0, ctime)
            return True

    def wait(self):
        """
        还需要等多久才能访问
        :return: 
        """
        ctime = time.time()
        return 60 - (ctime - self.history[-1])

 2.在setting中配置全局节流

REST_FRAMEWORK = {
    # 节流
    "DEFAULT_THROTTLE_CLASSES": ['utils.throttle.VisitThrottle'],

}

 

3.看看访问结果

  • 60S之类访问次数超过3次,限制访问
  • 提示多少秒后能访问

正常访问

访问被限制,提示多少秒后能访问

2.节流源码简析

1.所有的drf功能都是基于APIView实现的,看下dispatch方法

APIView.dispatch

    def dispatch(self, request, *args, **kwargs):
        """
        `.dispatch()` is pretty much the same as Django's regular dispatch,
        but with extra hooks for startup, finalize, and exception handling.
        """
        self.args = args
        self.kwargs = kwargs
        # 对原始request进行再次封装,增加一些drf的功能
        # Request(
        #     request,
        #     parsers=self.get_parsers(),   解析器
        #     authenticators=self.get_authenticators(), 认证
        #     negotiator=self.get_content_negotiator(),
        #     parser_context=parser_context
        # )
        # request(原始request,[BasicAuthentications对象,])
        # 获取原生request,request._request
        # 获取认证类的对象,request.authticators
        # 1.封装request
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            # 2.认证,权限,节流触发
            self.initial(request, *args, **kwargs)

            # Get the appropriate handler method
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

 2.APIView.initial

    def initial(self, request, *args, **kwargs):
        """
        Runs anything that needs to occur prior to calling the method handler.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # Ensure that the incoming request is permitted
        # 3.认证
        self.perform_authentication(request)
        # 4.权限
        self.check_permissions(request)
        # 5.节流
        self.check_throttles(request)

3.APIView.check_throttles

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        # 6.get_throttles 获取节流器
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                self.throttled(request, throttle.wait())

 4.APIView.get_throttles

    def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        # 7返回所有的节流器
        return [throttle() for throttle in self.throttle_classes]   

5.throttle_class

class APIView(View):
    # The following policies may be set at either globally, or per-view.
    renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
    parser_classes = api_settings.DEFAULT_PARSER_CLASSES
    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES    # 读取配置中的节流器
    permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
    content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
    metadata_class = api_settings.DEFAULT_METADATA_CLASS
    versioning_class = api_settings.DEFAULT_VERSIONING_CLASS

6.api_settings

api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)    # 
 @property
    def user_settings(self):
        if not hasattr(self, '_user_settings'):
            self._user_settings = getattr(settings, 'REST_FRAMEWORK', {}) # 读取用户setting中的配置
        return self._user_settings

    def __getattr__(self, attr):
        if attr not in self.defaults:
            raise AttributeError("Invalid API setting: '%s'" % attr)

        try:
            # Check if present in user settings
            val = self.user_settings[attr]     
        except KeyError:
            # Fall back to defaults
            val = self.defaults[attr]

        # Coerce import strings into classes
        if attr in self.import_strings:
            val = perform_import(val, attr)    # 动态import用户自定义的节流类

        # Cache the result
        self._cached_attrs.add(attr)
        setattr(self, attr, val)
        return val

7.throttle.allow_request(request,self)

        for throttle in self.get_throttles():
            # throttle.allow_request(request, self) 节流逻辑,返回bool,返回Fales触发throttled节流,,如果自定义节流类需要重新改方法和wait方法
            if not throttle.allow_request(request, self):
                self.throttled(request, throttle.wait())

 

三.drf自带的节流器

上面写了一个自定义节流器,助于我们理解源码,drf内置了几个节流器,用起来比较简单方便

1.BaseThrottle类,

自己需要重写allow_request和wait方法

get_ident 获取IP地址

class BaseThrottle(object):
    """
    Rate throttling of requests.
    """
    
    # 节流逻辑,基础该类需要重写该方法
    def allow_request(self, request, view):
        """
        Return `True` if the request should be allowed, `False` otherwise.
        """
        raise NotImplementedError('.allow_request() must be overridden')
    # 获取远程IP地址
    def get_ident(self, request):
        """
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr
    
    # 触发节流器后给予用户的提示,需要重写
    def wait(self):
        """
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        return None

 

2.SimpleRateThrottle类

我们可以通过继承SimpleRateThrottle类,来实现节流,会更加的简单,因为SimpleRateThrottle里面都帮我们写好了

class VisitThrottle(SimpleRateThrottle):
    scope = 'anonymous'  # 这里面的值,自己随便定义,settings里面根据这个值配置Rate

    def get_cache_key(self, request, view):
        # 通过IP标识节流
        return self.get_ident(request)


class UserThrottle(SimpleRateThrottle):
    scope = 'auth_user'
    
    def get_cache_key(self, request, view):
        # 验证用户通过用户名标示计数节流
        return request.user.username

 

settings,全局节流

REST_FRAMEWORK = {
    # 节流
    "DEFAULT_THROTTLE_CLASSES": ['utils.throttle.VisitThrottle'],
    "DEFAULT_THROTTLE_RATES": {
        # 匿名用户
        'anonymous': '3/m',
        # 登录用户
        'auth_user': '10/m'
    }

}

 

 views.py

局部节流

class Login(APIView):
    # 登录路由,非认证用户,节流3/m
    throttle_classes = [VisitThrottle]

    def post(self, request):
        response = dict()
        fields = {'username', 'password'}
        user_info = dict()
        if fields.issubset(set(request.data)):
            for key in fields:
                user_info[key] = request.data[key]

        user_instance = authenticate(**user_info)

        if user_instance is not None:
            access_token = generate_token()
            UserToken.objects.update_or_create(user=user_instance, defaults={'token': access_token})
            response['status_code'] = 200
            response['status_message'] = '登录成功'
            response['access_token'] = access_token
            response['user_role'] = user_instance.get_user_level_display()
        else:
            response['status_code'] = 201
            response['status_message'] = '用户名或密码错误'

        return Response(response)
class BookView(ModelViewSet):
    authentication_classes = [UserAuth]
    permission_classes = [UserPermission]
    # 登录用户节流,10/m
    throttle_classes = [UserThrottle]
    queryset = Book.objects.all()
    serializer_class = BookSerialize

 说明:

  • utils.throttle.UserThrottle   这个是全局配置(根据ip限制,10/m)
  • DEFAULT_THROTTLE_RATES      --->>>设置访问频率的
  • throttle_classes = [VisitThrottle,]     --->>>局部配置(不适用settings里面默认的全局配置)

总结

基本使用

  • 创建类,继承BaseThrottle, 实现:allow_request ,wait  
  • 创建类,继承SimpleRateThrottle,   实现:  get_cache_key, scope='auth_user'      (配置文件中的key) 

 

posted on 2018-12-11 16:10  Louiszj  阅读(173)  评论(0)    收藏  举报

导航