c/c++实现有栈协程

有栈协程

有栈协程通过切换执行上下文实现,核心是切换栈寄存器和跳转代码地址(IP寄存器),同时需要保存并切换当前编译器ABI规定的非易失(callee-saved)寄存器。

System V AMD64 ABI 和 MSVC x64 ABI 的非易失性寄存器

RBX、RBP、RSP、R12、R13、R14、R15(两种ABI共有)
XMM6-XMM15(仅MSVC x64 ABI;System V AMD64 ABI下所有XMM寄存器都是易失的)

RDI、RSI(仅MSVC x64 ABI)

MSVC比GCC(System V)多了两个整数寄存器RDI、RSI,这些寄存器在切换时是必须要保存的。更多详情参考官方文档。

MSVC x64 ABI

切换核心汇编如下

linux amd64

// C++-side declaration of the context switch routine implemented in assembly below.
// save_ receives the current context; jmp_ is the context to switch to.
extern "C" __attribute__((sysv_abi, naked, noinline)) void
switch_jmp(context_amd64 *save_, context_amd64 *jmp_);

.text
.globl switch_jmp

switch_jmp:
    // rdi = save_, rsi = jmp_
 
    // Save the current context into save_
    movq %rbx, 0x000(%rdi)    // RBX
    movq %rbp, 0x008(%rdi)    // RBP
    movq %rsp, 0x010(%rdi)    // RSP
    movq %r12, 0x018(%rdi)    // R12
    movq %r13, 0x020(%rdi)    // R13
    movq %r14, 0x028(%rdi)    // R14
    movq %r15, 0x030(%rdi)    // R15

#ifdef SAVE_FLOAT
    // movaps requires the destination address to be 16-byte aligned
    movaps %xmm6, 0x040(%rdi)   // XMM6
    movaps %xmm7, 0x050(%rdi)   // XMM7
    movaps %xmm8, 0x060(%rdi)   // XMM8
    movaps %xmm9, 0x070(%rdi)   // XMM9
    movaps %xmm10, 0x080(%rdi)  // XMM10
    movaps %xmm11, 0x090(%rdi)  // XMM11
    movaps %xmm12, 0x0A0(%rdi)  // XMM12
    movaps %xmm13, 0x0B0(%rdi)  // XMM13
    movaps %xmm14, 0x0C0(%rdi)  // XMM14
    movaps %xmm15, 0x0D0(%rdi)  // XMM15
#endif

    // Restore the target context from jmp_
    movq 0x000(%rsi), %rbx  // RBX
    movq 0x008(%rsi), %rbp  // RBP
    movq 0x010(%rsi), %rsp  // RSP (from here on we run on the target stack)
    movq 0x018(%rsi), %r12  // R12
    movq 0x020(%rsi), %r13  // R13
    movq 0x028(%rsi), %r14  // R14
    movq 0x030(%rsi), %r15  // R15
#ifdef SAVE_FLOAT
    movaps 0x040(%rsi), %xmm6
    movaps 0x050(%rsi), %xmm7
    movaps 0x060(%rsi), %xmm8
    movaps 0x070(%rsi), %xmm9
    movaps 0x080(%rsi), %xmm10
    movaps 0x090(%rsi), %xmm11
    movaps 0x0A0(%rsi), %xmm12
    movaps 0x0B0(%rsi), %xmm13
    movaps 0x0C0(%rsi), %xmm14
    movaps 0x0D0(%rsi), %xmm15
#endif

// Pops the address on top of the restored stack and jumps to it
ret

对齐相关

一定要注意16字节对齐。凡是涉及xmm浮点数操作都要16字节对齐。
现代编译器默认都会使用SSE指令来处理浮点运算或做优化,一旦栈没有对齐,要求对齐的SSE指令(例如movaps)会直接触发异常!所以不仅仅是xmm寄存器保存区的内存地址要对齐,栈也必须对齐。

栈必须是16字节对齐的,给有栈协程创建栈基地址必须是16字节对齐。
注意,栈对齐是包含call压入的ret地址的,也就是说,构造的协程入口栈是16字节对齐再压入一个指针地址。

 // Reserve one empty slot to keep 16-byte alignment (important): it stands in for the return address a normal call would push
stack_ = stack_ - sizeof(uintptr_t);
// Push the address of entry; the ret at the end of switch_jmp pops it and jumps there
stack_ = stack_ - sizeof(uintptr_t);
*static_cast<uintptr_t *>((void *)stack_) =
            (uintptr_t)&coroutine_base::entry;

如上,stack是按页申请的内存,它的栈顶地址一定是16字节对齐的。
我们压入的entry地址会被switch_jmp的ret弹出,但上面还多压入了一个指针大小的空位!
这是为了让进入entry时的rsp,与"在16字节对齐的栈上执行call压入返回地址之后"的状态一致,此时栈指针才是正确的。

c++异常在windows下的有栈协程里失效

实际情况可能很复杂,msvc c++实现在SEH(结构化异常处理)机制之上,如上的上下文切换在windows会让协程的c++异常无法被catch拦截,会直接崩溃。

借鉴了boost.context的实现,hack了线程gs:[030h] NT_TIB结构,切换了栈base地址,msvc c++异常就可以被拦截 成功运行,但MINGW下异常还是无法被正常catch!!

    ; load NT_TIB: r10 = pointer read from gs:[30h]
    mov  r10,  gs:[030h]

    ; save the value at offset 1478h into the context slot at 48h
    mov  rax, [r10+01478h]
    mov  [rcx+048h], rax

    ; save the value at offset 10h into the context slot at 50h
    mov  rax, [r10+010h]
    mov  [rcx+050h], rax

    ; save the value at offset 08h into the context slot at 58h
    mov  rax,  [r10+08h]
    mov  [rcx+058h], rax

    ; set: write the three values stored in the target context (rdx) back into the TIB
    mov  r10,  gs:[030h]

    mov  rax, [rdx+048h]
    mov  [r10+01478h], rax

    mov  rax, [rdx+050h]
    mov  [r10+010h], rax
    
    mov  rax, [rdx+058h]
    mov  [r10+08h], rax

关于TLS,不推荐自己hack gs fs寄存器实现TLS,在原有TLS上封装协程的TLS就好。

关于栈切换保存在哪儿?有的实现是直接在当前栈压入,但这种实现是需要返回值的。还是采取了非压入栈的,在独立的地方存储上下文。

封装相关:协程挂起时,应该跳出到调度器,保持干净的栈,所以挂起时是返回到调度器的resume调用之后。

关于entry参数传递

比如,我们可以在switch_jmp最后一个参数加一个指针,
void switch_jmp(context_amd64 *save_, context_amd64 *jmp_,void* args_1);

在switch_jmp汇编的ret前,将第三个参数寄存器赋值到第一个参数寄存器,这样进入static void entry(void* args_1)就能传参。但是由于必须使用tls,所以就没有采用这种实现。

有栈协程实现demo,全部源码:

//allocate.h
#include <cstddef>
#include <cassert>

#ifdef _WIN32
#include <Windows.h>
std::size_t page_size() {
    SYSTEM_INFO si;
    ::GetSystemInfo(&si);
    return static_cast<std::size_t>(si.dwPageSize);
}
void *allocate_stack(size_t size_, size_t &out_size) {
    size_t page_size_ = page_size();
    const std::size_t pages = (size_ + page_size_ - 1) / page_size_;
    const std::size_t size__ = (pages + 1) * page_size_;
    out_size = size__;

    void *vp = ::VirtualAlloc(0, size__, MEM_COMMIT, PAGE_READWRITE);
    assert(vp != nullptr && "VirtualAlloc NULL");
    DWORD old_options;
    // 增加一个PAGE_GUARD保护页,拦截栈溢出
    const BOOL result = ::VirtualProtect(
        vp, page_size_, PAGE_READWRITE | PAGE_GUARD, &old_options);
    assert(FALSE != result && "VirtualProtect false");
    return vp;
}

/// Releases a stack obtained from allocate_stack.
/// @param ptr   lowest address of the mapping; nullptr is ignored
/// @param size  unused here: MEM_RELEASE is called with size 0 and frees the whole region
inline void free_stack(void *ptr, size_t size) {
    (void)size;
    if (ptr != nullptr) {
        ::VirtualFree(ptr, 0, MEM_RELEASE);
    }
}
#elif __linux__
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

/// Returns the system page size in bytes as reported by sysconf(_SC_PAGESIZE).
/// `inline` because this definition lives in a header and would otherwise
/// violate the one-definition rule when included from several translation units.
inline std::size_t page_size() {
    const long bytes = ::sysconf(_SC_PAGESIZE);
    // sysconf reports errors as -1, which must not be cast to a huge size_t.
    assert(bytes > 0 && "sysconf(_SC_PAGESIZE) failed");
    return static_cast<std::size_t>(bytes);
}

void *allocate_stack(size_t size_, size_t &out_size) {
    size_t page_size_ = page_size();

    const std::size_t pages = (size_ + page_size_ - 1) / page_size_;
    const std::size_t size__ = (pages + 1) * page_size_;
    out_size = size__;

    void *vp = ::mmap(0, size__, PROT_READ | PROT_WRITE,
                      MAP_PRIVATE | MAP_ANON | MAP_STACK, -1, 0);
    assert(vp != nullptr && "::mmap return null!");
    const int result(::mprotect(vp, page_size_, PROT_NONE));
    assert(0 == result && "::mprotect faild!");
    return vp;
}
/// Releases a stack obtained from allocate_stack.
/// @param ptr   lowest address of the mapping; nullptr is ignored
/// @param size  total mapped size in bytes as reported through out_size; 0 is ignored
inline void free_stack(void *ptr, size_t size) {
    if (ptr == nullptr || size == 0) {
        return;
    }
    (void)::munmap(ptr, size);
}

#endif
//stackfull.cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <functional>
#include <iomanip>
#include <iostream>
#include <list>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <string.h>
#include <tuple>
#include <type_traits>
#include <utility>

#include "allocate.h"

#if !(defined(__x86_64__) || defined(_M_X64) || defined(__amd64__) ||          \
      defined(__amd64))
#error "此程序仅支持 AMD64 (x86-64) 架构平台,不支持当前平台编译。"
#endif

namespace co {

struct context_amd64;
// Declaration of the context-switch routine. It stores the current register
// context into save_ and loads the context held in jmp_. One declaration is
// selected per toolchain so the calling convention matches the assembly file
// chosen by the build.
#ifdef _WIN64
#if defined(_MSC_VER)
// MSVC: arguments arrive in RCX (save_) and RDX (jmp_).
extern "C" __declspec(noinline) void switch_jmp(context_amd64 *save_,
                                                context_amd64 *jmp_);
#elif __MINGW64__
// MinGW-w64: explicitly request the Microsoft x64 calling convention.
extern "C" __attribute__((ms_abi, naked, noinline)) void
switch_jmp(context_amd64 *save_, context_amd64 *jmp_);
#endif
#elif __linux__
// Linux: System V calling convention, arguments arrive in RDI (save_) and RSI (jmp_).
extern "C" __attribute__((sysv_abi, naked, noinline)) void
switch_jmp(context_amd64 *save_, context_amd64 *jmp_);
#endif

// Non-volatile (callee-saved) registers of the System V AMD64 ABI and the MSVC x64 ABI:
//   RBX, RBP, RSP, R12, R13, R14, R15
//   RDI, RSI     (MSVC only)
//   XMM6-XMM15   (NOTE(review): callee-saved on MSVC x64; System V treats all XMM registers as volatile - confirm)

// Index of the saved RBP inside context_amd64::regs.
constexpr size_t rbp_index = 1;
// Index of the saved RSP inside context_amd64::regs.
constexpr size_t rsp_index = 2;
// Storage for one 128-bit XMM register, kept as two 64-bit halves.
// The type itself is 16-byte aligned because switch_jmp stores and loads these
// slots with movaps, which requires a 16-byte aligned address.
struct alignas(16) xmm_regs {
    uint64_t _1;
    uint64_t _2;
};
// Saved register context used by switch_jmp.
//
// The assembly implementations address these fields with hard-coded byte
// offsets, so the member order and sizes must not change. The static_asserts
// after the struct pin the layout the assembly relies on.
struct alignas(16) context_amd64 {
    uint64_t regs[7]{}; // RBX, RBP, RSP, R12, R13, R14, R15 (offsets 0x00-0x30)
#ifdef _WIN64
    uint64_t rdi_rsi[2]{};  // RDI, RSI (offsets 0x38, 0x40)
    uint64_t tib_info[3]{}; // fc_dealloc, limit, base of the Windows TIB (offsets 0x48-0x58)
#else
    uint64_t alignas__ = 0; // padding only, keeps the XMM area at offset 0x40
#endif // _WIN64_
    // SAVE_FLOAT is enabled from the CMakeLists
#ifdef SAVE_FLOAT
    xmm_regs regs_xmm[10]{}; // XMM6-XMM15, must be 16-byte aligned
#endif
};

static_assert(alignof(context_amd64) == 16,
              "context_amd64 must be 16-byte aligned for movaps");
static_assert(offsetof(context_amd64, regs) == 0x00,
              "switch_jmp expects the integer registers at offset 0");
#ifdef _WIN64
static_assert(offsetof(context_amd64, rdi_rsi) == 0x38,
              "switch_jmp expects RDI/RSI at offset 0x38");
static_assert(offsetof(context_amd64, tib_info) == 0x48,
              "switch_jmp expects the TIB fields at offset 0x48");
#endif
#ifdef SAVE_FLOAT
#ifdef _WIN64
static_assert(offsetof(context_amd64, regs_xmm) == 0x60,
              "switch_jmp expects XMM6 at offset 0x60");
#else
static_assert(offsetof(context_amd64, regs_xmm) == 0x40,
              "switch_jmp expects XMM6 at offset 0x40");
#endif
#endif

struct coroutine_base;

// Per-thread switching state: the context to return to when a coroutine
// suspends or finishes, and the coroutine that is currently being run.
struct context_tls {
    context_amd64 break_{};
    coroutine_base *co_current_{nullptr};
};

// Stack size requested for every coroutine, in bytes (2 MiB).
constexpr size_t STACK_SIZE = 1024 * 1024 * 2;
// Per-thread switching state shared by resume(), co_suspend() and entry().
thread_local context_tls switch_;

// Suspends the currently running coroutine; defined after coroutine_base.
void co_suspend();

/// Prints one trace line of the form `====[action]====name` to std::cout.
///
/// The bracketed action is right-aligned in a 15-character field and the name
/// in a 10-character field, both padded with '='. The stream's fill character
/// and format flags are restored afterwards so the trace output does not
/// change how the rest of the program prints.
///
/// @param action  short label of the event, e.g. "resume"
/// @param name    name of the coroutine the event refers to
void debug_print(const char *action, const char *name) {
    std::string tag;
    tag += '[';
    tag += action;
    tag += ']';

    const std::ios_base::fmtflags saved_flags = std::cout.flags();
    const char saved_fill = std::cout.fill('=');

    // std::endl is kept on purpose: the line is flushed before the next switch.
    std::cout << std::right << std::setw(15) << tag << std::setw(10) << name
              << std::endl;

    std::cout.fill(saved_fill);
    std::cout.flags(saved_flags);
}

/// Type-erased base of a stackful coroutine.
///
/// Owns the coroutine's stack and saved register context and implements
/// resume/suspend bookkeeping on top of switch_jmp. The derived class supplies
/// the body (invoke) and the storage for yielded/returned values (result).
///
/// Not thread-safe: the switching state lives in the thread_local `switch_`,
/// and resume() saves the caller into the single per-thread `break_` slot.
/// NOTE(review): because `break_` is a single slot, resuming one coroutine from
/// inside another overwrites the outer return context - confirm that nesting is
/// not required.
/// NOTE(review): destroying a coroutine that has not finished frees its stack
/// without running destructors of objects living on that stack.
struct coroutine_base {
    /// First function executed on the coroutine's own stack.
    ///
    /// Reached through the `ret` in switch_jmp, not through a call, so it has
    /// no valid return address and must leave by switching back to `break_`.
    /// Exceptions thrown by the body are captured in except_ptr.
    static void entry() noexcept {
        context_tls &tls_ = switch_;
        coroutine_base *this_ = tls_.co_current_;
        debug_print("start", this_->name());
        try {
            this_->invoke();
        } catch (...) {
            this_->except_ptr = std::current_exception();
        }
        this_->done_ = true;
        debug_print("end", this_->name());
        // There is no return address: we must jump away and never come back.
        switch_jmp(&this_->context_, &tls_.break_);
        // resume() refuses to re-enter a finished coroutine, so this is
        // unreachable; abort rather than return into the dummy slot.
        std::abort();
    }

  public:
    coroutine_base(const coroutine_base &) = delete;
    coroutine_base &operator=(const coroutine_base &) = delete;

    /// Runs the coroutine until it suspends or finishes. No-op once done.
    void resume() {
        if (done_)
            return;
        context_tls &tls_ = switch_;
        coroutine_base *const previous = tls_.co_current_;
        tls_.co_current_ = this;
        debug_print("resume", this->name());
        switch_jmp(&tls_.break_, &this->context_);
        // Back on the caller's stack: this coroutine is no longer the running
        // one, so a stray co_suspend() cannot act on it.
        tls_.co_current_ = previous;
        debug_print("suspend", this->name());
    }

    /// True once the coroutine body has returned or thrown.
    bool done() const { return done_; }

    /// Moves the stored value out of the result storage.
    /// The caller must name the same type R that was stored.
    template <typename R> R get() {
        return std::move(*std::launder(static_cast<R *>(result())));
    }

    /// Exception captured from the coroutine body; empty if none was thrown.
    std::exception_ptr &except() { return except_ptr; }

    /// Assigns r into the result storage (copy for lvalues, move for rvalues).
    /// NOTE(review): this assigns rather than constructs, so it is only sound
    /// for types whose lifetime can begin in raw storage (e.g. integers) -
    /// confirm before yielding non-trivial types.
    template <typename R> void set(R &&r) {
        using type = std::remove_cv_t<std::remove_reference_t<R>>;
        *std::launder(static_cast<type *>(result())) = std::forward<R>(r);
    }

    /// Stores a copy of r in the running coroutine's result slot and suspends.
    /// Must be called from inside a coroutine.
    template <typename RR> static void current_yield(RR &r) {
        context_tls &tls_ = switch_;
        assert(tls_.co_current_ != nullptr &&
               "current_yield() called outside of a coroutine");
        // Pass r as an lvalue so the caller's object is copied, never moved from.
        tls_.co_current_->set(r);
        co::co_suspend();
    }

    const char *name() const { return name_.c_str(); }
    void set_name(const char *_name) { name_ = _name; }
    virtual ~coroutine_base() { free_stack(stack_m_, stack_size_); }

  protected:
    /// Allocates the stack and seeds it so that the first switch_jmp into this
    /// coroutine "returns" into entry().
    /// @param stack_size requested usable stack size in bytes
    /// @throws std::bad_alloc if the stack could not be allocated
    explicit coroutine_base(size_t stack_size) {
        stack_m_ = static_cast<uint8_t *>(
            allocate_stack(stack_size, stack_size_));
        if (stack_m_ == nullptr) {
            throw std::bad_alloc();
        }
        // The stack grows downwards, so start from the highest address.
        uint8_t *const stack_top = stack_m_ + stack_size_;
        stack_ = stack_top;

#ifdef _WIN64
        // Microsoft x64 ABI: a callee may use the 32 bytes of home space above
        // its return address. Reserve them so entry() never writes past the
        // end of the mapping. 32 bytes keeps the 16-byte alignment.
        stack_ -= 4 * sizeof(uintptr_t);
#endif
        // Dummy return-address slot: keeps RSP in the state a real `call`
        // would leave it in (16-byte aligned before the push).
        stack_ -= sizeof(uintptr_t);
        *reinterpret_cast<uintptr_t *>(stack_) = 0;
        // Address of entry; popped by the `ret` at the end of switch_jmp.
        stack_ -= sizeof(uintptr_t);
        *reinterpret_cast<uintptr_t *>(stack_) =
            reinterpret_cast<uintptr_t>(&coroutine_base::entry);

        context_.regs[rsp_index] = reinterpret_cast<uintptr_t>(stack_);
        context_.regs[rbp_index] = reinterpret_cast<uintptr_t>(stack_);
#ifdef _WIN64
        // Stack bounds swapped into the Windows TIB by switch_jmp so that MSVC
        // C++ exceptions can be caught inside the coroutine. Computed from the
        // real top of the mapping, not from the already-decremented stack_.
        const uintptr_t stack_low =
            reinterpret_cast<uintptr_t>(stack_top - stack_size);
        context_.tib_info[0] = stack_low; // deallocation stack
        context_.tib_info[1] = stack_low; // stack limit (lowest usable address)
        context_.tib_info[2] = reinterpret_cast<uintptr_t>(stack_top); // stack base
#endif
    }

    /// Runs the user-supplied coroutine body.
    virtual void invoke() = 0;
    /// Returns the address of the storage used by get()/set().
    virtual void *result() = 0;

    context_amd64 context_{};          // saved registers while suspended
    uint8_t *stack_m_ = nullptr;       // lowest address of the stack mapping
    uint8_t *stack_ = nullptr;         // initial stack pointer handed to switch_jmp
    size_t stack_size_ = 0;            // total mapped size in bytes
    std::atomic_bool done_{false};     // set once the body has finished
    std::exception_ptr except_ptr{};   // exception captured from the body
    std::string name_;                 // label used in debug output

    friend void co_suspend();
};

void co_suspend() {
    context_tls &tls_ = switch_;
    coroutine_base *this_ = tls_.co_current_;
    switch_jmp(&this_->context_, &tls_.break_);
};

/// Raw, suitably aligned storage for one value of type R.
///
/// The buffer is aligned to at least std::max_align_t (as before) and to
/// alignof(R) when R is over-aligned, so constructing an R in it is always
/// correctly aligned.
template <typename R> struct result_storage {
    static constexpr std::size_t storage_alignment =
        alignof(R) > alignof(std::max_align_t) ? alignof(R)
                                               : alignof(std::max_align_t);

    alignas(storage_alignment) uint8_t ret_s[sizeof(R)]{};

    /// Address of the first byte of the buffer.
    void *get_addr() { return static_cast<void *>(&this->ret_s[0]); }
};

/// Specialization for coroutines that produce no value: there is no storage,
/// so asking for its address is an error.
template <> struct result_storage<void> {
    /// @throws std::runtime_error always
    [[noreturn]] void *get_addr() { throw std::runtime_error("void null!"); }
};

/// Concrete coroutine: binds a callable returning R to its arguments and
/// provides the storage for its result.
///
/// @tparam R     return type of the callable (may be void)
/// @tparam Args  argument types stored in the coroutine and passed to the callable
template <typename R, typename... Args>
struct coroutine : public coroutine_base {
  public:
    /// Stores the callable and its arguments; the body does not start running
    /// until resume() is called.
    /// NOTE(review): declared noexcept although stack allocation and
    /// std::function construction can throw, in which case the program
    /// terminates - confirm this is intended.
    template <typename F, typename... Args2>
    coroutine(F &&_func, Args2 &&..._args) noexcept
        : coroutine_base(STACK_SIZE), args_(std::forward<Args2>(_args)...),
          coro_main_(std::forward<F>(_func)) {}

  private:
    result_storage<R> ret{};
    std::tuple<Args...> args_{};
    std::function<R(Args...)> coro_main_;

    /// Calls the stored callable with the stored arguments and, for non-void
    /// R, constructs the return value in the result storage.
    void invoke() override {
        if constexpr (std::is_void_v<R>) {
            std::apply(coro_main_, args_);
        } else {
            new (ret.get_addr()) R(std::apply(coro_main_, args_));
        }
    }

    /// Address of the result storage used by get()/set().
    void *result() override { return ret.get_addr(); }
};

/// Creates a named coroutine that will run `_func(_args...)` on its own stack.
///
/// @param name   label used in debug output
/// @param _func  callable to run
/// @param _args  arguments stored in the coroutine and passed to the callable
/// @return owning raw pointer; the caller is responsible for deleting it
///         (for example by wrapping it in std::unique_ptr<co::coroutine_base>)
template <typename F, typename... Args>
co::coroutine_base *create_co(const char *name, F &&_func, Args &&..._args) {
    using R = std::invoke_result_t<F, Args...>;
    // Hold the object in a unique_ptr until it is fully set up, so it is not
    // leaked if set_name throws.
    auto owner = std::make_unique<co::coroutine<R, Args...>>(
        std::forward<F>(_func), std::forward<Args>(_args)...);
    owner->set_name(name);
    return owner.release();
}
} // namespace co

int main() {
    std::cout << 1.132 << std::endl;
    using coro_unptr = std::unique_ptr<co::coroutine_base>;
    std::list<coro_unptr> task_;

    task_.emplace_back(co::create_co("co_1", []() {
        std::cout << "hello " << std::endl;

        co::co_suspend();
        std::cout << "world~" << std::endl;
    }));

    task_.emplace_back(co::create_co(
        "co_2",
        [](int x) {
            int my_value = 123;
            int other = 200;
            other += 100;
            double f64 = 1.23;
            f64 = f64 + 2.222;
            co::co_suspend();
            double ret = (my_value + other + x) + f64;
            std::cout << "co_r value:" << ret << std::endl;
            return ret;
        },
        10000));

    while (task_.size() > 0) {
        for (auto &t : task_) {
            if (!t->done()) {
                t->resume();
            }
        }
        std::erase_if(task_, [](const auto &x) { return x->done(); });
    }

    coro_unptr generator_test(co::create_co(
        "co_3",
        [](size_t x) -> size_t {
            for (size_t i = 0; i < x; i++) {
                co::coroutine_base::current_yield(i);
            }
            return 0;
        },
        10));

    while (true) {
        if (!generator_test->done()) {
            generator_test->resume();
            size_t v = generator_test->get<size_t>();
            std::cout << "generator:" << v << std::endl;
        } else
            break;
    }
    coro_unptr throw_test(co::create_co(
        "co_test", []() { throw std::runtime_error("error !!!"); }));
    try {
        throw_test->resume();
        if (throw_test->except())
            std::rethrow_exception(throw_test->except());
    } catch (const std::exception &e) {
        std::cout << "exception: " << e.what() << std::endl;
    }

    std::cout << "end####" << std::endl;
    return 0;
}

;win_msvc.asm
; Context switch for the Microsoft x64 calling convention (MASM).
; void switch_jmp(context_amd64 *save_, context_amd64 *jmp_)
; Stores the current context into save_ and loads the one held in jmp_.
; The byte offsets below must match the layout of context_amd64.
.code

switch_jmp proc frame
    .endprolog

    ; RCX = save_, RDX = jmp_
    ;----------- save the current context into save_ -------------
    mov [rcx + 000h], rbx   ; RBX
    mov [rcx + 008h], rbp   ; RBP
    mov [rcx + 010h], rsp   ; RSP
    mov [rcx + 018h], r12   ; R12
    mov [rcx + 020h], r13   ; R13
    mov [rcx + 028h], r14   ; R14
    mov [rcx + 030h], r15   ; R15
    mov [rcx + 038h], rdi   ; RDI
    mov [rcx + 040h], rsi   ; RSI

    ; load NT_TIB and save three of its fields (context_amd64::tib_info)
    mov  r10,  gs:[030h]

    mov  rax, [r10+01478h]  ; tib_info[0]
    mov  [rcx+048h], rax

    mov  rax, [r10+010h]    ; tib_info[1]
    mov  [rcx+050h], rax

    mov  rax,  [r10+08h]    ; tib_info[2]
    mov  [rcx+058h], rax
             
IFDEF SAVE_FLOAT
    ; movaps requires 16-byte aligned slots
    movaps [rcx + 0060h], xmm6    ; XMM6
    movaps [rcx + 0070h], xmm7    ; XMM7
    movaps [rcx + 0080h], xmm8    ; XMM8
    movaps [rcx + 0090h], xmm9    ; XMM9
    movaps [rcx + 00A0h], xmm10   ; XMM10
    movaps [rcx + 00B0h], xmm11   ; XMM11
    movaps [rcx + 00C0h], xmm12   ; XMM12
    movaps [rcx + 00D0h], xmm13   ; XMM13
    movaps [rcx + 00E0h], xmm14   ; XMM14
    movaps [rcx + 00F0h], xmm15   ; XMM15
ENDIF

    
    ;----------- restore the target context from jmp_ -------------

    ; restore the integer registers
    mov rbx, [rdx + 000h] ; RBX
    mov rbp, [rdx + 008h] ; RBP
    mov rsp, [rdx + 010h] ; RSP (from here on we run on the target stack)
    mov r12, [rdx + 018h] ; R12
    mov r13, [rdx + 020h] ; R13
    mov r14, [rdx + 028h] ; R14
    mov r15, [rdx + 030h] ; R15
    mov rdi, [rdx + 038h] ; RDI
    mov rsi, [rdx + 040h] ; RSI

    ; load NT_TIB and write the target's three saved fields back into it
    mov  r10,  gs:[030h]

    mov  rax, [rdx+048h]
    mov  [r10+01478h], rax

    mov  rax, [rdx+050h]
    mov  [r10+010h], rax
    
    mov  rax, [rdx+058h]
    mov  [r10+08h], rax

IFDEF SAVE_FLOAT
    movaps  xmm6,  [rdx+0060h]
    movaps  xmm7,  [rdx+0070h]
    movaps  xmm8,  [rdx+0080h]
    movaps  xmm9,  [rdx+0090h]
    movaps  xmm10, [rdx+00A0h]
    movaps  xmm11, [rdx+00B0h]
    movaps  xmm12, [rdx+00C0h]
    movaps  xmm13, [rdx+00D0h]
    movaps  xmm14, [rdx+00E0h]
    movaps  xmm15, [rdx+00F0h]
ENDIF
    
    ; pops the address on top of the restored stack and jumps to it
    ret
    switch_jmp endp
end    

//linux_gas.S
// Context switch for the System V AMD64 calling convention (GNU as, run
// through the C preprocessor).
// void switch_jmp(context_amd64 *save_, context_amd64 *jmp_)
// Stores the current context into save_ and loads the one held in jmp_.
// The byte offsets below must match the layout of context_amd64.
.text
.globl switch_jmp
.type switch_jmp, @function

switch_jmp:
    // rdi = save_, rsi = jmp_

    // Save the current context into save_
    movq %rbx, 0x000(%rdi)    // RBX
    movq %rbp, 0x008(%rdi)    // RBP
    movq %rsp, 0x010(%rdi)    // RSP
    movq %r12, 0x018(%rdi)    // R12
    movq %r13, 0x020(%rdi)    // R13
    movq %r14, 0x028(%rdi)    // R14
    movq %r15, 0x030(%rdi)    // R15

#ifdef SAVE_FLOAT
    // movaps requires 16-byte aligned slots.
    // NOTE(review): System V treats all XMM registers as caller-saved, so this
    // is only needed to mirror the Windows layout - confirm.
    movaps %xmm6, 0x040(%rdi)   // XMM6
    movaps %xmm7, 0x050(%rdi)   // XMM7
    movaps %xmm8, 0x060(%rdi)   // XMM8
    movaps %xmm9, 0x070(%rdi)   // XMM9
    movaps %xmm10, 0x080(%rdi)  // XMM10
    movaps %xmm11, 0x090(%rdi)  // XMM11
    movaps %xmm12, 0x0A0(%rdi)  // XMM12
    movaps %xmm13, 0x0B0(%rdi)  // XMM13
    movaps %xmm14, 0x0C0(%rdi)  // XMM14
    movaps %xmm15, 0x0D0(%rdi)  // XMM15
#endif

    // Restore the target context from jmp_
    movq 0x000(%rsi), %rbx  // RBX
    movq 0x008(%rsi), %rbp  // RBP
    movq 0x010(%rsi), %rsp  // RSP (from here on we run on the target stack)
    movq 0x018(%rsi), %r12  // R12
    movq 0x020(%rsi), %r13  // R13
    movq 0x028(%rsi), %r14  // R14
    movq 0x030(%rsi), %r15  // R15
#ifdef SAVE_FLOAT
    movaps 0x040(%rsi), %xmm6
    movaps 0x050(%rsi), %xmm7
    movaps 0x060(%rsi), %xmm8
    movaps 0x070(%rsi), %xmm9
    movaps 0x080(%rsi), %xmm10
    movaps 0x090(%rsi), %xmm11
    movaps 0x0A0(%rsi), %xmm12
    movaps 0x0B0(%rsi), %xmm13
    movaps 0x0C0(%rsi), %xmm14
    movaps 0x0D0(%rsi), %xmm15
#endif

    // Pops the address on top of the restored stack and jumps to it
    ret

.size switch_jmp, .-switch_jmp

// Mark the object as not needing an executable stack; without this note the
// linker assumes it does and makes the whole program's stack executable.
.section .note.GNU-stack,"",@progbits

# add_compile_definitions() was introduced in CMake 3.12, so 3.10 is too low.
cmake_minimum_required (VERSION 3.12)


project ("stackfull")
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Enables saving/restoring XMM6-XMM15 in both the C++ struct and the assembly.
add_compile_definitions(SAVE_FLOAT)

# Pick the assembly implementation of switch_jmp that matches the toolchain.
if (UNIX AND NOT APPLE)
    message(STATUS "Platform: Linux")
    enable_language(ASM)
    set(ASM_SOURCES linux_gas.S)
elseif (WIN32)
    if (MSVC)
        message(STATUS "Platform: Windows MSVC")
        enable_language(ASM_MASM)
        set(ASM_SOURCES win_msvc.asm)
    elseif (MINGW)
        message(STATUS "Platform: Windows MinGW")
        enable_language(ASM)
        # Run the GAS source through the C preprocessor so #ifdef works.
        set_source_files_properties(win_gas.S PROPERTIES
            COMPILE_FLAGS "-x assembler-with-cpp"
        )
        set(ASM_SOURCES win_gas.S)
        add_compile_options(-fexceptions)
    endif ()
else()
    message(FATAL_ERROR "Unsupported platform")
endif()


add_executable (stackfull "stackfull.cpp" ${ASM_SOURCES})

posted @ 2025-09-18 22:41  0xc  阅读(5)  评论(0)    收藏  举报