django_rest_framework源码分析

CBV&APIView

'''原生django as_view方法'''
class View(object):
    """Excerpt of Django's base ``View`` class, showing ``as_view``."""

    # HTTP verbs a view may implement as handler methods; as_view() refuses
    # these names as keyword arguments.
    http_method_names = ['get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace']

    @classonlymethod
    def as_view(cls, **initkwargs):
        """Return the function that the URLconf routes requests to.

        ``initkwargs`` are passed to ``cls(...)`` each time the returned
        function is called. Every key must already be a class attribute and
        must not be an HTTP method name, otherwise ``TypeError`` is raised.
        """

        # Validate initkwargs up front, before building the view function.
        for key in initkwargs:
            if key in cls.http_method_names:
                raise TypeError("You tried to pass in the %s method name as a "
                                "keyword argument to %s(). Don't do that."
                                % (key, cls.__name__))
            if not hasattr(cls, key):
                raise TypeError("%s() received an invalid keyword %r. as_view "
                                "only accepts arguments that are already "
                                "attributes of the class." % (cls.__name__, key))

        def view(request, *args, **kwargs):
            # A new instance of cls is created on every call, i.e. per request.
            self = cls(**initkwargs)
            # Without a head() handler, HEAD requests are served by get().
            if hasattr(self, 'get') and not hasattr(self, 'head'):
                self.head = self.get
            self.request = request
            self.args = args
            self.kwargs = kwargs
            # Request handling continues in dispatch().
            return self.dispatch(request, *args, **kwargs)
        # Expose the class and its init kwargs on the returned function.
        view.view_class = cls
        view.view_initkwargs = initkwargs

        # Copy the class's name/docstring/module onto view; updated=() skips
        # merging the class __dict__.
        update_wrapper(view, cls, updated=())

        # Merge dispatch's function __dict__ into view (assigned=() copies no
        # name/docstring), so flags that decorators set on dispatch carry over.
        update_wrapper(view, cls.dispatch, assigned=())
        return view




'''APIView类'''
class APIView(View):
    """DRF's base view (excerpt): class-level component configuration and ``as_view``.

    Each component attribute defaults to the matching entry in DRF's
    ``api_settings``. A subclass (CBV) overrides it by assigning its own value.
    """

    # Defaults come from DRF settings unless the view class assigns its own value.
    renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
    parser_classes = api_settings.DEFAULT_PARSER_CLASSES
    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES  # authentication components
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES              # throttling (rate) components
    permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES          # permission components
    content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
    metadata_class = api_settings.DEFAULT_METADATA_CLASS
    versioning_class = api_settings.DEFAULT_VERSIONING_CLASS

    settings = api_settings

    # ---- as_view ----

    @classmethod
    def as_view(cls, **initkwargs):
        """Return the URLconf entry point for this view, exempt from Django's CSRF check.

        Must be a classmethod: it is called on the class (``MyView.as_view()``)
        and passes ``cls`` on to ``View.as_view``.
        """
        # Guard the class-level queryset: evaluating it directly would cache
        # its results and reuse them between requests (see the error message).
        if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
            def force_evaluation():
                raise RuntimeError(
                    'Do not evaluate the `.queryset` attribute directly, '
                    'as the result will be cached and reused between requests. '
                    'Use `.all()` or call `.get_queryset()` instead.'
                )
            cls.queryset._fetch_all = force_evaluation

        # Reuse Django's View.as_view(); the returned function still ends up in dispatch().
        view = super(APIView, cls).as_view(**initkwargs)
        view.cls = cls
        view.initkwargs = initkwargs

        # csrf_exempt turns off Django's CSRF check for this view. The same
        # decorator can also be applied to individual Django views.
        return csrf_exempt(view)
        



'''dispatch'''

    def dispatch(self, request, *args, **kwargs):
        """Handle one request: wrap it, run initial(), call the handler, finalize.

        Exceptions raised by initial() (versioning, authentication,
        permission and throttle checks) or by the handler are turned into a
        response by handle_exception().
        """
        self.args = args
        self.kwargs = kwargs
        # Replace Django's HttpRequest with a DRF Request that wraps it.
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers

        try:
            # Content negotiation, versioning, authentication, permission and throttle checks.
            self.initial(request, *args, **kwargs)


            # http_method_names = ['get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace']
            # As in Django's View, look up the handler named after the HTTP verb;
            # a missing handler or an unknown verb falls back to http_method_not_allowed.
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
            # Exceptions from initial() (all components) and the handler end up here.
            # NOTE(review): the original note says handle_exception() calls the
            # authenticator's authenticate_header(), so custom authenticators
            # should define it; handle_exception() is not shown in this excerpt.
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

'''initialize_request'''

     def initialize_request(self, request, *args, **kwargs):

        parser_context = self.get_parser_context(request)

        return Request(                                   #用Request类封装原生request对象
            request,
            parsers=self.get_parsers(),
            authenticators=self.get_authenticators(),       #get_authenticators 内部以列表生成式的方式调用配置的认证组件
            parser_context=parser_context
        )

    def get_authenticators(self):
    #以列表生成式的方式调用配置的认证组件 authentication_classes中为认证组件的类加上()实例化为对象
    #然后返回都是认证组件对象的列表
        return [auth() for auth in self.authentication_classes]   


'''Request对象'''
class Request(object):
    """DRF's wrapper around Django's ``HttpRequest`` (excerpt).

    ``initialize_request`` constructs one of these; the original Django request
    stays reachable as ``_request``. Authentication runs lazily on the first
    access to ``user``.
    """

    def __init__(self, request, parsers=None, authenticators=None,
                 negotiator=None, parser_context=None):
        assert isinstance(request, HttpRequest), (
            'The `request` argument must be an instance of '
            '`django.http.HttpRequest`, not `{}.{}`.'
            .format(request.__class__.__module__, request.__class__.__name__)
        )

        # The wrapped Django request; use _request to reach it.
        self._request = request
        self.parsers = parsers or ()
        self.authenticators = authenticators or ()
        self.negotiator = negotiator or self._default_negotiator()
        self.parser_context = parser_context
        # Body-derived values start as the Empty placeholder; _data, _files and
        # _full_data are filled in by _load_data_and_files().
        self._data = Empty
        self._files = Empty
        self._full_data = Empty
        self._content_type = Empty
        self._stream = Empty

        if self.parser_context is None:
            self.parser_context = {}
        # Parsers receive this dict: give them the request and its encoding.
        self.parser_context['request'] = self
        self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET

        # If the Django request carries a forced user/token, a single
        # ForcedAuthentication replaces every configured authenticator.
        force_user = getattr(request, '_force_auth_user', None)
        force_token = getattr(request, '_force_auth_token', None)
        if force_user is not None or force_token is not None:
            forced_auth = ForcedAuthentication(force_user, force_token)
            self.authenticators = (forced_auth,)

    @property
    def user(self):
        """Return the authenticated user, running the authenticators on first access."""
        if not hasattr(self, '_user'):
            # NOTE(review): wrap_attributeerrors presumably re-raises
            # AttributeErrors from authenticators as another type, so the
            # property lookup does not silently swallow them; confirm.
            with wrap_attributeerrors():
                self._authenticate()
        return self._user

    @user.setter
    def user(self, value):
        """Store ``value`` as the user here and on the wrapped Django request.

        ``_authenticate`` and ``_not_authenticated`` assign ``self.user``.
        With only a getter the property is read-only, and those assignments
        raise ``AttributeError``.
        """
        self._user = value
        self._request.user = value

    def _authenticate(self):
        """Try each authenticator in order; the first non-None result wins.

        ``authenticator.authenticate(request)`` returns None to abstain, or a
        ``(user, auth)`` 2-tuple that is unpacked into ``self.user`` and
        ``self.auth`` (so custom authenticators must define ``authenticate``).
        If every authenticator abstains, the request is marked
        unauthenticated. An ``APIException`` from an authenticator also marks
        it unauthenticated and is then re-raised.
        """
        for authenticator in self.authenticators:
            try:
                user_auth_tuple = authenticator.authenticate(self)
            except exceptions.APIException:
                self._not_authenticated()
                raise

            if user_auth_tuple is not None:
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple
                # Authenticated: the remaining authenticators are not consulted.
                return

        self._not_authenticated()

    def _not_authenticated(self):
        """Mark the request as unauthenticated.

        ``user`` and ``auth`` are created by calling the UNAUTHENTICATED_USER /
        UNAUTHENTICATED_TOKEN settings when those are set, and are None
        otherwise. (Per the original note, the default user setting yields an
        anonymous user; confirm against DRF settings.)
        """
        self._authenticator = None

        if api_settings.UNAUTHENTICATED_USER:
            self.user = api_settings.UNAUTHENTICATED_USER()
        else:
            self.user = None

        if api_settings.UNAUTHENTICATED_TOKEN:
            self.auth = api_settings.UNAUTHENTICATED_TOKEN()
        else:
            self.auth = None


'''initial'''

     def initial(self, request, *args, **kwargs):

        self.format_kwarg = self.get_format_suffix(**kwargs)


        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg


        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme        #版本控制放回版本号


        self.perform_authentication(request)       #启用认证 调用request.user
        self.check_permissions(request)                #启用权限 同样以列表生成式的方式调用配置
        self.check_throttles(request)                #启用频率 同样以列表生成式的方式调用配置

权限&频率组件

    def initial(self, request, *args, **kwargs):
        """Pre-handler hook; the last three calls run the auth/permission/throttle components."""

        self.format_kwarg = self.get_format_suffix(**kwargs)


        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg


        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme


        # Permission and throttle checks run after authentication.
        self.perform_authentication(request)           # authentication
        self.check_permissions(request)                 # permissions
        self.check_throttles(request)                    # throttling



'''check_permissions'''

        def check_permissions(self, request):

        for permission in self.get_permissions():                  #get_permissions函数返回值做for循环
            if not permission.has_permission(request, self):         #执行组件对象中的has_permission方法 判断返回值是T/F    所以权限组件必须写has_permission方法并返回Ture或False
                self.permission_denied(
                    request, message=getattr(permission, 'message', None)            #权限组件未通过抛异常 可自定message
                )

                
                
                
    def get_permissions(self):
    #与认证组件一个套路
        return [permission() for permission in self.permission_classes]        


'''check_throttles'''
        #套路基本一样
    def check_throttles(self, request):
        """Ask every configured throttle whether this request is allowed.

        For each throttle whose ``allow_request(request, view)`` returns a
        falsy value, ``throttled()`` is called with that throttle's ``wait()``.
        Throttle classes therefore must define both ``allow_request`` and ``wait``.
        """

        for throttle in self.get_throttles():           # list of throttle instances
            if not throttle.allow_request(request, self):       # True/False
                self.throttled(request, throttle.wait())


    def get_throttles(self):
        """Instantiate and return every class listed in ``throttle_classes``."""
        # Same pattern as get_authenticators() / get_permissions().
        return [throttle() for throttle in self.throttle_classes]

内置频率组件

    def check_throttles(self, request):
        """Call throttled() with throttle.wait() for every throttle that rejects the request."""

        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                self.throttled(request, throttle.wait())       # the throttle's own wait() method

'''BaseThrottle 内置频率组件基类'''

        def get_ident(self, request):      #获取唯一标示

        xff = request.META.get('HTTP_X_FORWARDED_FOR')  #识别代理ip时获取用户本机ip
        remote_addr = request.META.get('REMOTE_ADDR')   #获取客户端ip
        num_proxies = api_settings.NUM_PROXIES           

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr        #一般xff为None 返回remote_addr


'''SimpleRateThrottle 简单频率组件'''

class SimpleRateThrottle(BaseThrottle):
    """Cache-based throttle allowing ``num_requests`` requests per ``duration`` seconds.

    The rate is a string such as ``'10/m'``. It is taken from a ``rate``
    attribute if set, otherwise from ``THROTTLE_RATES[scope]``. Subclasses set
    ``scope`` (or ``rate``) and override ``get_cache_key``.
    """

    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES  # rates from DRF settings, keyed by scope

    def __init__(self):
        # A truthy ``rate`` attribute takes precedence; otherwise look it up by scope.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """Return the cache key identifying this client, or None to skip throttling.

        Must be overridden. A typical key fills ``cache_format`` with
        ``self.scope`` and ``self.get_ident(request)`` (the client IP).
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """Return the rate string configured for ``self.scope``.

        Raises:
            ImproperlyConfigured: if ``scope`` is unset, or has no entry in
                ``THROTTLE_RATES``.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """Turn ``'<count>/<period>'`` into ``(num_requests, duration_in_seconds)``.

        Only the first character of the period is used, so ``'5/min'`` and
        ``'5/m'`` both mean 5 per 60 seconds. ``None`` yields ``(None, None)``.
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """Return True if this request fits in the current window, else False."""
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        # Timestamps of this client's earlier requests, newest first
        # (throttle_success inserts at index 0).
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop timestamps that fell out of the window; the oldest sit at the end.
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """Record this request and allow it.

        The updated history is written back with ``self.duration`` as the
        cache timeout.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """Reject the request; rejected requests are not recorded in the history."""
        return False

    def wait(self):
        """Return the suggested number of seconds before retrying, or None.

        The time left until the oldest recorded request leaves the window is
        spread over ``num_requests - len(history) + 1`` requests; None when
        that count is not positive.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

版本控制组件

    def initial(self, request, *args, **kwargs):
        """Pre-handler hook; shown here for its versioning step."""
        self.format_kwarg = self.get_format_suffix(**kwargs)

        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Versioning: determine_version() returns two values, which are
        # stored on the request as request.version / request.versioning_scheme.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        self.perform_authentication(request)
        self.check_permissions(request)
        self.check_throttles(request)

    def determine_version(self, request, *args, **kwargs):
        """Return ``(version, scheme)`` for this request.

        Returns ``(None, None)`` when ``versioning_class`` is None (it is
        configurable). Otherwise ``scheme`` is a new ``versioning_class()``
        instance and ``version`` is whatever that scheme's own
        ``determine_version`` returns.
        """
        if self.versioning_class is None:
            return (None, None)
        scheme = self.versioning_class()
        return (scheme.determine_version(request, *args, **kwargs), scheme)
        
        
'''内置QueryParameterVersioning组件'''

class BaseVersioning(object):
    """Base class for DRF versioning schemes (excerpt)."""

    default_version = api_settings.DEFAULT_VERSION    # used when the client sends no version
    allowed_versions = api_settings.ALLOWED_VERSIONS  # accepted versions; empty means any
    version_param = api_settings.VERSION_PARAM        # name of the parameter carrying the version

    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
        """Reverse a URL; QueryParameterVersioning overrides this to add the version."""
        return _reverse(viewname, args, kwargs, request, format, **extra)

    def is_allowed_version(self, version):
        """Return True if ``version`` may be used.

        Everything is allowed when ``allowed_versions`` is empty. Otherwise
        ``version`` must be a non-None value equal to ``default_version``, or
        appear in ``allowed_versions``.
        """
        if not self.allowed_versions:
            return True
        return ((version is not None and version == self.default_version) or
                (version in self.allowed_versions))


class QueryParameterVersioning(BaseVersioning):
    """Read the API version from a query parameter (e.g. ``?version=v1`` if VERSION_PARAM is 'version')."""

    invalid_version_message = _('Invalid version in query parameter.')

    def determine_version(self, request, *args, **kwargs):
        """Return the version from the query string, or ``default_version`` if absent.

        Raises:
            NotFound: if the version is not allowed.
        """
        # request.query_params is the wrapped Django request's GET QueryDict.
        version = request.query_params.get(self.version_param, self.default_version)
        if not self.is_allowed_version(version):
            raise exceptions.NotFound(self.invalid_version_message)
        return version

    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
        """Reverse a URL and, when the request has a version, add it to the query string."""
        url = super(QueryParameterVersioning, self).reverse(
            viewname, args, kwargs, request, format, **extra
        )
        if request.version is not None:
            return replace_query_param(url, self.version_param, request.version)
        return url


# Request.query_params, used by determine_version() above. It is a property of
# DRF's Request class, not of the versioning classes, so it is shown here
# as a comment:
#
#     @property
#     def query_params(self):
#         # The GET data of the wrapped Django request.
#         return self._request.GET

 

解析器

request.data

    @property
    def data(self):
        """Return the parsed request body (data and files merged), parsing on first access."""
        if not _hasattr(self, '_full_data'):
            self._load_data_and_files()
        return self._full_data

    def _load_data_and_files(self):
        """Parse the body once and cache the results in ``_data``, ``_files`` and ``_full_data``.

        NOTE(review): ``_hasattr`` (not shown) presumably treats the ``Empty``
        placeholder set in ``Request.__init__`` as "missing"; confirm.
        """
        if not _hasattr(self, '_data'):
            self._data, self._files = self._parse()
            # Usually no files are present. When there are (presumably
            # multipart uploads), merge them into a copy so _data stays untouched.
            if self._files:
                self._full_data = self._data.copy()
                self._full_data.update(self._files)
            else:
                self._full_data = self._data

            # For form media types, push the parsed POST/FILES back onto the
            # wrapped Django request.
            if is_form_media_type(self.content_type):
                self._request._post = self.POST
                self._request._files = self.FILES


'''Request'''
  def initialize_request(self, request, *args, **kwargs):

        parser_context = self.get_parser_context(request)

        return Request(
            request,
            parsers=self.get_parsers(),                    #封装request时传入解析器对象   内部赋值给self.parser
            authenticators=self.get_authenticators(),
            negotiator=self.get_content_negotiator(),
            parser_context=parser_context
        )

    def get_parsers(self):
        #也是老套路
        return [parser() for parser in self.parser_classes]




'''parse解析'''
    
    def _parse(self):
        """Parse the request body with the negotiated parser and return ``(data, files)``.

        An empty pair is returned when there is no stream or no content type
        (a QueryDict for form media types, else a dict).

        Raises:
            UnsupportedMediaType: if no parser accepts the content type.
            Any exception from the parser is re-raised after empty
            data/files are stored on the request.
        """

        # Content type of the request, e.g. 'text/html; charset=utf-8'
        # (per the original note, taken from CONTENT_TYPE in the request META).
        media_type = self.content_type
        try:
            stream = self.stream
        except RawPostDataException:
            # Re-raise unless the Django request already has _post; if it
            # does and form parsing is supported, reuse Django's POST/FILES.
            if not hasattr(self._request, '_post'):
                raise

            if self._supports_form_parsing():
                return (self._request.POST, self._request.FILES)
            stream = None

        if stream is None or media_type is None:
            if media_type and is_form_media_type(media_type):
                empty_data = QueryDict('', encoding=self._request._encoding)
            else:
                empty_data = {}
            empty_files = MultiValueDict()
            return (empty_data, empty_files)

        # Choose a parser from self.parsers (list of parser instances): the
        # first one whose media_type matches the request's content type, or None.
        parser = self.negotiator.select_parser(self, self.parsers)

        if not parser:
            raise exceptions.UnsupportedMediaType(media_type)

        try:
            # Run the chosen parser's parse() to obtain the data.
            parsed = parser.parse(stream, media_type, self.parser_context)
        except Exception:
            # Store empty data/files before re-raising, presumably so later
            # reads of request.data do not repeat the failing parse.
            self._data = QueryDict('', encoding=self._request._encoding)
            self._files = MultiValueDict()
            self._full_data = self._data
            raise

        # A result with .data/.files supplies both; any other result is
        # treated as data with no files.
        try:
            return (parsed.data, parsed.files)
        except AttributeError:
            empty_files = MultiValueDict()
            return (parsed, empty_files)



'''select_parser选择解析器'''

        def select_parser(self, request, parsers):

        for parser in parsers:
            #获取当前请求的数据类型 与解析器列表中的支持格式对应然后返回相应的解析器对象 如果没有返回None
            if media_type_matches(parser.media_type, request.content_type):       #parser.media_type 为解析器中支持的数据类型类似media_type = 'application/json'
                return parser                                            
        return None
        

'''内置JSONParser解析器'''


class JSONParser(BaseParser):
    """Parser for request bodies sent as ``application/json``."""
    media_type = 'application/json'
    renderer_class = renderers.JSONRenderer
    strict = api_settings.STRICT_JSON

    def parse(self, stream, media_type=None, parser_context=None):
        """Decode ``stream`` with the context's encoding and return the parsed JSON.

        Raises:
            ParseError: if decoding/parsing raises ``ValueError`` (invalid JSON).
        """

        parser_context = parser_context or {}
        # Request.__init__ stores 'encoding' in parser_context; otherwise use the project default.
        encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)

        try:
            decoded_stream = codecs.getreader(encoding)(stream)
            # parse_constant is invoked for NaN/Infinity/-Infinity literals.
            # NOTE(review): json here is presumably DRF's own json utility
            # module (stdlib json has no strict_constant); confirm.
            parse_constant = json.strict_constant if self.strict else None
            return json.load(decoded_stream, parse_constant=parse_constant)
        except ValueError as exc:
            raise ParseError('JSON parse error - %s' % six.text_type(exc))

 

序列化组件

'''序列化组件构造和初始化'''
class BaseSerializer(Field):
    """Serializer base class (excerpt): construction and the ``many=True`` switch."""

    def __init__(self, instance=None, data=empty, **kwargs):
        self.instance = instance
        if data is not empty:
            self.initial_data = data
        self.partial = kwargs.pop('partial', False)
        # Stored as _context; HyperlinkedRelatedField reads it through self.root.
        self._context = kwargs.pop('context', {})
        # __new__ received its own copy of kwargs, so 'many' is still here.
        kwargs.pop('many', None)
        super(BaseSerializer, self).__init__(**kwargs)

    def __new__(cls, *args, **kwargs):
        # many=True (for a queryset/list) returns a list serializer built by
        # many_init instead of an instance of cls; many defaults to False.
        if kwargs.pop('many', False):
            return cls.many_init(*args, **kwargs)
        # many=False: construct cls itself, serializing a single object.
        return super(BaseSerializer, cls).__new__(cls, *args, **kwargs)

    @classmethod
    def many_init(cls, *args, **kwargs):
        """Build the list serializer used for ``many=True``.

        Must be a classmethod: ``__new__`` calls ``cls.many_init(*args, ...)``.
        A plain function would bind the first positional argument as ``cls``.

        A child ``cls(*args, **kwargs)`` serializes each item. The list
        serializer class is ``Meta.list_serializer_class`` if defined, else
        ``ListSerializer``. It receives ``*args``, ``child``, ``allow_empty``
        (when given) and any kwargs named in ``LIST_SERIALIZER_KWARGS``.
        """
        allow_empty = kwargs.pop('allow_empty', None)
        child_serializer = cls(*args, **kwargs)
        list_kwargs = {
            'child': child_serializer,
        }
        if allow_empty is not None:
            list_kwargs['allow_empty'] = allow_empty
        list_kwargs.update({
            key: value for key, value in kwargs.items()
            if key in LIST_SERIALIZER_KWARGS
        })
        meta = getattr(cls, 'Meta', None)
        list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer)
        return list_serializer_class(*args, **list_kwargs)
        


'''data'''


class BaseSerializer(Field):
    """BaseSerializer excerpt: the ``data`` property."""
    @property
    def data(self):
        """Return the serialized output, computed on first access and cached in ``_data``.

        Uses ``to_representation(instance)`` when an instance was given and
        there are no errors, ``to_representation(validated_data)`` after
        validation without errors, and ``get_initial()`` otherwise.

        Raises:
            AssertionError: if ``data=`` was passed but ``is_valid()`` has not run.
        """
        if hasattr(self, 'initial_data') and not hasattr(self, '_validated_data'):
            msg = (
                'When a serializer is passed a `data` keyword argument you '
                'must call `.is_valid()` before attempting to access the '
                'serialized `.data` representation.\n'
                'You should either call `.is_valid()` first, '
                'or access `.initial_data` instead.'
            )
            raise AssertionError(msg)

        if not hasattr(self, '_data'):
            if self.instance is not None and not getattr(self, '_errors', None):
                # The output comes from to_representation().
                self._data = self.to_representation(self.instance)
            elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None):
                self._data = self.to_representation(self.validated_data)
            else:
                self._data = self.get_initial()
        return self._data




'''serializer'''

class Serializer(BaseSerializer):
    """Serializer excerpt: ``data`` and ``to_representation``."""

    @property
    def data(self):
        """Return ``BaseSerializer.data`` wrapped in a ReturnDict that references this serializer."""
        ret = super(Serializer, self).data
        return ReturnDict(ret, serializer=self)

    def to_representation(self, instance):
        """Convert ``instance`` into an ordered dict, field by field."""
        ret = OrderedDict()
        fields = self._readable_fields

        for field in fields:
            try:
                # Each field fetches its value via get_attribute(),
                # e.g. CharField().get_attribute(instance).
                attribute = field.get_attribute(instance)
            except SkipField:
                continue

            # For a PKOnlyObject (presumably a pk-only stand-in for a related
            # object) the None check is done on its pk.
            check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute
            if check_for_none is None:
                ret[field.field_name] = None
            else:
                # Each field converts its own value.
                ret[field.field_name] = field.to_representation(attribute)
        return ret


'''field'''

class CharField(Field):
    """Text field (excerpt: only ``to_representation``)."""
    def to_representation(self, value):
        """Return ``value`` as text; ``six.text_type`` is ``str`` on Python 3."""
        return six.text_type(value)


#CharField继承Field

class Field(object):
    """Base serializer field (excerpt: ``get_attribute`` and ``bind``)."""

    def get_attribute(self, instance):
        """Return this field's value from ``instance`` by following ``source_attrs``.

        On ``KeyError``/``AttributeError``: return ``get_default()`` if a
        default is set, None if ``allow_null``, raise ``SkipField`` if the
        field is not required. Otherwise the same exception type is re-raised
        with a descriptive message.
        """

        try:
            # Module-level get_attribute() helper; instance is the object being
            # serialized, source_attrs is the field's source split on '.'.
            return get_attribute(instance, self.source_attrs)
        except (KeyError, AttributeError) as exc:
            if self.default is not empty:
                return self.get_default()
            if self.allow_null:
                return None
            if not self.required:
                raise SkipField()
            msg = (
                'Got {exc_type} when attempting to get a value for field '
                '`{field}` on serializer `{serializer}`.\nThe serializer '
                'field might be named incorrectly and not match '
                'any attribute or key on the `{instance}` instance.\n'
                'Original exception text was: {exc}.'.format(
                    exc_type=type(exc).__name__,
                    field=self.field_name,
                    serializer=self.parent.__class__.__name__,
                    instance=instance.__class__.__name__,
                    exc=exc
                )
            )
            raise type(exc)(msg)


    def bind(self, field_name, parent):
        """Attach the field to its parent serializer under ``field_name``.

        Fills in default ``label`` and ``source`` and derives ``source_attrs``.
        """

        # Reject redundant declarations such as
        # my_field = serializer.CharField(source='my_field')
        assert self.source != field_name, (
            "It is redundant to specify `source='%s'` on field '%s' in "
            "serializer '%s', because it is the same as the field name. "
            "Remove the `source` keyword argument." %
            (field_name, self.__class__.__name__, parent.__class__.__name__)
        )

        self.field_name = field_name
        self.parent = parent

        # Default label from the name, e.g. 'user_name' -> 'User name'.
        if self.label is None:
            self.label = field_name.replace('_', ' ').capitalize()

        if self.source is None:
            self.source = field_name

        # source='*' means an empty path, so get_attribute() yields the whole instance.
        if self.source == '*':
            self.source_attrs = []
        else:
            # This split is why source can be a dotted path like 'userdetail.pk'.
            self.source_attrs = self.source.split('.')

    
def get_attribute(instance, attrs):
    """Follow the attribute path ``attrs`` from ``instance`` and return the final value.

    This is the module-level helper used by ``Field.get_attribute`` (not the
    method of the same name). At each step a Mapping is indexed by key and
    anything else is read with ``getattr``. If the looked-up value is a simple
    callable (per ``is_simple_callable``) it is called. A lookup that raises
    ``ObjectDoesNotExist`` makes the whole result None.

    Example: ``attrs == ['userdetail', 'pk']`` resolves ``instance.userdetail.pk``.
    An empty ``attrs`` returns ``instance`` itself.

    Raises:
        ValueError: if calling a callable step raises AttributeError or KeyError.
    """
    # collections.Mapping was removed in Python 3.10; the ABC lives in collections.abc.
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping

    for attr in attrs:
        try:
            if isinstance(instance, Mapping):
                instance = instance[attr]
            else:
                # Each step replaces instance, e.g. user -> user.userdetail -> user.userdetail.pk.
                instance = getattr(instance, attr)
        except ObjectDoesNotExist:
            return None
        if is_simple_callable(instance):
            try:
                instance = instance()
            except (AttributeError, KeyError) as exc:
                raise ValueError('Exception raised in callable attribute "{0}"; original exception was: {1}'.format(attr, exc))
    return instance


'''HyperlinkedRelatedField返回url'''

class HyperlinkedRelatedField(RelatedField):
    """Related field rendered as a link to the related object's URL (excerpt)."""
    @property
    def context(self):
        """Return the root serializer's ``_context`` dict ({} if it has none)."""
        return getattr(self.root, '_context', {})


    def to_representation(self, value):
        """Return a ``Hyperlink`` for ``value``, or None when ``get_url`` returns None.

        Requires ``'request'`` in the serializer context. Raises
        ``ImproperlyConfigured`` if the URL cannot be reversed.
        """
        # The serializer must be created with context={'request': request}.
        assert 'request' in self.context, (
            "`%s` requires the request in the serializer"
            " context. Add `context={'request': request}` when instantiating "
            "the serializer." % self.__class__.__name__
        )

        request = self.context['request']
        format = self.context.get('format', None)

        # When both are set and differ, the field's own format wins.
        if format and self.format and self.format != format:
            format = self.format

        try:
            # get_url() builds the URL for value, using the request.
            url = self.get_url(value, self.view_name, request, format)
        except NoReverseMatch:
            msg = (
                'Could not resolve URL for hyperlinked relationship using '
                'view name "%s". You may have failed to include the related '
                'model in your API, or incorrectly configured the '
                '`lookup_field` attribute on this field.'
            )
            if value in ('', None):
                value_string = {'': 'the empty string', None: 'None'}[value]
                msg += (
                    " WARNING: The value of the field on the model instance "
                    "was %s, which may be why it didn't match any "
                    "entries in your URL conf." % value_string
                )
            raise ImproperlyConfigured(msg % self.view_name)

        if url is None:
            return None

        return Hyperlink(url, value)



    def get_url(self, obj, view_name, request, format):
        """Return the URL for ``obj``, or None when ``obj.pk`` is None or ''.

        The URL is reversed from ``view_name`` with
        ``{lookup_url_kwarg: getattr(obj, lookup_field)}`` as kwargs.
        """

        if hasattr(obj, 'pk') and obj.pk in (None, ''):
            return None

        # Plain getattr: lookup_field must be a single attribute name on obj.
        # NOTE(review): per the original note, __init__ (not shown) defaults lookup_field to 'pk'.
        lookup_value = getattr(obj, self.lookup_field)
        # The URL pattern's named group (lookup_url_kwarg) -> the looked-up value.
        kwargs = {self.lookup_url_kwarg: lookup_value}
        return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
        # The URL is produced by reverse resolution.


'''
总结:HyperlinkedRelatedField用反向解析的方法获得url
需要传view_name, lookup_field, lookup_url_kwarg 三个参数 
view_name = 路由需要起别名 这个参数传入别名                            
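lookup_field = url是以哪个字段来取到数据的 lookup_field传入那个字段 一般为id     #get_url中用getattr(obj, lookup_field)取值, 不支持用'.'跨表, 只能是单个属性名如lookup_field='id'; 跨表时用source指定关联对象, 如source='userdetail', lookup_field='id'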
lookup_url_kwarg = 有名分组的参数名称                                    #lookup_url_kwarg='pk'


在使用序列化时需要传入context参数 内容{'request': request}        #不传会在断言处报错
'''

 

视图组件

'''GenericAPIView'''

class GenericAPIView(views.APIView):
    """``APIView`` plus helpers for a queryset, a serializer and pagination (excerpt)."""

    queryset = None            # queryset this view operates on
    serializer_class = None    # serializer class used by get_serializer()

    lookup_field = 'pk'
    lookup_url_kwarg = None

    filter_backends = api_settings.DEFAULT_FILTER_BACKENDS

    # Paginator class for this view; defaults to DEFAULT_PAGINATION_CLASS in settings.
    pagination_class = api_settings.DEFAULT_PAGINATION_CLASS

    def get_queryset(self):
        """Return the queryset for this request.

        A ``QuerySet`` is re-issued with ``.all()`` on every call, so the
        class attribute may be declared with or without ``.all()``. Each call
        gets a fresh copy instead of the class-level object, whose result
        cache would otherwise be reused between requests.

        Raises:
            AssertionError: if ``queryset`` is unset and this method is not overridden.
        """
        assert self.queryset is not None, (
            "'%s' should either include a `queryset` attribute, "
            "or override the `get_queryset()` method."
            % self.__class__.__name__
        )

        queryset = self.queryset
        if isinstance(queryset, QuerySet):
            queryset = queryset.all()
        return queryset

    def get_serializer(self, *args, **kwargs):
        """Instantiate the view's serializer with ``context`` filled in automatically.

        Any ``context`` passed by the caller is overwritten with
        ``get_serializer_context()`` (presumably including the request that
        HyperlinkedRelatedField asserts on).
        """
        # Going through get_serializer_class() rather than reading
        # serializer_class directly lets subclasses override the choice.
        serializer_class = self.get_serializer_class()
        kwargs['context'] = self.get_serializer_context()
        return serializer_class(*args, **kwargs)

    def get_serializer_class(self):
        """Return ``serializer_class``.

        Raises:
            AssertionError: if it is unset and this method is not overridden.
        """
        assert self.serializer_class is not None, (
            "'%s' should either include a `serializer_class` attribute, "
            "or override the `get_serializer_class()` method."
            % self.__class__.__name__
        )
        return self.serializer_class

    @property
    def paginator(self):
        """Return the paginator instance (None without ``pagination_class``), created once per view."""
        if not hasattr(self, '_paginator'):
            if self.pagination_class is None:
                self._paginator = None
            else:
                self._paginator = self.pagination_class()
        return self._paginator

    def paginate_queryset(self, queryset):
        """Paginate ``queryset`` with the paginator; return None when pagination is off."""
        if self.paginator is None:
            return None
        return self.paginator.paginate_queryset(queryset, self.request, view=self)
        

'''GenericViewSet'''

class GenericViewSet(ViewSetMixin, generics.GenericAPIView):
    """GenericAPIView combined with ViewSetMixin's method-to-action mapping; adds no code."""
    # ViewSetMixin is listed first, so it precedes GenericAPIView in the MRO and
    # its methods (e.g. initialize_request) take precedence.
    # GenericAPIView subclasses APIView, so all APIView behaviour is inherited.
    pass


'''ViewSetMixin'''


class ViewSetMixin(object):
    """Mixin mapping HTTP methods to action names (used for router-style URLs).

    ``action_map`` (e.g. ``{'get': 'list'}``) is presumably assigned by this
    mixin's own ``as_view`` (not shown here).
    """

    def initialize_request(self, request, *args, **kwargs):
        """Wrap the request as usual, then record ``self.action`` for its method.

        ``self.action`` becomes ``'metadata'`` for OPTIONS, and otherwise
        ``self.action_map.get(method)`` (None for an unmapped method), e.g.
        ``{'get': 'list'}`` gives ``self.action == 'list'`` for GET.
        """
        request = super(ViewSetMixin, self).initialize_request(request, *args, **kwargs)
        method = request.method.lower()
        if method == 'options':
            self.action = 'metadata'
        else:
            self.action = self.action_map.get(method)
        return request


'''ModelViewSet封装的最后一层封装了增删改查与半自动路由'''


class ModelViewSet(mixins.CreateModelMixin,
                   mixins.RetrieveModelMixin,
                   mixins.UpdateModelMixin,
                   mixins.DestroyModelMixin,
                   mixins.ListModelMixin,
                   GenericViewSet):
    """
    A viewset that provides default `create()`, `retrieve()`, `update()`,
    `partial_update()`, `destroy()` and `list()` actions.

    This is the outermost layer: the mixins supply create/read/update/delete
    and list actions, and GenericViewSet supplies ViewSetMixin's action mapping
    plus GenericAPIView's queryset/serializer helpers.
    """
    pass

 

posted @ 2019-04-12 13:02  SwZ1886  阅读(327)  评论(0)    收藏  举报