C++20协程解糖 - 动手实现协程2 - 实现co_await和co_return

在开始之前,我们先修复上一篇文章中的一个bug,SharedState::add_finish_callback中post_all_callbacks应当提前判断settled,否则会在未设置结果的情况下添加callback,callback也会被立即post


template<class T>
class SharedState : public SharedStateBase {
    // ...
    // private
    void add_finish_callback(std::function<void(T&)> callback) {
        finish_callbacks.push_back(std::move(callback));
        if (settled) {
            post_all_callbacks();
        }
    }
};


概述

今天我们要实现的东西包括

  1. 给schedular加上timer支持
  2. 给Future和Promise补充必要设施以支持C++20协程

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

开始动手

首先是Schedular的timer支持,我们这里使用一个简单的优先队列用来管理所有timer,并在poll函数中处理完当帧pending_state后,视情况sleep到最近的timer到期,并处理所有到期的timer


class Schedular {
    // ...
    // public
    using timer_callback = std::function<void()>;
    using timer_item = std::tuple<bool, float, chrono::time_point<chrono::steady_clock, chrono::duration<float>>, timer_callback>;
    using timer = std::chrono::steady_clock;
    struct timer_item_cmp {
        bool operator()(const timer_item& a, const timer_item& b) const {
            return std::get<2>(a) > std::get<2>(b);
        }
    };

    // ...
    // public
    void poll() {
        size_t sz = pending_states.size();
        for (size_t i = 0; i != sz; i++) {
            auto state = std::move(pending_states[i]);
            state->invoke_all_callback();
        }
        pending_states.erase(pending_states.begin(), pending_states.begin() + sz);
        if (timer_queue.empty()) {
            return;
        }
        if (pending_states.empty()) { //如果pending_states为空,则可以sleep较长的时间,等待第一个将要完成的timer
            std::this_thread::sleep_until(std::get<2>(timer_queue.front()));
            auto now = timer::now();
            do {
                deal_one_timer();
            } while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now);
        } else { //否则只能处理当帧到期的timer,不能sleep,要及时返回给caller,让caller及时下一次poll处理剩下的pending_states
            auto now = timer::now();
            while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now) {
                deal_one_timer();
            }
        }
    }

    void add_timer(bool repeat, float delay, timer_callback callback) {
        auto cur_time = chrono::time_point_cast<chrono::duration<float>>(timer::now());
        auto timeout = cur_time + chrono::duration<float>(delay);
        timer_queue.emplace_back(repeat, delay, timeout, callback);
        std::push_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
    }

    // ...
    // private
    void deal_one_timer() {
        std::pop_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
        auto item = std::move(timer_queue.back());
        timer_queue.pop_back();
        std::get<3>(item)();
        if (std::get<0>(item)) {
            add_timer(true, std::get<1>(item), std::move(std::get<3>(item)));
        }
    }

    std::deque<timer_item> timer_queue;
};


这样之后,基于当前调度器的delay函数就可以写出来了


class Schedular {
    // ...
    // public
    Future<float> delay(float second) {
        auto promise = Promise<float>(*this);
        add_timer(false, second, [=]() mutable {
            promise.set_result(second);
        });
        return promise.get_future();
    }
    // ...
};


因为之前我们设计的Future和Promise并不支持void,于是这里简单用Future<float>代替,返回的是等待的秒数。

需要注意的是,这个delay函数虽然返回Future,但并不是协程,协程的判断标准是当且仅当函数中使用了co_await/co_yield/co_return,和返回类型无关。

这个函数同样展示了将回调式API封装为Future的做法,就是把Promise.set_result作为回调传入给API,并返回Promise.get_future,使用者在Future这边等待就好了。

有了这些东西之后,我们可以先把本次的测试代码写出来了

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢


Future<float> func(Schedular& schedular) {
    std::cout << "start sleep\n";
    auto r = co_await schedular.delay(1.2);
    co_return r;
}

Future<int> func2(Schedular& schedular) {
    auto r = co_await func(schedular);
    std::cout << "slept for " << r << "s\n";
    co_return 42;
}


这里需要注意的是C++协程在编译器实现中,会自动构造一个Promise对象,而我们的Promise并不支持默认构造,必须传入一个Schedular参数。好在C++会替我们自动将协程参数作为作为构造函数参数来构造Promise,因此要在协程参数中指定Schedular,相当于指定Schedular构造了Promise。为一个协程显式指定调度器,是一个很合理的设计,python也是类似的设计。C#将协程调度器隐藏进了Task,因为它有一个全局的默认调度器。如果我们的实现中提供一个全局构造的Schedular,让Promise自动去找他调度,那这里的协程也可以没有参数。

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

为了让Future支持协程,代码中还需要补充一系列的内容,列举在下面


template<class T>
class Future {
    // ...
    // public
    // 协程接口
    using promise_type = Promise<T>;

    bool await_ready() { return _state->settled; }

    void await_suspend(exp::coroutine_handle<> handle) {
        add_finish_callback([=](T&) mutable { handle.resume(); });
    }

    T await_resume() { return _state->value; }
    // 协程接口
    // ...
};

  • promise_type用来指定本future对应的promise,结果的输入端
  • await_ready检查future是否已经完成
  • await_suspend用来通知future,协程为了等待它完成,已经暂停,future需要在自己完成的时候,主动恢复协程
  • await_resume用来通知future,协程已经恢复执行,需要从future中取出结果,用作co_await表达式的结果。这里我们直接返回拷贝,实现中较为合理的是把future持有的对象移动出去,但这样的话被await的future就不能再单独获取结果了。

为了让Promise支持协程,需要补充的内容在下面


// ...
// 在最开头
// 如果你的编译器已经不需要std::experimental了,那就去掉这行,后面使用std而不是exp
namespace exp = std::experimental;

template<class T>
class Promise {
    // ...
    // public
    // 协程接口
    Future<T> get_return_object();

    exp::suspend_never initial_suspend() { return {}; }

    exp::suspend_never final_suspend() noexcept { return {}; }

    void return_value(T v) { set_result(v); }

    void unhandled_exception() { std::terminate(); }
    // 协程接口
    // ...
};

// ...
// 在Future定义后面
template<class T>
Future<T> Promise<T>::get_return_object() {
    return get_future();
}

  • initial_suspend用来表明协程是否在调用时暂停,异步任务一般返回suepend_never,调用时立即启动
  • final_suspend用来表明协程是否在co_return后暂停(延迟销毁),我们是使用shared_state的异步任务,因此可以不暂停协程,直接自动销毁协程,让shared_state留在空中靠引用计数清零销毁
  • return_value用于co_return将结果传入
  • unhandled_exception用于协程中出现了未处理异常的情况,这里面可以通过std::current_exception来获取当前异常,我们的简化版不可能出现异常,出了就直接terminate
  • get_return_object就是get_future。大家不要忘记一个协程是先构造promise,后从promise获取future的

有了这些东西之后,编译就不应该再出现错误了,我的编译选项是 clang++-9 test.cpp -stdlib=libc++ -std=c++2a 

运行?还差最后一点

为了方便,我们效法python,给Schedular补一个run_until_compete的方法

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢


class Schedular {
    // ...
    // public
    template<class F, class... Args>
    auto run_until_complete(F&& fn, Args&&... args) -> typename std::invoke_result_t<F&&, Args&&...>::result_type {
        auto future = std::forward<F>(fn)(std::forward<Args>(args)...);
        while (!future.await_ready()) {
            poll();
        }
        return future.await_resume();
    }
};


然后main

int main() {
    Schedular schedular;

    auto r = schedular.run_until_complete(func2, schedular);

    std::cout << "run complete with " << r << "\n";
}


运行结果就有了

start sleep
slept for 1.2s
run complete with 42


怎么样,是不是很简单呢,赶紧自己写一个吧!

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

附录 - 全部代码



#include <vector>
#include <deque>
#include <memory>
#include <iostream>
#include <functional>
#include <chrono>
#include <thread>
#include <algorithm>
#include <experimental/coroutine>

namespace exp = std::experimental;

template<class T>
class Future;

template<class T>
class Promise;

class Schedular;

class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
    friend class Schedular;
public:
    virtual ~SharedStateBase() = default;
private:
    virtual void invoke_all_callback() = 0;
};

template<class T>
class SharedState : public SharedStateBase {
    friend class Future<T>;
    friend class Promise<T>;
public:
    SharedState(Schedular& schedular)
        : schedular(&schedular)
    {}
    SharedState(const SharedState&) = delete;
    SharedState(SharedState&&) = delete;
    SharedState& operator=(const SharedState&) = delete;
    SharedState& operator=(SharedState&&) = delete;

private:
    template<class U>
    void set(U&& v) {
        if (settled) {
            return;
        }
        settled = true;
        value = std::forward<U>(v);
        post_all_callbacks();
    }

    T& get() { return value; }

    void add_finish_callback(std::function<void(T&)> callback) {
        finish_callbacks.push_back(std::move(callback));
        if (settled) {
            post_all_callbacks();
        }
    }

    void post_all_callbacks();

    virtual void invoke_all_callback() override {
        callback_posted = false;
        size_t sz = finish_callbacks.size();
        for (size_t i = 0; i != sz; i++) {
            auto v = std::move(finish_callbacks[i]);
            v(value);
        }
        finish_callbacks.erase(finish_callbacks.begin(), finish_callbacks.begin()+sz);
    }

    bool settled = false;
    bool callback_posted = false;
    Schedular* schedular = nullptr;
    T value;
    std::vector<std::function<void(T&)>> finish_callbacks;
};

template<class T>
class Promise {
public:
    Promise(Schedular& schedular)
        : _schedular(&schedular)
        , _state(std::make_shared<SharedState<T>>(*_schedular))
    {}

    Future<T> get_future();

    // 协程接口
    Future<T> get_return_object();
    exp::suspend_never initial_suspend() { return {}; }
    exp::suspend_never final_suspend() noexcept { return {}; }
    void return_value(T v) { set_result(v); }
    void unhandled_exception() { std::terminate(); }
    // 协程接口

    template<class U>
    void set_result(U&& value) {
        if (_state->settled) {
            throw std::invalid_argument("already set result");
        }
        _state->set(std::forward<U>(value));
    }
private:
    Schedular* _schedular;
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
class Future {
public:
    using result_type = T;
    using promise_type = Promise<T>;
    friend class Promise<T>;
private:
    Future(std::shared_ptr<SharedState<T>> state)
        : _state(std::move(state))
    {
    }
public:
    // 协程接口
    bool await_ready() { return _state->settled; }
    void await_suspend(exp::coroutine_handle<> handle) {
        add_finish_callback([=](T&) mutable { handle.resume(); });
    }
    T await_resume() { return _state->value; }
    // 协程接口

    void add_finish_callback(std::function<void(T&)> callback) {
        _state->add_finish_callback(std::move(callback));
    }
private:
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
Future<T> Promise<T>::get_future() {
    return Future<T>(_state);
}

template<class T>
Future<T> Promise<T>::get_return_object() {
    return get_future();
}

namespace chrono = std::chrono;

class Schedular {
    template<class T>
    friend class SharedState;
public:
    using timer_callback = std::function<void()>;
    using timer_item = std::tuple<bool, float, chrono::time_point<chrono::steady_clock, chrono::duration<float>>, timer_callback>;
    using timer = std::chrono::steady_clock;
    struct timer_item_cmp {
        bool operator()(const timer_item& a, const timer_item& b) const {
            return std::get<2>(a) > std::get<2>(b);
        }
    };

    Schedular() = default;
    Schedular(Schedular&&) = delete;
    Schedular(const Schedular&) = delete;
    Schedular& operator=(Schedular&&) = delete;
    Schedular& operator=(const Schedular&) = delete;

    void poll() {
        size_t sz = pending_states.size();
        for (size_t i = 0; i != sz; i++) {
            auto state = std::move(pending_states[i]);
            state->invoke_all_callback();
        }
        pending_states.erase(pending_states.begin(), pending_states.begin() + sz);
        if (timer_queue.empty()) {
            return;
        }
        if (pending_states.empty()) { //如果pending_states为空,则可以sleep较长的时间,等待第一个将要完成的timer
            std::this_thread::sleep_until(std::get<2>(timer_queue.front()));
            auto now = timer::now();
            do {
                deal_one_timer();
            } while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now);
        } else { //否则只能处理当帧到期的timer,不能sleep,要及时返回给caller,让caller及时下一次poll处理剩下的pending_states
            auto now = timer::now();
            while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now) {
                deal_one_timer();
            }
        }
    }

    template<class F, class... Args>
    auto run_until_complete(F&& fn, Args&&... args) -> typename std::invoke_result_t<F&&, Args&&...>::result_type {
        auto future = std::forward<F>(fn)(std::forward<Args>(args)...);
        while (!future.await_ready()) {
            poll();
        }
        return future.await_resume();
    }

    void add_timer(bool repeat, float delay, timer_callback callback) {
        auto cur_time = chrono::time_point_cast<chrono::duration<float>>(timer::now());
        auto timeout = cur_time + chrono::duration<float>(delay);
        timer_queue.emplace_back(repeat, delay, timeout, callback);
        std::push_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
    }

    Future<float> delay(float second) {
        auto promise = Promise<float>(*this);
        add_timer(false, second, [=]() mutable {
            promise.set_result(second);
        });
        return promise.get_future();
    }
private:
    void deal_one_timer() {
        std::pop_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
        auto item = std::move(timer_queue.back());
        timer_queue.pop_back();
        std::get<3>(item)();
        if (std::get<0>(item)) {
            add_timer(true, std::get<1>(item), std::move(std::get<3>(item)));
        }
    }

    void post_call_state(std::shared_ptr<SharedStateBase> state) {
        pending_states.push_back(std::move(state));
    }
    std::vector<std::shared_ptr<SharedStateBase>> pending_states;
    std::deque<timer_item> timer_queue;
};

template<class T>
void SharedState<T>::post_all_callbacks() {
    if (callback_posted) {
        return;
    }
    callback_posted = true;
    schedular->post_call_state(shared_from_this());
}

Future<float> func(Schedular& schedular) {
    std::cout << "start sleep\n";
    auto r = co_await schedular.delay(1.2);
    co_return r;
}

Future<int> func2(Schedular& schedular) {
    auto r = co_await func(schedular);
    std::cout << "slept for " << r << "s\n";
    co_return 42;
}

int main() {
    Schedular schedular;

    auto r = schedular.run_until_complete(func2, schedular);

    std::cout << "run complete with " << r << "\n";
}


posted on 2020-05-22 23:19  PointerSMQ  阅读(1080)  评论(0编辑  收藏  举报