lolicute

导航

DRF中封装 + 认证 + 权限 + 限流

DRF中封装 + 认证 + 权限 + 限流

1. request封装

1.1 属性

先来学一个关于面向对象的知识点。

class Request(object):
    def __init__(self, req, xx):
        self._request = req
        self.xx = xx


obj = Request(1, 2)
print(obj.xx)
print(obj._request)

获取对象中的成员时,本质上会调用 __getattribute__方法,默认我们不定义就用父类中的。

class Request(object):
    def __init__(self, req, xx):
        self._request = req
        self.xx = xx

    def __getattribute__(self, item):
        print("执行__getattribute__", item)
        return super().__getattribute__(item)


obj = Request(1, 2)
print(obj.xx)
print(obj._request)
# int(obj.v1) # 报错
# 注意:如果不是对象中的成员,就会报错。

不过想要访问对象中不存在成员,则可以通过定义 __getattr__实现。

  • 先执行自己的 __getattribute__
  • 再执行父类的__getattribute__
    • 是自己对象,直接获取并返回
    • 不是自己对象,调用__getattr__
class Request(object):
    def __init__(self, req, xx):
        self._request = req
        self.xx = xx

    def __getattribute__(self, item):
        print("执行__getattribute__", item)
        return super().__getattribute__(item)

    def __getattr__(self, item):
        print("__getattr__", item)
        return 999


obj = Request(1, 2)
print(obj.xx)
print(obj._request)
print(obj.v1)

1.2 对象封装

class HttpRequest(object):
    def __init__(self):
        pass
    
    def v1(self):
        print("v1")
        
    def v2(self):
        print("v1")

class Request(object):
    def __init__(self,req, xx):
        self._request = req
        self.xx = xx

request = HttpRequest()
request.v1()
request.v2()

request = Request(request,111)
request._request.v1()
request._request.v2()
class HttpRequest(object):
    def __init__(self):
        pass
    
    def v1(self):
        print("v1")
        
    def v2(self):
        print("v1")

class Request(object):
    def __init__(self,req, xx):
        self._request = req
        self.xx = xx
        
    def __getattr__(self, attr):
        try:
            return getattr(self._request, attr)
        except AttributeError:
            return self.__getattribute__(attr)
        
request = HttpRequest()
request.v1()
request.v2()

request = Request(request,111)
request.v1()
request.v2()

1.3 源码分析

image-20210819150601089

  • 路由

    path('login/', views.LoginView.as_view()),#第一步
    
  • 视图

    • 去LoginView类中找as_view方法,没有去父类当中找

      class LoginView(APIView):
          authentication_classes = []
      
          def post(self, request):
              # 1.接收用户POST提交的用户名和密码
              # print(request.query_params)
              user = request.data.get("username")
              pwd = request.data.get("password")
      
              # 2.数据库校验
              user_object = models.UserInfo.objects.filter(username=user, password=pwd).first()
              if not user_object:
                  return Response({"status": False, 'msg': "用户名或密码错误"})
      
              # 3.正确
              token = str(uuid.uuid4())
              user_object.token = token
              user_object.save()
      
              return Response({"status": True, 'data': token})
      
    • drf框架中APIView类中as_view方法

      class APIView(View):
      
          @classmethod
          def as_view(cls, **initkwargs): #第二步
              if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
                  def force_evaluation():
                      raise RuntimeError(
                          'Do not evaluate the `.queryset` attribute directly, '
                          'as the result will be cached and reused between requests. '
                          'Use `.all()` or call `.get_queryset()` instead.'
                      )
                  cls.queryset._fetch_all = force_evaluation
      
              view = super().as_view(**initkwargs) #调用APIView父类中的as_view方法
              view.cls = cls
              view.initkwargs = initkwargs
      
              if DJANGO_VERSION >= (5, 1):
                  view.login_required = False
      
              return csrf_exempt(view)
      
    • base.py(APIView父类中的as_view方法)

      class View:
          @classonlymethod
          def as_view(cls, **initkwargs):#第三步
              def view(request, *args, **kwargs):
                  self = cls(**initkwargs)
                  self.setup(request, *args, **kwargs)
                  if not hasattr(self, "request"):
                      raise AttributeError(
                          "%s instance has no 'request' attribute. Did you override "
                          "setup() and forget to call super()?" % cls.__name__
                      )
                  return self.dispatch(request, *args, **kwargs)#调用APIView中的dispatch方法
      
              view.view_class = cls
              view.view_initkwargs = initkwargs
      
      
    • drf框架中APIView类中dispatch方法

      class APIView(View):
              def dispatch(self, request, *args, **kwargs):
              self.args = args
              self.kwargs = kwargs
              request = self.initialize_request(request, *args, **kwargs)#初始化request对象
              self.request = request
              self.headers = self.default_response_headers 
      
              try:
                  self.initial(request, *args, **kwargs)
                  if request.method.lower() in self.http_method_names:
                      handler = getattr(self, request.method.lower(),
                                        self.http_method_not_allowed)
                  else:
                      handler = self.http_method_not_allowed
      
                  response = handler(request, *args, **kwargs)
      
              except Exception as exc:
                  response = self.handle_exception(exc)
      
              self.response = self.finalize_response(request, response, *args, **kwargs)
              return self.response
      
    • drf框架中APIView类中initialize_request方法

      class APIView(View):  
          def initialize_request(self, request, *args, **kwargs):
              """
              Returns the initial request object.
              """
              parser_context = self.get_parser_context(request)
      
              return Request( #封装的Request对象
                  request,
                  parsers=self.get_parsers(),
                  authenticators=self.get_authenticators(),
                  negotiator=self.get_content_negotiator(),
                  parser_context=parser_context
              )
      

1.4 request对象

drf中的request其实是对请求的再次封装,其目的就是在原来的request对象基础中再进行封装一些drf中需要用到的值。

image-20220904073610480

class UserView(NbApiView):
    def get(self, request):
        print(request.user, request.auth)
        print(request.data)
        print(request.auth)
        #django的request对象
        print(request.GET)
        print(request.method)
        print(request.path_info)

        return Response("UserView")

2.认证

在开发API过程中,有些功能需要登录才能访问,有些无需登录。drf中的认证组件主要就是用来实现此功能。

关于认证组件,我们用案例的形式,先来学习常见的用用场景,然后再来剖析源码。

2.1 案例1

项目要开发3个接口,其中1个无需登录接口、2个必须登录才能访问的接口。

image-20220904082534101

image-20220904085937250

在浏览器上中访问:/order/token=xxxdsfsdfdf

认证组件中返回的两个值,分别赋值给:request.userrequest.auth

  • 单独在每一个View类中配置
class MyAuthentication(BaseAuthentication):
    def authenticate(self, request):
        token = request.query_params.get("token")
        if not token:
            return

        user_object = models.UserInfo.objects.filter(token=token).first()
        if user_object:
            return user_object, token  # request.user = 用户对象; request.auth = token

    def authenticate_header(self, request):
        # return 'Basic realm="API"'
        return "API"

class LoginView(APIView):
    authentication_classes = [] #不用登录即可访问

    def post(self, request):
        # 1.接收用户POST提交的用户名和密码
        # print(request.query_params)
        user = request.data.get("username")
        pwd = request.data.get("password")

        # 2.数据库校验
        user_object = models.UserInfo.objects.filter(username=user, password=pwd).first()
        if not user_object:
            return Response({"status": False, 'msg': "用户名或密码错误"})

        # 3.正确
        token = str(uuid.uuid4())
        user_object.token = token
        user_object.save()

        return Response({"status": True, 'data': token})


class UserView(NbApiView):
    authentication_classes = [] #不用登录即可访问

    def get(self, request):
        return Response("UserView")

    def post(self, request):
        print(request.user, request.auth)
        return Response("UserView")


class OrderView(NbApiView):
    authentication_classes = [MyAuthentication] #登录才能访问

    def get(self, request):
        print(request.user, request.auth)
        return Response("OrderView")

class AvatarView(NbApiView):
    authentication_classes = [MyAuthentication] #登录才能访问

    def get(self, request):
        print(request.user, request.auth)
        return Response({"status": True, "data": [11, 22, 33, 44]})


2.2 案例2

项目要开发100个接口,其中1个无需登录接口、99个必须登录才能访问的接口。

此时,就需要用到drf的全局配置(认证组件的类不能放在视图view.py中,会因为导入APIView导致循环引用)。

image-20220904084906568

  • settings.py

    # ############## drf配置 ###############
    REST_FRAMEWORK = {
        "UNAUTHENTICATED_USER": None,
        "UNAUTHENTICATED_TOKEN": None,
        "DEFAULT_AUTHENTICATION_CLASSES": [
            "ext.auth.MyAuthentication"
        ],
    }
    

2.3 案例3

项目要开发100个接口,其中1个无需登录接口、98个必须登录才能访问的接口、1个公共接口(未登录时显示公共/已登录时显示个人信息)。

image-20220904090855727

class MyAuthentication(BaseAuthentication):
    def authenticate(self, request):
        token = request.query_params.get("token")
        if not token:
            return

        user_object = models.UserInfo.objects.filter(token=token).first()
        if user_object:
            return user_object, token  # request.user = 用户对象; request.auth = token

    def authenticate_header(self, request):
        # return 'Basic realm="API"'
        return "Token"

class CommonAuthentication(BaseAuthentication):#公共接口
    def authenticate(self, request):
        token = request.query_params.get("token")
        if not token:
            return

        user_object = models.UserInfo.objects.filter(token=token).first()
        if user_object:
            return user_object, token  # request.user = 用户对象; request.auth = token

    def authenticate_header(self, request):
        # return 'Basic realm="API"'
        return "Token"

2.4 案例4

项目要开发100个接口,其中1个无需登录接口、98个必须登录才能访问的接口、1个公共接口(未登录时显示公共/已登录时显示个人信息)。

原来的认证信息只能放在URL中传递,如果程序中支持放在很多地方,例如:URL中、请求头中等。

认证组件中,如果是使用了多个认证类,会按照顺序逐一执行其中的authenticate方法

  • 返回None或无返回值,表示继续执行后续的认证类
  • 返回 (user, auth) 元组,则不再继续并将值赋值给request.user和request.auth
  • 抛出异常 AuthenticationFailed(...),认证失败,不再继续向后走。

image-20220904093128480

2.5 源码分析

image-20210822092707803

  • drf框架中APIView类中dispatch方法
class APIView(View):
        def dispatch(self, request, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)#1.请求的封装(django的request对象+authenticators认证组件)--->加载认证组件过程
        self.request = request
        self.headers = self.default_response_headers 

        try:
            self.initial(request, *args, **kwargs)#2.获取全局的认证组件初始化
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
  • drf框架中APIView类中initialize_request方法

    class APIView(View):  
        authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
    
        def initialize_request(self, request, *args, **kwargs):
            """
            Returns the initial request object.
            """
            parser_context = self.get_parser_context(request)
    
            return Request( #封装的Request对象
                request,
                parsers=self.get_parsers(),
                authenticators=self.get_authenticators(),#	[对象,对象,对象,对象,对象]
                negotiator=self.get_content_negotiator(),
                parser_context=parser_context
            )
    
    • get_authenticators循环获取认证组件对象

      def get_authenticators(self):
          """
      	[认证类的对象,认证类的对象,认证类的对象,认证类的对象,认证类的对象]
      """
          return [auth() for auth in self.authentication_classes]
      
  • drf框架中APIView类中initial方法

    def initial(self, request, *args, **kwargs):
        self.format_kwarg = self.get_format_suffix(**kwargs)
    
        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg
    
        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme
    
        # Ensure that the incoming request is permitted
        self.perform_authentication(request)#关于认证相关的authen
        self.check_permissions(request)
        self.check_throttles(request)
    
    • perform_authentication方法

      def perform_authentication(self, request):
          """
              Perform authentication on the incoming request.
      
              Note that if you override this and simply 'pass', then authentication
              will instead be performed lazily, the first time either
              `request.user` or `request.auth` is accessed.
              """
          request.user
      
  • 找到Request对象中的user方法

    class Request:
    
        def __init__(self, request, parsers=None, authenticators=None,
                     negotiator=None, parser_context=None):
    		...
            self._request = request
            self.authenticators = authenticators or ()
            ...
      	    @property
        def user(self):
            if not hasattr(self, '_user'):
                with wrap_attributeerrors():
                    self._authenticate()
            return self._user
        def c(self):
            #读取每个认证组件的对象,执行_authenticate方法,self=request对象
            for authenticator in self.authenticators:
                try:
                    user_auth_tuple = authenticator.authenticate(self)
                except exceptions.APIException:
                    self._not_authenticated()
                    raise
    
                if user_auth_tuple is not None:
                    self._authenticator = authenticator
                    self.user, self.auth = user_auth_tuple
                    return
    
            self._not_authenticated()
        def _not_authenticated(self):
            self._authenticator = None
    
            if api_settings.UNAUTHENTICATED_USER:
                self.user = api_settings.UNAUTHENTICATED_USER()
            else:
                self.user = None
    
            if api_settings.UNAUTHENTICATED_TOKEN:
                self.auth = api_settings.UNAUTHENTICATED_TOKEN()
            else:
                self.auth = None
        @user.setter
        def user(self, value):
            self._user = value
            self._request.user = value
    

3 权限

在drf开发中,如果有些接口必须同时满足:A条件、B条件、C条件。 有些接口只需要满足:B条件、C条件,此时就可以利用权限组件来编写这些条件。

  • 且关系,默认支持:A条件 且 B条件 且 C条件,同时满足。

    class PermissionA(BasePermission):
        message = {"code": 1003, 'data': "无权访问"}
    
        def has_permission(self, request, view):
            if request.user.role == 2:
                return True
            return False
    	
        # 暂时先这么写
        def has_object_permission(self, request, view, obj):
            return True
    
  • 或关系,自定义(方便扩展)

    class APIView(View):
    	def check_permissions(self, request):
            """
            Check if the request should be permitted.
            Raises an appropriate exception if the request is not permitted.
            """
            for permission in self.get_permissions():
                if not permission.has_permission(request, self):
                    self.permission_denied(
                        request,
                        message=getattr(permission, 'message', None),
                        code=getattr(permission, 'code', None)
                    )
    

4 限流

限流,限制用户访问频率,例如:用户1分钟最多访问100次 或者 短信验证码一天每天可以发送50次, 防止盗刷。

  • 对于匿名用户,使用用户IP作为唯一标识。
  • 对于登录用户,使用用户ID或名称作为唯一标识。
缓存={
	用户标识:[12:33,12:32,12:31,12:30,12,]    1小时/5次   12:34   11:34
{
pip3 install django-redis
# settings.py
CACHES = {
    "default": {
        "BACKEND": "django_redis.cache.RedisCache",
        "LOCATION": "redis://127.0.0.1:6379",
        "OPTIONS": {
            "CLIENT_CLASS": "django_redis.client.DefaultClient",
            "PASSWORD": "qwe123",
        }
    }
}

image-20210822115201724

CACHES = {
    "default": {
        "BACKEND": "django_redis.cache.RedisCache",
        "LOCATION": "redis://127.0.0.1:6379",
        "OPTIONS": {
            "CLIENT_CLASS": "django_redis.client.DefaultClient",
            "PASSWORD": "qwe123",
        }
    }
}
from django.urls import path, re_path
from app01 import views

urlpatterns = [
    path('api/order/', views.OrderView.as_view()),
]
# views.py

from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import exceptions
from rest_framework import status
from rest_framework.throttling import SimpleRateThrottle
from django.core.cache import cache as default_cache


class ThrottledException(exceptions.APIException):
    status_code = status.HTTP_429_TOO_MANY_REQUESTS
    default_code = 'throttled'


class MyRateThrottle(SimpleRateThrottle):
    cache = default_cache  # 访问记录存放在django的缓存中(需设置缓存)
    scope = "user"  # 构造缓存中的key
    cache_format = 'throttle_%(scope)s_%(ident)s'

    # 设置访问频率,例如:1分钟允许访问10次
    # 其他:'s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day'
    THROTTLE_RATES = {"user": "10/m"}

    def get_cache_key(self, request, view):
        if request.user:
            ident = request.user.pk  # 用户ID
        else:
            ident = self.get_ident(request)  # 获取请求用户IP(去request中找请求头)

        # throttle_u # throttle_user_11.11.11.11ser_2

        return self.cache_format % {'scope': self.scope, 'ident': ident}

    def throttle_failure(self):
        wait = self.wait()
        detail = {
            "code": 1005,
            "data": "访问频率限制",
            'detail': "需等待{}s才能访问".format(int(wait))
        }
        raise ThrottledException(detail)


class OrderView(APIView):
    throttle_classes = [MyRateThrottle, ]

    def get(self, request):
        return Response({"code": 0, "data": "数据..."})

多个限流类

本质,每个限流的类中都有一个 allow_request 方法,此方法内部可以有三种情况:

  • 返回True,表示当前限流类允许访问,继续执行后续的限流类。
  • 返回False,表示当前限流类不允许访问,继续执行后续的限流类。所有的限流类执行完毕后,读取所有不允许的限流,并计算还需等待的时间。
  • 抛出异常,表示当前限流类不允许访问,后续限流类不再执行。

全局配置

REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES":["xxx.xxx.xx.限流类", ],
    "DEFAULT_THROTTLE_RATES": {
        "user": "10/m",
        "xx":"100/h"
    }
}

底层源码实现:

image-20210822121259284

image-20210822120127336

  • 对象加载
    获取每个限流类的对象,初始化(读取限制的配置,获取到 时间间隔+访问次数) --> num_requests, duration
def check_throttles(self, request):
        throttle_durations = []
        #循环每个对象
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):#认证失败将还需要等待的时间放入到throttle_durations列表中
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            duration = max(durations, default=None)
            self.throttled(request, duration)
  • allow_request是否限流
    def allow_request(self, request, view):
        if self.rate is None:
            return True
		#获取用户唯一标识
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True
		#历史访问记录[16:45,16:40:16:38....]
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        #拿到最后时间在一分钟以内的数据,剔除不需要的时间
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        #超过限制
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()
    
    def throttle_success(self):
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)#存储当前时间到缓存当中
        return True

    def throttle_failure(self):
        return False
    
    def wait(self):
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])#计算还需要等待时间
        else:
            remaining_duration = self.duration

            available_requests = self.num_requests - len(self.history) + 1 #访问次数
            if available_requests <= 0:
                return None

            return remaining_duration / float(available_requests)

posted on 2026-06-12 23:01  恍惚aa  阅读(2)  评论(0)    收藏  举报