超简单!手把手实现axum简易中间件

axum是Rust语言tokio生态中的重要一环,以轻量、模块化、易用而闻名于世。它的中间件系统集成自另一个叫tower的框架,这就意味着如果我们要写axum的中间件的话,就得了解一下这个tower的各个核心概念,并学习它的用法。但是,很多时候我们可能只是想写一点简单的小工具,为了小需求去学习一个复杂的大库也许并不是很值得。

下图是tower文档里的中间件范例,可以在axum中使用……但是还是算了吧,让我们看看别的解决方法?

还好,axum贴心地为我们提供了一个小工具:middleware::from_fn 。它可以让我们像写Handler一样地写中间件,就像这样:

没有乱七八糟的trait、生命周期和Into<Box<dyn>>,一切都是如此的自然和简单!

axum的文档这样描述这种构建中间件的方法:

当你满足这些条件时,使用middleware::from_fn

  • 你不打算实现自己的Future,而更愿意用熟悉的async/await语法。
  • 你不打算将自己的中间件发布为一个crate供别人使用,因为像这样编写的中间件只能与axum兼容。

显然,我们没有上面的两种顾虑(自己的Future实现?这不是月薪三千还随时可能被裁的可怜rust程序员该考虑的)。我们的目标就是十分钟内搞定需求,所以我们就快乐地选择middleware::from_fn了。

如果读者想更加详细地了解axum的中间件生态,可以去看看官方文档:axum::middleware

准备

接下来,我将会手把手带领你在10分钟内(根据熟悉axum的程度因人而异)写完一个简单的axum中间件。中间件的功能很简单:假设我们系统的登录接口没有设置验证码,为了防止系统内的用户密码被爆破,我们需要对登录接口加一层限流的中间件,限制每个用户每分钟能调用登录接口10次左右。

为了简化教程的复杂度,我们用timedmap库来代替Redis;在实际生产中,建议用Redis来保证服务的可扩展性。


首先,我们创建一个项目,并导入这样的依赖:

[dependencies]
anyhow = "1.0.79"
axum = "0.7.4"
axum-client-ip = "0.5.0"
lazy_static = "1.4.0"
timedmap = { version = "1.0.1", features = ["tokio"] }
tokio = { version = "1.36.0", features = ["full"] }

接下来,我们快速地构建一个简单的axum应用:

use anyhow::Result;
use axum::{routing::get, Router};
use tokio::net::TcpListener;

#[tokio::main]
async fn main() -> Result<()> {
    let app = Router::new().route("/login", get(|| async {"Username or password invalid."}));
    let listener = TcpListener::bind("0.0.0.0:8080").await?;
    axum::serve(listener, app).await?;

    Ok(())
}

这个12行的程序可以在访问/login时返回一段文本,模拟一个成熟系统的登录功能。

运行效果

接下来,我们在它的基础上进行改造,为它编写一个中间件。

中间件

axum的from_fn中间件有这样的模板:

use axum::{
    response::Response,
    middleware::Next,
    extract::Request,
};

async fn my_middleware(
    request: Request,
    next: Next,
) -> Response {
    // do something with `request`...

    let response = next.run(request).await;

    // do something with `response`...

    response
}

简单介绍一下这个模板中的关键事项:

  • 必须是async fn形式的异步函数
  • 必须有一个或多个提取器作为函数参数(在上面的模板中,是Request
  • 最后一个参数必须是Next
  • 返回值必须实现了IntoResponse

在了解了这些事项之后,我们就可以着手写中间件了。

最简单的中间件

话不多说,直接上代码:

use std::time::{self, SystemTime};

async fn hello_world(request: Request, next: Next) -> Response {
    let now = SystemTime::now()
        .duration_since(time::UNIX_EPOCH)
        .unwrap()
        .as_millis();
    if now % 2 == 0 {
        "Surprise!".into_response()
    } else {
        next.run(request).await
    }
}

这个中间件以二分之一的概率拦截请求并返回一个“Surprise”,虽然它并没有什么用 ,但是它很好地为我们的后续开发开了一个头。它就像一个真正的handler一样,接受各种各样的参数,并返回一个Response;它也像真正的handler一样,可以直接返回内容,而不管后续的中间件和handler。

搞定了中间件,接下来就是把它插入到路由里了:

let app = Router::new()
    .route("/login", get(|| async { "Username or password invalid." }))
    .route_layer(middleware::from_fn(hello_world)); // 看这一行

我们调用了axum::middleware::from_fn方法,这是我们之前所提到过的;它将我们的函数转换为一个真正的中间件,如下图所示,它用一个宏把我们的短短几行代码展开成了一个towerService

现在我们再启动程序,刷新刚刚的页面,可以看到新的结果:

程序当前的代码如下:

use std::time::{self, SystemTime};

use anyhow::Result;
use axum::{
    extract::Request,
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::get,
    Router,
};
use tokio::net::TcpListener;

async fn hello_world(request: Request, next: Next) -> Response {
    let now = SystemTime::now()
        .duration_since(time::UNIX_EPOCH)
        .unwrap()
        .as_millis();
    if now % 2 == 0 {
        "Surprise!".into_response()
    } else {
        next.run(request).await
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let app = Router::new()
        .route("/login", get(|| async { "Username or password invalid." }))
        .route_layer(middleware::from_fn(hello_world));
    let listener = TcpListener::bind("0.0.0.0:8080").await?;
    axum::serve(listener, app).await?;

    Ok(())
}

实战:登录限流

现在,我们可以继续我们的实战演练了。首先,我们准备一个数据结构来检查用户是否发出了过多请求:

struct RateLimiter {
    pub map: TimedMap<IpAddr, i32>,
}

impl RateLimiter {
    pub fn new() -> Self {
        Self {
            map: Default::default(),
        }
    }

    pub fn check(&self, ip_addr: &IpAddr) -> bool {
        let count = self.map.get(ip_addr).unwrap_or(0);
        if count > 10 {
            false
        }else{
            self.map.insert(ip_addr.clone(), count+1, Duration::from_secs(6));
            true
        }
    }
}

这里我们用了一个比较简单的算法来判断用户是否发送了过多请求:每次用户的请求到来时,我们从含有效期的哈希表中获取用户最近的连续请求次数;如果次数大于10次,我们判定用户请求量过大,返回false;否则,我们更新哈希表并返回true。这里可以注意到我们用的超时时长是6秒,这是因为每次更新数据时,有效期都会刷新,因此我们需要设置有效期为60秒/10次=6次/秒。


现在我们已经准备好了限流工具,那么我们要怎样将它放到中间件里呢?

也许有的读者刚接触axum,会说:

lazy_static!{
    pub static ref RATE_LIMITER: Mutex<RateLimiter> = Mutex::new(RateLimiter::new());
}

这种写法其实是不适合axum和异步编程的,具体的原因可以询问GPT;正确的做法是通过axum提供的State机制来将限流工具注入到中间件内:

#[derive(Clone)]
struct LimitState {
    pub limiter: Arc<RateLimiter>,
}

impl LimitState {
    pub fn new() -> Self {
        Self {
            limiter: Arc::new(RateLimiter::new()),
        }
    }
}

在这里,我们定义了一个用于中间件的State。应当注意,它内部的限流工具使用了Arc来包裹,并且State实现了Clone特质。这种写法可以保证limiter可以在多个线程/协程之间使用同一个RateLimiter实例。接下来,我们稍微修改中间件部分的from_fn,更换为from_fn_with_state

let state = LimitState::new();
let app = Router::new()
    .route("/login", get(|| async { "Username or password invalid." }))
    .route_layer(middleware::from_fn_with_state(state, rate_limit));  // 变化在这里

拦在我们面前的是最后一道难关:怎么获取用户的IP地址?这里我们直接使用axum-client-ip库,它提供了提取用户真实IP的工具,详情和配置方法可以参阅它的官方文档。在这里,我们只需要在主函数里额外加一行配置就好了:

use axum_client_ip::SecureClientIpSource;

let app = Router::new()
    .route("/login", get(|| async { "Username or password invalid." }))
    .route_layer(middleware::from_fn_with_state(state, rate_limit))
    .layer(SecureClientIpSource::ConnectInfo.into_extension());  // 在这里

在生产环境中,不要忘了把SecureClientIpSource更换成你应用的部署平台/CDN/反向代理使用的HTTP头;如果没有用CDN或者反代,则保留这样不变即可

此外,serve一行也要做一些修改:

axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;

一切障碍都已经清除,我们现在直接手起刀落,写完我们的中间件:

async fn rate_limit(
    State(state): State<LimitState>,
    SecureClientIp(ip_addr): SecureClientIp,
    request: Request,
    next: Next,
) -> Response {
    if state.limiter.check(&ip_addr) {
        next.run(request).await
    } else {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Too many requests! Please wait for one minute.",
        )
            .into_response()
    }
}

接下来就是测试环节!

可以看到,我们的中间件已经开始工作啦~

最后附上程序的完整代码:

use std::{
    net::{IpAddr, SocketAddr},
    sync::Arc,
    time::Duration,
};

use anyhow::Result;
use axum::{
    extract::{Request, State},
    http::StatusCode,
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::get,
    Router,
};
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
use timedmap::TimedMap;
use tokio::net::TcpListener;

struct RateLimiter {
    pub map: TimedMap<IpAddr, i32>,
}

impl RateLimiter {
    pub fn new() -> Self {
        Self {
            map: Default::default(),
        }
    }

    pub fn check(&self, ip_addr: &IpAddr) -> bool {
        let count = self.map.get(ip_addr).unwrap_or(0);
        if count > 10 {
            false
        } else {
            self.map
                .insert(ip_addr.clone(), count + 1, Duration::from_secs(6));
            true
        }
    }
}

#[derive(Clone)]
struct LimitState {
    pub limiter: Arc<RateLimiter>,
}

impl LimitState {
    pub fn new() -> Self {
        Self {
            limiter: Arc::new(RateLimiter::new()),
        }
    }
}

async fn rate_limit(
    State(state): State<LimitState>,
    SecureClientIp(ip_addr): SecureClientIp,
    request: Request,
    next: Next,
) -> Response {
    if state.limiter.check(&ip_addr) {
        next.run(request).await
    } else {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Too many requests! Please wait for one minute.",
        )
            .into_response()
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let state = LimitState::new();
    let app = Router::new()
        .route("/login", get(|| async { "Username or password invalid." }))
        .route_layer(middleware::from_fn_with_state(state, rate_limit))
        .layer(SecureClientIpSource::ConnectInfo.into_extension());
    let listener = TcpListener::bind("0.0.0.0:8080").await?;
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await?;

    Ok(())
}
posted @ 2024-02-06 20:33  Cinea  阅读(201)  评论(0编辑  收藏  举报