rust-web组件 第Ⅰ章 --chapter 2. actix-web 简单的api网关

toml

log = "0.4.29"
tracing = "0.1.44"
tracing-appender = "0.2.4"
tracing-subscriber = { version = "0.3.22", features = ["local-time", "env-filter"] }
logroller = { version = "0.1.10", features = ["tracing"] }
chrono = { version = "0.4.42", features = ["serde"] }
time = { version = "0.3.43", features = ["macros"] }
jsonwebtokens = "1.2.0"
actix-web = { version = "4.12.1"}
actix-cors = "0.7.1"
toml = "0.9.8"
nacos_rust_client = "0.3.2"
redis = { version = "1.0.1", features = ["ahash", "connection-manager", "tokio-comp"] }
serde = "1.0.228"
serde_json = "1.0.148"
local_ipaddress = "0.1.3"
awc = { version = "3.8.1" }
tokio = { version = "1.48.0", features = ["sync", "io-util"] }
tokio-stream = "0.1.17"
futures-util = { version = "0.3.31", default-features = false, features = ["std"] }
mimalloc = { version = "0.1.48", features = ["v3"] }
xid = "1.1.1"

main

use crate::init::config::Config;
use actix_web::{App, HttpServer, middleware, web};
use mimalloc::MiMalloc;
use std::io;

mod init;
mod util;

// Install mimalloc as the allocator for the whole process.
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;

#[actix_web::main]
async fn main() -> io::Result<()> {
    let _guard = init::logger::tracing_init();
    log::info!("{}", init::constant::BANNER);
    init::nacos::init_nacos().await;
    init::redis::init_redis().await;
    let log_format = r#"%a "%r" %s %b "%{Referer}i" "%{User-Agent}i" "%{x-request-user}i" "%{x-request-id}i" %T"#;
    HttpServer::new(move || {
        App::new()
            .app_data(web::Data::new(
                awc::Client::builder()
                    .timeout(std::time::Duration::from_secs(90))
                    .no_default_headers()
                    .finish(),
            ))
            .wrap(middleware::Logger::new(log_format))
            //将请求的userId提取并放入请求头
            .wrap(util::request_user::RequestUserMiddleware)
            //生成requestId, 方便日志追踪
            .wrap(util::request_id::RequestIdMiddleware)
            .wrap(
                actix_cors::Cors::default()
                    // 允许指定源
                    .allow_any_origin("cnblogs.com")
                    // 允许指定HTTP方法
                    .allowed_methods(vec!["GET", "POST"])
                    // 允许指定请求头
                    .allow_any_header()
                    // 设置预检请求缓存时间(秒)
                    .max_age(3600),
            )
            .default_service(web::to(init::service::forward_reqwest))
    })
    .bind((Config::global().server_ip(), Config::global().server_port()))
    .expect("server bind failed")
    .shutdown_timeout(10)
    .run()
    .await
}

request_id中间件

use actix_web::http::header;
use actix_web::{
    body::MessageBody,
    dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
};
use futures_util::future::{LocalBoxFuture, ready};

use crate::init::constant::X_REQUEST_ID;

// 1. The middleware type: a field-less (stateless) struct that can be instantiated directly.
#[derive(Clone, Default)]
pub struct RequestIdMiddleware;

// 2. `Transform` implementation: the factory that turns the middleware into a service.
impl<S, B> Transform<S, ServiceRequest> for RequestIdMiddleware
where
    B: MessageBody + 'static,
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
    S::Future: 'static,
{
    // Request/response types are passed through unchanged.
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    // The service that actually handles requests.
    type Transform = RequestIdService<S>;
    type InitError = ();
    type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;

    // Wrap the next service; construction cannot fail, so the future is
    // immediately ready with `Ok`.
    fn new_transform(&self, service: S) -> Self::Future {
        let wrapped = RequestIdService { service };
        Box::pin(ready(Ok(wrapped)))
    }
}

// 3. The middleware service: holds the next service and implements the request handling.
pub struct RequestIdService<S> {
    service: S, // the next service in the chain (e.g. a route handler)
}

// 4. `Service` implementation: the per-request logic.
impl<S, B> Service<ServiceRequest> for RequestIdService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    // Readiness is delegated to the wrapped service.
    forward_ready!(service);

    /// Ensures the request carries a request id and echoes it on the response.
    ///
    /// A client-supplied `X_REQUEST_ID` header is reused when it is a
    /// non-empty, readable string; otherwise a fresh `xid` is generated.
    /// The id is written to the request headers before the next service runs
    /// and to the response headers afterwards. Errors from the next service
    /// are propagated unchanged (no header is added in that case).
    fn call(&self, mut req: ServiceRequest) -> Self::Future {
        // NOTE(review): `from_static` panics on names that are not valid
        // lowercase header names; assumes `X_REQUEST_ID` is such a constant.
        let header_name = header::HeaderName::from_static(X_REQUEST_ID);

        // a. Reuse the client's id or generate one. The header value is built
        //    once here and cloned, instead of being re-parsed from a string
        //    (with a panicking unwrap) for the request and again for the response.
        let request_id = req
            .headers()
            .get(X_REQUEST_ID)
            .filter(|value| value.to_str().is_ok_and(|id| !id.is_empty()))
            .cloned()
            .unwrap_or_else(|| {
                header::HeaderValue::from_str(&xid::new().to_string())
                    .expect("a generated xid must be a valid header value")
            });

        // b. Write the id into the request headers for downstream middleware/handlers.
        req.headers_mut()
            .insert(header_name.clone(), request_id.clone());

        // c. Call the next service and add the id to its response.
        let fut = self.service.call(req);
        Box::pin(async move {
            let mut res = fut.await?;
            res.headers_mut().insert(header_name, request_id);
            Ok(res)
        })
    }
}

转发请求

use actix_web::{Error, HttpRequest, HttpResponse, error, http::header::AUTHORIZATION, web};
use awc::{Client, http::Uri};
use std::borrow::Cow;

use crate::init::{
    auth,
    constant::{IP_HEADER, X_REQUEST_USER},
    nacos,
};

/// Default handler of the gateway: authenticates the request and forwards it.
///
/// * When both the `X_REQUEST_USER` and `AUTHORIZATION` headers are present,
///   the request is forwarded only if both are readable strings, the user id
///   is non-empty and `is_valid_request` accepts them.
/// * Otherwise (at least one header missing) the request is forwarded only
///   if its path ends with `login`.
/// * Everything else is answered with `403 Forbidden`.
pub(crate) async fn forward_reqwest(
    req: HttpRequest,
    payload: web::Payload,
    client: web::Data<Client>,
) -> Result<HttpResponse, Error> {
    let headers = req.headers();

    let allowed = match (headers.get(X_REQUEST_USER), headers.get(AUTHORIZATION)) {
        (Some(raw_user), Some(raw_token)) => match (raw_user.to_str(), raw_token.to_str()) {
            (Ok(user), Ok(bearer)) => {
                !user.is_empty() && is_valid_request(user, bearer, &req).await
            }
            // A header that is not readable text can never authenticate.
            _ => false,
        },
        // No credentials at all: only the login endpoint is let through.
        _ => req.uri().path().ends_with("login"),
    };

    if allowed {
        proxy_to_service(req, payload, client).await
    } else {
        // Every unauthenticated request is rejected with 403.
        Ok(HttpResponse::Forbidden().finish())
    }
}

/// Decides whether the given user id / token pair authenticates the request.
///
/// NOTE(review): this is a stub — none of the parameters are read and it
/// always returns `false`, so as written every request that carries
/// credentials is rejected by the caller.
async fn is_valid_request(user_id: &str, token: &str, req: &HttpRequest) -> bool {
   // Identity-verification logic goes here.
    false
}

async fn proxy_to_service(
    req: HttpRequest,
    payload: web::Payload,
    client: web::Data<Client>,
) -> Result<HttpResponse, Error> {
    let path = req.uri().path();
    let mut parts = path.split('/');

    // 跳过第一个空元素
    parts.next();

    // 获取服务名称
    let req_server_name = match parts.next() {
        Some(name) => name,
        None => return Ok(HttpResponse::BadRequest().finish()),
    };

    let service_addr = nacos::select_service(req_server_name).await;
    if service_addr.is_empty() {
        return Ok(HttpResponse::NotFound().finish());
    }

    // 优化URL构建,使用Cow避免不必要的分配
    let remaining_path = if let Some(first_part) = parts.next() {
        let mut path_builder = String::with_capacity(path.len());
        path_builder.push('/');
        path_builder.push_str(first_part);

        for part in parts {
            path_builder.push('/');
            path_builder.push_str(part);
        }

        Cow::Owned(path_builder)
    } else {
        Cow::Borrowed("")
    };

    // 构建URL,避免多次格式化
    let query_str = req.uri().query().unwrap_or_default();
    let url_str = if query_str.is_empty() {
        format!("http://{}{}", service_addr, remaining_path)
    } else {
        format!("http://{}{}?{}", service_addr, remaining_path, query_str)
    };

    let new_url: Uri = url_str
        .parse()
        .map_err(|_| error::ErrorBadRequest("Invalid URL"))?;

    // 使用连接池优化
    let forwarded_req = client.request_from(new_url, req.head()).no_decompress();

    let res = forwarded_req
        .send_stream(payload)
        .await
        .map_err(error::ErrorInternalServerError)?;

    // 优化响应头复制
    let mut client_resp = HttpResponse::build(res.status());
    for (header_name, header_value) in res.headers().iter() {
        // 跳过连接相关的头部
        if header_name.as_str().eq_ignore_ascii_case("connection")
            || header_name
                .as_str()
                .eq_ignore_ascii_case("transfer-encoding")
        {
            continue;
        }
        client_resp.insert_header((header_name.clone(), header_value.clone()));
    }

    Ok(client_resp.streaming(res))
}

异步redis客户端

use std::sync::LazyLock;

use redis::{AsyncCommands, Client};

use super::config::Config;

/// Process-wide redis client, created on first access from the URL produced
/// by `generate_connection_url`. Panics on first access if that URL is
/// rejected by `Client::open`.
pub(crate) static CLIENT: LazyLock<Client> =
    LazyLock::new(|| Client::open(generate_connection_url()).unwrap());

/// Verifies at startup that redis is reachable by opening a connection and
/// sending `PING`.
///
/// # Panics
///
/// Panics when the connection cannot be established or the ping fails, so a
/// misconfigured gateway stops immediately with a message naming the step.
pub(crate) async fn init_redis() {
    let mut conn = CLIENT
        .get_connection_manager()
        .await
        .expect("failed to connect to redis during startup");
    let _pong: String = conn
        .ping()
        .await
        .expect("redis did not answer PING during startup");
    log::info!("redis connected");
}

/// Percent-encodes everything except RFC 3986 "unreserved" characters so the
/// text can be embedded in the user-info part of a URL.
fn percent_encode_userinfo(raw: &str) -> String {
    use std::fmt::Write as _;

    let mut encoded = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                encoded.push(char::from(byte));
            }
            _ => {
                // Writing into a `String` cannot fail.
                let _ = write!(encoded, "%{byte:02X}");
            }
        }
    }
    encoded
}

/// Builds the redis connection URL from the global configuration:
/// `redis://[:<password>@]<ip>:<port>/<db>`.
///
/// The password section is omitted when no password is configured. The
/// password is percent-encoded, because characters such as `@`, `/`, `:` or
/// `#` would otherwise change how the URL is split into its parts.
///
/// NOTE(review): a password that was already percent-encoded by hand in the
/// configuration must be stored in its raw form from now on.
fn generate_connection_url() -> String {
    let mut url = "redis://".to_owned();
    if !Config::global().redis_password().is_empty() {
        url.push(':');
        url.push_str(&percent_encode_userinfo(&Config::global().redis_password()));
        url.push('@');
    }
    url.push_str(&Config::global().redis_ip());
    url.push(':');
    url.push_str(&Config::global().redis_port().to_string());
    url.push('/');
    url.push_str(&Config::global().redis_db().to_string());
    url
}

// Optimized function that checks existence and gets value in a single connection
pub(crate) async fn get_value_if_exists(key: &str) -> Option<String> {
    let mut conn = CLIENT.get_connection_manager().await.unwrap();

    // Use a pipeline to reduce network round trips
    let (exists, value): (bool, String) = redis::pipe()
        .cmd("EXISTS")
        .arg(key)
        .cmd("GET")
        .arg(key)
        .query_async(&mut conn)
        .await
        .unwrap_or((false, String::new()));

    if exists { Some(value) } else { None }
}

pub(crate) async fn set_expire(key: &str, seconds: i64) {
    let mut conn = CLIENT.get_connection_manager().await.unwrap();
    match conn.expire(key, seconds).await {
        Ok(()) => (),
        Err(err) => {
            log::error!("set key:{key} expire time: {seconds} error: {err}");
        }
    }
}

posted @ 2025-12-31 10:36  JiajieZeee  阅读(2)  评论(0)    收藏  举报