flask_wtf flask 的 CSRF 源代码初研究

因为要搞一个基于flask的前后端分离的个人网站,所以需要研究下flask的csrf防护原理.

用的扩展是flask_wtf,也算是比较官方的扩展库了.

先上相关源代码:

  1 def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
  2     """Check if the given data is a valid CSRF token. This compares the given
  3     signed token to the one stored in the session.
  4 
  5     :param data: The signed CSRF token to be checked.
  6     :param secret_key: Used to securely sign the token. Default is
  7         ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
  8     :param time_limit: Number of seconds that the token is valid. Default is
  9         ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
 10     :param token_key: Key where token is stored in session for comparison.
 11         Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
 12 
 13     :raises ValidationError: Contains the reason that validation failed.
 14 
 15     .. versionchanged:: 0.14
 16         Raises ``ValidationError`` with a specific error message rather than
 17         returning ``True`` or ``False``.
 18     """
 19 
 20     secret_key = _get_config(  # NOTE(review): _get_config is not shown here; presumably argument, then config key, then default - confirm
 21         secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
 22         message='A secret key is required to use CSRF.'
 23     )
 24     field_name = _get_config(
 25         token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
 26         message='A field name is required to use CSRF.'
 27     )
 28     time_limit = _get_config(
 29         time_limit, 'WTF_CSRF_TIME_LIMIT', 3600, required=False
 30     )
 31 
 32     if not data:  # nothing was submitted by the client
 33         raise ValidationError('The CSRF token is missing.')
 34 
 35     if field_name not in session:  # no session-side token to compare against
 36         raise ValidationError('The CSRF session token is missing.')
 37 
 38     s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')
 39 
 40     try:
 41         token = s.loads(data, max_age=time_limit)  # decode the submitted token, enforcing the age limit
 42     except SignatureExpired:
 43         raise ValidationError('The CSRF token has expired.')
 44     except BadData:
 45         raise ValidationError('The CSRF token is invalid.')
 46 
 47     if not safe_str_cmp(session[field_name], token):  # decoded token must match the session copy
 48         raise ValidationError('The CSRF tokens do not match.')
 49 
 50 
 51 class CSRFProtect(object):
 52     """Enable CSRF protection globally for a Flask app.
 53 
 54     ::
 55 
 56         app = Flask(__name__)
 57         csrf = CSRFProtect(app)
 58 
 59     Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
 60     header sent with JavaScript requests. Render the token in templates using
 61     ``{{ csrf_token() }}``.
 62 
 63     See the :ref:`csrf` documentation.
 64     """
 65 
 66     def __init__(self, app=None):
 67         self._exempt_views = set()
 68         self._exempt_blueprints = set()
 69 
 70         if app:
 71             self.init_app(app)
 72 
 73     def init_app(self, app):
 74         app.extensions['csrf'] = self
 75 
 76         app.config.setdefault('WTF_CSRF_ENABLED', True)
 77         app.config.setdefault('WTF_CSRF_CHECK_DEFAULT', True)
 78         app.config['WTF_CSRF_METHODS'] = set(app.config.get(  # always stored as a set, whatever was configured
 79             'WTF_CSRF_METHODS', ['POST', 'PUT', 'PATCH', 'DELETE']
 80         ))
 81         app.config.setdefault('WTF_CSRF_FIELD_NAME', 'csrf_token')
 82         app.config.setdefault(
 83             'WTF_CSRF_HEADERS', ['X-CSRFToken', 'X-CSRF-Token']
 84         )
 85         app.config.setdefault('WTF_CSRF_TIME_LIMIT', 3600)
 86         app.config.setdefault('WTF_CSRF_SSL_STRICT', True)
 87 
 88         app.jinja_env.globals['csrf_token'] = generate_csrf        <><><><><><><><><><><><><><><><><><><>
 89         app.context_processor(lambda: {'csrf_token': generate_csrf})
 90 
 91         @app.before_request
 92         def csrf_protect():
 93             if not app.config['WTF_CSRF_ENABLED']:
 94                 return  # a bare return yields None, so the request simply proceeds unchecked
 95 
 96             if not app.config['WTF_CSRF_CHECK_DEFAULT']:
 97                 return
 98 
 99             if request.method not in app.config['WTF_CSRF_METHODS']:
100                 return
101 
102             if not request.endpoint:
103                 return
104 
105             view = app.view_functions.get(request.endpoint)
106 
107             if not view:
108                 return
109 
110             if request.blueprint in self._exempt_blueprints:
111                 return
112 
113             dest = '%s.%s' % (view.__module__, view.__name__)  # dotted "module.function" key looked up in _exempt_views
114 
115             if dest in self._exempt_views:
116                 return
117 
118             self.protect()
119 
120     def _get_csrf_token(self):
121         # find the ``csrf_token`` field in the submitted form
122         # if the form had a prefix, the name will be
123         # ``{prefix}-csrf_token``
124         field_name = current_app.config['WTF_CSRF_FIELD_NAME']
125 
126         for key in request.form:
127             if key.endswith(field_name):  # suffix match so prefixed field names are found too
128                 csrf_token = request.form[key]
129 
130                 if csrf_token:
131                     return csrf_token
132 
133         for header_name in current_app.config['WTF_CSRF_HEADERS']:  # form fields had no token: fall back to the headers
134             csrf_token = request.headers.get(header_name)
135 
136             if csrf_token:
137                 return csrf_token
138 
139         return None
140 
141     def protect(self):
142         if request.method not in current_app.config['WTF_CSRF_METHODS']:
143             return
144 
145         try:
146             validate_csrf(self._get_csrf_token())
147         except ValidationError as e:
148             logger.info(e.args[0])
149             self._error_response(e.args[0])  # NOTE(review): helper not shown here; presumably aborts the request - confirm
150 
151         if request.is_secure and current_app.config['WTF_CSRF_SSL_STRICT']:
152             if not request.referrer:
153                 self._error_response('The referrer header is missing.')
154 
155             good_referrer = 'https://{0}/'.format(request.host)
156 
157             if not same_origin(request.referrer, good_referrer):
158                 self._error_response('The referrer does not match the host.')
159 
160         g.csrf_valid = True  # mark this request as CSRF valid

 先说明下csrftoken的普通机制,上面代码中有一行代码后面被我加了一串<>符号,这行代码表明,默认的jinja2渲染的方式就是通过generate_csrf 方法生成csrftoken字符串,所以前后端分离的话,可以直接通过这个方法获取csrftoken,效果是一样的.

进入generate_csrf函数内部,会发现它做了这么点事:生成token,放在session里,然后返回一个加工过(签名后)的token.这一块说明每当不同的访问触发该函数,那么服务器session内的csrftoken值就会不一样,所以,你可以这么做,获取一次之后在有效期(默认一个小时)内可以重复使用,但是不建议这么做.然后如果不是form表单提交的话,该csrf系统不会从json中获取token,而会从请求头获取,所以需要在请求头内添加关键字段:X-CSRFToken(或 X-CSRF-Token),将这个值赋值为获取的token即可.

首先获取csrftoken的方式: _get_csrf_token

会先从表单中查找关键字段(字段名只要以 csrf_token 结尾即可,因此带前缀的表单字段也能匹配),如果获取到非空值,就返回该值;获取不到,再依次从请求头 X-CSRFToken、X-CSRF-Token 中获取;都没有则返回 None.方式和django的基本一致,毕竟也就这两种规范方式.

 91         @app.before_request
 92         def csrf_protect():

这两行代码表明wtf是如何实现校验的,通过flask的钩子函数在每次请求开始时进行校验,这是在初始化wtf init_app(app)的时候就已经添加了该钩子函数.

在django里面,一旦中间件的process_request返回任何值,中间件即开始执行响应回调,视图不再执行,那么上面的两行代码下面好像不停地return了好多次,到底啥意思呢,只好再找源码看看.相关源码在下面:

    @setupmethod
    def before_request(self, f):
        """Registers a function to run before each request.

        For example, this can be used to open a database connection, or to load
        the logged in user from the session.

        The function will be called without any arguments. If it returns a
        non-None value, the value is handled as if it was the return value from
        the view, and further request handling is stopped.

        :param f: The function to register.
        :return: ``f`` itself, unchanged.
        """
        # Registration only: nothing is called here. The function is appended
        # to the list stored under the key ``None``, which is created on
        # first use.
        self.before_request_funcs.setdefault(None, []).append(f)
        # Returning the function unchanged lets this method be used as a
        # decorator.
        return f

可以看到添加钩子函数的装饰器执行了什么操作,它只是把钩子函数放进了一个函数列表里,然后我们看看这个函数列表是以什么方式处理的.源码如下:

 

    def preprocess_request(self):
        """Called before the request is dispatched. Calls
        :attr:`url_value_preprocessors` registered with the app and the
        current blueprint (if any). Then calls :attr:`before_request_funcs`
        registered with the app and the blueprint.

        If any :meth:`before_request` handler returns a non-None value, the
        value is handled as if it was the return value from the view, and
        further request handling is stopped.

        :return: The first non-None value returned by a ``before_request``
            handler, or ``None`` if every handler returned ``None``.
        """

        # Blueprint of the request on top of the request-context stack
        # (``None`` when there is no blueprint).
        bp = _request_ctx_stack.top.request.blueprint

        # URL value preprocessors: entries under the key ``None`` run first,
        # followed by those registered for the current blueprint. Their
        # return values are ignored.
        funcs = self.url_value_preprocessors.get(None, ())
        if bp is not None and bp in self.url_value_preprocessors:
            funcs = chain(funcs, self.url_value_preprocessors[bp])
        for func in funcs:
            func(request.endpoint, request.view_args)

        # before_request handlers, in the same order: key ``None`` first,
        # then the current blueprint's.
        funcs = self.before_request_funcs.get(None, ())
        if bp is not None and bp in self.before_request_funcs:
            funcs = chain(funcs, self.before_request_funcs[bp])
        for func in funcs:
            rv = func()
            # The first handler to return something other than None ends the
            # loop: later handlers are skipped and that value is returned.
            # A bare ``return`` in a handler yields None and so does not
            # stop anything.
            if rv is not None:
                return rv

该方法的注释说明了,如果钩子函数返回任意非 None 的数据,那么该返回值等同于视图的响应,后续的请求处理随即停止.而仅仅 return(即返回 None)只是结束了钩子函数本身,并不会中断请求处理,视图仍然会被正常执行.

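 现在可以解释def csrf_protect():函数的内容了,即:未开启防护(WTF_CSRF_ENABLED 为假)时跳过校验,关闭了默认检查(WTF_CSRF_CHECK_DEFAULT 为假)时跳过校验,请求方式不在保护范围内时跳过校验,没有匹配到 endpoint 或视图无效时跳过校验,当前蓝图或视图被豁免(exempt)时跳过校验;其余情况才会调用 self.protect() 进行校验.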

 

csrf_protect 中会执行 protect ,protect 会执行 validate_csrf(),validate_csrf()是校验的关键,源代码如下:

def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Check if the given data is a valid CSRF token. This compares the given
    signed token to the one stored in the session.

    :param data: The signed CSRF token to be checked.
    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param time_limit: Number of seconds that the token is valid. Default is
        ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.

    :raises ValidationError: Contains the reason that validation failed.

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """

    # Resolve the three settings this check depends on.
    # NOTE(review): _get_config is not shown here; presumably it prefers the
    # explicit argument, then the named config key, then the given default,
    # and uses ``message`` when a required value is missing - confirm.
    secret_key = _get_config(
        secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
        message='A secret key is required to use CSRF.'
    )
    field_name = _get_config(
        token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
        message='A field name is required to use CSRF.'
    )
    time_limit = _get_config(
        time_limit, 'WTF_CSRF_TIME_LIMIT', 3600, required=False
    )

    # Nothing was submitted by the client.
    if not data:
        raise ValidationError('The CSRF token is missing.')

    # The session holds no token under ``field_name``, so there is nothing
    # to compare the submitted token against.
    if field_name not in session:
        raise ValidationError('The CSRF session token is missing.')

    # Serializer keyed by ``secret_key`` with a fixed salt; used below to
    # decode the submitted token.
    s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')

    # Decode the submitted token, enforcing ``time_limit`` as its maximum
    # age. The two failure modes are reported with distinct messages.
    try:
        token = s.loads(data, max_age=time_limit)
    except SignatureExpired:
        raise ValidationError('The CSRF token has expired.')
    except BadData:
        raise ValidationError('The CSRF token is invalid.')

    # Final check: the decoded token must match the copy in the session.
    if not safe_str_cmp(session[field_name], token):
        raise ValidationError('The CSRF tokens do not match.')

该方法前面部分就是在获取相关秘钥和关键字,如果不自己自定义的话,这一块通常不会出问题.后面可以看到,方法会先确认全局变量session中存在csrftoken字段,再用秘钥对提交上来的token进行解码并检查有效期,最后一步才进行比对.所以,wtf是通过比对session中保存的csrftoken和请求(表单或请求头)中提交并解码后的csrftoken是否一致来完成校验的.

 所以前后端分离方式开发的话,需要将csrftoken通过接口或者cookie的方式传给前端,前端将该部分数据取出保存,提交表单的时候带上.

至于关键字,最上面那段代码写的很清楚,默认的,表单字段是 csrf_token, 请求头是 X-CSRFToken 或 X-CSRF-Token.


posted @ 2019-06-18 11:51  华腾海神  阅读(1299)  评论(0编辑  收藏  举报