FastAPI系列:中间件

中间件介绍

中间件是一个函数,它在每个请求被特定的路径操作处理之前 ,以及在每个响应返回之前工作

装饰器版中间件

1.必须使用装饰器@app.middleware("http"),且middleware_type必须为http
2.中间件参数:request, call_next,且call_next 它将接收 request 作为参数

@app.middleware("http")
async def custom_middleware(request: Request, call_next):
    logger.info("Before request")
    response = await call_next(request)  # 让请求继续处理
    logger.info("After request")
    # 也可以在返回response之前做一些事情,比如添加响应头header
    # response.headers['xxx'] = 'xxx'
    return response


@app.get("/")
def read_root():
    logger.info("执行了.......")
    return {"message": "hello world"}

自定义中间件BaseHTTPMiddleware

BaseHTTPMiddleware是一个抽象类,允许您针对请求/响应接口编写ASGI中间件

要使用 实现中间件类BaseHTTPMiddleware,您必须重写该 async def dispatch(request, call_next)方法,

如果您想为中间件类提供配置选项,您应该重写该__init__方法,确保第一个参数是app,并且任何剩余参数都是可选关键字参数。app 如果执行此操作,请确保在实例上设置该属性。

# 通过继承BaseHTTPMiddleware来实现自定义的中间件
import time

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from fastapi import FastAPI, Request
from starlette.responses import Response

app = FastAPI()

# 基于BaseHTTPMiddleware的中间件实例
class TimeCcalculateMiddleware(BaseHTTPMiddleware):
    # dispatch必须实现
    async def dispatch(self, request: Request, call_next):
        print('start')
        start_time = time.time()
        response = await call_next(request)
        process_time = round(time.time() - start_time, 4)
        #返回接口响应事件
        response.headers['X-Process-Time'] = f"{process_time} (s)"
        print('end')
        return response


class AuthMiddleware(BaseHTTPMiddleware):
    def __init__(self,app, header_value='auth'):
        super().__init__(app)
        self.header_value = header_value
        
    #dispatch必须实现 
    async def dispatch(self, request:Request, call_next):
        print('auth start')
        response =  await call_next(request)
        response.headers['Custom'] = self.header_value
        print('auth end')
        return response
        
# fastapi实例的add_middleware方法
app.add_middleware(TimeCcalculateMiddleware)
app.add_middleware(AuthMiddleware, header_value='CustomAuth')

@app.get('/index')
async def index():
    print('index start')
    return  {
        'code': 200
    }

"""执行顺序
auth start
start
index start
end
auth end
"""

ip白名单中间件(基于纯ASGI中间)

根据官网说明BaseHTTPMiddleware有一些已知的局限性:

使用BaseHTTPMiddleware将阻止对contextlib.ContextVar的更改向上传播。
也就是说,如果您ContextVar在端点中设置 a 值并尝试从中间件读取它,您会发现该值与您在端点中设置的值不同

纯ASGI中间件,使用类的方式

class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        await self.app(scope, receive, send)

上面的中间件是最基本的ASGI中间件。它接收父 ASGI 应用程序作为其构造函数的参数,并实现async __call__调用该父应用程序的方法。
无论如何,ASGI 中间件必须是接受三个参数的可调用对象:scopereceivesend

  • scope是一个保存有关连接信息的字典,其中scope["type"]可能是:
  • receivesend 可以用来与ASGI服务器交换ASGI事件消息。这些消息的类型和内容取决于作用域类型。在ASGI规范中了解更多信息
# 基于自定义类来实现
from fastapi import FastAPI
app = FastAPI()

from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.requests import HTTPConnection
import typing

class WhiteIpMiddleware:
    def __init(self, app:ASGIApp, allow_ip: typing.Sequence[str] = ()) -> None:
        self.app = app
        self.allow_ip = allow_ip or '*'
        
    async def __call__(self, scope:Scope, receive:Receive, send:Send)->None:
        if scope['type'] in ('http','websocket') and scope['scheme'] in ('http', 'ws'):
            conn = HTTPConnection(scope=scope)
            if self.allow_ip and conn.client.host not in self.allow_ip:
                response = PlainTextResponse(content='不在ip白名单内', status_code=403)
                await response(scope, receive, send)
                return
            await self.app(scope, receive, send)
        else:
            await self.app(scope, receive, send)

app.add_middleware(WhiteIpMiddleware, allow_ip=['127.0.0.2'])

@app.get('/index')
async def index():
    print('index-start')
    return {'code': 200}

跨域中间件cors

同源:协议,域,端口相同

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware  # fastapi内置了一个CORSMiddleware,可以直接使用
import uvicorn
app = FastAPI()

origins = [
    "http://localhost.tiangolo.com",
    "https://localhost.tiangolo.com",
    "http://localhost",
    "http://localhost:8080",
]

app.add_middleware(
	CORSMiddleware,
    allow_origins=origins,  #一个允许跨域请求的源列表
    allow_credentials=True, #指示跨域请求支持 cookies,默认False, 另外,允许凭证时allow_origins 不能设定为 ['*'],必须指定源。
    allow_methods=["*"], # 一个允许跨域请求的 HTTP 方法列表,默认get
    allow_headers=["*"], # 一个允许跨域请求的 HTTP 请求头列表
)

@app.get("/")
async def main():
    return {"message": "hello world"}

if __name__ == '__main__':
    uvicorn.run(app=app)
posted @ 2024-02-28 17:56  我在路上回头看  阅读(320)  评论(0编辑  收藏  举报