认证组件
登录功能
# 表模型
class User(models.Model):
username = models.CharField('用户名', max_length=32)
password = models.CharField('密码', max_length=32)
class UserToken(models.Model):
user = models.OneToOneField(to='User', on_delete=models.CASCADE)
token = models.CharField(max_length=64)
# 视图类
from rest_framework.response import Response
from rest_framework.views import APIView
import uuid
class UserView(APIView):
def post(self, request, *args, **kwargs):
res = {'code': 100, 'msg': ''}
username = request.data.get('username')
password = request.data.get('password')
user = models.User.objects.filter(username=username, password=password).first()
if user:
# 用户登录成功,生成随机字符串
token = str(uuid.uuid4())
# 根据user去查询,如果存在则更新,不存在则添加
models.UserToken.objects.update_or_create(defaults={'token': token}, user=user)
res['msg'] = '登录成功'
res['token'] = token
else:
res['msg'] = '用户名或密码错误'
res['code'] = 1001
return Response(res)
# 路由
urlpatterns = [
path('login/',views.UserView.as_view())
]
认证类
# 新建一个py文件auth.py
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed
from app01 import models
class LoginAuth(BaseAuthentication):
def authenticate(self, request):
# 写认证规则
# 从request.query_params取出token进行验证
token = request.query_params.get('token')
# 从数据库中查询有token
user_token = models.UserToken.objects.filter(token=token).first()
if user_token:
# 登录了认证通过
# 如果返回user,token,后面视图类中通过request对象可以获取到当前登录用户
return
else:
# 未登录,抛出异常
raise AuthenticationFailed('你没有登录,认证失败')
全局和局部使用
# 局部使用(在视图类中)
class UserView(APIView):
authentication_classes = [LoginAuth,]
# 全局使用
REST_FRAMEWORK={
# 全局使用写得认证类
'DEFAULT_AUTHENTICATION_CLASSES':['app01.auth.LoginAuth']
}
# 局部禁用(在视图类中)
class UserView(APIView):
authentication_classes = []
# 认证类的查找顺序
先找视图类自己中有没有:authentication_classes,如果没有----》项目的配置文件DEFAULT_AUTHENTICATION_CLASSES------》drf的配置文件(有两个,这两个,等于无)
源码分析
# 1 APIVIew中的dispatch方法中的self.initial(request, *args, **kwargs)
# 2 --->APIView的initial方法有三句话
self.perform_authentication(request) # 认证
self.check_permissions(request) # 权限
self.check_throttles(request) # 频率
# 3 认证perform_authentication--->新封装的request对象的.user (是个方法)
# 4 Request类中的user方法里:self._authenticate()
# 5 Request类中的_authenticate方法
def _authenticate(self):
for authenticator in self.authenticators: # 配置在视图类中认证类列表,实例化得到的对象
try:
user_auth_tuple = authenticator.authenticate(self)
except exceptions.APIException:
self._not_authenticated()
raise
if user_auth_tuple is not None:
self._authenticator = authenticator
self.user, self.auth = user_auth_tuple
return
# 6 self.authenticators:self是Request类的对象,authenticators属性是类实例化是传入的
def __init__(self, request, parsers=None, authenticators=None,
negotiator=None, parser_context=None):
self.authenticators = authenticators or ()
# 7 APIView中的dispatch方法完成的初始化:
request = self.initialize_request(request, *args, **kwargs)
# 8 进人APIView中的initialize_request方法:
return Request(
request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(), # 这个完成了给他赋值
negotiator=self.get_content_negotiator(),
parser_context=parser_context
)
# 9 进人APIView中get_authenticators()方法:这个方法将我们的配置在视图中的加括号实例化的到对象
def get_authenticators(self):
# 列表推导式
return [auth() for auth in self.authentication_classes]
权限组件
权限组件使用
# 写一个权限类
from rest_framework.permissions import BasePermission
class UserPermission(BasePermission):
# 在类中设置异常提示语
# message = '你没有权限'
def has_permission(self, request, view):
# 写权限的逻辑,返回True表示有权限
if request.user.user_type == 1:
return True
else:
# 在对象中设置异常提示语
self.message = '你没有权限'
return False
# 局部使用,在视图类上配置
class BookView(ViewSetMixin,ListAPIView):
permission_classes = [UserPermission,]
# 全局使用(配置文件中)
REST_FRAMEWORK={
# 全局使用写得权限类
'DEFAULT_PERMISSION_CLASSES':['app01.auth.UserPermission']
}
源码分析
# APIView---> dispatch-self.initial---> self.check_permissions(request)
for permission in self.get_permissions():
if not permission.has_permission(request, self):
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)
频率组件
组件的使用
# 限制用户的访问次数,根据ip,根据用户id
# 写个类,继承基类,重写某个方法
from rest_framework.throttling import BaseThrottle,SimpleRateThrottle
class MyThrottle(SimpleRateThrottle):
scope = 'ip_throttle' # 一定要写
def get_cache_key(self, request, view):
# 返回谁,就以谁做限制
# 按ip限制
print(request.META)
return request.META.get('REMOTE_ADDR')
# 在视图类中配置
class BookView(ViewSetMixin,ListAPIView):
throttle_classes = [MyThrottle,]
# 在配置文件中配置
REST_FRAMEWORK = {
# 频率限制的配置信息
'DEFAULT_THROTTLE_RATES': {
'ip_throttle': '3/m' # key要跟类中的scop对应,1分钟只能访问3此
}
}
# 局部使用
class BookView(ViewSetMixin,ListAPIView):
throttle_classes = [MyThrottle,]
# 全局使用(在配置文件中)
REST_FRAMEWORK={
# 全局使用写频率类
'DEFAULT_THROTTLE_CLASSES': ['app01.auth.MyThrottle'],
}
自定义频率认证
# 自定义的逻辑
#(1)取出访问者ip
#(2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走
#(3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
#(4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
#(5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
class MyThrottles(BaseThrottle):
VISIT_RECORD = {}
def __init__(self):
self.history=None
def allow_request(self,request, view):
#(1)取出访问者ip
# print(request.META)
ip=request.META.get('REMOTE_ADDR')
import time
ctime=time.time()
# (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问
if ip not in self.VISIT_RECORD:
self.VISIT_RECORD[ip]=[ctime,]
return True
self.history=self.VISIT_RECORD.get(ip)
# (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
while self.history and ctime-self.history[-1]>60:
self.history.pop()
# (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
# (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
if len(self.history)<3:
self.history.insert(0,ctime)
return True
else:
return False
def wait(self):
import time
ctime=time.time()
return 60-(ctime-self.history[-1])
源码分析
# APIVIew的dispatch---》self.initial(request, *args, **kwargs)----》self.check_throttles(request)
throttle_durations = []
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
throttle_durations.append(throttle.wait())
# 其中的allow_request方法和wait方法
if throttle_durations:
durations = [
duration for duration in throttle_durations
if duration is not None
]
duration = max(durations, default=None)
self.throttled(request, duration)
# SimpleRateThrottle---》allow_request
def allow_request(self, request, view):
# 肯定有,要么在自定义的频率类中写,要么配置scope和配置文件
if self.rate is None:
return True
# ip 地址,唯一
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
# 当前ip地址访问的 时间列表
self.history = self.cache.get(self.key, [])
# 取当前时间
self.now = self.timer()
# 判断最后一个时间是否是无效时间,如果是就pop掉
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()