一、权限类Permissions

权限控制可以限制用户对于视图的访问和对于具体数据对象的访问。

  • 在执行视图的dispatch()方法前,会先进行视图访问权限的判断
  • 在通过get_object()获取具体对象时,会进行模型对象访问权限的判断

1.使用步骤

  • 第一步:写一个类,继承BasePermission
  • 第二步:重写has_permission方法
  • 第三步:在方法中校验用户是否有权限(request.user就是当前登录用户)
  • 第四步:如果有权限返回True,没有权限返回False
  • 第五步:self.message 是给前端的提示信息
  • 第六步:局部使用,全局使用,局部禁用

2.代码示例

from rest_framework.permissions import BasePermission
class UserTypePermission(BasePermission):
    def has_permission(self, request, view):
        # 只有超级管理员有权限
        if request.user.user_type == 1:
            return True  # 有权限
        else:
            # self.message = '普通用户和2b用户都没有权限' # 返回给前端的提示是什么样
            # 使用了choice后,user.user_type 拿到的是数字类型,想变成字符串 user.get_user_type_display()
            # self.message = '您是:%s 用户,您没有权限'%request.user.get_user_type_display()
            return False  # 没有权限

二、频率类

无论是否登录和是否有权限,都要限制访问的频率,比如一分钟访问3次

1.使用步骤

  • 第一步:写一个类:继承SimpleRateThrottle

  • 第二步:重写get_cache_key,返回唯一的字符串,会以这个字符串做频率限制

  • 第三步:写一个类属性scop='随意写',必须要跟配置文件对象

  • 第四步:配置文件中写

    'DEFAULT_THROTTLE_RATES': {
        '随意写': '3/m'  # 3/h  3/s  3/d
    }
    
  • 第五步:局部使用,全局使用,局部禁用

2.代码示例

频率类

from rest_framework.throttling import BaseThrottle, SimpleRateThrottle
class MyThrottling(SimpleRateThrottle):  # 我们继承SimpleRateThrottle去写,而不是继承BaseThrottle去写
    # 类属性,这个类属性可以随意命名,但要跟配置文件对应
    scope = 'luffy'
    def get_cache_key(self, request, view):
        # 返回什么,频率就以什么做限制
        # 可以通过用户id限制
        # 可以通过ip地址限制
        return request.META.get('REMOTE_ADDR')

配置文件

REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES':['app01.throttling.MyThrottling'],
    'DEFAULT_THROTTLE_RATES': {
        'luffy': '3/m'
    }
}

三、认证源码分析

  • 写个认证类,重写某个方法,配置在视图类上,就有认证了
  • 认证类加了,在视图类的方法中,request.user就是当前登录用户
  • 猜认证类的执行,是在视图类的方法之前执行的

1.源码分析

# 读APIView的执行流程:包装了新的request,执行了3大认证,执行视图类的方法,处理了全局异常

# 入口:APIView的dispatch

# 1.APIView里面重写了dispatch方法
class APIView:
    def dispatch(self, request, *args, **kwargs):
        try:
            # 并且里面有三大认证:认证、权限,频率
            self.initial(request, *args, **kwargs)
            
# 2.在initial里面写了三大认证
class APIView:
     def initial(self, request, *args, **kwargs):
            self.perform_authentication(request)  # 先走认证
            self.check_permissions(request)       # 再走权限
            self.check_throttles(request)         # 最后再走频率
            
# 3.进入perform_authentication方法,需要去drf的request对象中找user属性(方法)
class APIView:
    def perform_authentication(self, request):
	# 需要去drf的request对象中找user属性(方法)
        request.user
        
# 4.进入Request的user,刚开始还没用户,所以走self._authenticate()
from rest_framework.request import Request

    @property
    def user(self):
        if not hasattr(self, '_user'):
            # 没用户,打开上下文管理器
            with wrap_attributeerrors():
                # 没用户,认证出用户
                self._authenticate()  # 走的这一步
        # 有用户,直接返回用户
        return self._user
    
# 5.核心就是Request类中的_authenticate
class Request:
    def _authenticate(self):
        # 遍历拿到一个个认证器,进行认证
        # 所以说每次循环,拿到一个视图类中配置authentication_classes=[类名]对象
        for authenticate in self.authenticators:  # self:Request,去Request找authenticators属性
            try:
                # 认证器(对象)调用认证方法authenticate(认证类对象self.request请求对象)
                # 返回值:登录的用户和认证的信息组成的tuple
                # 该方法被try包裹,代表该方法抛异常,抛异常代表认证失败
                user_auth_tuple = authenticator.authenticate(self)  # 是Request对象
            except exceptions.APIException:
                self._not_authenticated()
                raise
                
            # 返回值处理
            if user_auth_tuple is not None:
                self._authenticator = authenticator
                # 解压赋值:如果有返回值,就将就将登录用户和登录认证分别保存到request分别保存到request.user、request.auth
                self.user, self.auth = user_auth_tuple
                return
        # 如果返回值user_auth_tuple为空,代表认证通过但是没有登录用户和登录认证信息,代表是游客
        self._not_authenticated()
     
    

# Request类中self.authenticators又是从哪里传过来的值?
     
# 1.先是在APIView中进入到drf中request方法
class APIView: 
    def dispatch(self, request, *args, **kwargs):
        # 封装新的request方法,self中的是django中request
        request = self.initialize_request(request, *args, **kwargs)
        
    # drf中request方法
    def initialize_request(self, request, *args, **kwargs):
        return Request(
            request,
            parsers=self.get_parsers(),
            # 而authenticators就是这传进来的Request
            authenticators=self.get_authenticators(),  # 列表[类的对象]
            negotiator=self.get_content_negotiator(),
            parser_context=parser_context
        )
                 
    def get_authenticators(self):
        # 列表中是一堆对象,视图类中配置authentication_classes=[类名]对象
        return [auth() for auth in self.authentication_classes]  # 配置认证类每次会实例化得到一个对象
    
# 2.又把authenticators返回给了Request类
class Request:
    # authenticators是APIView中dispatch方法传进来的值
    def __init__(self, request, parsers=None, authenticators=None,
        # 又把值赋值给了self.authenticators
        self.authenticators = authenticators or ()  # 这里又被Request类中的_authenticate循环

2.总结

认证类,要重写authenticate方法,认证通过返回两个值或None,认证不通过抛AuthenticationFailed(继承了APIException)异常

四、权限源码分析

# 1.APIView里面重写了dispatch方法
    def dispatch(self, request, *args, **kwargs):
        try:
            # 并且里面有三大认证:认证、权限,频率
            self.initial(request, *args, **kwargs)
            
# 2.在initial里面写了三大认证
     def initial(self, request, *args, **kwargs):
            self.perform_authentication(request)  # 先走认证
            self.check_permissions(request)       # 再走权限
            self.check_throttles(request)         # 最后再走频率
            
# 3.进入check_permissions方法
    def check_permissions(self, request):
        # self.get_permissions()是:视图类中的permission_classes = [SuperAdminPermission]
        for permission in self.get_permissions():
            # 每次从列表取出一个对象(配置在视图函数中限制类的对象),执行对象的has_permission方法
            # 其中self是视图类的对象,也就是执行视图类自己定义权限类的has_permission方法
            if not permission.has_permission(request, self):  # 认证失败才往下走
                self.permission_denied(
                    request,
                    message=getattr(permission, 'message', None),
                    code=getattr(permission, 'code', None)
                )
                
# 4.进入self.get_permissions():
        def get_permissions(self):
        """
        实例化并返回此视图所需的权限列表。
        """
        # 每次运行都会自动加括号,视图类permission_classes = [SuperAdminPermission()]
        return [permission() for permission in self.permission_classes]

五、频率类源码

1.源码分析

# 1、在进入源码里
class MyThrottling(SimpleRateThrottle):  # 这里进入
    scope = 'ip_1m_3' 
    
# 2、利用反射机制找属性
class SimpleRateThrottle(BaseThrottle):
    cache = default_cache  # 缓存
    timer = time.time      # 当前时间
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None  # 配置中key
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES  # 配置信息
    
    
    def __init__(self):
        if not getattr(self, 'rate', None):
            # 调用了get_rate方法
            self.rate = self.get_rate()    # rate:3/m
        # 将频率配置解析成次数和时间,分别 存放到self.num_requests、self.duration 
        self.num_requests, self.duration = self.parse_rate(self.rate)  # 调用了parse_rate方法
        
# 3、进入get_rate方法
    def get_rate(self):
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            # 根据key取出values值,把values值返回__init__
            return self.THROTTLE_RATES[self.scope]  # scope就是配置文件中的"scope"
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)
            
# 4、进入parse_rate方法
    def parse_rate(self, rate):
        if rate is None:
            return (None, None)
        # 通过/把"3/m"转成了3  m,并且解压赋值
        num, period = rate.split('/')
        # 把3转为整形赋值给num_requests
        num_requests = int(num)
        # 根据字段0的索引取值,m对应'm': 60
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        # 把3和60返回到__init__
        return (num_requests, duration)
    
# 5、走完上面几步,证明了有值,往下面走
    def allow_request(self, request, view):
        if self.rate is None:  # 如果没有值返回True
            return True
        
        # 当前登录用户的IP地址
        self.key = self.get_cache_key(request, view) # 返回的是ip地址
        # 判断是否有值,如果返回的是None,就不做限制
        if self.key is None:
            return True
        
        # 初次访问缓存为空,self.history为[],是存放时间的列表
        self.history = self.cache.get(self.key, []) # cache.get:获取缓存
        # 获取当前的时间,存放到self.now
        self.now = self.timer()
        
        # 判断self.history是否有值,并且当前时间与第一次访问时间间隔如果大于60s,第一次记录清除,不算作一次计数
        # self.now:10:56  比如
        # self.duration:[10:23,10:55]
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()  # 有值,删除缓存
            
        # 用history和 限制次数3进行比较
        # history长度第一次访问是0,第二次访问是1,第三次访问是2,直到第四次访问是3为失败
        if len(self.history) >= self.num_requests:  # self.num_requests:是解压赋值中的3
            # 直接返回False,代表频率限制
            return self.throttle_failure()
        # 如果history的长度未达到限制次数,代表可以访问
        # 将当前时间插入到history列表的开头,将history列表作为数据存到缓存中
        return self.throttle_success()  # 调用了throttle_success方法
    
# 6、进入了throttle_success方法
    def throttle_success(self):
        # 把将当前时间插入到history列表的开头
        self.history.insert(0, self.now)
        # 将history列表作为数据存到缓存中,key就是'ip_1m_3': '3/m' 中的ip_1m_3
        # self.duration就是超时时间60s就会删除
        self.cache.set(self.key, self.history, self.duration) # duration:就是__init__解析的60
        return True
    
# 7、最后计算下一次访问需要等待的时间
    def wait(self):
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

2.总结

要写频率类,必须重写allow_request方法,返回True(没有到频率的限制)或False(到了频率的限制)

六、鸭子类型

走路像鸭子,说话像鸭子,它就是鸭子——指的是面向对象中,子类不需要显示的继承某个类,只要有某个的方法和属性,那我就属于这个类

  • 假设有个鸭子类Duck类,有两个方法,run,speak方法
  • 假设又有一个普通鸭子类,PDuck,如果它也是鸭子,它需要继承Duck类,只要继承了鸭子类,什么都不需要写,普通鸭子类的对象就是鸭子这种类型;如果不继承,普通鸭子类的对象就不是鸭子这种类型
  • 假设又有一个唐老鸭子类,TDuck,如果它也是鸭子,它需要继承Duck类,只要继承了鸭子类,什么都不需要写,唐老鸭子类的对象就是鸭子这种类型;如果不继承,唐老鸭子类的对象就不是鸭子这种类型

python不推崇这个,它推崇鸭子类型,指的是:不需要显示的继承某个类,只要我的类中有run和speak方法,我就是鸭子这个类

有小问题:如果使用python鸭子类型的写法,如果方法写错了,它就不是这个类型了,会有问题,解决方式:

  • 方式一:abc模块,装饰后,必须重写方法,不重写就报错
  • 方式二:drf源码中使用的:父类中写这个方法,但没有具体实现,直接抛异常

注意

django的配置文件不要乱导入,乱导入可能会出错

  • django的运行是在加载完配置文件后才能运行
  • 因为模块的导入会执行那个模块,而这个模块中又有别的导入,别的导入必须djagno运行起来才能使

作业

settings.py

REST_FRAMEWORK = {
    'DEFAULT_PERMISSION_CLASSES': ['app01.auth.UserTypePermission'],
    'DEFAULT_THROTTLE_CLASSES': ['app01.auth.MyThrottling'],
    'DEFAULT_THROTTLE_RATES': {
        'shoto': '5/m'
    }
}

models.py

from django.db import models


# Create your models here.
class Book(models.Model):
    name = models.CharField(max_length=32)
    price = models.CharField(max_length=32)
    publish = models.ForeignKey(to='Publish', on_delete=models.CASCADE)

    def __str__(self):
        return self.name

    def publish_de(self):
        return {'name': self.publish.name, 'address': self.publish.address}


class Publish(models.Model):
    name = models.CharField(max_length=32)
    address = models.CharField(max_length=32)


class User(models.Model):
    username = models.CharField(max_length=32)
    password = models.CharField(max_length=32)
    user_type = models.IntegerField(default=3, choices=((1, '超级管理员'), (2, '普通管理员'), (3, '普通用户')))

    def __str__(self):
        return self.username


class UserToken(models.Model):
    user = models.OneToOneField(to='User', on_delete=models.CASCADE)
    token = models.CharField(max_length=64, null=True)

urls.py

from django.contrib import admin
from django.urls import path, include
from rest_framework.routers import SimpleRouter

from app01 import views

router = SimpleRouter()

router.register('books', views.BookView, 'books')
router.register('publish', views.PublishView, 'publish')
router.register('user', views.UserView, 'user')


urlpatterns = [
    path('admin/', admin.site.urls),
    path('', include(router.urls))
]

serializer.py

from rest_framework import serializers
from .models import Book, Publish


class BookSerializer(serializers.ModelSerializer):
    class Meta:
        model = Book
        fields = ['name', 'price', 'publish', 'publish_de']


class PublishSerializer(serializers.ModelSerializer):
    class Meta:
        model = Publish
        fields = '__all__'

auth.py

from rest_framework.permissions import BasePermission
from rest_framework.throttling import BaseThrottle, SimpleRateThrottle

from .models import UserToken
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed


class LoginAuth(BaseAuthentication):
    def authenticate(self, request):
        token = request.META.get('HTTP_TOKEN')
        user_token = UserToken.objects.filter(token=token).first()
        if user_token:
            return user_token.user, token
        else:
            raise AuthenticationFailed('您没有登录')


class UserTypePermission(BasePermission):
    def has_permission(self, request, view):
        if request.user.user_type == 1:
            return True
        else:
            self.message = '您是%s,您没有该权限' % request.user.get_user_type_display()
            return False


class MyThrottling(SimpleRateThrottle):
    scope = 'shoto'

    def get_cache_key(self, request, view):
        return request.META.get('REMOTE_ADDR')

views.py

import uuid

from django.shortcuts import render


# Create your views here.
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet, ModelViewSet

from .auth import LoginAuth, MyThrottling
from .models import Book, Publish, User, UserToken
from .serializer import BookSerializer, PublishSerializer


class BookView(ModelViewSet):
    authentication_classes = [LoginAuth, ]
    permission_classes = []
    # throttle_classes = [MyThrottling, ]

    queryset = Book.objects.all()
    serializer_class = BookSerializer


class PublishView(ModelViewSet):
    authentication_classes = [LoginAuth, ]

    queryset = Publish.objects.all()
    serializer_class = PublishSerializer


class UserView(ViewSet):
    permission_classes = []

    @action(methods=['POST', ], detail=False, url_path='login')
    def login(self, request):
        username = request.data.get('username')
        password = request.data.get('password')
        user = User.objects.filter(username=username, password=password).first()
        if user:
            token = str(uuid.uuid4())
            UserToken.objects.update_or_create(defaults={'token': token}, user=user)
            return Response({'code': 100, 'msg': '登录成功', 'token': token})
        else:
            return Response({'code': 101, 'msg': '用户名或密码错误'})
 posted on 2022-10-09 19:47  念白SAMA  阅读(62)  评论(0)    收藏  举报