c/c++实现有栈协程
有栈协程
有栈协程通过切换执行上下文实现:核心是切换栈指针寄存器(RSP)并跳转到新的代码地址(RIP),同时还必须保存并恢复当前调用约定(ABI)规定的非易失性(callee-saved)寄存器。
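System V AMD64 ABI 与 Microsoft x64 ABI 的非易失性寄存器:
- 两者共有:RBX、RBP、RSP、R12、R13、R14、R15
- 仅 Microsoft x64(MSVC / MinGW):RDI、RSI、XMM6-XMM15
System V ABI 中所有 XMM 寄存器都是易失的(caller-saved),因此在 Linux 下保存 XMM6-XMM15 并不是必需的;而在 Windows 下,RDI、RSI 和 XMM6-XMM15 在切换时必须保存。另外,两种 ABI 都要求保留 MXCSR 的控制位和 x87 控制字,本文的实现没有处理它们。更多详情请参考两份 ABI 的官方文档。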
切换核心汇编如下
linux amd64
// Declaration as seen from C++: System V AMD64 convention, so
// RDI = save_ and RSI = jmp_.
extern "C" __attribute__((sysv_abi, naked, noinline)) void
switch_jmp(context_amd64 *save_, context_amd64 *jmp_);
// Assembly implementation (GNU as, AT&T syntax).
.text
.globl switch_jmp
switch_jmp:
// rdi = save_, rsi = jmp_
// Save the current context into save_ (offsets follow context_amd64)
movq %rbx, 0x000(%rdi) // RBX
movq %rbp, 0x008(%rdi) // RBP
movq %rsp, 0x010(%rdi) // RSP
movq %r12, 0x018(%rdi) // R12
movq %r13, 0x020(%rdi) // R13
movq %r14, 0x028(%rdi) // R14
movq %r15, 0x030(%rdi) // R15
#ifdef SAVE_FLOAT
// movaps needs 16-byte aligned addresses; context_amd64 is alignas(16).
// NOTE(review): the System V ABI treats every XMM register as caller-saved,
// so saving XMM6-XMM15 here looks unnecessary on Linux - confirm against
// the psABI.
movaps %xmm6, 0x040(%rdi) // XMM6
movaps %xmm7, 0x050(%rdi) // XMM7
movaps %xmm8, 0x060(%rdi) // XMM8
movaps %xmm9, 0x070(%rdi) // XMM9
movaps %xmm10, 0x080(%rdi) // XMM10
movaps %xmm11, 0x090(%rdi) // XMM11
movaps %xmm12, 0x0A0(%rdi) // XMM12
movaps %xmm13, 0x0B0(%rdi) // XMM13
movaps %xmm14, 0x0C0(%rdi) // XMM14
movaps %xmm15, 0x0D0(%rdi) // XMM15
#endif
// Restore the context from jmp_
movq 0x000(%rsi), %rbx // RBX
movq 0x008(%rsi), %rbp // RBP
// Loading RSP switches to jmp_'s stack; the final ret pops the return
// address stored on that stack.
movq 0x010(%rsi), %rsp // RSP
movq 0x018(%rsi), %r12 // R12
movq 0x020(%rsi), %r13 // R13
movq 0x028(%rsi), %r14 // R14
movq 0x030(%rsi), %r15 // R15
#ifdef SAVE_FLOAT
movaps 0x040(%rsi), %xmm6
movaps 0x050(%rsi), %xmm7
movaps 0x060(%rsi), %xmm8
movaps 0x070(%rsi), %xmm9
movaps 0x080(%rsi), %xmm10
movaps 0x090(%rsi), %xmm11
movaps 0x0A0(%rsi), %xmm12
movaps 0x0B0(%rsi), %xmm13
movaps 0x0C0(%rsi), %xmm14
movaps 0x0D0(%rsi), %xmm15
#endif
ret
对齐相关
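一定要注意 16 字节对齐。movaps 等对齐版本的 SSE 指令要求内存操作数 16 字节对齐,否则会触发 #GP 异常,程序直接崩溃。
x86-64 下编译器默认使用 SSE/SSE2 指令处理浮点运算(也会用它们做优化),并且假定栈满足 ABI 的对齐要求。因此不仅保存 XMM 寄存器的上下文内存要 16 字节对齐,协程栈本身也必须对齐。
给有栈协程创建的栈,其栈顶(栈基地址)必须 16 字节对齐。
注意,ABI 的对齐约定是在执行 call 之前 RSP 为 16 的倍数;call 压入返回地址后,进入函数时 RSP ≡ 8 (mod 16)。也就是说,构造协程入口栈时,要先让栈顶 16 字节对齐,再压入一个占位的返回地址。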
// Reserve a placeholder "return address" slot (important for 16-byte
// alignment): it stands in for what a real call instruction would push.
stack_ = stack_ - sizeof(uintptr_t);
// Push the address of entry; the ret at the end of switch_jmp pops it.
stack_ = stack_ - sizeof(uintptr_t);
*static_cast<uintptr_t *>((void *)stack_) =
(uintptr_t)&coroutine_base::entry;
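如上,stack_ 指向用 mmap / VirtualAlloc 申请的内存的末端;这些内存按页对齐,自然也是 16 字节对齐的。
我们压入的 entry 地址会被 switch_jmp 末尾的 ret 弹出,但在它之上还额外压入了一个占位指针!
这样进入 entry 时,RSP 的状态就和在 16 字节对齐的栈上执行一次 call 之后完全相同(RSP ≡ 8 mod 16),符合 ABI 的要求。
注意:Windows x64 ABI 还要求调用者在返回地址之上预留 32 字节的 shadow space(home space),所以在 Windows 下构造入口栈时也要一并预留。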
c++异常在windows下的有栈协程里失效
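实际情况可能很复杂:MSVC 的 C++ 异常是在 SEH(结构化异常处理)机制之上实现的。上面这种上下文切换在 Windows 下会导致协程里抛出的 C++ 异常无法被 catch 拦截,程序直接崩溃。
借鉴 boost.context 的实现,在切换时同时保存并切换线程 TEB 中 NT_TIB 结构的栈信息。gs:[030h] 是 TEB 自身的地址;+08h 为 StackBase,+10h 为 StackLimit,+1478h 为 DeallocationStack。把它们切换为协程栈的范围之后,MSVC 的 C++ 异常就可以被正常拦截,程序能成功运行。
但在 MinGW 下,异常仍然无法被正常 catch!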
; Save the current thread's NT_TIB stack fields into save_ (RCX).
; gs:[030h] holds the TEB's own address (NT_TIB.Self).
mov r10, gs:[030h]
mov rax, [r10+01478h]   ; TEB.DeallocationStack
mov [rcx+048h], rax
mov rax, [r10+010h]     ; NT_TIB.StackLimit
mov [rcx+050h], rax
mov rax, [r10+08h]      ; NT_TIB.StackBase
mov [rcx+058h], rax
; Set: install the stack fields of jmp_ (RDX) into the TEB.
; (MASM comments start with ';' - a '//' line does not assemble.)
mov r10, gs:[030h]
mov rax, [rdx+048h]
mov [r10+01478h], rax
mov rax, [rdx+050h]
mov [r10+010h], rax
mov rax, [rdx+058h]
mov [r10+08h], rax
关于 TLS:不推荐自己 hack gs/fs 寄存器来实现 TLS,在原有 TLS 之上封装协程自己的 TLS 即可。
关于上下文保存在哪儿:有的实现是直接把上下文压入当前栈,但这种实现需要通过返回值把新的栈指针传回。这里采用不压栈的方式,把上下文保存在独立的结构体中。
封装相关:协程挂起时应该跳回调度器,保持栈的干净,所以挂起时是返回到调度器的 resume 调用之后。
关于 entry 的参数传递
比如,可以给 switch_jmp 增加第三个指针参数:
void switch_jmp(context_amd64 *save_, context_amd64 *jmp_,void* args_1);
在 switch_jmp 汇编的 ret 之前,把第三个参数寄存器(System V 下为 RDX,Windows 下为 R8)的值赋给第一个参数寄存器(RDI / RCX)。这样进入 static void entry(void* args_1) 时就能拿到参数。
但由于协程的其它操作(例如 co_suspend)无论如何都要依赖 TLS,所以没有采用这种实现。
有栈协程实现demo,全部源码:
//allocate.h
#include <cstddef>
#include <cassert>
#ifdef _WIN32
#include <Windows.h>
// Size of a virtual-memory page in bytes. All functions in this header are
// `inline`: they are defined in a header, so without it including the header
// from more than one .cpp produces multiple-definition link errors.
inline std::size_t page_size() {
  SYSTEM_INFO si;
  ::GetSystemInfo(&si);
  return static_cast<std::size_t>(si.dwPageSize);
}

// Commits a stack of at least size_ bytes (rounded up to whole pages) plus
// one extra PAGE_GUARD page at the lowest address, so running off the end of
// the stack hits the guard page. out_size receives the real allocation size.
// Returns nullptr if VirtualAlloc fails.
inline void *allocate_stack(size_t size_, size_t &out_size) {
  const size_t page = page_size();
  const std::size_t pages = (size_ + page - 1) / page;
  const std::size_t total_size = (pages + 1) * page;
  out_size = total_size;
  void *vp = ::VirtualAlloc(0, total_size, MEM_COMMIT, PAGE_READWRITE);
  assert(vp != nullptr && "VirtualAlloc NULL");
  if (vp == nullptr)
    return nullptr; // do not call VirtualProtect on a null region
  DWORD old_options;
  // Turn the lowest page into a PAGE_GUARD page to catch stack overflow.
  const BOOL result = ::VirtualProtect(
      vp, page, PAGE_READWRITE | PAGE_GUARD, &old_options);
  assert(FALSE != result && "VirtualProtect false");
  (void)result; // only checked by the assert; silence NDEBUG warnings
  return vp;
}

// Releases a stack obtained from allocate_stack. VirtualFree with
// MEM_RELEASE takes size 0 and frees the whole region, so size is unused.
inline void free_stack(void *ptr, size_t /*size*/) {
  ::VirtualFree(ptr, 0, MEM_RELEASE);
}
#elif __linux__
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
// Size of a virtual-memory page in bytes. All functions in this header are
// `inline`: they are defined in a header, so without it including the header
// from more than one .cpp produces multiple-definition link errors.
inline std::size_t page_size() {
  return static_cast<std::size_t>(::sysconf(_SC_PAGESIZE));
}

// Maps a stack of at least size_ bytes (rounded up to whole pages) plus one
// extra PROT_NONE guard page at the lowest address, so running off the end
// of the stack faults. out_size receives the real mapping size, which must
// be passed back to free_stack. Returns nullptr on failure.
inline void *allocate_stack(size_t size_, size_t &out_size) {
  const size_t page = page_size();
  const std::size_t pages = (size_ + page - 1) / page;
  const std::size_t total_size = (pages + 1) * page;
  out_size = total_size;
  void *vp = ::mmap(nullptr, total_size, PROT_READ | PROT_WRITE,
                    MAP_PRIVATE | MAP_ANON | MAP_STACK, -1, 0);
  // mmap reports failure with MAP_FAILED ((void *)-1), never with nullptr.
  assert(vp != MAP_FAILED && "::mmap failed!");
  if (vp == MAP_FAILED)
    return nullptr;
  const int result = ::mprotect(vp, page, PROT_NONE);
  assert(0 == result && "::mprotect faild!");
  if (result != 0) {
    ::munmap(vp, total_size); // do not leak the mapping
    return nullptr;
  }
  return vp;
}

// Unmaps a stack obtained from allocate_stack; size must be the out_size
// value allocate_stack reported.
inline void free_stack(void *ptr, size_t size) { ::munmap(ptr, size); }
#endif
//stackfull.cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <cstdint>
#include <functional>
#include <iomanip>
#include <memory>
#include <new>
#include <list>
#include <algorithm>
#include <atomic>
#include <string.h>
#include "allocate.h"
#if !(defined(__x86_64__) || defined(_M_X64) || defined(__amd64__) || \
defined(__amd64))
#error "此程序仅支持 AMD64 (x86-64) 架构平台,不支持当前平台编译。"
#endif
namespace co {
struct context_amd64;
// switch_jmp(save_, jmp_): stores the caller's callee-saved registers and
// stack pointer into *save_, loads the ones held in *jmp_, and returns on
// jmp_'s stack. It is implemented in assembly (win_msvc.asm / linux_gas.S);
// only the per-toolchain declarations live here, each pinned to the calling
// convention its assembly expects.
// NOTE: on _WIN64 with neither _MSC_VER nor __MINGW64__ defined, nothing is
// declared and every use of switch_jmp fails to compile.
#ifdef _WIN64
#if defined(_MSC_VER)
// MSVC: Microsoft x64 convention (RCX = save_, RDX = jmp_).
extern "C" __declspec(noinline) void switch_jmp(context_amd64 *save_,
context_amd64 *jmp_);
#elif __MINGW64__
// MinGW: force the Microsoft x64 convention explicitly.
extern "C" __attribute__((ms_abi, naked, noinline)) void
switch_jmp(context_amd64 *save_, context_amd64 *jmp_);
#endif
#elif __linux__
// Linux: System V AMD64 convention (RDI = save_, RSI = jmp_).
extern "C" __attribute__((sysv_abi, naked, noinline)) void
switch_jmp(context_amd64 *save_, context_amd64 *jmp_);
#endif
// Callee-saved (non-volatile) registers preserved across switch_jmp:
//   both ABIs:           RBX, RBP, RSP, R12, R13, R14, R15
//   Microsoft x64 only:  RDI, RSI, XMM6-XMM15
// (System V treats every XMM register as caller-saved.)
// Slots in context_amd64::regs that the coroutine constructor seeds by hand.
constexpr size_t rbp_index{1};
constexpr size_t rsp_index{2};
// Storage for one saved 128-bit XMM register. alignas(16) makes the type
// itself carry the 16-byte alignment that movaps requires, instead of
// relying only on the offset at which context_amd64 happens to place it.
struct alignas(16) xmm_regs {
  uint64_t _1;
  uint64_t _2;
};
// Saved execution context of one side of a switch_jmp. The assembly
// addresses every field by a hard-coded byte offset, so the layout below is
// pinned by the static_asserts that follow.
struct alignas(16) context_amd64 {
  uint64_t regs[7]{}; // RBX, RBP, RSP, R12, R13, R14, R15
#ifdef _WIN64
  uint64_t rdi_rsi[2]{};  // RDI, RSI
  uint64_t tib_info[3]{}; // NT_TIB: DeallocationStack, StackLimit, StackBase
#else
  uint64_t alignas__ = 0; // padding only: places regs_xmm at 0x40
#endif // _WIN64
  // SAVE_FLOAT is enabled from CMakeLists.txt.
#ifdef SAVE_FLOAT
  xmm_regs regs_xmm[10]{}; // XMM6-XMM15, must be 16-byte aligned (movaps)
#endif
};

// Offsets used by switch_jmp in win_msvc.asm / linux_gas.S.
static_assert(offsetof(context_amd64, regs) == 0x00,
              "switch_jmp expects RBX at offset 0x00");
#ifdef _WIN64
static_assert(offsetof(context_amd64, rdi_rsi) == 0x38,
              "switch_jmp expects RDI/RSI at offset 0x38");
static_assert(offsetof(context_amd64, tib_info) == 0x48,
              "switch_jmp expects NT_TIB fields at offset 0x48");
#ifdef SAVE_FLOAT
static_assert(offsetof(context_amd64, regs_xmm) == 0x60,
              "switch_jmp expects XMM6 at offset 0x60");
#endif
#else
#ifdef SAVE_FLOAT
static_assert(offsetof(context_amd64, regs_xmm) == 0x40,
              "switch_jmp expects XMM6 at offset 0x40");
#endif
#endif
struct coroutine_base;

// Per-thread scheduler state: the context to jump back to when a coroutine
// suspends or finishes, and the coroutine currently being resumed.
struct context_tls {
  context_amd64 break_{};
  coroutine_base *co_current_{nullptr};
};

// Requested stack size of every coroutine: 2 MiB.
constexpr size_t STACK_SIZE = 2 * 1024 * 1024;
thread_local context_tls switch_;

void co_suspend();
// Prints "====[action]" right-aligned in 15 columns followed by the name
// right-aligned in 10 columns, padded with '='. std::right and std::setfill
// are sticky, so the stream's previous flags and fill are restored before
// returning; otherwise they would leak into every later use of std::cout.
void debug_print(const char *action, const char *name) {
  const std::ios_base::fmtflags saved_flags = std::cout.flags();
  const char saved_fill = std::cout.fill();

  std::string action_ = "[";
  action_ += action;
  action_ += "]";

  std::cout << std::right << std::setfill('=');
  std::cout << std::setw(15) << action_;
  std::cout << std::setw(10) << name << std::endl;

  std::cout.flags(saved_flags);
  std::cout.fill(saved_fill);
}
struct coroutine_base {
  // First code executed on a fresh coroutine stack. switch_jmp "returns"
  // into it through the entry address pushed by the constructor, so it has
  // no real caller (the slot above it is a null placeholder). It must never
  // return; when the body finishes it jumps back to the scheduler instead.
  static void entry() noexcept {
    context_tls &tls_ = switch_;
    coroutine_base *this_ = tls_.co_current_;
    debug_print("start", this_->name());
    try {
      this_->invoke();
    } catch (...) {
      // Nothing above this frame can unwind; park the exception so the
      // resumer can inspect / rethrow it via except().
      this_->except_ptr = std::current_exception();
    }
    this_->done_ = true;
    debug_print("end", this_->name());
    switch_jmp(&this_->context_,
               &tls_.break_); // no return address: must jump, never return
  }

public:
  // Runs the coroutine on this thread until it suspends or finishes; a no-op
  // once done(). There is a single per-thread break_ context, so resuming a
  // coroutine from inside another coroutine is not supported.
  void resume() {
    if (done_)
      return;
    context_tls &tls_ = switch_;
    tls_.co_current_ = this;
    debug_print("resume", this->name());
    switch_jmp(&tls_.break_, &this->context_);
    // Back on the scheduler: nothing is running any more, so a stray
    // co_suspend() sees nullptr instead of this (now idle) coroutine.
    tls_.co_current_ = nullptr;
    debug_print("suspend", this->name());
  }
  bool done() { return done_; }
  // Moves the value most recently stored in the result slot (by set() or
  // by the coroutine's return) out as an R; R must be the slot's type.
  template <typename R> R get() {
    return std::move(*std::launder(static_cast<R *>(result())));
  }
  std::exception_ptr &except() { return except_ptr; }
  // Copy-assigns r into the result slot.
  // NOTE(review): the slot is raw storage; assigning into it before an
  // object was constructed there is only sound for trivially-copyable
  // types - confirm before yielding non-trivial values.
  template <typename R> void set(R &&r) {
    using type = std::remove_reference_t<R>;
    *std::launder(static_cast<type *>(result())) = r;
  }
  // Publishes r as the running coroutine's result and suspends it.
  template <typename RR> static void current_yield(RR &r) {
    context_tls &tls_ = switch_;
    tls_.co_current_->set(std::forward<RR>(r));
    co::co_suspend();
  }
  const char *name() { return name_.c_str(); }
  // nullptr is treated as an empty name (std::string from null is UB).
  void set_name(const char *_name) { name_ = _name ? _name : ""; }
  virtual ~coroutine_base() { free_stack(stack_m_, stack_size_); }

protected:
  // Allocates the coroutine stack (stack_size bytes rounded up plus a guard
  // page, see allocate_stack) and forges an initial context so that the
  // first switch_jmp into it "returns" into entry() on an ABI-aligned stack.
  // Throws std::bad_alloc if the stack cannot be allocated.
  coroutine_base(size_t stack_size) {
    stack_m_ = (uint8_t *)allocate_stack(stack_size, stack_size_);
    if (stack_m_ == nullptr)
      throw std::bad_alloc();
    // Stacks grow downwards: start at the highest address of the mapping,
    // which is page aligned and therefore 16-byte aligned.
    uint8_t *const stack_top = stack_m_ + stack_size_;
    stack_ = stack_top;
#ifdef _WIN64
    // Microsoft x64 ABI: the caller reserves 32 bytes of home (shadow)
    // space directly above the return address, and the callee may write
    // it. entry() is reached by ret rather than call, so reserve it here;
    // without it entry() could write past the end of the allocation.
    // 32 is a multiple of 16, so the alignment below is unchanged.
    stack_ = stack_ - 32;
#endif
    // Placeholder return address (entry() never returns). It stands in for
    // what a real call would push, so that on entry rsp % 16 == 8.
    stack_ = stack_ - sizeof(uintptr_t);
    *static_cast<uintptr_t *>((void *)stack_) = 0;
    // Address of entry(); popped by the ret at the end of switch_jmp.
    stack_ = stack_ - sizeof(uintptr_t);
    *static_cast<uintptr_t *>((void *)stack_) =
        (uintptr_t)&coroutine_base::entry;
    context_.regs[rsp_index] = (uintptr_t)(void *)stack_;
    context_.regs[rbp_index] = (uintptr_t)(void *)stack_;
#ifdef _WIN64
    // NT_TIB stack bounds installed by switch_jmp so MSVC's SEH-based C++
    // exception dispatch accepts frames on this stack. They describe the
    // real allocation (as boost.context does): DeallocationStack and
    // StackLimit = lowest address, StackBase = highest address.
    context_.tib_info[0] = (uintptr_t)stack_m_;
    context_.tib_info[1] = (uintptr_t)stack_m_;
    context_.tib_info[2] = (uintptr_t)stack_top;
#endif
  }
  virtual void invoke() = 0;
  virtual void *result() = 0;
  context_amd64 context_{};
  uint8_t *stack_m_ = nullptr; // base of the allocation (guard page first)
  uint8_t *stack_ = nullptr;   // initial stack pointer of the coroutine
  size_t stack_size_ = 0;      // real allocation size from allocate_stack
  std::atomic_bool done_ = false;
  std::exception_ptr except_ptr{};
  std::string name_;
  friend void co_suspend();
};
// Suspends the coroutine currently running on this thread and jumps back to
// the scheduler's resume() call. Must only be called from inside a coroutine.
void co_suspend() {
  context_tls &tls_ = switch_;
  coroutine_base *this_ = tls_.co_current_;
  // Outside a coroutine there is no context to save; dereferencing a null
  // co_current_ would be undefined behavior.
  assert(this_ != nullptr && "co_suspend() called outside of a coroutine");
  switch_jmp(&this_->context_, &tls_.break_);
}
// Raw, suitably aligned storage for a coroutine's result of type R. The
// strictest of max_align_t and R's own alignment applies, so over-aligned
// result types (e.g. alignas(32)) are stored correctly as well.
template <typename R> struct result_storage {
  alignas(std::max_align_t) alignas(R) uint8_t ret_s[sizeof(R)]{};
  void *get_addr() { return static_cast<void *>(&this->ret_s[0]); }
};
// A void coroutine has no result slot; asking for its address is an error.
template <> struct result_storage<void> {
  void *get_addr() { throw std::runtime_error("void null!"); }
};
// Concrete coroutine running a callable R(Args...) with stored arguments.
template <typename R, typename... Args>
struct coroutine : public coroutine_base {
public:
  // Forwards the callable and arguments into place. Not noexcept: allocating
  // the stack, the std::function and the argument tuple may all throw, and a
  // noexcept constructor would turn any such failure into std::terminate.
  template <typename F, typename... Args2>
  coroutine(F &&_func, Args2 &&..._args)
      : coroutine_base(STACK_SIZE), args_(std::forward<Args2>(_args)...),
        coro_main_(std::forward<F>(_func)) {}

  // Destroys the return value constructed by invoke(), if any; without this
  // a non-trivially-destructible R (e.g. std::string) would leak.
  ~coroutine() override {
    if constexpr (!std::is_void_v<R>) {
      if (constructed_)
        std::destroy_at(std::launder(static_cast<R *>(ret.get_addr())));
    }
  }

private:
  result_storage<R> ret{};
  std::tuple<Args...> args_{};
  std::function<R(Args...)> coro_main_;
  bool constructed_ = false; // true once invoke() placement-new'ed an R

  void invoke() override {
    if constexpr (std::is_void_v<R>) {
      std::apply(coro_main_, args_);
    } else {
      new (ret.get_addr()) R(std::apply(coro_main_, args_));
      constructed_ = true;
    }
  }
  void *result() override { return ret.get_addr(); }
};
// Creates a named coroutine running _func(_args...) and returns an owning
// raw pointer (callers wrap it, e.g. in std::unique_ptr<coroutine_base>).
// Args is deduced from forwarding references, so an lvalue argument makes
// the coroutine store a reference to the caller's object, which must then
// outlive the coroutine.
template <typename F, typename... Args>
co::coroutine_base *create_co(const char *name, F &&_func, Args &&..._args) {
  using R = std::invoke_result_t<F, Args...>;
  // Own the coroutine (and its stack) until it is fully set up, so it is
  // released if set_name() throws.
  auto owner = std::make_unique<co::coroutine<R, Args...>>(
      std::forward<F>(_func), std::forward<Args>(_args)...);
  owner->set_name(name);
  return owner.release();
}
} // namespace co
int main() {
std::cout << 1.132 << std::endl;
using coro_unptr = std::unique_ptr<co::coroutine_base>;
std::list<coro_unptr> task_;
task_.emplace_back(co::create_co("co_1", []() {
std::cout << "hello " << std::endl;
co::co_suspend();
std::cout << "world~" << std::endl;
}));
task_.emplace_back(co::create_co(
"co_2",
[](int x) {
int my_value = 123;
int other = 200;
other += 100;
double f64 = 1.23;
f64 = f64 + 2.222;
co::co_suspend();
double ret = (my_value + other + x) + f64;
std::cout << "co_r value:" << ret << std::endl;
return ret;
},
10000));
while (task_.size() > 0) {
for (auto &t : task_) {
if (!t->done()) {
t->resume();
}
}
std::erase_if(task_, [](const auto &x) { return x->done(); });
}
coro_unptr generator_test(co::create_co(
"co_3",
[](size_t x) -> size_t {
for (size_t i = 0; i < x; i++) {
co::coroutine_base::current_yield(i);
}
return 0;
},
10));
while (true) {
if (!generator_test->done()) {
generator_test->resume();
size_t v = generator_test->get<size_t>();
std::cout << "generator:" << v << std::endl;
} else
break;
}
coro_unptr throw_test(co::create_co(
"co_test", []() { throw std::runtime_error("error !!!"); }));
try {
throw_test->resume();
if (throw_test->except())
std::rethrow_exception(throw_test->except());
} catch (const std::exception &e) {
std::cout << "exception: " << e.what() << std::endl;
}
std::cout << "end####" << std::endl;
return 0;
}
;win_msvc.asm
; switch_jmp(save_, jmp_) for MSVC / ml64 (Microsoft x64 convention).
; Saves the callee-saved registers, RSP and the TEB stack bounds of the
; current side into save_, then loads those of jmp_ and returns on jmp_'s
; stack. Offsets must match struct context_amd64.
.code
switch_jmp proc frame
.endprolog
;RCX save_,RDX jmp_
;----------- save the current context into save_ -------------
mov [rcx + 000h], rbx ; RBX
mov [rcx + 008h], rbp ; RBP
mov [rcx + 010h], rsp ; RSP
mov [rcx + 018h], r12 ; R12
mov [rcx + 020h], r13 ; R13
mov [rcx + 028h], r14 ; R14
mov [rcx + 030h], r15 ; R15
mov [rcx + 038h], rdi ; RDI
mov [rcx + 040h], rsi ; RSI
; load NT_TIB stack fields from the TEB (gs:[030h] = TEB address) into save_
mov r10, gs:[030h]
mov rax, [r10+01478h]
mov [rcx+048h], rax
mov rax, [r10+010h]
mov [rcx+050h], rax
mov rax, [r10+08h]
mov [rcx+058h], rax
; XMM6-XMM15 are callee-saved in the Microsoft x64 ABI; movaps needs the
; 16-byte alignment that context_amd64 provides
IFDEF SAVE_FLOAT
movaps [rcx + 0060h], xmm6 ; XMM6
movaps [rcx + 0070h], xmm7 ; XMM7
movaps [rcx + 0080h], xmm8 ; XMM8
movaps [rcx + 0090h], xmm9 ; XMM9
movaps [rcx + 00A0h], xmm10 ; XMM10
movaps [rcx + 00B0h], xmm11 ; XMM11
movaps [rcx + 00C0h], xmm12 ; XMM12
movaps [rcx + 00D0h], xmm13 ; XMM13
movaps [rcx + 00E0h], xmm14 ; XMM14
movaps [rcx + 00F0h], xmm15 ; XMM15
ENDIF
;----------- restore the context from jmp_ -------------
; restore integer registers (loading RSP switches to jmp_'s stack)
mov rbx, [rdx + 000h] ; RBX
mov rbp, [rdx + 008h] ; RBP
mov rsp, [rdx + 010h] ; RSP
mov r12, [rdx + 018h] ; R12
mov r13, [rdx + 020h] ; R13
mov r14, [rdx + 028h] ; R14
mov r15, [rdx + 030h] ; R15
mov rdi, [rdx + 038h] ; RDI
mov rsi, [rdx + 040h] ; RSI
; store jmp_'s NT_TIB stack fields back into the TEB
mov r10, gs:[030h]
mov rax, [rdx+048h]
mov [r10+01478h], rax
mov rax, [rdx+050h]
mov [r10+010h], rax
mov rax, [rdx+058h]
mov [r10+08h], rax
IFDEF SAVE_FLOAT
movaps xmm6, [rdx+0060h]
movaps xmm7, [rdx+0070h]
movaps xmm8, [rdx+0080h]
movaps xmm9, [rdx+0090h]
movaps xmm10, [rdx+00A0h]
movaps xmm11, [rdx+00B0h]
movaps xmm12, [rdx+00C0h]
movaps xmm13, [rdx+00D0h]
movaps xmm14, [rdx+00E0h]
movaps xmm15, [rdx+00F0h]
ENDIF
; pops the return address stored on jmp_'s stack
ret
switch_jmp endp
end
//linux_gas.S
// switch_jmp(context_amd64 *save_, context_amd64 *jmp_), System V AMD64.
// Saves the callee-saved registers and RSP into *save_, loads the ones in
// *jmp_, then returns on jmp_'s stack. Offsets must match context_amd64.
.text
.globl switch_jmp
// ELF symbol type/size so debuggers, profilers and backtraces can attribute
// addresses inside this routine to switch_jmp.
.type switch_jmp, @function
switch_jmp:
// rdi = save_, rsi = jmp_
// Save the current context into save_
movq %rbx, 0x000(%rdi) // RBX
movq %rbp, 0x008(%rdi) // RBP
movq %rsp, 0x010(%rdi) // RSP
movq %r12, 0x018(%rdi) // R12
movq %r13, 0x020(%rdi) // R13
movq %r14, 0x028(%rdi) // R14
movq %r15, 0x030(%rdi) // R15
#ifdef SAVE_FLOAT
movaps %xmm6, 0x040(%rdi) // XMM6
movaps %xmm7, 0x050(%rdi) // XMM7
movaps %xmm8, 0x060(%rdi) // XMM8
movaps %xmm9, 0x070(%rdi) // XMM9
movaps %xmm10, 0x080(%rdi) // XMM10
movaps %xmm11, 0x090(%rdi) // XMM11
movaps %xmm12, 0x0A0(%rdi) // XMM12
movaps %xmm13, 0x0B0(%rdi) // XMM13
movaps %xmm14, 0x0C0(%rdi) // XMM14
movaps %xmm15, 0x0D0(%rdi) // XMM15
#endif
// Restore the context from jmp_ (loading RSP switches stacks)
movq 0x000(%rsi), %rbx // RBX
movq 0x008(%rsi), %rbp // RBP
movq 0x010(%rsi), %rsp // RSP
movq 0x018(%rsi), %r12 // R12
movq 0x020(%rsi), %r13 // R13
movq 0x028(%rsi), %r14 // R14
movq 0x030(%rsi), %r15 // R15
#ifdef SAVE_FLOAT
movaps 0x040(%rsi), %xmm6
movaps 0x050(%rsi), %xmm7
movaps 0x060(%rsi), %xmm8
movaps 0x070(%rsi), %xmm9
movaps 0x080(%rsi), %xmm10
movaps 0x090(%rsi), %xmm11
movaps 0x0A0(%rsi), %xmm12
movaps 0x0B0(%rsi), %xmm13
movaps 0x0C0(%rsi), %xmm14
movaps 0x0D0(%rsi), %xmm15
#endif
// Pops the return address stored on jmp_'s stack
ret
.size switch_jmp, .-switch_jmp

// Mark the stack as non-executable: without this note section GNU ld
// assumes the object needs an executable stack (and newer binutils warn).
#if defined(__ELF__)
.section .note.GNU-stack,"",@progbits
#endif
# add_compile_definitions() and CMAKE_CXX_STANDARD 20 both require
# CMake 3.12 or newer, so 3.10 was not a usable minimum version.
cmake_minimum_required (VERSION 3.12)
project ("stackfull")
# C++20 is needed e.g. for std::erase_if on std::list used in main().
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Also passed to the assembly sources: enables saving XMM6-XMM15.
add_compile_definitions(SAVE_FLOAT)
if (UNIX AND NOT APPLE)
  message(STATUS "Platform: Linux")
  enable_language(ASM)
  set(ASM_SOURCES linux_gas.S)
elseif (WIN32)
  if (MSVC)
    message(STATUS "Platform: Windows MSVC")
    enable_language(ASM_MASM)
    set(ASM_SOURCES win_msvc.asm)
  elseif (MINGW)
    message(STATUS "Platform: Windows MinGW")
    enable_language(ASM)
    set_source_files_properties(win_gas.S PROPERTIES
      COMPILE_FLAGS "-x assembler-with-cpp"
    )
    set(ASM_SOURCES win_gas.S)
    add_compile_options(-fexceptions)
  endif ()
else()
  message(FATAL_ERROR "Unsupported platform")
endif()
add_executable (stackfull "stackfull.cpp" ${ASM_SOURCES})