Lock Free 之 Epoch Based Reclamation

Epoch Based Reclamation

  

Epoch Based Reclamation

epoch based reclamation 

算法参考文档:

http://www.cs.toronto.edu/~tomhart/papers/tomhart_thesis.pdf

https://aturon.github.io/blog/2015/08/27/epoch/#epoch-based-reclamation

 

为什么要 Epoch Based Reclamation

设想我们要在无 gc 的语言中实现一种无锁但支持并发的数据结构(可能是个map、也可能是个 array):

  1. 线程 A 想替换该数据结构中的某个数据节点,线程 A 使用 newNode 原子替换了oldNode,并同时释放了oldNode,防止了内存泄露,这一切看起来没什么问题;
  2. 在线程 A 的上述操作中,若巧好有另一个线程 B 在线程 A 替换前便进行了读取(读到 oldNode 指针),但稍后线程 A 释放了 oldNode,那么线程 B 此时持有的便是一个悬空指针(一旦使用就 coredump 了)。

在拥有 gc 的语言中,如 java, golang,不会出现这种问题。但是对于 c/c++,则没有一种很好的方法应对这问题,除非 非gc语言也有一种延迟回收内存的机制,于是Epoch Based Reclamation作为其中一种方法应运而生。

 

 基本概念

如下图,表示一个节点的删除包含 逻辑删除(delete) 和 物理删除(free) 两个过程:

  • 逻辑删除(delete),一个节点在被逻辑删除之时可能会有其他线程正在访问它,逻辑删除不回收内存空间;
  • 物理删除(free / reclaim),物理删除之后不会再被线程访问到,会将对应的内存空间回收;
  • Grace Period,记时间段 T=[t1, t2],如果 t1 之前逻辑删除的节点,都可以在 t2 之后安全的回收,那么称 T = [t1, t2] 是一个 Grace Period。

 

算法原理

  • 维护了一个全局的 epoch (取值为0、1、2),epoch 的每个取值都对应一个 retire list(存放逻辑删除后待回收的指针);
  • 为每个线程维护一个局部的 active flag 和 epoch (取值自然也为0、1、2);
  • 线程进入临界区前会设置 active flag = true,并设置自己的局部 epoch 为全局 epoch 的值,离开临界区时设置 active flag = false;
  • 线程删除数据时,先放入对应的 retire_list [global_epoch](线程局部 epoch 等于 global epoch);
  • 全局 epoch (假设为 E)的增长规则,若所有活跃线程的 epoch 是否都等于 E 时,置换 E = (E + 1) % 3
 
理解这个算法的关键是理清线程局部 epoch 和全局 epoch 的关系:
在这个算法里,任何时刻,全局 epoch 的值如果为 E,那么线程的局部 epoch 值只能是 E 或  (E-1) % 3,不可能为 (E+1) % 3。
换言之,若有某一个活跃线程中的局部 e = (E - 1) % 3 的时候,E 不会增长,故而不会存在有 e = E - 2 (也即 E + 1)的线程。
 
如下表:
全局 epoch 取值 局部 epoch 可能取值 grace period 可回收的 restire_list
0 2、0 [1, 0] 1
1 0、1 [2, 1] 2
2 1、2 [0, 2] 0

 

前面说到 当全局 epoch = E 时, 活跃线程的 epoch 只能是 E 或 E - 1,当最后一个 e = E -1 的线程完成了临界区操作,也即所有活跃线程中 epoch 计数都等于E的时候,也就是说 E - 2(也即 E + 1)对应的回收队列 retireList 中的节点再也没有任何一个线程访问了,故而retireList[E - 2] 中的数据可以真正的释放了。如此做后,全局计数E = E + 1 

 

看个例子的推演,假设就两个线程 A、B,一个共享数据 N,初始数据为 N0。A 不断更改 N,B 则不断读取 N:

epoch

t0

t1

t2

t3

全局

0 -> 1

1 -> 2

2 -> 0

0

写线程 A

0,N0 -> N1

1,N1 -> N2

2,N2 -> N3

0,N3 -> N4

读线程 B

0,读到 N0

1,读到 N1

2,读到 N2

0,读到 N3

retire_list[0]

[N0]

[N0]

[N0] -> gc -> [ ]

[N3 ]

retire_list[1]

[ ]

[N1]

[N1]

[N1] -> gc -> [ ]

retire_list[1]

[ ]

[ ] -> gc

[N2]

[N2]

 

 

实现 

一个完整的实现例子:

  1 #ifndef DEMO_SMR_H
  2 #define DEMO_SMR_H
  3 
  4 #include <atomic>
  5 #include <vector>
  6 #include <string>
  7 #include <iostream>
  8 #include <chrono>
  9 #include <thread>
 10 
 11 #define CACHELINE_SIZE 64
 12 #define MAX_THREAD_NUM 503
 13 #define FETCH_AND_ADD(address,offset) __sync_fetch_and_add(address,offset)
 14 #define CPU_BARRIER() __asm__ __volatile__("mfence": : :"memory")
 15 
 16 struct alignas(64) ReadIndicator {
 17     void arrive(void) {
 18         counter.fetch_add(1, std::memory_order_seq_cst);
 19     }
 20 
 21     void depart(void) {
 22         counter.fetch_sub(1, std::memory_order_release);
 23     }
 24 
 25     bool empty(void) {
 26         return counter.load(std::memory_order_seq_cst) == 0;
 27     }
 28 
 29 private:
 30     std::atomic<uint64_t> counter{0};
 31 };
 32 
 33 struct alignas(64) ReadIndicatorGuard {
 34     explicit ReadIndicatorGuard(ReadIndicator& inst) : indicator(inst) {
 35         indicator.arrive();
 36     }
 37 
 38     ~ReadIndicatorGuard() { indicator.depart(); }
 39 
 40 private:
 41     ReadIndicator& indicator;
 42 };
 43 
 44 template <typename T>
 45 class SMRManagerBase {
 46 protected:
 47     using RetireListType = std::vector<T*>;
 48     static constexpr int EBR_CYCLE = 3;
 49 
 50 public:
 51     SMRManagerBase() {
 52         type_name = typeid(T).name();
 53     }
 54 
 55     virtual ~SMRManagerBase() {
 56         for (int i = 0; i < EBR_CYCLE; ++i) {
 57             for (T* t : retire_lists_[i]) {
 58                 delete t;
 59             }
 60         }
 61     }
 62 
 63     /* 读者访问临界资源前,首先调用这个函数 */
 64     virtual void reader_enter() = 0;
 65     virtual void reader_leave() = 0;
 66     virtual int smr_type() const = 0;
 67 
 68     int32_t zombie_cnt() const { return zombie_cnt_; }
 69     int32_t reclaim_cnt() const { return reclaim_cnt_; }
 70 
 71     // 写者回收资源。待回收的资源可能不会立马被回收
 72     int reclaim(T* const p) { return writer_reclaim_batch(p); }
 73 
 74     // 批量回收内存
 75     int reclaim(const std::vector<T*>& values) {
 76         return writer_reclaim_batch(values);
 77     }
 78 
 79     int reclaim() { return writer_reclaim_batch(nullptr); }
 80 
 81     int fast_reclaim(int64_t interval_ts, int64_t times) {
 82         int count = 0;
 83         while (times-- > 0) {
 84             std::this_thread::sleep_for(std::chrono::microseconds(interval_ts));
 85             count += reclaim();
 86         }
 87         return count;
 88     }
 89 
 90     int fast_reclaim(T* const p, int64_t interval_ts = 400, int64_t times = 3) {
 91         int count = 0;
 92         count += reclaim(p);
 93         count += fast_reclaim(interval_ts, times);
 94         return count;
 95     }
 96 
 97     // 物理删除一个 retire_list
 98     int writer_free(int epoch) {
 99         int count = 0;
100         RetireListType& retire_list = retire_lists_[epoch];
101         for (auto& retire_pointer : retire_list) {
102             if (retire_pointer != nullptr) {
103                 ++count;
104                 delete retire_pointer;
105                 retire_pointer = nullptr;
106             }
107         }  // end for
108 
109         if (retire_list.size() > 1000) {
110             retire_list.clear();
111             retire_list.shrink_to_fit();
112         }
113         zombie_cnt_ -= count;
114         reclaim_cnt_ += count;
115         return count;
116     }
117 
118     virtual ReadIndicator& get_read_indicator() = 0;
119 
120     ReadIndicatorGuard* get_read_guard() {
121         return new ReadIndicatorGuard(get_read_indicator());
122     }
123 
124 protected:
125     virtual int32_t writer_gc() = 0;
126     virtual int32_t get_epoch() = 0;
127 
128     // 写者回收资源。待回收的资源可能不会立马被回收。
129     int writer_reclaim_batch(T* const p) {
130         writer_record(p);
131         return writer_gc();
132     }
133 
134     //批量回收内存,提高效率
135     int writer_reclaim_batch(const std::vector<T*>& values) {
136         writer_record(values);
137         return writer_gc();
138     }
139 
140     /* 逻辑删除,将 p 写到 retire_list 中 */
141     void writer_record(T* const p) {
142         if (p == nullptr) {
143             return;
144         }
145 
146         RetireListType& retire_list = retire_lists_[get_epoch()];
147         bool found_vacant = false;
148         for (auto& retire_pointer : retire_list) {
149             if (retire_pointer == nullptr) {
150                 retire_pointer = p;
151                 found_vacant = true;
152                 break;
153             }
154         }
155         if (!found_vacant) {
156             retire_list.push_back(p);
157         }
158 
159         ++zombie_cnt_;
160     }
161 
162     void writer_record(const std::vector<T*>& reclaim_nodes) {
163         if (reclaim_nodes.empty()) {
164             return;
165         }
166         RetireListType& retire_list = retire_lists_[get_epoch()];
167 
168         int32_t rn_index = 0;
169         int32_t store_index = 0;
170         int cnt = 0;
171         while (rn_index < reclaim_nodes.size()) {
172             // 尝试找到一个空的位置存储一下
173             while (store_index < retire_list.size() &&
174                    retire_list[store_index] != nullptr) {
175                 ++store_index;
176             }
177 
178             // 如果没有空的位置,说明 retire_list 需要扩容
179             if (store_index >= retire_list.size()) {
180                 break;
181             }
182 
183             // 找到空位置的
184             retire_list[store_index++] = reclaim_nodes[rn_index++];
185             ++cnt;
186         }
187 
188         if (rn_index < reclaim_nodes.size()) {
189             int remains = reclaim_nodes.size() - rn_index;
190             retire_list.reserve(retire_list.size() + remains + 10);
191 
192             while (rn_index < reclaim_nodes.size()) {
193                 retire_list.push_back(reclaim_nodes[rn_index++]);
194                 ++cnt;
195             }
196         }
197 
198         zombie_cnt_ += reclaim_nodes.size();
199         std::cout << "reclaim node:" << reclaim_nodes.size() << ", cnt:" << cnt << std::endl;
200     }
201 
202 protected:
203     RetireListType retire_lists_[3];
204 
205     int32_t zombie_cnt_ = 0;            // 逻辑删除的计数器
206     int32_t reclaim_cnt_ = 0;           // 物理删除的计数器
207     std::string type_name;
208 };
209 
210 template <typename T>
211 // SMR is short for Safety Memory Reclamation
212 class SMRManager : public SMRManagerBase<T> {
213     using Base = typename SMRManager::SMRManagerBase;
214 
215 public:
216     using Base::fast_reclaim;
217     using Base::reclaim;
218     using Base::reclaim_cnt;
219     using Base::writer_free;
220     using Base::zombie_cnt;
221     using Base::type_name;
222 
223     // 这个构造需要在单线程下进行
224     SMRManager() {
225         global_epoch_.epoch_ = 0;
226         local_epoches_ = new EpochType[MAX_THREAD_NUM];
227         active_flags_ = new ActiveFlagType[MAX_THREAD_NUM];
228         for (int i = 0; i < MAX_THREAD_NUM; ++i) {
229             local_epoches_[i].epoch_ = 0;
230             active_flags_[i].active_ = false;
231         }
232     }
233 
234     ~SMRManager() {
235         delete[] local_epoches_;
236         delete[] active_flags_;
237     }
238 
239     SMRManager(const SMRManager& rhs) = delete;
240     SMRManager& operator=(const SMRManager& rhs) = delete;
241 
242     /*读者访问临界资源前,首先调用这个函数*/
243     void reader_enter() override {
244         active_flags_[get_thread_id()].active_ = true;
245         CPU_BARRIER();
246         local_epoches_[get_thread_id()].epoch_ = global_epoch_.epoch_;
247         CPU_BARRIER();
248     }
249 
250     /*读者离开临界区后,调用这个函数*/
251     void reader_leave() override {
252         CPU_BARRIER();
253         active_flags_[get_thread_id()].active_ = false;
254     }
255 
256     ReadIndicator& get_read_indicator() override {
257         static ReadIndicator indicator;
258         std::cout << "unexpected call, type_name:" << type_name << std::endl;
259         return indicator;
260     }
261 
262     int smr_type() const override { return 0; }
263 
264 private:
265     int32_t get_epoch() override {
266         return global_epoch_.epoch_;
267     }
268 
269     int writer_gc() override {
270         for (int i = 0; i < MAX_THREAD_NUM; i++) {
271             if (active_flags_[i].active_ && local_epoches_[i].epoch_ != global_epoch_.epoch_) {
272                 // 此时有活跃读线程 lag 了 epoch
273                 return writer_free((global_epoch_.epoch_ + 1) % 3);
274             }
275         }
276         // 所有的活跃读线程的 epoch 都升上来了
277         global_epoch_.epoch_ = (global_epoch_.epoch_ + 1) % 3;
278         CPU_BARRIER();  // 这个 memory barrier 的使用可以做一些优化
279         return writer_free((global_epoch_.epoch_ + 1) % 3);
280     }
281 
282     int get_thread_id() {
283         static __thread int tid = -1;
284         if (tid == -1) {
285             tid = FETCH_AND_ADD(&thread_count_, 1);
286             if (tid >= MAX_THREAD_NUM) {
287                 // abort();
288             }
289         }
290         return tid;
291     }
292 
293 private:
294     struct EpochType {
295         volatile int epoch_;
296     } __attribute__((aligned(CACHELINE_SIZE)));
297 
298     struct ActiveFlagType {
299         volatile bool active_;
300     } __attribute__((aligned(CACHELINE_SIZE)));  // 非常重要,保证原子,且避免 false sharing
301 
302 private:
303     EpochType global_epoch_;            // 全局 epoch
304     EpochType* local_epoches_;          // 数组,每个线程占一个元素
305     ActiveFlagType* active_flags_;      // 数组,每个线程占一个元素
306 
307     int32_t zombie_cnt_ = 0;
308     int32_t reclaim_cnt_ = 0;
309     static int thread_count_;
310 };
311 
312 template <typename T>
313 int SMRManager<T>::thread_count_ = 0;
314 
315 
316 
317 #endif //DEMO_SMR_H
View Code

 

 

 

 

SMR

SMR(Safety Memory Reclamation)是基于 Epoch Based Reclamation 的延迟 gc 技术,下面有个使用 SMR 实现线程安全的 map 例子:

  1 template <typename TKey, typename TValue>
  2 class ConcurrentMap {
  3 public:
  4     using Key = TKey;
  5     using Value = TValue;
  6 
  7     ConcurrentMap() {
  8         lazy_batch_.reserve(kLazyBatchSize);
  9         smr_.reset(new SMRManager<Value>());
 10     }
 11 
 12     ~ConcurrentMap() {
 13         data_.clear();
 14     }
 15 
 16     void Acquire() {
 17         smr_->reader_enter();
 18     }
 19     void Release() {
 20         smr_->reader_leave();
 21     }
 22     ReadIndicatorGuard *get_read_guard() {
 23         return smr_->get_read_guard();
 24     }
 25     int SmrType() const {
 26         return smr_->smr_type();
 27     }
 28     void SetSmrType() {
 29         smr_.reset(new SMRManager<Value>());
 30     }
 31 
 32     size_t Size() const {
 33         return data_.size();
 34     }
 35     size_t Capacity() const {
 36         return data_.capacity();
 37     }
 38     int32_t zombie_cnt() const {
 39         return smr_->zombie_cnt();
 40     }
 41     int32_t reclaim_cnt() const {
 42         return smr_->reclaim_cnt();
 43     }
 44 
 45     const Value *Get(const Key &key) const {
 46         const Value *node = nullptr;
 47         auto it = data_.find(key);
 48         if (it != data_.end()) {
 49             node = it->second;
 50         }
 51         return node;
 52     }
 53 
 54     void lazy_reclaim(Value *&node) {
 55         if (!node) {
 56             return;
 57         }
 58 
 59         lazy_batch_.push_back(node);
 60         if (lazy_batch_.size() >= kLazyBatchSize) {
 61             smr_->reclaim(lazy_batch_);
 62             lazy_batch_.clear();
 63         }
 64     }
 65     // 删除 key
 66     int Delete(const Key &key) {
 67         auto it = data_.find(key);
 68         if (it != data_.end()) {
 69             lazy_reclaim(it->second);
 70             data_.erase(it);
 71             return 1;
 72         }
 73 
 74         return 0;
 75     }
 76 
 77     // 新增或更新 key
 78     int Insert(const Key &key, const Value &value) {
 79         Value *node = new Value(value);
 80         Value *tmp = nullptr;
 81 
 82         auto it = data_.find(key);
 83         // key 存在
 84         if (it != data_.end()) {
 85             Value *&v = it->second;
 86             if (v->equal(*node)) {
 87                 // value 相等,不更新直接退出
 88                 delete node;
 89             } else {
 90                 // value 不相等,删除老的 value 对象
 91                 tmp = v;
 92                 v = node;
 93                 lazy_reclaim(tmp);
 94             }
 95         } else {
 96             data_.emplace(key, node);
 97         }
 98         return 0;
 99     }
100 
101 private:
102     const int32_t kLazyBatchSize = 1000;
103     std::vector<Value *> lazy_batch_;
104 
105     std::map<Key, Value *> data_;
106     std::unique_ptr<SMRManagerBase<Value>> smr_;
107 };
View Code

 

posted @ 2022-10-21 23:56  如果的事  阅读(1350)  评论(0)    收藏  举报