
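Flask source code analysis

Source code analysis

Preparation

Also, when quoting source code I have removed unrelated parts and kept only the core code relevant to the execution flow.

Startup flow analysis

Let's start with hello world, the example given in the official Flask documentation: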

from flask import Flask

app = Flask(__name__)

@app.route("/")
def hello_world():
    return "<p>Hello, World!</p>"


if __name__ == '__main__':
    app.run()

Running this script calls app.run(), which starts the server, so the startup code must live in the run method:

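def run(self, host=None, port=None, debug=None,
        load_dotenv=True, **options):
    # SERVER_NAME (e.g. "example.com:8080") supplies a fallback host and port
    server_name = self.config.get("SERVER_NAME")
    sn_host = sn_port = None

    if server_name:
        sn_host, _, sn_port = server_name.partition(":")

    if not host:
        if sn_host:
            host = sn_host
        else:
            host = "127.0.0.1"

    if port or port == 0:
        port = int(port)
    elif sn_port:
        port = int(sn_port)
    else:
        port = 5000

    from werkzeug.serving import run_simple

    try:
        run_simple(host, port, self, **options)  # self here is the app object
    finally:
        self._got_first_request = False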
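The host/port defaulting in run() can be tried out in isolation. The sketch below is a hypothetical standalone function (`resolve_host_port` is not a Flask API); it only reproduces the SERVER_NAME fallback and the 127.0.0.1:5000 defaults shown above, assuming nothing else in run() changes host or port.

```python
from typing import Optional, Tuple


def resolve_host_port(
    host: Optional[str] = None,
    port: Optional[int] = None,
    server_name: Optional[str] = None,
) -> Tuple[str, int]:
    """Hypothetical helper mirroring the defaulting logic in Flask.run()."""
    sn_host = sn_port = None
    if server_name:
        # "example.com:8080" -> ("example.com", ":", "8080")
        sn_host, _, sn_port = server_name.partition(":")

    if not host:
        host = sn_host or "127.0.0.1"

    # port=0 is meaningful (let the OS pick a free port), so test it explicitly
    if port or port == 0:
        port = int(port)
    elif sn_port:
        port = int(sn_port)
    else:
        port = 5000
    return host, port


print(resolve_host_port())                                 # ('127.0.0.1', 5000)
print(resolve_host_port(server_name="example.com:8080"))  # ('example.com', 8080)
print(resolve_host_port(port=0))                          # ('127.0.0.1', 0)
```

Note the explicit `port == 0` check: 0 is falsy in Python, so a plain `if port:` would silently replace it with 5000.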

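run does the following: it sets the host and port, handles some other options (omitted here), and finally calls run_simple from werkzeug.serving. run_simple wraps these options and, underneath, uses a socket in an endless loop waiting for requests. When an HTTP request arrives, werkzeug parses it into a WSGI environ, and werkzeug.serving.WSGIRequestHandler then calls the app to handle it: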

# inside WSGIRequestHandler.run_wsgi (excerpt)
def execute(app: "WSGIApplication") -> None:
    application_iter = app(environ, start_response)
    try:
        for data in application_iter:
            write(data)
        if not headers_sent:
            write(b"")
    finally:
        if hasattr(application_iter, "close"):
            application_iter.close()  # type: ignore

try:
    execute(self.server.app)  # self.server.app is our app
except (ConnectionError, socket.timeout) as e:
    self.connection_dropped(e, environ)

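Here application_iter = app(environ, start_response) is called. That is simply calling the app object like a function, and calling an instance invokes its class's __call__ method, in this case Flask.__call__: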

def __call__(self, environ: dict, start_response: t.Callable) -> t.Any:
    """The WSGI server calls the Flask application object as the
    WSGI application. This calls :meth:`wsgi_app`, which can be
    wrapped to apply middleware.
    """
    return self.wsgi_app(environ, start_response)
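The contract that execute() and Flask.__call__ rely on is plain WSGI: a callable taking environ and start_response and returning an iterable of bytes. Below is a minimal stdlib-only sketch; `hello_app`, `UppercaseMiddleware` and `call_like_a_server` are hypothetical names for illustration, and `call_like_a_server` is only a rough stand-in for werkzeug's execute(), not its real code.

```python
from wsgiref.util import setup_testing_defaults


def hello_app(environ, start_response):
    """A bare WSGI application: the same calling convention Flask.__call__ follows."""
    body = f"<p>Hello from {environ['PATH_INFO']}</p>".encode("utf-8")
    start_response("200 OK", [("Content-Type", "text/html"),
                              ("Content-Length", str(len(body)))])
    return [body]  # any iterable of bytes


class UppercaseMiddleware:
    """Wraps another WSGI callable, the way middleware wraps app.wsgi_app."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __call__(self, environ, start_response):
        return [chunk.upper() for chunk in self.wrapped(environ, start_response)]


def call_like_a_server(app, path="/"):
    """Rough stand-in for execute(): build an environ, call the app, join the body."""
    environ = {"PATH_INFO": path}
    setup_testing_defaults(environ)  # fills REQUEST_METHOD, SERVER_NAME, wsgi.* keys
    captured = {}

    def start_response(status, headers, exc_info=None):
        captured["status"] = status
        captured["headers"] = headers

    app_iter = app(environ, start_response)
    try:
        body = b"".join(app_iter)
    finally:
        if hasattr(app_iter, "close"):
            app_iter.close()
    return captured["status"], body


print(call_like_a_server(hello_app, "/hi"))                       # ('200 OK', b'<p>Hello from /hi</p>')
print(call_like_a_server(UppercaseMiddleware(hello_app), "/hi"))  # ('200 OK', b'<P>HELLO FROM /HI</P>')
```

Because __call__ only forwards to wsgi_app, middleware is usually applied as app.wsgi_app = SomeMiddleware(app.wsgi_app), which keeps the Flask object itself reachable.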

__call__ just delegates to wsgi_app, so let's look at that:

def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any:
    # 1. call request_context to create the request context (which creates the request object)
    ctx = self.request_context(environ)
    error: t.Optional[BaseException] = None
    try:
        try:
            ctx.push()  # 2. push ctx onto the stack
            # 3. call full_dispatch_request
            response = self.full_dispatch_request()
        except Exception as e:
            # handle the error
            error = e
            response = self.handle_exception(e)
        except:
            # non-Exception errors (e.g. KeyboardInterrupt): record and re-raise
            error = sys.exc_info()[1]
            raise
        return response(environ, start_response)
    finally:
        if self.should_ignore_error(error):
            error = None
        # 4. finally pop ctx off the stack
        ctx.auto_pop(error)

1 request_context

There's a lot here, so let's take it piece by piece, starting with request_context:

def request_context(self, environ: dict) -> RequestContext:
    # self is the app object; it is passed as the first argument to create a RequestContext
    return RequestContext(self, environ)

class RequestContext:
    # RequestContext wraps a request object
    def __init__(self, app, environ, request=None, session=None):
        self.app = app
        if request is None:
            request = app.request_class(environ)  # if no request is passed in, one is created by default
        self.request = request
        self.url_adapter = None
        try:
            self.url_adapter = app.create_url_adapter(self.request)
        except HTTPException as e:
            self.request.routing_exception = e
        self.flashes = None
        self.session = session  # note: session starts out as None
        self._implicit_app_ctx_stack: t.List[t.Optional["AppContext"]] = []
        self.preserved = False
        self._preserved_exc = None
        self._after_request_functions: t.List[AfterRequestCallable] = []

2 ctx.push

Next comes ctx.push(), i.e. the push method of RequestContext:

def push(self):
    # since this is ctx.push, self is the RequestContext object we just created

    app_ctx = _app_ctx_stack.top
    if app_ctx is None or app_ctx.app != self.app:
        # create an AppContext, which holds the app object and g
        app_ctx = self.app.app_context()
        # push the app context
        app_ctx.push()
        self._implicit_app_ctx_stack.append(app_ctx)
    else:
        self._implicit_app_ctx_stack.append(None)

    # _request_ctx_stack is a Flask-wide global, an instance of LocalStack
    _request_ctx_stack.push(self)

    # session is None only on the first push
    if self.session is None:
        # self.app is the app we created
        session_interface = self.app.session_interface
        # load the session for this request (by default, from its cookie)
        self.session = session_interface.open_session(self.app, self.request)

        if self.session is None:
            # if it is still None (e.g. no secret_key), fall back to a null session
            self.session = session_interface.make_null_session(self.app)
    if self.url_adapter is not None:
        self.match_request()

Let's take a brief look at _request_ctx_stack.push(self):

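class LocalStack:
    def __init__(self) -> None:
        # _local is an instance of Local
        self._local = Local()
    def push(self, obj: t.Any) -> t.List[t.Any]:
        # obj is the self passed in above, i.e. the RequestContext object
        rv = getattr(self._local, "stack", []).copy()  # getattr with a default: on the first call there is no "stack" yet, so an empty list is used
        # rv = []
        rv.append(obj)  # append obj to the list
        # rv = [obj]
        # conceptually, _local stores {id of this request's thread: {}, id of another thread: {}, ...}
        # when several requests arrive, each is handled in its own thread; Local reads and
        # writes data keyed by the current thread's id (greenlets are supported too),
        # so every request works on its own data independently
        self._local.stack = rv  # store rv as "stack"
        # _local: {id of this request's thread: {"stack": [obj]}, id of another thread: {"stack": [obj]}, ...}
        return rv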
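To see the per-thread isolation that push relies on, here is a simplified LocalStack built on threading.local. This is an assumption-laden sketch: werkzeug's own Local is implemented differently (and also handles coroutines), but the copy-append-rebind shape of push and the `top` property follow the code above. `ThreadLocalStack` and `handle` are hypothetical names.

```python
import threading


class ThreadLocalStack:
    """Simplified LocalStack: one list per thread, backed by threading.local."""

    def __init__(self):
        self._local = threading.local()

    def push(self, obj):
        # copy, append, then rebind: the same shape as werkzeug's push()
        rv = getattr(self._local, "stack", []).copy()
        rv.append(obj)
        self._local.stack = rv
        return rv

    @property
    def top(self):
        try:
            return self._local.stack[-1]
        except (AttributeError, IndexError):
            return None


stack = ThreadLocalStack()
seen = {}


def handle(request_id):
    stack.push(f"ctx-{request_id}")
    seen[request_id] = stack.top  # each thread only sees its own push


threads = [threading.Thread(target=handle, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(sorted(seen.items()))  # [(0, 'ctx-0'), (1, 'ctx-1'), (2, 'ctx-2')]
print(stack.top)             # None: the main thread never pushed anything
```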

3 full_dispatch_request

def full_dispatch_request(self) -> Response:
    # check whether the app has already handled a request; on the first request, run any functions registered with @app.before_first_request
    self.try_trigger_before_first_request_functions()
    try:
        request_started.send(self)  # send the request_started signal
        # likewise, run the functions registered with @app.before_request
        rv = self.preprocess_request()
        if rv is None:
            # if the before_request functions return None, dispatch to the view; otherwise skip straight to the end
            rv = self.dispatch_request()
    except Exception as e:
        rv = self.handle_user_exception(e)
    # finally run the @app.after_request functions and save the session
    return self.finalize_request(rv)

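Apart from these request hooks, the key part is self.dispatch_request. Let's set routing and error handling aside for now and finish walking through the request flow first.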
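The hook ordering in full_dispatch_request, and the rule that a non-None return from a before_request function skips the view, can be modeled in a few lines. `HookedApp` is a hypothetical toy: requests are plain dicts and responses are strings. As in Flask's process_response, after_request functions run in reverse registration order.

```python
class HookedApp:
    """Toy model of before_request / after_request ordering (not Flask's API)."""

    def __init__(self):
        self.before_request_funcs = []
        self.after_request_funcs = []

    def before_request(self, f):
        self.before_request_funcs.append(f)
        return f

    def after_request(self, f):
        self.after_request_funcs.append(f)
        return f

    def preprocess_request(self, request):
        for func in self.before_request_funcs:
            rv = func(request)
            if rv is not None:  # a non-None return short-circuits dispatch
                return rv
        return None

    def full_dispatch_request(self, request, view):
        rv = self.preprocess_request(request)
        if rv is None:
            rv = view(request)
        # "finalize": after_request functions run in reverse registration order
        for func in reversed(self.after_request_funcs):
            rv = func(rv)
        return rv


app = HookedApp()


@app.before_request
def require_token(request):
    if request.get("token") != "secret":
        return "401 Unauthorized"
    return None


@app.after_request
def add_footer(response):
    return response + " [footer]"


def index(request):
    return "index page"


print(app.full_dispatch_request({"token": "secret"}, index))  # index page [footer]
print(app.full_dispatch_request({}, index))                   # 401 Unauthorized [footer]
```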

4 ctx.auto_pop

Whether or not an exception occurred, ctx.auto_pop(error) runs at the end and pops this request's ctx off the stack:

def auto_pop(self, exc: t.Optional[BaseException]) -> None:
    if self.request.environ.get("flask._preserve_context") or (
            exc is not None and self.app.preserve_context_on_exception
    ):  # keep the context (and the error) instead of popping it if the environ asks for it, or if an error occurred and PRESERVE_CONTEXT_ON_EXCEPTION is enabled
        self.preserved = True
        self._preserved_exc = exc  # type: ignore
    else:
        # otherwise pop; if no error occurred, exc is None here
        self.pop(exc)

auto_pop then calls ctx.pop to do the actual popping:

def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None:  # type: ignore
    # pop the app context recorded by push() from this ctx's own list (None if an existing one was reused)
    app_ctx = self._implicit_app_ctx_stack.pop()
    clear_request = False

    try:
        # once _implicit_app_ctx_stack is empty, run the teardown logic below
        if not self._implicit_app_ctx_stack:
            self.preserved = False
            self._preserved_exc = None
            if exc is _sentinel:
                exc = sys.exc_info()[1]
            self.app.do_teardown_request(exc)

            request_close = getattr(self.request, "close", None)
            if request_close is not None:
                request_close()
            clear_request = True
    finally:
        # mirroring push, finally pop ctx off _request_ctx_stack
        rv = _request_ctx_stack.pop()

        # if this request context pushed an app context, pop that too
        if app_ctx is not None:
            app_ctx.pop(exc)

That completes the request flow.

Routing and dispatch

The previous section traced the overall flow. Here we look in detail at self.dispatch_request inside full_dispatch_request, i.e. how a request is routed to a view.

def dispatch_request(self) -> ResponseReturnValue:
    # get this request's request object from _request_ctx_stack
    req = _request_ctx_stack.top.request
    if req.routing_exception is not None:
        # if routing failed, raise the stored exception
        self.raise_routing_exception(req)
    # otherwise read the matched url_rule; e.g. for 127.0.0.1:5000 the rule is /
    rule = req.url_rule

    if (
        getattr(rule, "provide_automatic_options", False)
        and req.method == "OPTIONS"
    ):  # if the rule provides automatic OPTIONS handling and this is an OPTIONS request, answer it directly
        return self.make_default_options_response()

    # look up the view function for the endpoint in self.view_functions and call it with the URL arguments
    return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)

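_request_ctx_stack appears again. As mentioned earlier, it is a global: each time a request arrives, the current request context is pushed onto it, so it can be used throughout request handling. The previous section briefly introduced its properties; we'll study it separately later. In short, dispatch_request does one thing: it takes the current request object and hands it to the appropriate function (the view function we wrote). So when was url_rule matched?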

In the request_context part of the previous section, creating ctx (the line ctx = self.request_context(environ)) ends up calling RequestContext.__init__. That initializer runs app.create_url_adapter(self.request). Let's look at that method:

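def create_url_adapter(self, request: t.Optional[Request]) -> t.Optional[MapAdapter]:
    # the current request is passed in, so it is not None here
    if request is not None:
        # unless subdomain matching is enabled, always use the default subdomain
        if not self.subdomain_matching:
            subdomain = self.url_map.default_subdomain or None
        else:
            subdomain = None

        # bind the app's url_map to the WSGI environ
        return self.url_map.bind_to_environ(
            request.environ,
            server_name=self.config["SERVER_NAME"],
            subdomain=subdomain,
        )

    # without a request, still bind if SERVER_NAME is configured
    if self.config["SERVER_NAME"] is not None:
        return self.url_map.bind(
            self.config["SERVER_NAME"],
            script_name=self.config["APPLICATION_ROOT"],
            url_scheme=self.config["PREFERRED_URL_SCHEME"],
        )
    # with neither a request nor SERVER_NAME, return None
    return None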

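bind_to_environ and bind have the same effect: both return a werkzeug.routing.MapAdapter object, which is used for URL matching. That belongs to werkzeug's source, so I won't dig into it here; see the werkzeug documentation if you're interested. Just remember that at this point ctx.url_adapter is not None. This matters.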

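Back in the ctx.push part of the previous section, I skipped some of RequestContext.push. I left it for this section because, besides pushing ctx, push also performs route matching. That's what the last two lines do: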

def push(self):  # self is the current request's ctx, i.e. a RequestContext object
    ...  # (the code shown earlier)
    if self.url_adapter is not None:
        self.match_request()

Remember, ctx.url_adapter is not None at this point, so self.match_request() is called:

def match_request(self) -> None:
    try:
        # match the route
        result = self.url_adapter.match(return_rule=True)
        # store the matched url_rule and view_args on the request object
        self.request.url_rule, self.request.view_args = result
    except HTTPException as e:
        self.request.routing_exception = e

This calls ctx.url_adapter.match, which werkzeug implements. The matched rule is stored on ctx.request, and dispatch_request later uses it to find the view function we wrote.
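The match-then-dispatch split can be sketched without werkzeug. `ToyRule` and `ToyMapAdapter` are hypothetical, heavily simplified stand-ins for werkzeug.routing.Rule and MapAdapter (only `<name>` and `<int:name>` placeholders, no HTTP methods, no 405 handling); the point is that matching yields (rule, view_args) and dispatch looks up view_functions[rule.endpoint].

```python
import re


class NotFound(Exception):
    """Raised when no rule matches (stands in for werkzeug's 404 exception)."""


class ToyRule:
    """Hypothetical Rule: supports only <name> and <int:name> placeholders."""

    _placeholder = re.compile(r"<(?:(\w+):)?(\w+)>")

    def __init__(self, rule, endpoint):
        self.rule = rule
        self.endpoint = endpoint
        self.converters = {}

        def to_group(m):
            converter, name = m.group(1), m.group(2)
            self.converters[name] = int if converter == "int" else str
            pattern = r"\d+" if converter == "int" else r"[^/]+"
            return f"(?P<{name}>{pattern})"

        self.regex = re.compile("^" + self._placeholder.sub(to_group, rule) + "$")

    def match(self, path):
        m = self.regex.match(path)
        if m is None:
            return None
        return {name: self.converters[name](value) for name, value in m.groupdict().items()}


class ToyMapAdapter:
    def __init__(self, rules):
        self.rules = rules

    def match(self, path):
        # like MapAdapter.match(return_rule=True): return (rule, view_args)
        for rule in self.rules:
            view_args = rule.match(path)
            if view_args is not None:
                return rule, view_args
        raise NotFound(path)


view_functions = {
    "root": lambda: "hello!",
    "user": lambda uid: f"user #{uid}",
}
adapter = ToyMapAdapter([ToyRule("/", "root"), ToyRule("/user/<int:uid>", "user")])


def dispatch(path):
    # match_request(): the result becomes request.url_rule / request.view_args
    url_rule, view_args = adapter.match(path)
    # dispatch_request(): find the view by endpoint and call it with view_args
    return view_functions[url_rule.endpoint](**view_args)


print(dispatch("/"))         # hello!
print(dispatch("/user/42"))  # user #42
```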

Context, part 1: Local

We've mentioned Flask's globals several times, such as _request_ctx_stack and _local in the ctx.push code. What are they? This is Flask's context mechanism. Let's start with a simple example:

from flask import Flask,request

app = Flask(__name__)

@app.route("/",endpoint="root")
def hello_world():
    if request.method =="GET":
        print(request)
    return "hello!"

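Notice that whenever you use the request object in a view, whether request.method or print(request), you get the current request. In other words, if someone else sends the same request at the same moment, how does Flask make sure the request you see is yours? The answer is a global variable, but that raises a problem. Look at this example first: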

from threading import Thread
import time
s = 0
def add(num):
    global s
    s=num
    time.sleep(1)
    print(s)

if __name__ == '__main__':
    for i in range(5):
        trd = Thread(target=add,args=[i])
        trd.start()

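This creates 5 threads: the first sets s to 0, the second to 1, and so on. After sleeping for 1 second, every thread prints 4, because they all share the same s and the last assignment wins. This is a data-safety problem, a classic race condition in concurrent programming. There are many ways to deal with it, such as Peterson's algorithm, P/V semaphore operations, or locks, but that's operating-systems territory, so I won't expand on it. Locking works when multiple threads share one piece of data, but each request should have its own request object. What we want is for one thread's changes to its request object not to affect other threads, and that is exactly what threading.local provides: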

from threading import Thread,local
import time
s = local()
def add(num):
    s.val=num
    time.sleep(1)
    print(s.val)

if __name__ == '__main__':
    for i in range(5):
        trd = Thread(target=add,args=[i])
        trd.start()

Conceptually, local keeps a separate copy of the data for each thread. We can implement the same idea with a dict:

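from threading import get_ident, Thread

Local = {}
def set_value(k, v):
    curr = get_ident()
    if curr in Local:  # if the dict already has the current thread id, store directly
        Local[curr][k] = v
    else:  # otherwise create an entry for this thread id
        Local[curr] = {k: v}
def get_value(k):
    curr = get_ident()
    return Local[curr][k]

def add(num):
    set_value("val", num)

if __name__ == '__main__':
    threads = []
    for i in range(5):
        trd = Thread(target=add, args=[i])
        trd.start()
        threads.append(trd)
    for trd in threads:
        trd.join()  # wait for all threads before printing
    print(Local)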

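The result looks like this: Local={13324: {'val': 0}, 4164: {'val': 1}, 12764: {'val': 2}, 3120: {'val': 3}, 2004: {'val': 4}}. Each thread is keyed by its own id, so reads don't interfere with each other. Going a step further, we can wrap this in a class: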

from threading import get_ident,Thread,current_thread
import time

class Local(object):
    storage = {}
    get_ident = get_ident

    def __setattr__(self, k, v):
        ident =self.get_ident()
        origin = self.storage.get(ident)
        if not origin:
            origin={}
        origin[k] = v
        self.storage[ident] = origin
    def __getattr__(self, k):
        ident = self.get_ident()
        v= self.storage[ident].get(k)
        return v

locals_values = Local()
def func(num):

    locals_values.KEY=num
    time.sleep(2)
    print(locals_values.KEY,current_thread().name)

for i in range(10):
    t = Thread(target=func,args=(i,),name='thread-%s'%i)
    t.start()

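This works, but there is a small problem: storage is a class attribute, so every Local instance shares the same dict. We want each instance to hold its own dict: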

# try greenlet first so that coroutines are supported, fall back to threads
try:
    from greenlet import getcurrent as get_ident
except ImportError:
    from threading import get_ident
# each Local instance gets its own storage
from threading import Thread
import time
class Local(object):
    def __init__(self):
        object.__setattr__(self, 'storage', {})  # set the attribute via object to bypass our own __setattr__
        # self.storage = {}  # don't write this: it calls __setattr__(self, "storage", {}), which reads self.storage before it exists
    def __setattr__(self, k, v):
        ident = get_ident()
        if ident in self.storage:  # storage is a normal instance attribute, so this lookup does not go through __getattr__
            self.storage[ident][k] = v
        else:
            self.storage[ident] = {k: v}
    def __getattr__(self, k):
        ident = get_ident()
        try:
            # __getattr__ only runs for missing attributes; if storage itself were missing,
            # self.storage would re-enter __getattr__("storage") and recurse forever
            return self.storage[ident][k]
        except KeyError:
            raise AttributeError(k) from None
obj = Local()
def task(arg):
    obj.val = arg
    time.sleep(1)
    print(obj.val)
for i in range(10):
    t = Thread(target=task, args=(i,))
    t.start()

werkzeug builds on the ideas in the code above, with more features:

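  • it prefers coroutine (greenlet) identity when coroutines are available
  • it defines __release_local__ to release the stored data
  • it defines LocalStack, which lets you use a Local like a stack: push, pop, and read the top element
  • it defines LocalProxy, a proxy for a Local that forwards operations on itself to the wrapped __local object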
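The __release_local__ and LocalStack bullets fit together: when the last item is popped, the stack releases the current thread's storage instead of leaving an empty list behind. Below is a simplified, thread-only sketch under that assumption; `ToyLocal` and `ToyLocalStack` are hypothetical names, not werkzeug's code.

```python
from threading import get_ident


class ToyLocal:
    """Per-thread attribute storage with a werkzeug-style __release_local__."""

    def __init__(self):
        object.__setattr__(self, "_storage", {})  # bypass our own __setattr__

    def __getattr__(self, name):
        try:
            return self._storage[get_ident()][name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self._storage.setdefault(get_ident(), {})[name] = value

    def __release_local__(self):
        # forget everything stored for the current thread only
        self._storage.pop(get_ident(), None)


class ToyLocalStack:
    def __init__(self):
        self._local = ToyLocal()

    def push(self, obj):
        rv = getattr(self._local, "stack", []).copy()
        rv.append(obj)
        self._local.stack = rv
        return rv

    def pop(self):
        stack = getattr(self._local, "stack", None)
        if stack is None:
            return None
        if len(stack) == 1:
            # popping the last item releases this thread's storage entirely
            self._local.__release_local__()
            return stack[-1]
        return stack.pop()

    @property
    def top(self):
        try:
            return self._local.stack[-1]
        except (AttributeError, IndexError):
            return None


s = ToyLocalStack()
s.push("app_ctx")
s.push("req_ctx")
print(s.pop(), s.pop(), s.top)  # req_ctx app_ctx None
print(s._local._storage)        # {}: nothing is left behind for this thread
```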

Flask is built on werkzeug, so it naturally inherits these features.

Context, part 2: the context objects

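Flask has two kinds of context: the application context and the request context. The former stores app-related information, the latter request-related information. They are defined in globals.py, which contains very little code: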

from functools import partial

from werkzeug.local import LocalProxy, LocalStack


def _lookup_req_object(name):
    top = _request_ctx_stack.top  # top is the current ctx
    if top is None:
        raise RuntimeError(_request_ctx_err_msg)
    return getattr(top, name)


def _lookup_app_object(name):
    top = _app_ctx_stack.top
    if top is None:
        raise RuntimeError(_app_ctx_err_msg)
    return getattr(top, name)


def _find_app():
    top = _app_ctx_stack.top
    if top is None:
        raise RuntimeError(_app_ctx_err_msg)
    return top.app


# context locals
_request_ctx_stack = LocalStack()
_app_ctx_stack = LocalStack()
current_app: "Flask" = LocalProxy(_find_app)  # type: ignore
request: "Request" = LocalProxy(partial(_lookup_req_object, "request"))  # type: ignore
session: "SessionMixin" = LocalProxy(  # type: ignore
    partial(_lookup_req_object, "session")
)
g: "_AppCtxGlobals" = LocalProxy(partial(_lookup_app_object, "g"))  # type: ignore

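The application context provides current_app and g; the request context provides request and session. Note also the two singleton LocalStack objects, which give each thread or coroutine its own isolated stack. Here's how LocalStack is written: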

class LocalStack:

    def __init__(self) -> None:
        self._local = Local()  # _local is a Local object

    def __release_local__(self) -> None:
        self._local.__release_local__()  # clears the stack data of the current thread or coroutine

    @property
    def __ident_func__(self) -> t.Callable[[], int]:
        return self._local.__ident_func__

    @__ident_func__.setter
    def __ident_func__(self, value: t.Callable[[], int]) -> None:
        object.__setattr__(self._local, "__ident_func__", value)

    def __call__(self) -> "LocalProxy":
        def _lookup() -> t.Any:
            rv = self.top
            if rv is None:
                raise RuntimeError("object unbound")
            return rv

        return LocalProxy(_lookup)

    def push(self, obj: t.Any) -> t.List[t.Any]:
        """Pushes a new item to the stack"""
        rv = getattr(self._local, "stack", []).copy()
        rv.append(obj)
        self._local.stack = rv
        return rv  # type: ignore

    def pop(self) -> t.Any:
        """Removes the topmost item from the stack, will return the
        old value or `None` if the stack was already empty.
        """
        stack = getattr(self._local, "stack", None)
        if stack is None:
            return None
        elif len(stack) == 1:
            release_local(self._local)
            return stack[-1]
        else:
            return stack.pop()

    @property
    def top(self) -> t.Any:
        """The topmost item on the stack.  If the stack is empty,
        `None` is returned.
        """
        try:
            return self._local.stack[-1]
        except (AttributeError, IndexError):
            return None

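Its main methods are push, pop and top. We looked at push briefly in the ctx.push section, and together with the request-flow analysis this confirms that the request context stores request and session. The __call__ method returns a proxy for the top element of the current thread's or coroutine's stack. If that's unclear, read on.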

LocalProxy is a proxy for a Local object. The source:

class LocalProxy:
    __slots__ = ("__local", "__name", "__wrapped__")

    def __init__(
        self,
        local: t.Union["Local", t.Callable[[], t.Any]],
        name: t.Optional[str] = None,
    ) -> None:
        object.__setattr__(self, "_LocalProxy__local", local)  # sets self.__local
        object.__setattr__(self, "_LocalProxy__name", name)  # sets self.__name

        if callable(local) and not hasattr(local, "__release_local__"):
            # "local" is a callable that is not an instance of Local or
            # LocalManager: mark it as a wrapped function.
            object.__setattr__(self, "__wrapped__", local) 

    def _get_current_object(self) -> t.Any:
        if not hasattr(self.__local, "__release_local__"):  # type: ignore
            return self.__local()  # type: ignore

        try:
            return getattr(self.__local, self.__name)  # type: ignore
        except AttributeError:
            name = self.__name  # type: ignore
            raise RuntimeError(f"no object bound to {name}") from None

    __doc__ = _ProxyLookup(  # type: ignore
        class_value=__doc__, fallback=lambda self: type(self).__doc__
    )
    # __del__ should only delete the proxy
    __repr__ = _ProxyLookup(  # type: ignore
        repr, fallback=lambda self: f"<{type(self).__name__} unbound>"
    )
    __str__ = _ProxyLookup(str)  # type: ignore
    __bytes__ = _ProxyLookup(bytes)

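self.__local is stored in __init__. Because it is a private, double-underscore attribute, Python name-mangles it to _LocalProxy__local, which is why the value is stored under that name. LocalProxy also overrides practically every magic method (only a few are shown here), and _get_current_object returns the real object behind the proxy for the current request.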

Back to request = LocalProxy(partial(_lookup_req_object, "request")). partial creates a partial function: it pre-fills some of a function's arguments without calling it. For example:

from functools import partial
def add(a,b,c):
    print(a+b+c)

func = partial(add, 1, 2)  # pre-fill two arguments of add; returns a new callable
func(3)  # now only the remaining argument is needed; prints 6

I may cover how partial works in a separate post. With partial understood, the code is equivalent to:

def _lookup_req_object(name):
    top = _request_ctx_stack.top
    if top is None:
        raise RuntimeError(_request_ctx_err_msg)
    return getattr(top, name)


request = LocalProxy(partial(_lookup_req_object, "request"))
# func = partial(_lookup_req_object, "request"), so func() == _lookup_req_object("request")
# i.e. request = LocalProxy(func), which calls __init__:
class LocalProxy:
    def __init__(self, local, name=None):
        # local: func
        # name: None
        ...

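So self.__local() in _get_current_object is really func(), which finds the request object in the current ctx and returns it. When we print(request) in a view, LocalProxy's __str__ is called; likewise, request.method goes through the corresponding __getattr__. This is a textbook use of the proxy pattern.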

Once request is clear, session works the same way. Let's combine everything and revisit the request flow:

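  • When a request arrives, a request context is created first; it wraps the current request (its session is filled in later, during push)
  • ctx.push: the app context is pushed onto _app_ctx_stack (if needed) and the request context onto _request_ctx_stack; the session is opened and the URL is matched
  • The request is dispatched to the view
  • If the view uses the current request (e.g. print(request)), the corresponding magic method of the LocalProxy (e.g. __str__) is triggered
  • When the request ends, ctx is popped and cleanup runs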

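Addendum: the application context provides current_app and g. The former exposes the app, but what is g? It's simply a global namespace for our own use: throughout a request's lifetime you can store and read values on g. You might ask: with so many globals already, why not just store things on current_app or request? Why bother with g?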

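You could, but these objects already carry their own attributes; request, for example, has a method attribute, and assigning to it could overwrite an existing field. To avoid such mistakes, g, which exists specifically for this purpose, is a good choice. Keep in mind that g only lives for the duration of a request: when the request ends, g is discarded. That differs from session: within its expiration time, a client's different requests share the same session data.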

session

ctx.push这一节里提到了session的操作过程,如下

# session is None only on the first push
if self.session is None:
    # self.app is the app we created
    session_interface = self.app.session_interface
    # load the session for this request (by default, from its cookie)
    self.session = session_interface.open_session(self.app, self.request)

    if self.session is None:
        # if it is still None (e.g. no secret_key), fall back to a null session
        self.session = session_interface.make_null_session(self.app)

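Let's dig deeper, starting with session_interface. By default it's the SecureCookieSessionInterface class in sessions.py (it can be replaced, e.g. with the third-party Flask-Session extension). The open_session and make_null_session methods called afterwards both belong to it: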

class SecureCookieSessionInterface(SessionInterface):
    # salt, combined with secret_key when signing
    salt = "cookie-session"
    # hash function used for the signature; sha1 by default
    digest_method = staticmethod(hashlib.sha1)
    # how the signing key is derived from secret_key; hmac by default
    key_derivation = "hmac"
    # serializer: JSON extended with a few extra Python types
    serializer = session_json_serializer
    session_class = SecureCookieSession

    def get_signing_serializer(
        self, app: "Flask"
    ) -> t.Optional[URLSafeTimedSerializer]:
        # without a secret_key, return None
        if not app.secret_key:
            return None
        signer_kwargs = dict(
            key_derivation=self.key_derivation, digest_method=self.digest_method
        )
        # build a serializer that signs data and produces URL-safe strings
        return URLSafeTimedSerializer(
            app.secret_key,
            salt=self.salt,
            serializer=self.serializer,
            signer_kwargs=signer_kwargs,
        )

    def open_session(
        self, app: "Flask", request: "Request"
    ) -> t.Optional[SecureCookieSession]:
        # get the signing serializer
        s = self.get_signing_serializer(app)
        if s is None:
            # no secret_key configured: return None
            return None
        # read the session cookie from the request
        val = request.cookies.get(self.get_cookie_name(app))
        if not val:
            # no cookie yet: return a new, empty session object
            return self.session_class()
        # max age: PERMANENT_SESSION_LIFETIME from the app, 31 days by default
        max_age = int(app.permanent_session_lifetime.total_seconds())
        try:
            # check the signature (and age) so tampered data is rejected, then deserialize
            data = s.loads(val, max_age=max_age)
            return self.session_class(data)
        except BadSignature:
            return self.session_class()

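URLSafeTimedSerializer comes from itsdangerous. It signs the data (with a timestamp) so tampering can be detected, and produces URL-safe strings. Note that the data is only signed, not encrypted.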

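The default session_class is SecureCookieSession. It's essentially a dict subclass that adds a few attributes (such as modified) and overrides some magic methods to track changes. Its implementation isn't complicated; take a look if you're curious.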

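Next, let's see how the session is saved. When a request comes in, Flask loads the session and keeps it in the context so views can read and modify it; when the response is returned, Flask automatically writes the session back to a cookie.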

full_dispatch_request小节中,请求结束会调用 self.finalize_request(),其中就会调用process_response方法返回响应,其中的这部分就是session处理:

def process_response(self, response: Response) -> Response:
    ctx = _request_ctx_stack.top
    ...  # the after_request functions run here (omitted)
    # if the session is not a null session, call save_session
    if not self.session_interface.is_null_session(ctx.session):
        self.session_interface.save_session(self, ctx.session, response)

    return response

which calls save_session:

def save_session(
    self, app: "Flask", session: SessionMixin, response: "Response"
) -> None:
    name = self.get_cookie_name(app)
    domain = self.get_cookie_domain(app)
    path = self.get_cookie_path(app)
    secure = self.get_cookie_secure(app)
    samesite = self.get_cookie_samesite(app)

    # if the session is empty, set no cookie; if it was emptied during this request, delete the existing one
    if not session:
        if session.modified:
            response.delete_cookie(
                name, domain=domain, path=path, secure=secure, samesite=samesite
            )

        return
    # skip if the cookie doesn't need to be set
    if not self.should_set_cookie(app, session):
        return
    # gather the cookie options, sign and serialize the session, and set the cookie
    httponly = self.get_cookie_httponly(app)
    expires = self.get_expiration_time(app, session)
    val = self.get_signing_serializer(app).dumps(dict(session))  # type: ignore
    response.set_cookie(
        name,
        val,  # type: ignore
        expires=expires,
        httponly=httponly,
        domain=domain,
        path=path,
        secure=secure,
        samesite=samesite,
    )
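The branches of save_session reduce to a small decision table. The function below is a hypothetical summary, not a Flask API; the should-set rule (modified, or permanent with SESSION_REFRESH_EACH_REQUEST enabled) is my reading of SessionInterface.should_set_cookie, so treat it as an assumption.

```python
def cookie_action(session: dict, modified: bool, permanent: bool = False,
                  refresh_each_request: bool = True) -> str:
    """Hypothetical summary of save_session's branches (not a Flask API)."""
    if not session:
        # empty session: delete the cookie only if it was emptied during this request
        return "delete_cookie" if modified else "do_nothing"
    # should_set_cookie(): modified, or permanent with SESSION_REFRESH_EACH_REQUEST on
    if not (modified or (permanent and refresh_each_request)):
        return "do_nothing"
    return "set_cookie"


print(cookie_action({}, modified=True))                               # delete_cookie
print(cookie_action({"user_id": 1}, modified=False))                  # do_nothing
print(cookie_action({"user_id": 1}, modified=False, permanent=True))  # set_cookie
```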

Signals

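In the earlier post on basic usage I briefly showed how to use signals; now let's look at how they're implemented. That post introduced Flask's 10 built-in signals and noted that they are placed at certain points in advance but do nothing until someone subscribes. Let's find where each of the ten signals is sent in the source: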

1 template_rendered and before_render_template

Sent before and after a template is rendered. In templating.py:

def render_template(template_name_or_list, **context):
    # get the current app context
    ctx = _app_ctx_stack.top
    # inject the context-processor variables (e.g. request, session, g) into the template context
    ctx.app.update_template_context(context)
    # call _render
    return _render(
        ctx.app.jinja_env.get_or_select_template(template_name_or_list),
        context,
        ctx.app,
    )

def _render(template: Template, context: dict, app: "Flask") -> str:
    """Renders the template and fires the signal"""
    #-------------------- signal: before the template is rendered --------------------#
    before_render_template.send(app, template=template, context=context)
    # render the template
    rv = template.render(context)
    #-------------------- signal: after the template is rendered --------------------#
    template_rendered.send(app, template=template, context=context)
    return rv

2 appcontext_pushed and appcontext_popped

Sent when the app context is pushed and popped. In ctx.py:

class AppContext:
    def push(self) -> None:
        """Binds the app context to the current context."""
        self._refcnt += 1
        _app_ctx_stack.push(self)
        #-------------------- signal: app context pushed --------------------#
        appcontext_pushed.send(self.app)

    def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None:  # type: ignore
        """Pops the app context."""
        try:
            self._refcnt -= 1
            if self._refcnt <= 0:
                if exc is _sentinel:
                    exc = sys.exc_info()[1]
                self.app.do_teardown_appcontext(exc)
        finally:
            rv = _app_ctx_stack.pop()
        assert rv is self, f"Popped wrong app context.  ({rv!r} instead of {self!r})"
        #-------------------- signal: app context popped --------------------#
        appcontext_popped.send(self.app)

3 request_started and request_finished

In app.py:

def full_dispatch_request(self) -> Response:
    self.try_trigger_before_first_request_functions()
    try:
        #-------------------- signal: request started (before the before_request functions) --------------------#
        request_started.send(self)
        rv = self.preprocess_request()
        if rv is None:
            rv = self.dispatch_request()
    except Exception as e:
        rv = self.handle_user_exception(e)
    return self.finalize_request(rv)

def finalize_request(
    self,
    rv: t.Union[ResponseReturnValue, HTTPException],
    from_error_handler: bool = False,
) -> Response:

    response = self.make_response(rv)
    try:
        response = self.process_response(response)
        #-------------------- signal: request finished (response processed) --------------------#
        request_finished.send(self, response=response)
    except Exception:
        if not from_error_handler:
            raise
        self.logger.exception(
            "Request finalizing failed with an error while handling an error"
        )
    return response

4 request_tearing_down

In app.py:

def do_teardown_request(
    self, exc: t.Optional[BaseException] = _sentinel  # type: ignore
) -> None:
    if exc is _sentinel:
        exc = sys.exc_info()[1]

    for name in chain(request.blueprints, (None,)):
        if name in self.teardown_request_funcs:
            for func in reversed(self.teardown_request_funcs[name]):
                self.ensure_sync(func)(exc)
    #-------------------- signal: request tearing down --------------------#
    request_tearing_down.send(self, exc=exc)

5 appcontext_tearing_down

In app.py:

def do_teardown_appcontext(
    self, exc: t.Optional[BaseException] = _sentinel  # type: ignore
) -> None:
    if exc is _sentinel:
        exc = sys.exc_info()[1]

    for func in reversed(self.teardown_appcontext_funcs):
        self.ensure_sync(func)(exc)
    #-------------------- signal: app context tearing down --------------------#
    appcontext_tearing_down.send(self, exc=exc)

6 got_request_exception

In app.py:

def handle_exception(self, e: Exception) -> Response:

    exc_info = sys.exc_info()
    #-------------------- signal: an exception occurred --------------------#
    got_request_exception.send(self, exception=e)

    if self.propagate_exceptions:
        # Re-raise if called with an active exception, otherwise
        # raise the passed in exception.
        if exc_info[1] is e:
            raise

        raise e

7 message_flashed

In helpers.py:

def flash(message: str, category: str = "message") -> None:

    flashes = session.get("_flashes", [])
    flashes.append((category, message))
    session["_flashes"] = flashes
    #-------------------- signal: sent automatically whenever flash() adds a message --------------------#
    message_flashed.send(
        current_app._get_current_object(),  # type: ignore
        message=message,
        category=category,
    )

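Knowing where each signal is sent in the source, together with the analysis of the startup and request flow, we can use these signals with confidence.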

Next, let's look at the signal machinery itself, defined in signals.py:

import typing as t

try:
    from blinker import Namespace
    # if Namespace imports, blinker is installed, so signals_available is set to True: signals are enabled
    signals_available = True
except ImportError:
    # blinker is not installed
    signals_available = False
    # Flask defines its own Namespace
    class Namespace:  # type: ignore
        def signal(self, name: str, doc: t.Optional[str] = None) -> "_FakeSignal":
            return _FakeSignal(name, doc)
    # a fake signal: it can be sent, but does nothing
    class _FakeSignal:
        def __init__(self, name: str, doc: t.Optional[str] = None) -> None:
            self.name = name
            self.__doc__ = doc

        def send(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
            # send simply does nothing
            pass

        def _fail(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
            raise RuntimeError(
                "Signalling support is unavailable because the blinker"
                " library is not installed."
            ) from None

        connect = connect_via = connected_to = temporarily_connected_to = _fail
        disconnect = _fail
        has_receivers_for = receivers_for = _fail
        del _fail


# the signal namespace
_signals = Namespace()


# the ten signals
template_rendered = _signals.signal("template-rendered")
before_render_template = _signals.signal("before-render-template")
request_started = _signals.signal("request-started")
request_finished = _signals.signal("request-finished")
request_tearing_down = _signals.signal("request-tearing-down")
got_request_exception = _signals.signal("got-request-exception")
appcontext_tearing_down = _signals.signal("appcontext-tearing-down")
appcontext_pushed = _signals.signal("appcontext-pushed")
appcontext_popped = _signals.signal("appcontext-popped")
message_flashed = _signals.signal("message-flashed")

Flask relies on blinker for signals; the code that follows is blinker's source, not Flask's.

Let's take the creation of one signal as an example and see what happens inside. First, signal is called:

class Namespace(dict):
    """A mapping of signal names to signals."""
    # a dict mapping signal names to signals
    def signal(self, name, doc=None):
        try:
            # on the first call the name is not in the dict yet, so KeyError is raised
            return self[name]
        except KeyError:
            # setdefault: if the key is absent, insert it with the default value and return that value
            return self.setdefault(name, NamedSignal(name, doc))

On the first call, the KeyError is caught and NamedSignal's __init__ is called:

class NamedSignal(Signal):
    """A named generic notification emitter."""

    def __init__(self, name, doc=None):
        Signal.__init__(self, doc)

        #: The name of this signal.
        self.name = name

which calls Signal.__init__:

class Signal(object):
    def __init__(self, doc=None):
        if doc:
            self.__doc__ = doc
        self.receivers = {}
        self._by_receiver = defaultdict(set)
        self._by_sender = defaultdict(set)
        self._weak_senders = {}

Next, to subscribe to a signal, e.g. signals.request_started.connect(func), we call connect:

def connect(self, receiver, sender=ANY, weak=True):
    # compute receiver_id
    receiver_id = hashable_identity(receiver)
    if weak:
        # create a weak reference to the receiver, with self._cleanup_receiver as the cleanup callback
        receiver_ref = reference(receiver, self._cleanup_receiver)
        receiver_ref.receiver_id = receiver_id
    else:
        receiver_ref = receiver
    if sender is ANY:
        sender_id = ANY_ID
    else:
        sender_id = hashable_identity(sender)
    # store the reference in receivers and index it by sender and by receiver
    self.receivers.setdefault(receiver_id, receiver_ref)
    self._by_sender[sender_id].add(receiver_id)
    self._by_receiver[receiver_id].add(sender_id)
    del receiver_ref
    return receiver

When we subscribe to a signal, its receivers dict stores the function we connected. When Flask internally calls the signal's send method, e.g. appcontext_pushed.send(self.app):

class Signal(object):
    def send(self, *sender, **kwargs):
        if len(sender) == 0:
            sender = None
        elif len(sender) > 1:
            raise TypeError('send() accepts only one positional argument, '
                            '%s given' % len(sender))
        else:
            # exactly one positional argument: that's the sender (here self.app)
            sender = sender[0]
        # if nobody subscribed, receivers is empty and we stop here
        if not self.receivers:
            return []
        # if someone subscribed, receivers is not empty
        else:
            return [(receiver, receiver(sender, **kwargs))  # our function is called here
                    for receiver in self.receivers_for(sender)]

Here is the implementation of self.receivers_for(sender):

def receivers_for(self, sender):
    # check again whether self.receivers is empty
    if self.receivers:
        # compute sender_id
        sender_id = hashable_identity(sender)
        if sender_id in self._by_sender:
            ids = (self._by_sender[ANY_ID] |
                   self._by_sender[sender_id])
        else:
            ids = self._by_sender[ANY_ID].copy()
        # ids holds the set of receiver_ids to notify
        for receiver_id in ids:
            # look up the receiver by receiver_id; it wraps our function
            receiver = self.receivers.get(receiver_id)
            if receiver is None:
                continue
            if isinstance(receiver, WeakTypes):
                # calling the weak reference returns our function (or None if it has been garbage-collected)
                strong = receiver()
                if strong is None:
                    self._disconnect(receiver_id, ANY_ID)
                    continue
                receiver = strong
            # this is a generator: yield each live receiver
            yield receiver
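To tie connect, send and receivers_for together, here is a heavily simplified blinker-style signal. It is a sketch, not blinker's code: `ToySignal` is a hypothetical class, ids come from id(), only plain functions are supported (bound methods would need weakref.WeakMethod), and dead receivers are dropped through a weakref callback, similar in spirit to blinker's _cleanup_receiver.

```python
import gc
import weakref


class ToySignal:
    """Simplified blinker-style signal: weakly referenced receivers, filtered by sender."""

    ANY = object()  # sentinel meaning "any sender"

    def __init__(self, name):
        self.name = name
        self.receivers = {}   # receiver_id -> weakref to the receiver function
        self._by_sender = {}  # sender_id -> set of receiver_ids

    def connect(self, receiver, sender=ANY):
        receiver_id = id(receiver)

        def cleanup(_ref, rid=receiver_id):
            # called when the receiver is garbage-collected
            self.receivers.pop(rid, None)

        self.receivers.setdefault(receiver_id, weakref.ref(receiver, cleanup))
        self._by_sender.setdefault(id(sender), set()).add(receiver_id)
        return receiver

    def receivers_for(self, sender):
        ids = self._by_sender.get(id(self.ANY), set()) | self._by_sender.get(id(sender), set())
        for receiver_id in ids:
            ref = self.receivers.get(receiver_id)
            func = ref() if ref is not None else None
            if func is None:  # disconnected or already collected
                continue
            yield func

    def send(self, sender=None, **kwargs):
        if not self.receivers:  # nobody subscribed: nothing to do
            return []
        return [(func, func(sender, **kwargs)) for func in self.receivers_for(sender)]


request_started = ToySignal("request-started")
app = object()  # hypothetical sender standing in for the Flask app


def on_request_started(sender, **extra):
    return "notified"


request_started.connect(on_request_started)
print(request_started.send(app))  # e.g. [(<function on_request_started at 0x...>, 'notified')]

del on_request_started  # only the weak reference is left
gc.collect()
print(request_started.send(app))  # [] once the function has been collected
```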

TODO: there are parts of the Flask source I haven't read yet; I'll keep updating this post.

posted @ 2021-11-25 00:27  yyyz