Rust：手动实现 Future trait

Rust 的 Future 可以简单理解成一个状态机：每次被 poll 时推进一步，尚未完成就返回 Poll::Pending，完成后返回 Poll::Ready。async/await 则是语法糖，编译器会把 async 代码转换成实现了 Future trait 的状态机。

use std::{sync::{Arc, Mutex}, task::Waker, time::Duration};

/// Program entry point: prints a start marker, awaits a hand-written
/// three-second [`SleepFut`] on the tokio runtime, then prints an end marker.
#[tokio::main]
async fn main() {
    println!("main start");
    let three_seconds = Duration::from_secs(3);
    let sleep = SleepFut::new(three_seconds);
    // The runtime keeps polling `sleep` (each time it is woken) until it is Ready.
    sleep.await;
    println!("main end");
}


/// A future that completes once `duration` has elapsed.
///
/// The actual waiting happens on a dedicated OS thread that is spawned the
/// first time the future is polled; `state` is shared with that thread.
#[derive(Debug)]
struct SleepFut {
    /// How long to wait before the future becomes ready.
    duration: Duration,
    /// State shared between the future and its timer thread.
    state: Arc<Mutex<SleepState>>,
}

/// Mutable state shared between a [`SleepFut`] and its timer thread.
#[derive(Debug)]
struct SleepState {
    /// Waker registered by the most recent poll; the timer thread takes it
    /// and wakes it when the sleep finishes.
    waker: Option<Waker>,
    /// Current step of the sleep state machine.
    inner_state: SleepInnerState,
}

/// Steps of the sleep state machine driven by `SleepFut::poll`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SleepInnerState {
    /// Never polled yet; no timer thread has been started.
    Init,
    /// Timer thread is running; the future is waiting to be woken.
    Sleeping,
    /// The duration has elapsed; the next poll returns `Ready`.
    Done,
}

impl SleepFut {
    /// Builds a future that resolves once `duration` has passed.
    ///
    /// Construction is cheap: the shared state starts in `Init` with no
    /// registered waker, and nothing else happens until the first poll.
    fn new(duration: Duration) -> Self {
        let initial = SleepState {
            inner_state: SleepInnerState::Init,
            waker: None,
        };
        Self {
            state: Arc::new(Mutex::new(initial)),
            duration,
        }
    }
}

impl Future for SleepFut {
    type Output = ();

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {

        println!("polling...");

        let mut guard = self.state.lock().unwrap();

        match guard.inner_state {
            SleepInnerState::Init => {

                println!("init");

                guard.inner_state = SleepInnerState::Sleeping;
                let duration = self.duration;
                let state = self.state.clone();
                
                std::thread::spawn(move || {
                    println!("thread start");
                    std::thread::sleep(duration);
                    let mut guard = state.lock().unwrap();
                    guard.inner_state = SleepInnerState::Done;
                    if let Some(waker) = guard.waker.take()  {
                        waker.wake();
                    }
                    println!("thread end");
                });

                guard.waker = Some(cx.waker().clone());
                std::task::Poll::Pending
            }
            SleepInnerState::Sleeping => {
                println!("sleeping");
                // 只有当新的 waker 与现有 waker 不同时才更新
                // 这里使用 will_wake 进行优化
                match &guard.waker {
                    Some(w) if w.will_wake(cx.waker()) => {
                        // 已经是相同的 waker,不需要更新
                    }
                    _ => {
                        // 更新 waker
                        guard.waker = Some(cx.waker().clone());
                    }
                }

                std::task::Poll::Pending
            },
            SleepInnerState::Done => {
                println!("done");

                std::task::Poll::Ready(())
            },
        }

    }
}
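运行输出如下。可以看到 poll 只被调用了两次：第一次 poll 启动计时线程并返回 Pending；线程睡眠结束后把状态设为 Done 并通过 waker 唤醒任务，运行时再次 poll，此时返回 Ready。
main start
polling...
init
thread start
thread end
polling...
done
main end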
posted @ 2025-09-19 16:33  等你下课啊  阅读(3)  评论(0)    收藏  举报