认证、权限与频率组件

一、认证组件

1 局部视图认证

    # urls.py: route /login/ to the class-based LoginView (URL name "login").
    url(r'^login/$', views.LoginView.as_view(),name="login"),

models.py

from django.db import models

class User(models.Model):
    """Account that logs in with a name and a password."""

    name=models.CharField(max_length=32)
    # NOTE(review): the password is stored (and compared at login) as plain
    # text; consider Django's make_password/check_password.
    pwd=models.CharField(max_length=32)

class Token(models.Model):
    """Login token issued to a user; at most one token row per user."""

    # on_delete is a required argument from Django 2.0 on (omitting it raises
    # TypeError). CASCADE is what older Django versions applied implicitly,
    # so this stays backward-compatible: deleting a user deletes its token.
    user = models.OneToOneField("User", on_delete=models.CASCADE)
    token = models.CharField(max_length=128)

    def __str__(self):
        return self.token

class Book(models.Model):
    """A book with one publisher and any number of authors."""

    title = models.CharField(max_length=32)
    price = models.IntegerField()
    pub_date = models.DateField()
    # on_delete is required from Django 2.0 on; CASCADE keeps the behavior
    # older Django versions applied implicitly.
    publish = models.ForeignKey("Publish", on_delete=models.CASCADE)
    authors = models.ManyToManyField("Author")

    def __str__(self):
        return self.title

先看认证组件源代码流程:

每个请求(包括登录请求)进来后都会先执行APIView类的dispatch方法,下面源码注释中的编号(1)、(2)……表示调用顺序:

# Excerpt of DRF's APIView, trimmed to the authentication path. The (n)
# markers give the step order of the walkthrough; the ones outside comments
# are annotations, so this excerpt is not runnable as written.
class APIView(View):
    # Default comes from DRF's api_settings; a view that defines its own
    # authentication_classes overrides it.
    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
    
    # Builds the list that Request._authenticate iterates over
    # ("for authenticator in self.authenticators").
    def get_authenticators(self):  # (12)

        return [auth() for auth in self.authentication_classes]  # e.g. [TokenAuth()]  (13)
        # Uses the view's own authentication_classes if defined, otherwise the
        # default above. The resulting list holds authenticator *instances*.

    def perform_authentication(self, request):  # (4)

        request.user  # (5)
        # Merely reading request.user triggers authentication. Tracing
        # `request` back: perform_authentication <- initial <- dispatch, where
        # dispatch replaced it with
        # `request = self.initialize_request(request, *args, **kwargs)`,
        # i.e. a newly built DRF Request instance. `user` is therefore the
        # @property defined on the Request class (not a static method).

    def initialize_request(self, request, *args, **kwargs):
        # Wrap the incoming request in DRF's Request, handing it this view's
        # authenticator instances.
        return Request(
            request,
            authenticators=self.get_authenticators(),  # e.g. [TokenAuth()]  (11)
        )

    def initial(self, request, *args, **kwargs):  # (2)
        # Authentication component
        self.perform_authentication(request) (3)
        # Permission component
        self.check_permissions(request)
        # Throttling (request rate) component
        self.check_throttles(request)

    def dispatch(self, request, *args, **kwargs):
        # From here on `request` is the DRF Request built above.
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request

        try:
            self.initial(request, *args, **kwargs)  # (1)
            # NOTE: excerpt - the rest of dispatch (handlers, except clause)
            # is omitted.

# request.py - excerpt of DRF's Request class (authentication path only).
class Request:
    def __init__(self, request, parsers=None, authenticators=None,):
        # Keep the original request object; the DRF Request wraps it.
        self._request = request

        self.authenticators = authenticators or () (10)  # passed in at construction time
        # by APIView.initialize_request as authenticators=self.get_authenticators()

    @property
    def user(self):  (6)
        # Authentication runs only while `_user` is not set yet; afterwards
        # the stored `_user` is returned directly.
        if not hasattr(self, '_user'):
            with wrap_attributeerrors():
                self._authenticate()  (7)
        return self._user

    # All of the authentication logic lives here.
    def _authenticate(self):  # (8)
        # Try each authenticator instance in order, e.g. [TokenAuth()]  (9)
        # (the list comes from APIView.get_authenticators).
        for authenticator in self.authenticators:
            try:
                # authenticate() is a bound method, so the argument passed
                # here is not the authenticator but this Request object; it
                # arrives as the `request` parameter of authenticate().
                user_auth_tuple = authenticator.authenticate(self)

            except exceptions.APIException:
                # An authenticator rejected the request (e.g. by raising
                # AuthenticationFailed): call _not_authenticated() (body not
                # shown) and re-raise.
                self._not_authenticated()
                raise

            # None means "not handled" -> try the next authenticator.
            # A (user, auth) tuple means success: remember which authenticator
            # succeeded, expose the pair as request.user / request.auth, stop.
            if user_auth_tuple is not None:
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple
                return
        # NOTE(review): the excerpt ends here; what happens when every
        # authenticator returns None is not shown - check the DRF source.

使用:

在app01.service.auth.py中:

from rest_framework import exceptions
from rest_framework.authentication import BaseAuthentication #
class TokenAuth(BaseAuthentication):
    """Authenticate a request by its ``token`` query-string parameter.

    On success the returned pair becomes ``request.user`` (the owner's
    *name*, a str - not a User instance) and ``request.auth`` (the token).

    Raises:
        exceptions.AuthenticationFailed: no Token row matches the token.
    """

    def authenticate(self, request):
        # The module imports shown for auth.py do not bring Token into scope,
        # so every call used to fail with NameError.
        from app01.models import Token

        token = request.GET.get("token")
        token_obj = Token.objects.filter(token=token).first()
        if not token_obj:
            raise exceptions.AuthenticationFailed("验证失败!")
        return token_obj.user.name, token_obj.token

views.py

def get_random_str(user):
    """Return a fresh, unpredictable token string.

    The result is 32 lowercase hex characters - the same shape as the MD5
    hexdigest this function used to return - so it still fits the token
    column.

    The previous implementation hashed ``user`` plus ``time.time()``. That is
    guessable by anyone who knows the user name and roughly when they logged
    in, and two calls within one clock tick returned the same token.
    ``secrets`` draws from the operating system's CSPRNG instead.

    :param user: user name; kept for backward compatibility, it no longer
        influences the token.
    """
    import secrets

    return secrets.token_hex(16)
#登录
from .models import User
from app01.service.auth import *
class LoginView(APIView):
    """Log a user in and issue a fresh token.

    POST fields: ``name`` and ``pwd``. The response dict carries
    ``state_code`` (1000 = success, 1001 = wrong name/password), ``msg``
    and, on success, ``token``.
    """

    # Demo only: this shows per-view authentication, but it also makes the
    # login endpoint itself demand a valid token. A real login view should
    # set ``authentication_classes = []``; that also exempts it once a global
    # DEFAULT_AUTHENTICATION_CLASSES is configured.
    authentication_classes = [TokenAuth, ]

    def post(self, request):
        # The imports shown for views.py do not bring Token into scope.
        from .models import Token

        name = request.data.get("name")
        pwd = request.data.get("pwd")
        user = User.objects.filter(name=name, pwd=pwd).first()
        res = {"state_code": 1000, "msg": None}
        if user:
            random_str = get_random_str(user.name)
            # One token row per user: create it, or overwrite the old token.
            Token.objects.update_or_create(user=user, defaults={"token": random_str})
            res["token"] = random_str
        else:
            res["state_code"] = 1001  # error status code
            res["msg"] = "用户名或者密码错误"
        # Hand DRF the dict itself; its renderer produces the JSON.
        # Response(json.dumps(res)) made the body a JSON *string* that
        # contained JSON, which clients had to decode twice.
        return Response(res)

以上只是在登录视图(LoginView)上局部配置了认证组件

2 全局视图认证

settings.py配置如下:

# Project settings.py: run TokenAuth on every DRF view unless a view sets
# its own authentication_classes.
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": [
        "app01.service.auth.TokenAuth",
    ],
}

二、权限组件

1 局部视图权限

源码流程:

# Excerpt of DRF's APIView, trimmed to the permission path. The (n) markers
# give the step order of the walkthrough; the ones outside comments are
# annotations, so this excerpt is not runnable as written.
class APIView(View):

    permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES (7)
    # A view's own permission_classes take precedence; otherwise this
    # settings default is used.
    # ... (omitted)
    
    def get_permissions(self):  # (5)
        """
        Instantiates and returns the list of permissions that this view requires.
        """
        return [permission() for permission in self.permission_classes]  # (6)

    # ... (omitted)
    def check_permissions(self, request):  # (3)
        """
        Check if the request should be permitted.
        Raises an appropriate exception if the request is not permitted.
        """
        # Every permission is asked in order; the first one whose
        # has_permission() is falsy triggers permission_denied() with that
        # permission's optional `message` attribute.
        for permission in self.get_permissions():  # (4)
            if not permission.has_permission(request, self):  # `self` is the current view
                self.permission_denied(
                    request, message=getattr(permission, 'message', None)
                )
    # ... (omitted)
    def initial(self, request, *args, **kwargs):

        # Authentication component
        self.perform_authentication(request)
        # Permission component
        self.check_permissions(request)   (2)
        self.check_throttles(request)

    def dispatch(self, request, *args, **kwargs):
        # ... (omitted)
        try:
            self.initial(request, *args, **kwargs)  # (1)
        # ... (omitted)

用法:

app01\models.py

class User(models.Model):
    """Account model extended with a membership level (``user_type``)."""

    name=models.CharField(max_length=32)
    # NOTE(review): plain-text password; consider make_password/check_password.
    pwd=models.CharField(max_length=32)
    # Membership levels: 1 = regular user, 2 = VIP, 3 = SVIP.
    type_choices=((1,"普通用户"),(2,"VIP"),(3,"SVIP"))
    # New accounts are regular users (1) unless stated otherwise.
    user_type=models.IntegerField(choices=type_choices,default=1)

在app01.service.permissions.py中:

from rest_framework.authentication import BaseAuthentication
from rest_framework.permissions import BasePermission


class SVIPPermission(BasePermission):
    """Allow access only to users whose ``user_type`` is 3 (SVIP).

    Subclasses BasePermission rather than BaseAuthentication. BasePermission
    also supplies has_object_permission(), which ModelViewSet detail routes
    (retrieve/update/destroy) call; without it those routes fail with
    AttributeError.

    Assumes the authentication class put the user's *name* into
    ``request.user`` (as TokenAuth does by returning ``token_obj.user.name``).
    """

    message = "只有超级用户才能访问"

    def has_permission(self, request, view):
        # The module imports shown for permissions.py do not bring User into
        # scope.
        from app01.models import User

        username = request.user
        user = User.objects.filter(name=username).first()
        # No user with that name: deny instead of crashing with
        # AttributeError on None.user_type.
        if user is None:
            return False
        return user.user_type == 3

app01\views.py:

from app01.service.permissions import SVIPPermission
class AuthorModelView(viewsets.ModelViewSet):
    """CRUD endpoints for Author, open only to token-authenticated SVIP users."""

    # What the viewset serves and how it is serialized.
    serializer_class = AuthorModelSerializers
    queryset = Author.objects.all()

    # Per-view overrides of the REST_FRAMEWORK defaults.
    permission_classes = [SVIPPermission]
    authentication_classes = [TokenAuth]

2 全局视图权限

settings.py配置如下:

REST_FRAMEWORK = {
    # Must name the class actually defined in app01.service.auth (TokenAuth).
    # "app01.service.auth.Authentication" does not exist there, so DRF could
    # not import it when the setting is first read.
    "DEFAULT_AUTHENTICATION_CLASSES": ["app01.service.auth.TokenAuth",],
    "DEFAULT_PERMISSION_CLASSES": ["app01.service.permissions.SVIPPermission",],
}

三、throttle(访问频率)组件

自定义的频率类写在app01.service.throttles.py中(见下文“用法”)。

源码流程(与权限组件类似):

# Excerpt of DRF's APIView, trimmed to the throttling path (same shape as
# the permission check). Not runnable as written: the try below has no
# except clause in this excerpt.
class APIView(View):

    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES


    def get_throttles(self):
        # A fresh instance of every throttle class on each call.
        return [throttle() for throttle in self.throttle_classes]
    # ... (omitted)

    def check_throttles(self, request):
        # Every throttle is consulted (no early exit); each one that refuses
        # the request contributes its wait() estimate.
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Drop estimates of None and report the longest remaining wait
            # (None if no throttle could estimate one).
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            duration = max(durations, default=None)
            self.throttled(request, duration)
            # ... (omitted)

    def initial(self, request, *args, **kwargs):

        # Throttling (request rate) component
        self.check_throttles(request)

    def dispatch(self, request, *args, **kwargs):

        try:
            self.initial(request, *args, **kwargs)
            # ... (omitted)


# Excerpt of DRF's own settings module (not the project's settings.py,
# which is read through REST_FRAMEWORK). These are the built-in defaults
# used when REST_FRAMEWORK omits a key; trimmed here to one entry.
DEFAULTS = {

    # No throttle classes by default, i.e. no rate limiting.
    'DEFAULT_THROTTLE_CLASSES': [],

}


class APISettings:
    """Expose REST_FRAMEWORK settings as attributes, falling back to DEFAULTS.

    Excerpt: ``import_strings`` and ``_cached_attrs`` are used below, but
    their initialisation is not shown here.
    """

    def __init__(self, user_settings=None, defaults=None, import_strings=None):
        # Explicit user_settings go through __check_user_settings (not shown)
        # and are stored; otherwise the user_settings property reads them
        # lazily.
        if user_settings:
            self._user_settings = self.__check_user_settings(user_settings)
        self.defaults = defaults or DEFAULTS

    @property
    def user_settings(self):
        # Read the project's settings.REST_FRAMEWORK on first use ({} when it
        # is absent) and keep it in _user_settings.
        if not hasattr(self, '_user_settings'):
            self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
        return self._user_settings

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails, i.e. on the first
        # access to a setting (the value is stored via setattr below).
        if attr not in self.defaults:
            raise AttributeError("Invalid API setting: '%s'" % attr)

        try:
            # Check if present in user settings
            val = self.user_settings[attr]
        except KeyError:
            # Fall back to defaults
            val = self.defaults[attr]

        # Coerce import strings into classes, e.g. a dotted path such as
        # "app01.service.auth.TokenAuth" becomes the class it names.
        if attr in self.import_strings:
            val = perform_import(val, attr)

        # Cache the result as a real attribute so __getattr__ is not called
        # again for this name.
        self._cached_attrs.add(attr)
        setattr(self, attr, val)
        return val

# Module-level instance used as api_settings.DEFAULT_*_CLASSES by APIView.
# user_settings is None, so the project's REST_FRAMEWORK dict is read lazily
# on first access.
api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
(以上为DRF源码)

用法:

from rest_framework.throttling import BaseThrottle

VISIT_RECORD={}
class VisitThrottle(BaseThrottle):
    """Allow at most ``num_requests`` requests per ``duration`` seconds per client IP.

    Visit timestamps live in the module-level ``VISIT_RECORD`` dict, keyed by
    ``REMOTE_ADDR`` and stored newest first. NOTE(review): that dict is
    per-process memory - not shared between workers, lost on restart, and
    its entries are never evicted.
    """

    # Previously hard-coded as "3 requests per 60 seconds"; override these
    # class attributes in a subclass to change the limit.
    num_requests = 3
    duration = 60

    def __init__(self):
        # The current client's timestamp list; read by wait().
        self.history = None

    def allow_request(self, request, view):
        import time

        # NOTE(review): behind a reverse proxy REMOTE_ADDR is the proxy's
        # address, so all clients would share one bucket.
        remote_addr = request.META.get('REMOTE_ADDR')
        now = time.time()

        # A first visit gets an empty list and takes the same path as the
        # others (the debug print of the address was removed).
        history = VISIT_RECORD.setdefault(remote_addr, [])
        self.history = history

        # Oldest timestamps sit at the end; drop those outside the window.
        while history and history[-1] < now - self.duration:
            history.pop()

        if len(history) < self.num_requests:
            history.insert(0, now)
            return True
        return False

    def wait(self):
        """Seconds until the oldest recorded visit leaves the window.

        Returns None (DRF's "unknown") when there is no history to measure,
        instead of failing on ``None[-1]``.
        """
        import time

        if not self.history:
            return None
        return self.duration - (time.time() - self.history[-1])

在views.py中:

from app01.service.throttles import *

class BookViewSet(generics.ListCreateAPIView):
    """List and create Book objects, rate-limited by VisitThrottle."""

    serializer_class = BookSerializers
    queryset = Book.objects.all()
    # Per-view override of DEFAULT_THROTTLE_CLASSES.
    throttle_classes = [VisitThrottle]

全局视图throttle,settings.py配置如下:

REST_FRAMEWORK = {
    # Must name the class actually defined in app01.service.auth (TokenAuth);
    # "app01.service.auth.Authentication" does not exist, so DRF could not
    # import it.
    "DEFAULT_AUTHENTICATION_CLASSES": ["app01.service.auth.TokenAuth",],
    "DEFAULT_PERMISSION_CLASSES": ["app01.service.permissions.SVIPPermission",],
    "DEFAULT_THROTTLE_CLASSES": ["app01.service.throttles.VisitThrottle",],
}

使用内置throttle类(SimpleRateThrottle)

在app01.service.throttles.py修改为:

from rest_framework.throttling import SimpleRateThrottle


class VisitThrottle(SimpleRateThrottle):
    """Built-in rate limiting keyed by client IP.

    The rate is looked up by ``scope`` in
    ``REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]`` (e.g. ``"5/m"``).
    """

    scope = "visit_rate"

    def get_cache_key(self, request, view):
        # Namespace the key with the scope through SimpleRateThrottle's
        # cache_format, as DRF's own built-in throttles do. A bare IP would
        # share its history with any other throttle that also keys by IP in
        # the same cache.
        return self.cache_format % {
            "scope": self.scope,
            "ident": self.get_ident(request),
        }

settings.py设置:

REST_FRAMEWORK = {
    # Must name the class actually defined in app01.service.auth (TokenAuth);
    # "app01.service.auth.Authentication" does not exist, so DRF could not
    # import it.
    "DEFAULT_AUTHENTICATION_CLASSES": ["app01.service.auth.TokenAuth",],
    "DEFAULT_PERMISSION_CLASSES": ["app01.service.permissions.SVIPPermission",],
    "DEFAULT_THROTTLE_CLASSES": ["app01.service.throttles.VisitThrottle",],
    # scope -> "count/period" for SimpleRateThrottle subclasses;
    # "5/m" allows 5 requests per minute for scope "visit_rate".
    "DEFAULT_THROTTLE_RATES": {
        "visit_rate": "5/m",
    },
}

rest framework框架的访问频率记录推荐放到 redis/memcached 等缓存中(上面自定义的VISIT_RECORD只保存在当前进程的内存里,多进程部署时无法共享)

posted @ 2020-06-06 21:03  zh_小猿  阅读(144)  评论(0)  编辑  收藏  举报