project2
Resources
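- https://github.com/cmu-db/bustub Bustub GitHub repo
- https://www.gradescope.com/ GradeScope, the autograder site; course entry code: PXWVR5
- https://discord.gg/YF7dMCg Discord server for course discussion
- Mirrored lecture videos are available on bilibili; search for them yourself.
- B+ Tree Visualization: an animated demo of B+ tree insertion and deletion. My implementation follows its logic.
Two blog posts that helped me a lot: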
[做个数据库](做个数据库:2022 CMU15-445 Project2 B+Tree Index - 知乎)
[迷茫就是能力配不上梦想](CMU15445-2022fall-Project2 - 知乎)
Overview
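Project 2 asks us to implement a B+ tree index for Bustub. It is split into two parts:
- Checkpoint1: single-threaded B+ tree
- Checkpoint2: multi-threaded B+ tree
The B+ tree interface provided by the project is very small: essentially just lookup, insert, and delete. Almost no internal helper functions are given, which leaves us free to design everything ourselves (and with no obvious place to start). As a result, any valid B+ tree implementation is accepted.
The figure shows where the B+ tree index sits in Bustub:

Pages are obtained through the buffer pool manager we implemented in Project 1.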
Checkpoint1 Single Thread B+Tree
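Checkpoint1 has two parts:
- Task1: B+Tree pages, the various pages in the B+ tree. In Bustub's B+ tree index every node is a page. There are leaf pages, internal pages, and their parent class, tree page.
- Task2: B+Tree Data Structure (Insertion, Deletion, Point Search). This is the core of Checkpoint1: insertion, deletion, and point lookup in the B+ tree.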
Task1 B+Tree Pages
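Task1 is very simple to implement: just ordinary getters and setters. The main thing worth covering here is the memory layout of a page.
We first worked with pages in Project 1. A page can hold many kinds of database data, for example indexes and the actual table data.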
/** The actual data that is stored within a page. */
char data_[BUSTUB_PAGE_SIZE]{};
/** The ID of this page. */
page_id_t page_id_ = INVALID_PAGE_ID;
/** The pin count of this page. */
int pin_count_ = 0;
/** True if the page is dirty, i.e. it is different from its corresponding page on disk. */
bool is_dirty_ = false;
/** Page latch. */
ReaderWriterLatch rwlatch_;

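Here data_ is where the page's contents actually live; its size is BUSTUB_PAGE_SIZE, which is 4 KB. The other members are the page's metadata.
All tree page data in the B+ tree is stored in the data_ member of a page.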
B_PLUS_TREE_PAGE
b_plus_tree_page is the parent class of the other two page classes, i.e. the abstraction of a tree page in the B+ tree.
IndexPageType page_type_; // leaf or internal. 4 Byte
lsn_t lsn_; // temporarily unused. 4 Byte
int size_; // tree page data size(not in byte, in count). 4 Byte
int max_size_; // tree page data max size(not in byte, in count). 4 Byte
page_id_t parent_page_id_; // 4 Byte
page_id_t page_id_; // 4 Byte
// 24 Byte in total
These fields make up the header of a tree page.

Of the 4 KB of page data, 24 bytes hold the header and the rest holds the tree page's data, i.e. the KV pairs.
B_PLUS_TREE_INTERNAL_PAGE
Corresponds to an internal node of the B+ tree.
MappingType array_[1];
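An internal page adds no new metadata, so its header is still 24 bytes. Its only member is this odd-looking array of size 1. A size of 1 looks wrong, since it could hold only a single KV pair, yet the size cannot be changed. Are we supposed to index past the end of the array into the memory that follows? That is pretty much the idea. The idiom is known as a flexible array member (the size-1 form is often called the "struct hack"). Note that flexible array members are a C99 feature; in C++ they exist only as a compiler extension, and strictly speaking the standard does not define indexing past the declared bound. It works here because the object is never allocated on its own: it is always a view over the 4 KB data_ buffer of a page, so the memory behind array_ is guaranteed to exist.
Put simply: suppose a class has an array member. When an object of the class is created you do not know how many elements the array should have, but you do know the total size of the object in bytes. That is where a flexible array fits. It must be the last member of the class, and there can be only one. Nothing is filled in automatically: the array simply starts where the other members end, and the code treats all the remaining memory of the object as array elements, which determines its effective length.
For example, take a class C: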
class C {
int a; // 4 byte
int array[1]; // unknown size
};
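Now create an object of C backed by 24 bytes of memory. a takes 4 bytes, so the remaining 20 bytes can be addressed through array, which effectively holds 5 elements.
This is really just a simple example of C++ object memory layout. It also makes clear why there can be only one flexible array and why it must come last: it extends toward the end of the object.
Also, although members are laid out in memory in declaration order, watch out for alignment padding. Every field in the header is 4 bytes, so there is no padding here.
At this point the purpose of the size-1 array should be clear: it covers whatever is left of the page's 4 KB of data after the 24-byte header, and that space is used to store KV pairs.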
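The arithmetic above can be checked with a minimal sketch. The names TrailingCapacity, kPageSize, and kHeaderSize are made up for illustration and are not Bustub identifiers, and the 12-byte entry (an 8-byte key plus a 4-byte page id) is an assumption; the real entry size depends on the key and value types.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical constants mirroring the numbers in the text.
constexpr std::size_t kPageSize = 4096;  // size of a page's data_ buffer
constexpr std::size_t kHeaderSize = 24;  // tree page header

// Number of elements of `elem_size` bytes that fit after the first `used`
// bytes of an object that is backed by `total` bytes of memory.
inline std::size_t TrailingCapacity(std::size_t total, std::size_t used, std::size_t elem_size) {
  assert(used <= total && elem_size > 0);
  return (total - used) / elem_size;
}
```

For class C, TrailingCapacity(24, 4, 4) gives 5, matching the example. For a 4 KB page with a 24-byte header and the assumed 12-byte pairs, TrailingCapacity(kPageSize, kHeaderSize, 12) gives 339.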

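In an internal page, the K of a KV pair is an index key that can be compared, and the V is a page id pointing to a node on the next level. The project requires the first key to be empty. The reason is that in an internal page n keys split the number line into n+1 regions, which correspond to n+1 values. You could equally treat the last key as the empty one, as long as the rest of your code is consistent.

The child on the next level is chosen by comparing keys. Which side the equality goes on can also vary. In short, as long as the result is a valid B+ tree, meaning node sizes respect the minimum and maximum limits, the implementation details are up to you.
Note also that the keys in an internal page do not necessarily represent indexed values that exist; they are only signposts that guide an insert/delete/lookup to the leaf page where the key actually belongs.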
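As a rough illustration of choosing a child by key comparison, here is a sketch over plain integers. ChildIndex is a hypothetical helper, not part of Bustub, and it assumes one particular convention: keys[0] is the unused placeholder and child i covers keys[i] <= k < keys[i + 1]. Other placements of the equality are equally valid.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Returns the index of the child to follow for `target`.
// keys[0] is the unused placeholder key; the real keys start at index 1
// and are sorted in ascending order.
inline int ChildIndex(const std::vector<int> &keys, int target) {
  assert(!keys.empty());
  // Find the first real key strictly greater than target; the child just
  // before that key is the one whose range contains target.
  auto it = std::upper_bound(keys.begin() + 1, keys.end(), target);
  return static_cast<int>(it - keys.begin()) - 1;
}
```

With keys {placeholder, 10, 20}, a target of 5 goes to child 0, 10 and 19 go to child 1, and 20 or 25 go to child 2: two real keys, three children.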
B_PLUS_TREE_LEAF_PAGE
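The memory layout of a leaf page is almost the same as that of an internal page, except that a leaf page has one more member, next_page_id, which points to the next leaf page (used for range scans). The leaf page header is therefore 28 bytes.
In a leaf page, K is the actual index key and V is a record id, which identifies one row in a table. Keys and values in a leaf page correspond one to one, unlike an internal page, which has one more value than keys. This also shows that every B+ tree index in Bustub, primary or secondary, is a non-clustered index.
Here is a brief summary of how clustered and non-clustered indexes, and primary and secondary (non-primary-key) indexes, differ.
In a clustered index, the value in a leaf page holds some or all of the fields of a row, and always includes the primary key fields. In a non-clustered index, the value in a leaf page is a record id, i.e. a pointer to a row.
With clustered indexes, the leaf pages of the primary index contain all fields, while the leaf pages of a secondary index contain the primary key and the indexed fields. A lookup by primary key has the whole row as soon as it reaches the leaf page. A lookup through a secondary index returns the result directly if the requested fields are covered by the index; otherwise the primary key found there must be used for a second lookup in the primary index to fetch the full row and, from it, the requested fields. This second lookup is the "back to the table" step.
With non-clustered indexes, both primary and secondary lookups end with a record id, which is then used to fetch the actual row.
The advantages of a clustered index: the whole row is stored in the leaf page, so no second lookup is needed, the cache hit rate is high, and primary key lookups perform well. The drawbacks: secondary indexes may need to go back to the primary index, and because whole rows live in leaf pages, updating the index is expensive, with page splits and merges being costly.
The advantages of a non-clustered index: leaf pages store only record ids, so updates are cheap, and secondary indexes perform almost the same as the primary index. The drawback is that every query needs a second lookup by record id.
That covers Task1. What has to be implemented is very simple; the important part is understanding the role and memory layout of each page type.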
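The two-step lookup of a non-clustered index can be pictured with a toy in-memory model. Everything here (Row, NonClusteredTable, LookupNameDemo, using a vector position as the record id) is an illustrative assumption and has nothing to do with Bustub's actual table heap or RID type.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

struct Row {
  int id;
  std::string name;
};

// Non-clustered layout: rows live in a separate heap, and the index only
// maps a key to a record id (here simply the row's position in the heap).
struct NonClusteredTable {
  std::vector<Row> heap;
  std::map<int, std::size_t> index;

  void Insert(const Row &row) {
    index[row.id] = heap.size();
    heap.push_back(row);
  }

  // Returns nullptr when the key is not indexed.
  const Row *Find(int id) const {
    auto it = index.find(id);  // step 1: the index lookup yields a record id
    if (it == index.end()) return nullptr;
    assert(it->second < heap.size());
    return &heap[it->second];  // step 2: a second lookup fetches the row
  }
};

// Small demo: returns the name stored for `id`, or "" if it is absent.
inline std::string LookupNameDemo(int id) {
  NonClusteredTable table;
  table.Insert({7, "alice"});
  table.Insert({3, "bob"});
  const Row *row = table.Find(id);
  return row == nullptr ? std::string() : row->name;
}
```

A clustered index would instead keep the Row itself as the value in the index, so step 2 disappears for primary key lookups.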
Task2 B+Tree Data Structure (Insertion, Deletion, Point Search)
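Task2 is the core of the single-threaded B+ tree, and there are many details. To start, here is a website that animates B+ tree insertions and deletions. Use it to see how the tree changes in each case. The actual implementation is of course up to you; the site is only one example.
1 First, GetMinSize()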
auto BPlusTreePage::IsRootPage() const -> bool { return parent_page_id_ == INVALID_PAGE_ID; }
auto BPlusTreePage::GetMinSize() const -> int {
if (IsRootPage()) {
if (IsLeafPage()) {
return 0;
}
return 1;
}
if (IsLeafPage()) {
return max_size_ / 2;
}
return (max_size_ + 1) / 2; // round up for internal pages: every key keeps at least one child, and fuller nodes keep the tree as short as possible
}
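To make the rounding concrete, here is a standalone restatement of the same rule with a few values worked out. MinSize is a hypothetical free function written only for this illustration; the logic is copied from GetMinSize() above.

```cpp
#include <cassert>

// Same rule as GetMinSize(), expressed without the page classes.
inline int MinSize(bool is_root, bool is_leaf, int max_size) {
  assert(max_size > 0);
  if (is_root) {
    return is_leaf ? 0 : 1;  // a root leaf may be empty; a root internal page needs one entry
  }
  return is_leaf ? max_size / 2 : (max_size + 1) / 2;  // internal pages round up
}
```

With the sizes used in the tests later on (leaf max 2, internal max 3), a non-root leaf needs at least 1 entry and a non-root internal page at least 2. For a max size of 5 the minimums are 2 for a leaf and 3 for an internal page.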
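2 Use RAII to unpin pages promptly
Every page obtained through NewPage or FetchPage must be unpinned with UnpinPage as soon as it is no longer needed. I wrote a small class that relies on RAII and wraps these calls in its constructor and destructor.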
#pragma once
#include <stdexcept>
#include <type_traits>
#include <typeinfo>  // std::bad_cast
#include "buffer/buffer_pool_manager.h"
#include "common/config.h"
#include "storage/page/b_plus_tree_internal_page.h"
#include "storage/page/b_plus_tree_leaf_page.h"
#include "storage/page/b_plus_tree_page.h"
namespace bustub {
template <typename PageType>
class PageGuard {
// Befriend every PageGuard instantiation so that guards of different page types can access each other's private members
template <typename T>
friend class PageGuard;
public:
enum class LOCKTYPE { WRITE, READ, NONE };
enum class TRYLOCKTYPE { WRITE, READ, NONE };
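// Default constructor, used as the target of conversions
PageGuard() = default;
// Copy construction is disabled
PageGuard(const PageGuard &) = delete;
// Copy assignment is disabled
PageGuard &operator=(const PageGuard &) = delete;
// Main constructor: fetches page_id, or allocates a new page when page_id is INVALID_PAGE_ID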
PageGuard(BufferPoolManager *bpm, page_id_t page_id, bool is_dirty = false, LOCKTYPE lock_type = LOCKTYPE::NONE,
TRYLOCKTYPE try_lock_type = TRYLOCKTYPE::NONE)
: bpm_(bpm), page_id_(page_id), is_dirty_(is_dirty), lock_type_(lock_type), try_lock_type_(try_lock_type) {
if (page_id != INVALID_PAGE_ID) {
page_ = bpm_->FetchPage(page_id_);
Lock(lock_type_);
TryLock(try_lock_type_);
} else {
page_ = bpm_->NewPage(&page_id_);
is_dirty_ = true; // a newly allocated page is always marked dirty
}
}
// Converting constructor between B+ tree page types
template <typename OtherType>
PageGuard(PageGuard<OtherType> &&other) {
static_assert(std::is_base_of_v<BPlusTreePage, PageType> && std::is_base_of_v<BPlusTreePage, OtherType>,
"Can only convert between B+ tree page types");
if (other.page_) {
TransferOwnershipFrom(other);
} else {
throw std::bad_cast();
}
}
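// Move constructor: takes over the pin and any latch held by other,
// so that the latch is still released when the new guard is destroyed
PageGuard(PageGuard &&other) noexcept
: bpm_(other.bpm_), page_(other.page_), page_id_(other.page_id_), is_dirty_(other.is_dirty_),
lock_type_(other.lock_type_), try_lock_type_(other.try_lock_type_),
trywlock_result_(other.trywlock_result_), tryrlock_result_(other.tryrlock_result_) {
other.Reset();
}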
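PageGuard &operator=(PageGuard &&other) noexcept {
if (this != &other) {
// Release whatever this guard currently holds: latch first, then the pin
if (page_ != nullptr) {
UnLock(lock_type_);
TryUnLock(try_lock_type_);
bpm_->UnpinPage(page_id_, is_dirty_);
}
// Take over the other guard's page, dirty flag, and latch state
bpm_ = other.bpm_;
page_ = other.page_;
page_id_ = other.page_id_;
is_dirty_ = other.is_dirty_;
lock_type_ = other.lock_type_;
try_lock_type_ = other.try_lock_type_;
trywlock_result_ = other.trywlock_result_;
tryrlock_result_ = other.tryrlock_result_;
// Leave the source empty so that its destructor does nothing
other.bpm_ = nullptr;
other.page_ = nullptr;
other.page_id_ = INVALID_PAGE_ID;
other.is_dirty_ = false;
}
return *this;
}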
~PageGuard() {
if (page_ != nullptr) {
// Release the latch before the pin: once unpinned, the frame may be evicted and reused
UnLock(lock_type_);
TryUnLock(try_lock_type_);
if (page_id_ != INVALID_PAGE_ID) {
bpm_->UnpinPage(page_id_, is_dirty_);
}
}
}
// Explicit conversion to a guard of another page type
template <typename TargetType>
PageGuard<TargetType> Convert() {
static_assert(std::is_base_of_v<BPlusTreePage, TargetType>, "Target must be a B+ tree page type");
PageGuard<TargetType> new_guard;
new_guard.TransferOwnershipFrom(*this);
return new_guard;
}
// Access operator
PageType *operator->() { return reinterpret_cast<PageType *>(page_->GetData()); }
// State queries
explicit operator bool() const { return page_ != nullptr; }
page_id_t GetPageId() const { return page_id_; }
void SetPageId(page_id_t page_id) { page_id_ = page_id; }
bool IsDirty() const { return is_dirty_; }
void SetDirty(bool dirty) { is_dirty_ = dirty; }
void Lock(LOCKTYPE lock_type) {
switch (lock_type) {
case LOCKTYPE::WRITE:
page_->WLatch();
break;
case LOCKTYPE::READ:
page_->RLatch();
break;
default:
break;
}
}
void TryLock(TRYLOCKTYPE lock_type) {
switch (lock_type) {
case TRYLOCKTYPE::WRITE:
trywlock_result_ = page_->TryWLatch();
break;
case TRYLOCKTYPE::READ:
tryrlock_result_ = page_->TryRLatch();
break;
default:
break;
}
}
void UnLock(LOCKTYPE lock_type) {
switch (lock_type) {
case LOCKTYPE::WRITE:
page_->WUnlatch();
break;
case LOCKTYPE::READ:
page_->RUnlatch();
break;
default:
break;
}
}
void TryUnLock(TRYLOCKTYPE lock_type) {
switch (lock_type) {
case TRYLOCKTYPE::WRITE:
if (trywlock_result_) page_->WUnlatch();
break;
case TRYLOCKTYPE::READ:
if (tryrlock_result_) page_->RUnlatch();
break;
default:
break;
}
}
bool GetTryrLockResult() { return tryrlock_result_; }
bool GetTrywLockResult() { return trywlock_result_; }
private:
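// Transfers the page, dirty flag, and latch state from another guard
template <typename OtherPageType>
void TransferOwnershipFrom(PageGuard<OtherPageType> &other) {
bpm_ = other.bpm_;
page_ = other.page_;
page_id_ = other.page_id_;
is_dirty_ = other.is_dirty_;
// Each PageGuard<T> has its own enum types, so the latch state needs a cast
lock_type_ = static_cast<LOCKTYPE>(other.lock_type_);
try_lock_type_ = static_cast<TRYLOCKTYPE>(other.try_lock_type_);
trywlock_result_ = other.trywlock_result_;
tryrlock_result_ = other.tryrlock_result_;
other.Reset();
}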
void Reset() {
page_ = nullptr;
page_id_ = INVALID_PAGE_ID;
is_dirty_ = false;
}
BufferPoolManager *bpm_ = nullptr;
Page *page_ = nullptr;
page_id_t page_id_ = INVALID_PAGE_ID;
bool is_dirty_ = false;
LOCKTYPE lock_type_ = LOCKTYPE::NONE;
TRYLOCKTYPE try_lock_type_ = TRYLOCKTYPE::NONE;
bool trywlock_result_ = false;
bool tryrlock_result_ = false;
};
} // namespace bustub
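Stripped of latching and type conversion, the RAII idea behind the guard fits in a few lines. The sketch below is not the class above: FakePool, PinGuard, and PinCountsAroundScope are hypothetical names, and FakePool only counts pins instead of managing real frames.

```cpp
#include <cassert>
#include <map>
#include <utility>

// Hypothetical stand-in for the buffer pool: it only counts pins per page id.
class FakePool {
 public:
  void Pin(int page_id) { ++pins_[page_id]; }
  void Unpin(int page_id) {
    assert(pins_[page_id] > 0);
    --pins_[page_id];
  }
  int PinCount(int page_id) const {
    auto it = pins_.find(page_id);
    return it == pins_.end() ? 0 : it->second;
  }

 private:
  std::map<int, int> pins_;
};

// Move-only guard: pins in the constructor, unpins in the destructor.
class PinGuard {
 public:
  PinGuard(FakePool *pool, int page_id) : pool_(pool), page_id_(page_id) { pool_->Pin(page_id_); }
  PinGuard(const PinGuard &) = delete;
  PinGuard &operator=(const PinGuard &) = delete;
  // Moving hands the pin over; the source is left empty and unpins nothing.
  PinGuard(PinGuard &&other) noexcept
      : pool_(std::exchange(other.pool_, nullptr)), page_id_(other.page_id_) {}
  ~PinGuard() {
    if (pool_ != nullptr) pool_->Unpin(page_id_);
  }

 private:
  FakePool *pool_;
  int page_id_;
};

// Returns the pin count seen inside a scope and the count after leaving it.
inline std::pair<int, int> PinCountsAroundScope() {
  FakePool pool;
  int inside = 0;
  {
    PinGuard guard(&pool, 7);
    PinGuard moved(std::move(guard));  // ownership moves; still exactly one pin
    inside = pool.PinCount(7);
  }
  return {inside, pool.PinCount(7)};
}
```

The page is pinned exactly once while a guard is alive and the count drops back to zero when the scope ends, including on early returns, which is what makes forgetting an UnpinPage call much harder.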
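3 When to split and when to steal/merge
Insertion and deletion involve splitting pages and stealing/merging respectively. The key question is when to trigger each one.
When to split on insert:
- If you split when GetSize() == GetMaxSize(), you may have to split first and insert afterwards.
- I chose to split when GetSize() > GetMaxSize(), with the check placed after the insert logic, so the entry is inserted first and the page is split afterwards. Mind the special case where the max size is left at its default, since the page then has no room left for the temporary extra entry:
- Leaf page: GetMaxSize() defaults to LEAF_PAGE_SIZE ((BUSTUB_PAGE_SIZE - LEAF_PAGE_HEADER_SIZE) / sizeof(MappingType))
- Internal page: GetMaxSize() defaults to INTERNAL_PAGE_SIZE ((BUSTUB_PAGE_SIZE - INTERNAL_PAGE_HEADER_SIZE) / (sizeof(MappingType)))
Leaf page insertion: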
INDEX_TEMPLATE_ARGUMENTS
auto B_PLUS_TREE_LEAF_PAGE_TYPE::IsFull() -> bool { return GetSize() > GetMaxSize(); }
INDEX_TEMPLATE_ARGUMENTS
auto B_PLUS_TREE_LEAF_PAGE_TYPE::Insert(const KeyType &key, const ValueType &value, const KeyComparator &cmp,
bool &have) -> bool {
...
}
Internal page insertion:
INDEX_TEMPLATE_ARGUMENTS
auto B_PLUS_TREE_INTERNAL_PAGE_TYPE::IsFull() -> bool { return GetSize() > GetMaxSize(); }
INDEX_TEMPLATE_ARGUMENTS
auto B_PLUS_TREE_INTERNAL_PAGE_TYPE::Insert(const KeyType &key, const ValueType &value, const KeyComparator &cmp,
BufferPoolManager *bpm, bool &have) -> bool {
...
}
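The insert-first, split-afterwards order can be sketched on a sorted vector of integers. InsertThenSplit is a hypothetical helper; the split point (the left page keeps the first size / 2 entries) is an assumption chosen for the sketch, duplicate keys are not rejected here, and a std::vector can grow freely, whereas a real page needs a spare slot for the temporary extra entry.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Inserts `key` into the sorted `leaf`, then splits if the leaf now holds
// more than max_size entries. Returns the new right sibling, which is
// empty when no split happened.
inline std::vector<int> InsertThenSplit(std::vector<int> &leaf, int key, std::size_t max_size) {
  assert(max_size >= 1);
  leaf.insert(std::lower_bound(leaf.begin(), leaf.end(), key), key);
  if (leaf.size() <= max_size) {
    return {};
  }
  // Assumed split point: the left page keeps the first size / 2 entries.
  auto mid = leaf.begin() + static_cast<std::ptrdiff_t>(leaf.size() / 2);
  std::vector<int> right(mid, leaf.end());
  leaf.erase(mid, leaf.end());
  return right;
}
```

With a max size of 2, inserting 3 into {1, 2} first yields {1, 2, 3} and then splits it into {1} and {2, 3}.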
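When to steal/merge:
- Same idea as insertion: delete first, then check, and steal/merge only when GetSize() < GetMinSize().
- The root is a special case, because the root may hold fewer than half the maximum number of KV pairs.
Leaf page deletion: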
INDEX_TEMPLATE_ARGUMENTS
auto B_PLUS_TREE_LEAF_PAGE_TYPE::IsMin() -> bool {
if (IsRootPage()) {
return GetSize() == GetMinSize();
}
return GetSize() < GetMinSize();
}
INDEX_TEMPLATE_ARGUMENTS
auto B_PLUS_TREE_LEAF_PAGE_TYPE::Remove(const KeyType &key, const KeyComparator &KeyCmp, bool &nohave) -> bool {
...
}
Internal page deletion:
INDEX_TEMPLATE_ARGUMENTS
auto B_PLUS_TREE_INTERNAL_PAGE_TYPE::IsMin() -> bool {
if (IsRootPage()) {
return GetSize() == GetMinSize();
}
return GetSize() < GetMinSize();
}
INDEX_TEMPLATE_ARGUMENTS
auto B_PLUS_TREE_INTERNAL_PAGE_TYPE::Remove(const KeyType &key, const KeyComparator &KeyCmp, bool &nohave) -> bool {
...
}
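The delete-first, check-afterwards rule can be written as a small decision function for non-root pages. Rebalance and AfterDelete are hypothetical names, and the choice between stealing and merging (steal only when the sibling holds more than the minimum) is the usual textbook rule, assumed here rather than taken from the code above.

```cpp
#include <cassert>

enum class Rebalance { kNone, kSteal, kMerge };

// Decides what to do once an entry has already been removed from a
// non-root page. `sibling_size` is the size of the neighbour considered.
inline Rebalance AfterDelete(int size, int min_size, int sibling_size) {
  assert(size >= 0 && min_size >= 0 && sibling_size >= 0);
  if (size >= min_size) {
    return Rebalance::kNone;  // the page is still valid
  }
  if (sibling_size > min_size) {
    return Rebalance::kSteal;  // the sibling can spare one entry
  }
  return Rebalance::kMerge;  // the sibling is at its minimum: combine the two pages
}
```

For a page with a minimum of 2, dropping to 1 entry triggers a steal if the sibling holds 3 entries and a merge if the sibling holds only 2.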
4 Helper functions
Both split and steal/merge first need to locate the leaf page:
INDEX_TEMPLATE_ARGUMENTS
auto BPLUSTREE_TYPE::FindLeafPage(const KeyType &key, page_id_t &page_id, Transaction *transaction) -> LeafPage * {
...
}
Splitting leaf pages and internal pages:
INDEX_TEMPLATE_ARGUMENTS
void BPLUSTREE_TYPE::SplitLeafPage(page_id_t target_id) {
...
}
INDEX_TEMPLATE_ARGUMENTS
void BPLUSTREE_TYPE::SplitInternalPage(page_id_t target_id) {
...
}
Merging leaf pages and internal pages:
INDEX_TEMPLATE_ARGUMENTS
void BPLUSTREE_TYPE::MergeLeafPage(page_id_t target_id, const KeyType &target_key) {
...
}
INDEX_TEMPLATE_ARGUMENTS
void BPLUSTREE_TYPE::MergeInternalPage(page_id_t target_id, const KeyType &target_key) {
...
}
There are other helper functions as well, which I will not list one by one.
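5 Propagating key updates upwards after insertion and deletion
When an insertion or deletion changes the first key of a leaf page, the change may have to be propagated upwards (for example when that leaf is the first child of its parent).
Inserting a new first key into a leaf: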
auto old_key = leafpage->KeyAt(0);
auto target_old_key = old_key;
bool have = false;
bool flag = leafpage->Insert(key, value, comparator_, have);
if (!flag && have) {
buffer_pool_manager_->UnpinPage(leafpageid, false);
return false;
}
auto new_key = leafpage->KeyAt(0);
auto current_parent_id = leafpage->GetParentPageId();
buffer_pool_manager_->UnpinPage(leafpageid, true);
// propagate the changed key upwards
while (comparator_(old_key, new_key) != 0 && current_parent_id != INVALID_PAGE_ID) {
auto parent_page = PageGuard<InternalPage>(buffer_pool_manager_, current_parent_id, true);
old_key = parent_page->KeyAt(0);
parent_page->SetKey(target_old_key, new_key, comparator_);
new_key = parent_page->KeyAt(0);
current_parent_id = parent_page->GetParentPageId();
}
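Deletion: deleting the first key of a leaf, stealing one entry from the right sibling (when the left page's size is 0 after the deletion), and merging into the right sibling can all require propagating a key upwards.
6 Maintaining the singly linked list of leaf pages during insertion and deletion
When merging to the left, the right node is removed, which is simple: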
left_page->SetNextPageId(right_page->GetNextPageId());
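When merging to the right, however, the left node target is removed, so its predecessor has to be found first. There are two options:
- Traverse from begin(), which is too slow.
- Find the predecessor in traversal order: walk up until reaching an ancestor that still has a child to the left of the current subtree, then move into that left subtree and descend along the rightmost children: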
INDEX_TEMPLATE_ARGUMENTS
auto BPLUSTREE_TYPE::FindPredecessorLeaf(page_id_t target_id, page_id_t parent_id, const KeyType &target_key)
-> page_id_t {
...
}
INDEX_TEMPLATE_ARGUMENTS
auto BPLUSTREE_TYPE::FindLastLeafPage(page_id_t &page_id) -> page_id_t {
...
}
// Maintain the singly linked list
page_id_t prev_leaf_id = FindPredecessorLeaf(target->GetPageId(), parent_id, target_key);
if (prev_leaf_id != INVALID_PAGE_ID) {
auto prev_leaf = PageGuard<LeafPage>(buffer_pool_manager_, prev_leaf_id, true, PageGuard<LeafPage>::LOCKTYPE::WRITE);
prev_leaf->SetNextPageId(right_leaf_page->GetPageId());
prev_leaf.SetDirty(true);
}
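The walk-up-then-descend search can be sketched on a plain pointer-based tree. This is only an illustration of the idea: Node, this version of FindPredecessorLeaf, and PredecessorIdInDemoTree are hypothetical, and the real functions above work with page ids and the buffer pool rather than raw pointers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Node {
  int id = 0;
  Node *parent = nullptr;
  std::vector<Node *> children;  // empty for a leaf
};

// Returns the leaf immediately to the left of `leaf`, or nullptr if `leaf`
// is the leftmost leaf of the tree.
inline Node *FindPredecessorLeaf(Node *leaf) {
  assert(leaf != nullptr);
  Node *cur = leaf;
  // Walk up while `cur` is the first child, i.e. has nothing to its left.
  while (cur->parent != nullptr && cur->parent->children.front() == cur) {
    cur = cur->parent;
  }
  if (cur->parent == nullptr) {
    return nullptr;  // reached the root: there is no predecessor
  }
  // Step to the left sibling of `cur` ...
  const std::vector<Node *> &siblings = cur->parent->children;
  std::size_t pos = 0;
  while (siblings[pos] != cur) ++pos;
  Node *walk = siblings[pos - 1];
  // ... then descend along the rightmost children down to a leaf.
  while (!walk->children.empty()) walk = walk->children.back();
  return walk;
}

// Builds root -> {a -> {1, 2}, b -> {3, 4}} and returns the id of the
// predecessor of leaf `leaf_id`, or -1 if it has none.
inline int PredecessorIdInDemoTree(int leaf_id) {
  assert(leaf_id >= 1 && leaf_id <= 4);
  Node root, a, b;
  Node leaves[4];
  for (int i = 0; i < 4; ++i) {
    leaves[i].id = i + 1;
    leaves[i].parent = (i < 2) ? &a : &b;
    leaves[i].parent->children.push_back(&leaves[i]);
  }
  a.parent = &root;
  b.parent = &root;
  root.children = {&a, &b};
  Node *pred = FindPredecessorLeaf(&leaves[leaf_id - 1]);
  return pred == nullptr ? -1 : pred->id;
}
```

The interesting case is leaf 3: it is the first child of b, so the search climbs to b, steps left to a, and descends to a's last child, leaf 2. The cost is proportional to the height of the tree instead of the number of leaves.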
The remaining cases are not covered here. (The full code originally followed; it has been removed and is not shown.)
Debug Your B+Tree
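Once again I am impressed by how much care a truly world-class CS school puts into its course projects. To make debugging easier, 15-445 goes as far as providing a visualization of the B+ tree. There are two main ways to use it:
- Use the provided b_plus_tree_printer tool, which lets you run inserts, deletes, and other operations on a B+ tree and write the result to a dot file.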
$ # To build the tool
$ mkdir build
$ cd build
$ make b_plus_tree_printer -j$(nproc)
$ ./bin/b_plus_tree_printer
>> ... USAGE ...
>> 5 5 # set leaf node and internal node max size to be 5
>> f input.txt # Insert into the tree with some inserts
>> g my-tree.dot # output the tree to dot format
>> q # Quit the test (Or use another terminal)
- Call BPlusTree's Draw() function in your code to generate a dot file at a given path.
Once you have a dot file, you can render the B+ tree to a png locally:
dot -Tpng -O my-tree.dot
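Alternatively, paste the file contents into the online viewer linked here. (Recommended: it produces an SVG and has no limit on the size of the B+ tree.)
This visualization is very useful for catching basic B+ tree bugs early.
That completes Checkpoint1. There are many details I have not covered, such as boundary conditions in binary search, how to find the left and right siblings, how to move KV pairs between siblings, and what to do on the first Insert when the tree is empty. These are minor but fiddly problems; with some careful thought they can all be solved. More importantly, how you implement them is up to you.
Tests
Insert test cases: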
//===----------------------------------------------------------------------===//
//
// BusTub
//
// b_plus_tree_insert_test.cpp
//
// Identification: test/storage/b_plus_tree_insert_test.cpp
//
// Copyright (c) 2015-2021, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//
// #include <algorithm>
// #include <cstdio>
// #include "buffer/buffer_pool_manager_instance.h"
// #include "gtest/gtest.h"
// #include "storage/index/b_plus_tree.h"
// #include "../src/include/common/logger.h"
// #include <random>
// #include "test_util.h" // NOLINT
// namespace bustub {
// TEST(BPlusTreeTests, InsertTest1) {
// // create KeyComparator and index schema
// auto key_schema = ParseCreateStatement("a bigint");
// GenericComparator<8> comparator(key_schema.get());
// auto *disk_manager = new DiskManager("test.db");
// BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// // create b+ tree
// BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
// GenericKey<8> index_key;
// RID rid;
// // create transaction
// auto *transaction = new Transaction(0);
// // create and fetch header_page
// page_id_t page_id;
// auto header_page = bpm->NewPage(&page_id);
// ASSERT_EQ(page_id, HEADER_PAGE_ID);
// (void)header_page;
// int64_t key = 42;
// int64_t value = key & 0xFFFFFFFF;
// rid.Set(static_cast<int32_t>(key), value);
// index_key.SetFromInteger(key);
// tree.Insert(index_key, rid, transaction);
// auto root_page_id = tree.GetRootPageId();
// auto root_page = reinterpret_cast<BPlusTreePage *>(bpm->FetchPage(root_page_id)->GetData());
// // LOG_INFO(" auto root_page->GetSize() = %d", root_page->GetSize());
// // LOG_INFO(" auto root_page->GetMaxSize() = %d", root_page->GetMaxSize());
// ASSERT_NE(root_page, nullptr);
// ASSERT_TRUE(root_page->IsLeafPage());
// auto root_as_leaf = reinterpret_cast<BPlusTreeLeafPage<GenericKey<8>, RID, GenericComparator<8>> *>(root_page);
// ASSERT_EQ(root_as_leaf->GetSize(), 1);
// ASSERT_EQ(comparator(root_as_leaf->KeyAt(0), index_key), 0);
// bpm->UnpinPage(root_page_id, false);
// bpm->UnpinPage(HEADER_PAGE_ID, true);
// delete transaction;
// delete disk_manager;
// delete bpm;
// remove("test.db");
// remove("test.log");
// }
// TEST(BPlusTreeTests, InsertTest2) {
// // create KeyComparator and index schema
// auto key_schema = ParseCreateStatement("a bigint");
// GenericComparator<8> comparator(key_schema.get());
// auto *disk_manager = new DiskManager("test.db");
// BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// // create b+ tree
// BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
// GenericKey<8> index_key;
// RID rid;
// // create transaction
// auto *transaction = new Transaction(0);
// // create and fetch header_page
// page_id_t page_id;
// auto header_page = bpm->NewPage(&page_id);
// (void)header_page;
// // std::vector<int64_t> keys = {16, 17, 18};
// int64_t scale = 100;
// std::vector<int64_t> keys;
// for (int64_t key = 1; key < scale + 1; key++) {
// keys.push_back(key);
// }
// // randomized the insertion order
// auto rng = std::default_random_engine{};
// std::shuffle(keys.begin(), keys.end(), rng);
// for (auto key : keys) {
// LOG_INFO("key = %d", static_cast<int32_t>(key));
// int64_t value = key & 0xFFFFFFFF;
// // LOG_INFO("key >> 32 = %d", static_cast<int32_t>(key >> 32));
// rid.Set(static_cast<int32_t>(key >> 32), value);
// index_key.SetFromInteger(key);
// tree.Insert(index_key, rid, transaction);
// tree.Draw(bpm, "./InsertTest_step" + std::to_string(static_cast<int32_t>(key)) + ".dot");
// // LOG_INFO("rids[0].GetSlotNum() = %d, value = %ld", rids[0].GetSlotNum(), value);
// }
// tree.Draw(bpm, "./InsertTest_step.dot");
// std::vector<RID> rids;
// for (auto key : keys) {
// rids.clear();
// index_key.SetFromInteger(key);
// tree.GetValue(index_key, &rids);
// EXPECT_EQ(rids.size(), 1);
// int64_t value = key & 0xFFFFFFFF;
// // LOG_INFO("rids[0].GetSlotNum() = %d, value = %ld", rids[0].GetSlotNum(), value);
// EXPECT_EQ(rids[0].GetSlotNum(), value);
// }
// int64_t size = 0;
// bool is_present;
// for (auto key : keys) {
// rids.clear();
// index_key.SetFromInteger(key);
// is_present = tree.GetValue(index_key, &rids);
// EXPECT_EQ(is_present, true);
// EXPECT_EQ(rids.size(), 1);
// EXPECT_EQ(rids[0].GetPageId(), 0);
// EXPECT_EQ(rids[0].GetSlotNum(), key);
// size = size + 1;
// }
// EXPECT_EQ(size, keys.size());
// bpm->UnpinPage(HEADER_PAGE_ID, true);
// delete transaction;
// delete disk_manager;
// delete bpm;
// remove("test.db");
// remove("test.log");
// }
// TEST(BPlusTreeTests, InsertTest3) {
// // create KeyComparator and index schema
// auto key_schema = ParseCreateStatement("a bigint");
// GenericComparator<8> comparator(key_schema.get());
// auto *disk_manager = new DiskManager("test.db");
// BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// // create b+ tree
// BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// GenericKey<8> index_key;
// RID rid;
// // create transaction
// auto *transaction = new Transaction(0);
// // create and fetch header_page
// page_id_t page_id;
// auto header_page = bpm->NewPage(&page_id);
// ASSERT_EQ(page_id, HEADER_PAGE_ID);
// (void)header_page;
// std::vector<int64_t> keys = {5, 4, 3, 2, 1};
// for (auto key : keys) {
// int64_t value = key & 0xFFFFFFFF;
// rid.Set(static_cast<int32_t>(key >> 32), value);
// index_key.SetFromInteger(key);
// tree.Insert(index_key, rid, transaction);
// }
// std::vector<RID> rids;
// for (auto key : keys) {
// rids.clear();
// index_key.SetFromInteger(key);
// tree.GetValue(index_key, &rids);
// EXPECT_EQ(rids.size(), 1);
// int64_t value = key & 0xFFFFFFFF;
// EXPECT_EQ(rids[0].GetSlotNum(), value);
// }
// int64_t start_key = 1;
// int64_t current_key = start_key;
// index_key.SetFromInteger(start_key);
// for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
// auto location = (*iterator).second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// }
// EXPECT_EQ(current_key, keys.size() + 1);
// start_key = 3;
// current_key = start_key;
// index_key.SetFromInteger(start_key);
// for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
// auto location = (*iterator).second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// }
// bpm->UnpinPage(HEADER_PAGE_ID, true);
// delete transaction;
// delete disk_manager;
// delete bpm;
// remove("test.db");
// remove("test.log");
// }
// } // namespace bustub
//===----------------------------------------------------------------------===//
//
// BusTub
//
// b_plus_tree_insert_test.cpp
//
// Identification: test/storage/b_plus_tree_insert_test.cpp
//
// Copyright (c) 2015-2021, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <cstdio>
#include <random>
#include "../src/include/common/logger.h"
#include "buffer/buffer_pool_manager_instance.h"
#include "common/config.h"
#include "gtest/gtest.h"
#include "storage/index/b_plus_tree.h"
#include "test_util.h" // NOLINT
namespace bustub {
TEST(BPlusTreeTests, InsertTest1) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
ASSERT_EQ(page_id, HEADER_PAGE_ID);
(void)header_page;
int64_t key = 42;
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
auto root_page_id = tree.GetRootPageId();
auto root_page = reinterpret_cast<BPlusTreePage *>(bpm->FetchPage(root_page_id)->GetData());
ASSERT_NE(root_page, nullptr);
ASSERT_TRUE(root_page->IsLeafPage());
auto root_as_leaf = reinterpret_cast<BPlusTreeLeafPage<GenericKey<8>, RID, GenericComparator<8>> *>(root_page);
ASSERT_EQ(root_as_leaf->GetSize(), 1);
ASSERT_EQ(comparator(root_as_leaf->KeyAt(0), index_key), 0);
bpm->UnpinPage(root_page_id, false);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
TEST(BPlusTreeTests, InsertTest2) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t size = 0;
bool is_present;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
is_present = tree.GetValue(index_key, &rids);
EXPECT_EQ(is_present, true);
EXPECT_EQ(rids.size(), 1);
EXPECT_EQ(rids[0].GetPageId(), 0);
EXPECT_EQ(rids[0].GetSlotNum(), key);
size = size + 1;
}
EXPECT_EQ(size, keys.size());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
TEST(BPlusTreeTests, IteratorTest) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
ASSERT_EQ(page_id, HEADER_PAGE_ID);
(void)header_page;
std::vector<int64_t> keys = {5, 4, 3, 2, 1};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
current_key = 1;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
start_key = 3;
current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 20
* Description: Insert keys range from 1 to 5 repeatedly,
* check whether insertion of repeated keys fail.
* Then check whether the keys are distributed in separate
* leaf nodes
*/
TEST(BPlusTreeTests, SplitTest) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
// tree.Draw(bpm, "./thhtest" + std::to_string(key) + ".dot");
}
// LOG_INFO("462");
// tree.Draw(bpm, "./thhtest1_.dot");
// insert into repetitive key, all failed
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
EXPECT_EQ(false, tree.Insert(index_key, rid, transaction));
}
// LOG_INFO("471");
page_id_t root_page_id = tree.GetRootPageId();
index_key.SetFromInteger(1);
auto leaf_node = reinterpret_cast<BPlusTreeLeafPage<GenericKey<8>, RID, GenericComparator<8>> *>(
tree.FindLeafPage(index_key, root_page_id));
ASSERT_NE(nullptr, leaf_node);
EXPECT_EQ(1, leaf_node->GetSize());
EXPECT_EQ(2, leaf_node->GetMaxSize());
// tree.Draw(bpm, "./thhtest2_.dot");
// Walk the next 3 leaf pages through the sibling pointers
for (int i = 0; i < 3; i++) {
// LOG_INFO("i = %d", i);
EXPECT_NE(INVALID_PAGE_ID, leaf_node->GetNextPageId());
// FetchPage returns a Page *; the tree page lives in its data buffer
leaf_node = reinterpret_cast<BPlusTreeLeafPage<GenericKey<8>, RID, GenericComparator<8>> *>(
bpm->FetchPage(leaf_node->GetNextPageId())->GetData());
}
// LOG_INFO("487");
EXPECT_EQ(INVALID_PAGE_ID, leaf_node->GetNextPageId());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 20
* Description: Insert a set of keys range from 1 to 5 in the
* increasing order. Check whether the key-value pair is valid
* using GetValue
*/
TEST(BPlusTreeTests, InsertTest0) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 30
* Description: Insert a set of keys range from 1 to 5 in
* a reversed order. Check whether the key-value pair is valid
* using GetValue
*/
TEST(BPlusTreeTests, InsertTest3) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {5, 4, 3, 2, 1};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 30
* Description: Insert a set of keys range from 1 to 10000 in
* a random order. Check whether the key-value pair is valid
* using GetValue
*/
TEST(BPlusTreeTests, ScaleTest) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
int64_t scale = 10000;
std::vector<int64_t> keys;
for (int64_t key = 1; key < scale; key++) {
keys.push_back(key);
}
// LOG_INFO("616");
// randomized the insertion order
auto rng = std::default_random_engine{};
std::shuffle(keys.begin(), keys.end(), rng);
// int step = 0;
for (auto key : keys) {
// step++;
// LOG_INFO("key = %ld", key);
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
// tree.Draw(bpm, "./thh-insert" + std::to_string(key) + ".dot");
}
// tree.Draw(bpm, "./thh-insert" + std::to_string(1000) + ".dot");
// LOG_INFO("627");
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
TEST(BPlusTreeTests, Scaled_InsertTest1) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 3, 3);
GenericKey<8> index_key{};
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// DONE(wxx): modify the size here for testing
int size = 10000;
std::vector<int64_t> keys(size);
std::iota(keys.begin(), keys.end(), 1);
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(keys.begin(), keys.end(), g);
// int cnt = 0;
for (auto key : keys) {
// cnt++;
// std::cout << cnt << std::endl;
// tree.Draw(bpm, "/home/silas/tree/tree-insert" + std::to_string(cnt) + "key" + std::to_string(key) + ".dot");
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
std::shuffle(keys.begin(), keys.end(), g);
// LOG_INFO("718");
for (auto key : keys) {
// ++cnt;
// tree.Draw(bpm, "/home/silas/tree/tree-remove" + std::to_string(cnt) + "key" + std::to_string(key) + ".dot");
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
// LOG_INFO("725");
EXPECT_EQ(true, tree.IsEmpty());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 20
* Description: Insert keys range from 1 to 5 repeatedly,
* check whether insertion of repeated keys fail.
* Then check whether the keys are distributed in separate
* leaf nodes
*/
TEST(BPlusTreeConcurrentTestC1, SplitTest) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
// tree.Draw(bpm, "/home/silas/tree/tree-split" + std::to_string(12345) + ".dot");
// insert into repetitive key, all failed
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
EXPECT_EQ(false, tree.Insert(index_key, rid, transaction));
}
page_id_t root_page_id = tree.GetRootPageId();
index_key.SetFromInteger(1);
auto leaf_node = reinterpret_cast<BPlusTreeLeafPage<GenericKey<8>, RID, GenericComparator<8>> *>(
tree.FindLeafPage(index_key, root_page_id));
ASSERT_NE(nullptr, leaf_node);
EXPECT_EQ(1, leaf_node->GetSize());
EXPECT_EQ(2, leaf_node->GetMaxSize());
// Check the next 3 pages
for (int i = 0; i < 3; i++) {
EXPECT_NE(INVALID_PAGE_ID, leaf_node->GetNextPageId());
leaf_node = reinterpret_cast<BPlusTreeLeafPage<GenericKey<8>, RID, GenericComparator<8>> *>(
bpm->FetchPage(leaf_node->GetNextPageId()));
}
EXPECT_EQ(INVALID_PAGE_ID, leaf_node->GetNextPageId());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 20
* Description: Insert a set of keys range from 1 to 5 in the
* increasing order. Check whether the key-value pair is valid
* using GetValue
*/
TEST(BPlusTreeConcurrentTestC1, InsertTest1) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 30
* Description: Insert a set of keys range from 1 to 5 in
* a reversed order. Check whether the key-value pair is valid
* using GetValue
*/
TEST(BPlusTreeConcurrentTestC1, InsertTest2) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {5, 4, 3, 2, 1};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 30
* Description: Insert a set of keys ranging from 1 to 99 in
* a random order and check each key-value pair with GetValue.
* Then remove every key, insert them all again, and re-check with GetValue
*/
TEST(BPlusTreeConcurrentTestC1, ScaleTestC1) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
int64_t scale = 100;
std::vector<int64_t> keys;
for (int64_t key = 1; key < scale; key++) {
keys.push_back(key);
}
// randomized the insertion order
auto rng = std::default_random_engine{};
std::shuffle(keys.begin(), keys.end(), rng);
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
// tree.Draw(bpm, "/home/silas/tree/tree-" + std::to_string(2) + "-before_delete" + std::to_string(1000) + ".dot");
for (auto key : keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
// tree.Draw(bpm, "/home/silas/tree/tree-" + std::to_string(2) + "-delete" + std::to_string(1000) + ".dot");
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
// tree.Draw(bpm, "/home/silas/tree/tree-" + std::to_string(3) + "-insert" + std::to_string(1000) + ".dot");
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
TEST(BPlusTreeConcurrentTestC1, Scale_InsertTest1) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 6, 6);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
int size = 10000;
std::vector<int64_t> keys(size);
std::iota(keys.begin(), keys.end(), 1);
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(keys.begin(), keys.end(), g);
std::cout << "---------" << std::endl;
int i = 0;
(void)i;
for (auto key : keys) {
i++;
// std::cout << i << std::endl;
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
// i = 0;
// keys = {17, 9, 19, 3, 11, 1, 15, 7, 5, 13};
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
i++;
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
EXPECT_EQ(true, tree.IsEmpty());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 5
* Description: The same test that has been run for checkpoint 1,
* but added iterator for value checking
*/
TEST(GradeScopeBPlusTreeTests, InsertTest1) {
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 5
* Description: The same test that has been run for checkpoint 1
* but added iterator for value checking
*/
TEST(GradeScopeBPlusTreeTests, InsertTest2) {
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 2, 3);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {5, 4, 3, 2, 1};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
start_key = 3;
current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); !iterator.IsEnd(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 10
* Description: Insert a set of keys, use GetValue and iterator to
* check the inserted keys. Then delete a subset of the keys.
* Finally use the iterator to check the remained keys.
*/
TEST(GradeScopeBPlusTreeTests, DeleteTest1) {
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction); // the transaction is passed in
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
// for (auto pair : tree) {
// auto location = pair.second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// }
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(0, location.GetPageId());
EXPECT_EQ(current_key, location.GetSlotNum());
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
std::vector<int64_t> remove_keys = {1, 5};
for (auto key : remove_keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
start_key = 2;
current_key = start_key;
int64_t size = 0;
// for (auto pair : tree) {
// auto location = pair.second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// size = size + 1;
// }
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(0, location.GetPageId());
EXPECT_EQ(current_key, location.GetSlotNum());
current_key = current_key + 1;
size = size + 1;
}
EXPECT_EQ(size, 3);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 10
* Description: Similar to DeleteTest1, except that, during the Remove step,
* a different subset of keys is removed.
*/
TEST(GradeScopeBPlusTreeTests, DeleteTest2) {
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
// for (auto pair : tree) {
// auto location = pair.second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// }
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(0, location.GetPageId());
EXPECT_EQ(current_key, location.GetSlotNum());
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
std::vector<int64_t> remove_keys = {1, 5, 3, 4};
for (auto key : remove_keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
start_key = 2;
current_key = start_key;
int64_t size = 0;
// for (auto pair : tree) {
// auto location = pair.second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// size = size + 1;
// }
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(0, location.GetPageId());
EXPECT_EQ(current_key, location.GetSlotNum());
current_key = current_key + 1;
size = size + 1;
}
EXPECT_EQ(size, 1);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 10
* Description: Insert 9999 keys (1 to 9999). Use GetValue and the iterator to iterate
* through the inserted keys. Then remove keys 1 to 9899. Finally, use the iterator
* to check that exactly 100 keys remain, and remove those as well.
*/
TEST(GradeScopeBPlusTreeTests, ScaleTest) {
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(30, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
int64_t scale = 10000;
std::vector<int64_t> keys;
for (int64_t key = 1; key < scale; key++) {
keys.push_back(key);
}
// shuffle keys
// std::random_shuffle(keys.begin(), keys.end());
// NOTE: 'std::random_shuffle' has been removed in C++17; use 'std::shuffle' instead
// std::shuffle(keys.begin(), keys.end(), std::mt19937(std::random_device()));
auto rng = std::default_random_engine{};
std::shuffle(keys.begin(), keys.end(), rng);
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
// for (auto pair : tree) {
// (void)pair;
// current_key = current_key + 1;
// }
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
int64_t remove_scale = 9900;
std::vector<int64_t> remove_keys;
for (int64_t key = 1; key < remove_scale; key++) {
remove_keys.push_back(key);
}
// shuffle remove_keys
std::shuffle(remove_keys.begin(), remove_keys.end(), rng);
for (auto key : remove_keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
start_key = remove_scale;
int64_t size = 0;
// for (auto pair : tree) {
// (void)pair;
// size = size + 1;
// }
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
size = size + 1;
}
EXPECT_EQ(size, 100);
remove_keys.clear();
for (int64_t key = remove_scale; key < scale; key++) {
remove_keys.push_back(key);
}
for (auto key : remove_keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
EXPECT_EQ(true, tree.IsEmpty());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 10
* Description: Insert a set of keys, alternating between keys that
* should stay and keys that will be deleted.
* Then remove the keys marked for deletion, one after another.
* Check that the keys reached by the iterator are exactly the keys
* that were meant to stay.
*/
TEST(GradeScopeBPlusTreeTests, SequentialMixTest) {
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// first, populate index
std::vector<int64_t> for_insert;
std::vector<int64_t> for_delete;
size_t sieve = 2; // divide evenly
size_t total_keys = 1000;
for (size_t i = 1; i <= total_keys; i++) {
if (i % sieve == 0) {
for_insert.push_back(i);
} else {
for_delete.push_back(i);
}
}
// Insert all the keys, including the ones that will remain at the end and
// the ones that are going to be removed next.
for (size_t i = 0; i < total_keys / 2; i++) {
int64_t insert_key = for_insert[i];
int64_t insert_value = insert_key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(insert_key >> 32), insert_value);
index_key.SetFromInteger(insert_key);
tree.Insert(index_key, rid, transaction);
int64_t delete_key = for_delete[i];
int64_t delete_value = delete_key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(delete_key >> 32), delete_value);
index_key.SetFromInteger(delete_key);
tree.Insert(index_key, rid, transaction);
}
// Remove the keys in for_delete
for (auto key : for_delete) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
// Only half of the keys should remain
int64_t start_key = 2;
int64_t size = 0;
index_key.SetFromInteger(start_key);
// for (auto pair : tree) {
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto key = (*iterator).first;
EXPECT_EQ(key.ToString(), for_insert[size]);
size++;
}
EXPECT_EQ(size, for_insert.size());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
} // namespace bustub
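Every test above builds its RID from the key with the same two lines, and then asserts `GetPageId() == 0` and `GetSlotNum() == key`. The following is a minimal sketch of why those assertions hold. `SketchRid`, `RidFromKey`, and `DemoRidEncoding` are hypothetical names used only for illustration; they are not part of BusTub, and the field types are an assumption about what `RID` stores.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for bustub::RID: only the two fields the tests inspect.
struct SketchRid {
  std::int32_t page_id;
  std::uint32_t slot_num;
};

// Mirrors the idiom used throughout the tests:
//   int64_t value = key & 0xFFFFFFFF;
//   rid.Set(static_cast<int32_t>(key >> 32), value);
inline SketchRid RidFromKey(std::int64_t key) {
  SketchRid rid;
  rid.page_id = static_cast<std::int32_t>(key >> 32);           // high 32 bits
  rid.slot_num = static_cast<std::uint32_t>(key & 0xFFFFFFFF);  // low 32 bits
  return rid;
}

// Walks through the two cases worth knowing about.
inline void DemoRidEncoding() {
  // Every key in these tests is far below 2^32, so the page id is 0
  // and the slot number equals the key itself.
  assert(RidFromKey(5).page_id == 0);
  assert(RidFromKey(5).slot_num == 5u);

  // A key of 2^32 + 7 would spill into the page id.
  const std::int64_t big = (std::int64_t{1} << 32) + 7;
  assert(RidFromKey(big).page_id == 1);
  assert(RidFromKey(big).slot_num == 7u);
}
```

Because the largest key used in these tests is 10000, checking the slot number alone is enough to confirm that the value stored for a key is the one that was inserted.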
Insert test results:

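The scale tests shuffle their keys in two different ways. `std::default_random_engine{}` is default-seeded, so it produces the same order on every run with a given standard library, while `std::mt19937 g(rd())` is seeded from `std::random_device` and usually produces a different order each run. When a randomly ordered test fails, it helps to be able to replay the exact order. The sketch below shows one way to do that with an explicit seed; `ShuffledKeys` and `DemoReplay` are hypothetical helpers, not something BusTub provides.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper: returns the keys 1..n shuffled with an explicit seed.
// The same seed replays the same order on a given standard library.
inline std::vector<std::int64_t> ShuffledKeys(std::size_t n, std::uint32_t seed) {
  std::vector<std::int64_t> keys(n);
  std::iota(keys.begin(), keys.end(), 1);  // 1, 2, ..., n
  std::mt19937 gen(seed);
  std::shuffle(keys.begin(), keys.end(), gen);
  return keys;
}

// Replays one seed twice and checks that both runs agree.
inline void DemoReplay() {
  const std::vector<std::int64_t> first = ShuffledKeys(10, 2022);
  const std::vector<std::int64_t> second = ShuffledKeys(10, 2022);
  assert(first == second);  // same seed, same insertion order
}
```

One possible workflow is to draw the seed from `std::random_device` once, print it at the start of the test, and pass the printed value back in when a failure needs to be reproduced.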
Delete test cases:
//===----------------------------------------------------------------------===//
//
// BusTub
//
// b_plus_tree_delete_test.cpp
//
// Identification: test/storage/b_plus_tree_delete_test.cpp
//
// Copyright (c) 2015-2021, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <random>
#include <string>
#include "../src/include/common/logger.h"
#include "buffer/buffer_pool_manager_instance.h"
#include "gtest/gtest.h"
#include "storage/index/b_plus_tree.h"
#include "test_util.h" // NOLINT
namespace bustub {
TEST(BPlusTreeTests, DeleteTest1) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
std::vector<int64_t> remove_keys = {1, 5};
for (auto key : remove_keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
int64_t size = 0;
bool is_present;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
is_present = tree.GetValue(index_key, &rids);
if (!is_present) {
EXPECT_NE(std::find(remove_keys.begin(), remove_keys.end(), key), remove_keys.end());
} else {
EXPECT_EQ(rids.size(), 1);
EXPECT_EQ(rids[0].GetPageId(), 0);
EXPECT_EQ(rids[0].GetSlotNum(), key);
size = size + 1;
}
}
EXPECT_EQ(size, 3);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
TEST(BPlusTreeTests, DeleteTest2) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
std::vector<int64_t> remove_keys = {1, 5, 3, 4};
for (auto key : remove_keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
int64_t size = 0;
bool is_present;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
is_present = tree.GetValue(index_key, &rids);
if (!is_present) {
EXPECT_NE(std::find(remove_keys.begin(), remove_keys.end(), key), remove_keys.end());
} else {
EXPECT_EQ(rids.size(), 1);
EXPECT_EQ(rids[0].GetPageId(), 0);
EXPECT_EQ(rids[0].GetSlotNum(), key);
size = size + 1;
}
}
EXPECT_EQ(size, 1);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
TEST(BPlusTreeTests, ScaleTest) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
int64_t scale = 10000;
std::vector<int64_t> keys;
for (int64_t key = 1; key < scale; key++) {
keys.push_back(key);
}
// randomize the insertion order
auto rng = std::default_random_engine{};
std::shuffle(keys.begin(), keys.end(), rng);
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
// LOG_INFO("key = %ld", key);
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
// tree.Draw(bpm, "./thh-insert" + std::to_string(key) + ".dot");
}
// tree.Draw(bpm, "./thh-insert-final.dot");
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
TEST(BPlusTreeTests, InsertTest1) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 3, 3);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// (wxx) change the size here to adjust the test
int size = 10000;
// std::vector<int64_t> keys = {
// 4,1,2,13,14,12,8,3,6,7,15,5,9,11,10
// };
std::vector<int64_t> keys(size);
std::iota(keys.begin(), keys.end(), 1);
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
// LOG_INFO("insert-key = %ld", key);
}
// LOG_INFO("428");
// std::string out = "/home/ephmeral/My-tree.dot";
// std::cout << out << std::endl;
std::vector<RID> rids;
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
// LOG_INFO("447");
std::shuffle(keys.begin(), keys.end(), g);
// keys = {
// 1,11,3,4,8,13,6,15,14,5,10,7,9,12,2
// };
// std::sort(keys.begin(), keys.end(), [](int64_t a, int64_t b){
// return a < b;
// });
// std::string out = "/home/ephmeral/before_delete.dot";
// tree.Draw(bpm, out);
// int cnt = 0;
// tree.Draw(bpm, "thh_delete_befor.dot");
// int step = 0;
for (auto key : keys) {
// ++cnt;
// std::cout << "key = " << key << " count = " << cnt << std::endl;
// out = "/home/ephmeral/tree/in_delete" + std::to_string(cnt) + ".dot";
// step++;
// LOG_INFO("delete-key = %ld", key);
// tree.Draw(bpm, "thh_test_delete" + std::to_string(step) + ".dot");
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
// tree.Draw(bpm, "thh_test_delete" + std::to_string(70) + ".dot");
// LOG_INFO("463");
// int cnt = 0;
// for (auto key : keys) {
// rids.clear();
// index_key.SetFromInteger(key);
// if (tree.GetValue(index_key, &rids)) {
// std::cout << key << std::endl;
// ++cnt;
// }
// }
// std::cout << "the count is = " << cnt << std::endl;
EXPECT_EQ(true, tree.IsEmpty());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
TEST(BPlusTreeTests, InsertTest2) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator, 4, 4);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// (wxx) change the size here to adjust the test
int size = 5000;
std::vector<int64_t> keys(size);
std::iota(keys.begin(), keys.end(), 1);
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
EXPECT_EQ(true, tree.IsEmpty());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
TEST(BPlusTreeTests, InsertTest3) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
auto *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
auto *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// (wxx) change the size here to adjust the test
int size = 10000;
std::vector<int64_t> keys(size);
std::iota(keys.begin(), keys.end(), 1);
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
std::shuffle(keys.begin(), keys.end(), g);
for (auto key : keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
EXPECT_EQ(true, tree.IsEmpty());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
} // namespace bustub
Delete test results:

Checkpoint2 Multi Thread B+Tree
Checkpoint2 is also split into two parts:
- Task3: Index Iterator. Implement range scan over the leaf pages.
- Task4: Concurrent Index. Support concurrent operations on the B+ tree.
Task3 Index Iterator
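There is not much to say about this part: implement an iterator that walks the leaf pages. The iterator only needs to store a pointer to the current leaf page and the current position inside it. Once the current page has been fully traversed, follow its next page id to reach the next leaf page. As before, remember to unpin every page that has been traversed. Possible deadlock issues are not discussed here for now.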
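The iterator state described above (current leaf plus position, hopping through the next page id) can be sketched on a toy structure. This is only an illustration: `ToyLeaf`, `ToyLeafIterator` and `ScanAllKeys` are hypothetical names, the leaves live in a plain vector instead of the buffer pool, and an index of -1 stands in for INVALID_PAGE_ID.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Toy stand-in for a leaf page: key/value pairs plus the index of the next
// leaf (-1 plays the role of INVALID_PAGE_ID). Not BusTub's LeafPage.
struct ToyLeaf {
  std::vector<std::pair<int, int>> entries;
  int next_leaf = -1;
};

// The iterator state is just (current leaf, slot inside that leaf).
class ToyLeafIterator {
 public:
  ToyLeafIterator(const std::vector<ToyLeaf> *leaves, int leaf, std::size_t slot)
      : leaves_(leaves), leaf_(leaf), slot_(slot) {
    SkipExhaustedLeaves();
  }

  bool IsEnd() const { return leaf_ == -1; }

  const std::pair<int, int> &operator*() const {
    assert(!IsEnd());
    return Cur().entries[slot_];
  }

  ToyLeafIterator &operator++() {
    ++slot_;
    SkipExhaustedLeaves();
    return *this;
  }

 private:
  const ToyLeaf &Cur() const { return (*leaves_)[static_cast<std::size_t>(leaf_)]; }

  // When the current leaf is used up, hop to the next one. A real
  // implementation would unpin the old page here and fetch the next page.
  void SkipExhaustedLeaves() {
    while (leaf_ != -1 && slot_ >= Cur().entries.size()) {
      leaf_ = Cur().next_leaf;
      slot_ = 0;
    }
  }

  const std::vector<ToyLeaf> *leaves_;
  int leaf_;
  std::size_t slot_;
};

// Collects every key reachable from the first leaf, in scan order.
std::vector<int> ScanAllKeys(const std::vector<ToyLeaf> &leaves) {
  std::vector<int> keys;
  if (leaves.empty()) {
    return keys;
  }
  for (ToyLeafIterator it(&leaves, 0, 0); !it.IsEnd(); ++it) {
    keys.push_back((*it).first);
  }
  return keys;
}
```

In the real iterator the hop in `SkipExhaustedLeaves` is the place where the finished page gets unpinned before the next one is fetched.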
Task4 Concurrent Index
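This is the core of the concurrent B+ tree and probably the hardest part of Project 2. The B+ tree implemented earlier has to support concurrent Search/Insert/Delete operations. Logically, a single lock over the whole tree would work, but the performance would be predictably poor. Instead we use a special locking scheme called latch crabbing. As the name suggests, it moves like a crab: move one leg, put it down, move the other leg, put it down. The basic idea is:
1. Lock the parent page first.
2. Then lock the child page.
3. If the child page is safe, release the lock on the parent page.
Safe means that the current page is guaranteed not to split, steal, or merge under the current operation. The definition depends on the operation: for Search every node is safe; for Insert, check max size; for Delete, check min size.
The rationale and correctness are fairly clear. When a page is safe, the current operation can only change that page and the pages below it, so the locks on its ancestors can be released early to improve concurrency.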
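The per-operation safety rule can be written as a small predicate. This is a minimal sketch under stated assumptions: `Op`, `NodeSizes` and `IsSafe` are hypothetical names, and a node is assumed to split only when an insert would push it past max size and to underflow only when a delete would drop it below min size. The exact thresholds depend on when your implementation splits.

```cpp
#include <cassert>

enum class Op { kFind, kInsert, kDelete };

// Entry counts of one tree node (counts, not bytes).
struct NodeSizes {
  int size;
  int min_size;
  int max_size;
};

// Returns true if `op` on this node cannot trigger a split, steal or merge,
// which is what allows the ancestors' locks to be released early.
bool IsSafe(Op op, const NodeSizes &n) {
  assert(n.min_size <= n.max_size);
  switch (op) {
    case Op::kFind:
      return true;  // reads never restructure the tree
    case Op::kInsert:
      return n.size < n.max_size;  // assumption: room for one more entry
    case Op::kDelete:
      return n.size > n.min_size;  // assumption: one entry can be spared
  }
  return false;
}
```

The root needs its own rule for Delete, because it is allowed to hold fewer entries than min size.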
Search
For Search, start at the root page: take a read lock on the parent, then a read lock on the child page, then release the parent's lock. Repeat this recursively down the tree.
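The read path can be sketched with `std::shared_mutex` (C++17) on a toy two-level tree. `ToyNode` and `CrabbingFind` are hypothetical names, and the trace vector exists only to make the lock order visible. The separate root lock used later in these notes is left out here.

```cpp
#include <cassert>
#include <cstddef>
#include <shared_mutex>
#include <string>
#include <vector>

// Toy tree node. A leaf stores keys; an internal node stores separator keys
// and keys.size() + 1 children. Not BusTub's page layout.
struct ToyNode {
  std::shared_mutex latch;
  std::string name;
  std::vector<int> keys;
  std::vector<ToyNode *> children;
  bool IsLeaf() const { return children.empty(); }
};

// Descends to the leaf that may hold `key`. The child is read-locked before
// the parent is released. "+x" in the trace means lock x, "-x" means unlock x.
bool CrabbingFind(ToyNode *root, int key, std::vector<std::string> *trace) {
  ToyNode *cur = root;
  cur->latch.lock_shared();
  trace->push_back("+" + cur->name);
  while (!cur->IsLeaf()) {
    assert(cur->children.size() == cur->keys.size() + 1);
    std::size_t i = 0;
    while (i < cur->keys.size() && key >= cur->keys[i]) {
      ++i;
    }
    ToyNode *child = cur->children[i];
    child->latch.lock_shared();  // lock the child first ...
    trace->push_back("+" + child->name);
    cur->latch.unlock_shared();  // ... then let go of the parent
    trace->push_back("-" + cur->name);
    cur = child;
  }
  bool found = false;
  for (int k : cur->keys) {
    if (k == key) {
      found = true;
    }
  }
  cur->latch.unlock_shared();
  trace->push_back("-" + cur->name);
  return found;
}
```

At most two read locks are held at any moment, which is why readers rarely block each other for long.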




Insert
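For Insert, start at the root page: take a write lock on the parent, then a write lock on the child page. If the child page is safe, release the locks on all ancestors; otherwise keep them and continue down the tree.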




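While the child page is unsafe, the write locks on its ancestors must stay held; once a safe child page appears, all ancestor write locks are released. How do we record which pages are currently locked? This is where a parameter that was never mentioned in Checkpoint1 comes in: transaction.
transaction is Bustub's transaction object. In Project 2 you do not need to understand what a transaction is yet; just treat it as the thread that is currently operating on the B+ tree. Call the transaction's AddIntoPageSet() method to track the page locks acquired by the current thread. When a safe child page is found, release all page locks recorded in the transaction. In principle the locks can be released top-down or bottom-up, but since contention is usually heavier on the upper nodes, it is better to release them top-down.
After the whole Insert operation completes, release all locks.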
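The bookkeeping for Insert can be sketched as follows. This is an illustration under assumptions: `ToyPage`, `DescentResult` and `DescendForInsert` are hypothetical names, a `std::mutex` stands in for the page's write lock, a `std::deque` stands in for the transaction's page set, and a page counts as safe when one more entry fits.

```cpp
#include <cassert>
#include <deque>
#include <mutex>
#include <string>
#include <vector>

struct ToyPage {
  std::string name;
  int size = 0;
  int max_size = 0;
  std::mutex latch;  // stands in for the page's write lock
};

struct DescentResult {
  std::vector<std::string> released_early;  // in the order they were released
  std::vector<std::string> held_at_leaf;    // still locked at the leaf, top-down
};

// `path` lists the pages an insert would visit, from root to leaf.
DescentResult DescendForInsert(const std::vector<ToyPage *> &path) {
  DescentResult result;
  std::deque<ToyPage *> held;  // plays the role of the transaction's page set
  for (ToyPage *page : path) {
    assert(page != nullptr);
    page->latch.lock();
    if (page->size < page->max_size) {
      // The page is safe: release every ancestor, starting from the top.
      while (!held.empty()) {
        held.front()->latch.unlock();
        result.released_early.push_back(held.front()->name);
        held.pop_front();
      }
    }
    held.push_back(page);
  }
  for (ToyPage *page : held) {
    result.held_at_leaf.push_back(page->name);
  }
  // ... the insert itself and any splits would happen here ...
  while (!held.empty()) {  // finally release whatever is still held
    held.front()->latch.unlock();
    held.pop_front();
  }
  return result;
}
```

With a full root, a half-full middle page and a full leaf, the root is released as soon as the middle page is locked, while the middle page and the leaf stay locked until the insert finishes.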
Delete
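Basically the same as Insert; only the safety check differs (it checks min size). One more thing to note: when a steal/merge with a sibling is needed, the sibling must be locked as well, and the lock is released right after the steal/merge finishes. This prevents a data race with other threads that are running Search/Insert on the sibling. This lock does not need to be recorded in the transaction, because it is only held temporarily.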
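The short-lived sibling lock can be sketched with a scoped lock. `ToyLeafPage` and `TryStealFromLeft` are hypothetical names, and the caller is assumed to already hold the write locks on the target page and on the parent.

```cpp
#include <cassert>
#include <mutex>
#include <vector>

struct ToyLeafPage {
  std::vector<int> keys;  // sorted keys
  int min_size = 1;
  std::mutex latch;  // stands in for the page's write lock
};

// Moves the largest key of `left` to the front of `target` if `left` can
// spare one. The sibling is locked only for the duration of this call.
bool TryStealFromLeft(ToyLeafPage *left, ToyLeafPage *target) {
  assert(left != nullptr && target != nullptr);
  std::lock_guard<std::mutex> sibling_lock(left->latch);  // released at scope exit
  if (static_cast<int>(left->keys.size()) <= left->min_size) {
    return false;  // the sibling is at min size, the caller has to merge instead
  }
  target->keys.insert(target->keys.begin(), left->keys.back());
  left->keys.pop_back();
  return true;
}
```

After a successful steal the caller would still need to update the separator key in the parent to `target->keys.front()`.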
Implementation
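Latch crabbing happens during the Find Leaf step, so FindLeaf() from Checkpoint1 has to be modified to take locks along the way, depending on the operation.
1 Check whether the current node is safe
Note that insert and delete include a step that propagates key updates upward, so checking size alone is not enough. Taking the upward update into account avoids releasing the parent page's lock too early. Otherwise the upward update would have to take a write lock again, and locking bottom-up while other threads lock top-down easily leads to deadlock.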
// Used by insert and delete
INDEX_TEMPLATE_ARGUMENTS
auto BPLUSTREE_TYPE::IsSafePage(const KeyType &key, BPlusTreePage *page, OperationType type) const -> bool {
  if (type == OperationType::FIND) {
    return true;
  }
  KeyType cur_page0k;  // key stored at index 0 of this page
  if (page->IsLeafPage()) {
    auto leafpage = reinterpret_cast<LeafPage *>(page);
    // A page whose max size equals LEAF_PAGE_SIZE and that is one entry short of it is treated as unsafe
    if (leafpage->GetMaxSize() == static_cast<int>(LEAF_PAGE_SIZE) &&
        leafpage->GetSize() == leafpage->GetMaxSize() - 1) {
      return false;
    }
    cur_page0k = leafpage->KeyAt(0);
  } else {
    auto internal_page = reinterpret_cast<InternalPage *>(page);
    // Same rule for an internal page whose max size equals INTERNAL_PAGE_SIZE
    if (internal_page->GetMaxSize() == static_cast<int>(INTERNAL_PAGE_SIZE) &&
        internal_page->GetSize() == internal_page->GetMaxSize() - 1) {
      return false;
    }
    cur_page0k = internal_page->KeyAt(0);
  }
  if (type == OperationType::INSERT) {
    // Safe only if there is room and the new key is not smaller than the key at index 0
    return page->GetSize() < page->GetMaxSize() && comparator_(key, cur_page0k) >= 0;
  }
  // Delete
  // The root node is a special case
  if (page->IsRootPage()) {
    return page->GetSize() > (page->IsLeafPage() ? 1 : 2);
  }
  // Safe only if above min size and the deleted key is not the key at index 0
  return page->GetSize() > page->GetMinSize() && comparator_(key, cur_page0k) != 0;
}
// Used by find:
INDEX_TEMPLATE_ARGUMENTS
auto BPLUSTREE_TYPE::IsSafePage(BPlusTreePage *page, OperationType type) const -> bool {
  if (type == OperationType::FIND) {
    return true;
  }
  if (type == OperationType::INSERT) {
    return page->GetSize() < page->GetMaxSize();
  }
  // The root node is a special case
  if (page->IsRootPage()) {
    return page->GetSize() > (page->IsLeafPage() ? 1 : 2);
  }
  return page->GetSize() > page->GetMinSize();
}
2 Lock and unlock pages
// Used by insert and delete
INDEX_TEMPLATE_ARGUMENTS
auto BPLUSTREE_TYPE::LockPage(const KeyType &key, Page *page, Transaction *transaction, OperationType type, bool pessimistic) -> bool {
...
}
// Used by find:
INDEX_TEMPLATE_ARGUMENTS
auto BPLUSTREE_TYPE::LockPage(Page *page, Transaction *transaction, OperationType type, bool pessimistic) -> bool {
...
}
// Unlock
INDEX_TEMPLATE_ARGUMENTS
void BPLUSTREE_TYPE::UnlockAllPages(Transaction *transaction, OperationType type, bool is_dirty) {
...
}
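3 Lock the root node
(Insert and delete have to be protected in cases such as an empty tree or a change of the root node.) This is equivalent to adding one more page above the root page, which can be thought of as a dummy parent of the root.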
INDEX_TEMPLATE_ARGUMENTS
void BPLUSTREE_TYPE::LockRoot(Transaction *transaction, OperationType type) {
  auto page_set = transaction->GetPageSet();
  bool have_root_lock = false;
  // Scan the page set for the nullptr entry that marks the root lock
  for (auto it = page_set->begin(); it != page_set->end();) {
    if (*it == nullptr) {
      have_root_lock = true;
      break;  // assume each transaction holds at most one root lock marker
    } else {
      ++it;  // keep checking the next element
    }
  }
  if (have_root_lock) {
    return;
  }
  if (type == OperationType::FIND) {
    root_mutex_.lock_shared();
    // LOG_INFO("LockRoot by transaction %d, acquired root_mutex_ read lock", transaction->GetTransactionId());
  } else {
    root_mutex_.lock();
    // LOG_INFO("LockRoot by transaction %d, acquired root_mutex_ write lock", transaction->GetTransactionId());
  }
  transaction->AddIntoPageSet(nullptr);
}
INDEX_TEMPLATE_ARGUMENTS
void BPLUSTREE_TYPE::UnLockRoot(Transaction *transaction, OperationType type) {
  auto page_set = transaction->GetPageSet();
  bool have_root_lock = false;
  // Scan the page set and erase the nullptr entry that marks the root lock
  for (auto it = page_set->begin(); it != page_set->end();) {
    if (*it == nullptr) {
      it = page_set->erase(it);  // erase the nullptr and update the iterator
      have_root_lock = true;
      break;  // assume each transaction holds at most one root lock marker
    } else {
      ++it;  // keep checking the next element
    }
  }
  if (!have_root_lock) {
    return;
  }
  if (type == OperationType::FIND) {
    root_mutex_.unlock_shared();  // release the shared lock
    // LOG_INFO("UnLockRoot by transaction %d, released root_mutex_ read lock", transaction->GetTransactionId());
  } else {
    root_mutex_.unlock();  // release the exclusive lock
    // LOG_INFO("UnLockRoot by transaction %d, released root_mutex_ write lock", transaction->GetTransactionId());
  }
}
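The nullptr marker idea can be tried out in isolation. This is a toy model, not the tree's actual unlock routine: `ToyLatchedPage`, `ToyTreeLatches`, `ReleaseAll` and `RootLatchIsFree` are hypothetical names, and a plain `std::deque` stands in for the transaction's page set.

```cpp
#include <cassert>
#include <deque>
#include <shared_mutex>
#include <string>
#include <vector>

struct ToyLatchedPage {
  std::string name;
  std::shared_mutex latch;
};

class ToyTreeLatches {
 public:
  // Takes the root lock once and records it as a nullptr entry in `held`.
  void LockRoot(std::deque<ToyLatchedPage *> *held, bool exclusive) {
    for (ToyLatchedPage *page : *held) {
      if (page == nullptr) {
        return;  // this operation already holds the root lock
      }
    }
    if (exclusive) {
      root_latch_.lock();
    } else {
      root_latch_.lock_shared();
    }
    held->push_back(nullptr);
  }

  // Releases everything in `held` top-down. A nullptr entry means the root
  // lock. Returns the release order so that it can be inspected.
  std::vector<std::string> ReleaseAll(std::deque<ToyLatchedPage *> *held, bool exclusive) {
    std::vector<std::string> order;
    while (!held->empty()) {
      ToyLatchedPage *page = held->front();
      held->pop_front();
      std::shared_mutex &latch = (page == nullptr) ? root_latch_ : page->latch;
      if (exclusive) {
        latch.unlock();
      } else {
        latch.unlock_shared();
      }
      order.push_back(page == nullptr ? std::string("<root>") : page->name);
    }
    return order;
  }

  // Only meaningful when the calling thread does not hold the root lock.
  bool RootLatchIsFree() {
    if (root_latch_.try_lock()) {
      root_latch_.unlock();
      return true;
    }
    return false;
  }

 private:
  std::shared_mutex root_latch_;  // the dummy parent above the root page
};
```

Because the marker sits at the front of the queue, a top-down release frees the root lock first, which is where contention is usually highest.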
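4 Other locking cases:
-
When an insert splits a page, the newly created page does not need to be locked (no other thread can reach it until it is linked into the tree)
-
Delete steal/merge has to access the sibling node and takes a write lock on it, again using the pageguard class implemented earlier, which wraps locking and unlocking in RAII style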
// Leaf page: steal from the left sibling
if (parent_page->TheKeyLeftOne(target_key, comparator_, left_page_id) && left_page_id != INVALID_PAGE_ID) {
  // Write-lock the sibling; the guard unlocks and unpins it when it goes out of scope
  auto left_leaf_page =
      PageGuard<LeafPage>(buffer_pool_manager_, left_page_id, true, PageGuard<LeafPage>::LOCKTYPE::WRITE);
  if (left_leaf_page->GetSize() > left_leaf_page->GetMinSize()) {
    left_leaf_page->MoveOneTo(left_leaf_page.operator->(), target.operator->(), true);
    // Propagate the new first key up to the parent
    auto new_target_key = target->KeyAt(0);
    parent_page->SetKey(target_key, new_target_key, comparator_);
    return;
  }
  left_leaf_page.SetDirty(false);
}

#pragma once
#include <stdexcept>
#include <type_traits>
#include <typeinfo>  // std::bad_cast
#include "buffer/buffer_pool_manager.h"
#include "common/config.h"
#include "storage/page/b_plus_tree_internal_page.h"
#include "storage/page/b_plus_tree_leaf_page.h"
#include "storage/page/b_plus_tree_page.h"
namespace bustub {
// Forward declarations
// class BPlusTreePage;
// class LeafPage;
// class InternalPage;
template <typename PageType>
class PageGuard {
  // Template friend declaration (lets all PageGuard instantiations access each other's private members)
  template <typename T>
  friend class PageGuard;

 public:
  enum class LOCKTYPE { WRITE, READ, NONE };
  enum class TRYLOCKTYPE { WRITE, READ, NONE };
  // Default constructor, used as the target of conversions
  PageGuard() = default;
  // Copy construction is disabled
  PageGuard(const PageGuard &) = delete;
  // Copy assignment is disabled
  PageGuard &operator=(const PageGuard &) = delete;
  // Basic constructor
  PageGuard(BufferPoolManager *bpm, page_id_t page_id, bool is_dirty = false, LOCKTYPE lock_type = LOCKTYPE::NONE,
            TRYLOCKTYPE try_lock_type = TRYLOCKTYPE::NONE)
      : bpm_(bpm), page_id_(page_id), is_dirty_(is_dirty), lock_type_(lock_type), try_lock_type_(try_lock_type) {
    if (page_id != INVALID_PAGE_ID) {
      page_ = bpm_->FetchPage(page_id_);
      Lock(lock_type_);
      TryLock(try_lock_type_);
    } else {
      page_ = bpm_->NewPage(&page_id_);
      is_dirty_ = true;  // a new page is automatically marked dirty
      // A new page is not locked here, so there is nothing to unlock later
      lock_type_ = LOCKTYPE::NONE;
      try_lock_type_ = TRYLOCKTYPE::NONE;
    }
  }
  // Type-safe converting constructor
  template <typename OtherType>
  PageGuard(PageGuard<OtherType> &&other) {
    static_assert(std::is_base_of_v<BPlusTreePage, PageType> && std::is_base_of_v<BPlusTreePage, OtherType>,
                  "Can only convert between B+ tree page types");
    if (other.page_) {
      TransferOwnershipFrom(other);
    } else {
      throw std::bad_cast();
    }
  }
  // Move constructor: takes over the page together with its lock state
  PageGuard(PageGuard &&other) noexcept { TransferOwnershipFrom(other); }
  PageGuard &operator=(PageGuard &&other) noexcept {
    if (this != &other) {
      Release();                     // release the current resources
      TransferOwnershipFrom(other);  // take over the resources and reset the source
    }
    return *this;
  }
  ~PageGuard() { Release(); }
  // Explicit type conversion
  template <typename TargetType>
  PageGuard<TargetType> Convert() {
    static_assert(std::is_base_of_v<BPlusTreePage, TargetType>, "Target must be a B+ tree page type");
    PageGuard<TargetType> new_guard;
    new_guard.TransferOwnershipFrom(*this);
    return new_guard;
  }
  // Access operator
  PageType *operator->() { return reinterpret_cast<PageType *>(page_->GetData()); }
  // State checks
  explicit operator bool() const { return page_ != nullptr; }
  page_id_t GetPageId() const { return page_id_; }
  void SetPageId(page_id_t page_id) { page_id_ = page_id; }
  bool IsDirty() const { return is_dirty_; }
  void SetDirty(bool dirty) { is_dirty_ = dirty; }
  void Lock(LOCKTYPE lock_type) {
    switch (lock_type) {
      case LOCKTYPE::WRITE:
        page_->WLatch();
        break;
      case LOCKTYPE::READ:
        page_->RLatch();
        break;
      default:
        break;
    }
  }
  void TryLock(TRYLOCKTYPE lock_type) {
    switch (lock_type) {
      case TRYLOCKTYPE::WRITE:
        trywlock_result_ = page_->TryWLatch();
        break;
      case TRYLOCKTYPE::READ:
        tryrlock_result_ = page_->TryRLatch();
        break;
      default:
        break;
    }
  }
  void UnLock(LOCKTYPE lock_type) {
    switch (lock_type) {
      case LOCKTYPE::WRITE:
        page_->WUnlatch();
        break;
      case LOCKTYPE::READ:
        page_->RUnlatch();
        break;
      default:
        break;
    }
  }
  void TryUnLock(TRYLOCKTYPE lock_type) {
    switch (lock_type) {
      case TRYLOCKTYPE::WRITE:
        if (trywlock_result_) page_->WUnlatch();
        break;
      case TRYLOCKTYPE::READ:
        if (tryrlock_result_) page_->RUnlatch();
        break;
      default:
        break;
    }
  }
  bool GetTryrLockResult() { return tryrlock_result_; }
  bool GetTrywLockResult() { return trywlock_result_; }

 private:
  // Unlocks and unpins the held page (if any), then clears the guard
  void Release() {
    if (page_ != nullptr) {
      UnLock(lock_type_);
      TryUnLock(try_lock_type_);
      if (page_id_ != INVALID_PAGE_ID) {
        bpm_->UnpinPage(page_id_, is_dirty_);
      }
    }
    Reset();
  }
  // Resource transfer
  template <typename OtherPageType>
  void TransferOwnershipFrom(PageGuard<OtherPageType> &other) {
    bpm_ = other.bpm_;
    page_ = other.page_;
    page_id_ = other.page_id_;
    is_dirty_ = other.is_dirty_;
    // The lock enums are distinct types per instantiation, so convert explicitly
    lock_type_ = static_cast<LOCKTYPE>(other.lock_type_);
    try_lock_type_ = static_cast<TRYLOCKTYPE>(other.try_lock_type_);
    trywlock_result_ = other.trywlock_result_;
    tryrlock_result_ = other.tryrlock_result_;
    other.Reset();
  }
  void Reset() {
    page_ = nullptr;
    page_id_ = INVALID_PAGE_ID;
    is_dirty_ = false;
    lock_type_ = LOCKTYPE::NONE;
    try_lock_type_ = TRYLOCKTYPE::NONE;
    trywlock_result_ = false;
    tryrlock_result_ = false;
  }
  BufferPoolManager *bpm_ = nullptr;
  Page *page_ = nullptr;
  page_id_t page_id_ = INVALID_PAGE_ID;
  bool is_dirty_ = false;
  LOCKTYPE lock_type_ = LOCKTYPE::NONE;
  TRYLOCKTYPE try_lock_type_ = TRYLOCKTYPE::NONE;
  bool trywlock_result_ = false;
  bool tryrlock_result_ = false;
};
}  // namespace bustub
Without further ado, the code:
The code has been removed and is not shown here.
Testing
Concurrent test cases:
//===----------------------------------------------------------------------===//
//
// BusTub
//
// b_plus_tree_concurrent_test.cpp
//
// Identification: test/storage/b_plus_tree_concurrent_test.cpp
//
// Copyright (c) 2015-2021, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//
// #include <chrono> // NOLINT
// #include <cstdio>
// #include <functional>
// #include <thread> // NOLINT
// #include "../src/include/common/logger.h"
// #include "buffer/buffer_pool_manager_instance.h"
// #include "gtest/gtest.h"
// #include "storage/index/b_plus_tree.h"
// #include "test_util.h" // NOLINT
// namespace bustub {
// // helper function to launch multiple threads
// template <typename... Args>
// void LaunchParallelTest(uint64_t num_threads, Args &&...args) {
// std::vector<std::thread> thread_group;
// // Launch a group of threads
// for (uint64_t thread_itr = 0; thread_itr < num_threads; ++thread_itr) {
// thread_group.push_back(std::thread(args..., thread_itr));
// }
// // Join the threads with the main thread
// for (uint64_t thread_itr = 0; thread_itr < num_threads; ++thread_itr) {
// thread_group[thread_itr].join();
// }
// }
// // helper function to insert
// void InsertHelper(BPlusTree<GenericKey<8>, RID, GenericComparator<8>> *tree, const std::vector<int64_t> &keys,
// __attribute__((unused)) uint64_t thread_itr = 0) {
// GenericKey<8> index_key;
// RID rid;
// // create transaction
// auto *transaction = new Transaction(0);
// for (auto key : keys) {
// int64_t value = key & 0xFFFFFFFF;
// rid.Set(static_cast<int32_t>(key >> 32), value);
// index_key.SetFromInteger(key);
// tree->Insert(index_key, rid, transaction);
// }
// delete transaction;
// }
// // helper function to separate insert
// void InsertHelperSplit(BPlusTree<GenericKey<8>, RID, GenericComparator<8>> *tree, const std::vector<int64_t> &keys,
// int total_threads, __attribute__((unused)) uint64_t thread_itr) {
// GenericKey<8> index_key;
// RID rid;
// // create transaction
// auto *transaction = new Transaction(0);
// for (auto key : keys) {
// if (static_cast<uint64_t>(key) % total_threads == thread_itr) {
// int64_t value = key & 0xFFFFFFFF;
// rid.Set(static_cast<int32_t>(key >> 32), value);
// index_key.SetFromInteger(key);
// tree->Insert(index_key, rid, transaction);
// }
// }
// delete transaction;
// }
// // helper function to delete
// void DeleteHelper(BPlusTree<GenericKey<8>, RID, GenericComparator<8>> *tree, const std::vector<int64_t> &remove_keys,
// __attribute__((unused)) uint64_t thread_itr = 0) {
// GenericKey<8> index_key;
// // create transaction
// auto *transaction = new Transaction(0);
// for (auto key : remove_keys) {
// index_key.SetFromInteger(key);
// tree->Remove(index_key, transaction);
// }
// delete transaction;
// }
// // helper function to separate delete
// void DeleteHelperSplit(BPlusTree<GenericKey<8>, RID, GenericComparator<8>> *tree,
// const std::vector<int64_t> &remove_keys, int total_threads,
// __attribute__((unused)) uint64_t thread_itr) {
// GenericKey<8> index_key;
// // create transaction
// auto *transaction = new Transaction(0);
// for (auto key : remove_keys) {
// if (static_cast<uint64_t>(key) % total_threads == thread_itr) {
// index_key.SetFromInteger(key);
// tree->Remove(index_key, transaction);
// }
// }
// delete transaction;
// }
// TEST(BPlusTreeConcurrentTest, InsertTest1) {
// // create KeyComparator and index schema
// auto key_schema = ParseCreateStatement("a bigint");
// GenericComparator<8> comparator(key_schema.get());
// auto *disk_manager = new DiskManager("test.db");
// BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// // create b+ tree
// BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// // create and fetch header_page
// page_id_t page_id;
// auto header_page = bpm->NewPage(&page_id);
// (void)header_page;
// // keys to Insert
// std::vector<int64_t> keys;
// int64_t scale_factor = 100;
// for (int64_t key = 1; key < scale_factor; key++) {
// keys.push_back(key);
// }
// LaunchParallelTest(2, InsertHelper, &tree, keys);
// LOG_INFO("124");
// std::vector<RID> rids;
// GenericKey<8> index_key;
// for (auto key : keys) {
// rids.clear();
// index_key.SetFromInteger(key);
// tree.GetValue(index_key, &rids);
// EXPECT_EQ(rids.size(), 1);
// int64_t value = key & 0xFFFFFFFF;
// EXPECT_EQ(rids[0].GetSlotNum(), value);
// }
// LOG_INFO("136");
// int64_t start_key = 1;
// int64_t current_key = start_key;
// index_key.SetFromInteger(start_key);
// for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
// auto location = (*iterator).second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// }
// EXPECT_EQ(current_key, keys.size() + 1);
// bpm->UnpinPage(HEADER_PAGE_ID, true);
// delete disk_manager;
// delete bpm;
// remove("test.db");
// remove("test.log");
// }
// TEST(BPlusTreeConcurrentTest, InsertTest2) {
// // create KeyComparator and index schema
// auto key_schema = ParseCreateStatement("a bigint");
// GenericComparator<8> comparator(key_schema.get());
// auto *disk_manager = new DiskManager("test.db");
// BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// // create b+ tree
// BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// // create and fetch header_page
// page_id_t page_id;
// auto header_page = bpm->NewPage(&page_id);
// (void)header_page;
// // keys to Insert
// std::vector<int64_t> keys;
// int64_t scale_factor = 100;
// for (int64_t key = 1; key < scale_factor; key++) {
// keys.push_back(key);
// }
// LaunchParallelTest(2, InsertHelperSplit, &tree, keys, 2);
// std::vector<RID> rids;
// GenericKey<8> index_key;
// for (auto key : keys) {
// rids.clear();
// index_key.SetFromInteger(key);
// tree.GetValue(index_key, &rids);
// EXPECT_EQ(rids.size(), 1);
// int64_t value = key & 0xFFFFFFFF;
// EXPECT_EQ(rids[0].GetSlotNum(), value);
// }
// int64_t start_key = 1;
// int64_t current_key = start_key;
// index_key.SetFromInteger(start_key);
// for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
// auto location = (*iterator).second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// }
// EXPECT_EQ(current_key, keys.size() + 1);
// bpm->UnpinPage(HEADER_PAGE_ID, true);
// delete disk_manager;
// delete bpm;
// remove("test.db");
// remove("test.log");
// }
// TEST(BPlusTreeConcurrentTest, DeleteTest1) {
// // create KeyComparator and index schema
// auto key_schema = ParseCreateStatement("a bigint");
// GenericComparator<8> comparator(key_schema.get());
// auto *disk_manager = new DiskManager("test.db");
// BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// // create b+ tree
// BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// GenericKey<8> index_key;
// // create and fetch header_page
// page_id_t page_id;
// auto header_page = bpm->NewPage(&page_id);
// (void)header_page;
// // sequential insert
// std::vector<int64_t> keys = {1, 2, 3, 4, 5};
// InsertHelper(&tree, keys);
// std::vector<int64_t> remove_keys = {1, 5, 3, 4};
// LaunchParallelTest(2, DeleteHelper, &tree, remove_keys);
// int64_t start_key = 2;
// int64_t current_key = start_key;
// int64_t size = 0;
// index_key.SetFromInteger(start_key);
// for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
// auto location = (*iterator).second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// size = size + 1;
// }
// EXPECT_EQ(size, 1);
// bpm->UnpinPage(HEADER_PAGE_ID, true);
// delete disk_manager;
// delete bpm;
// remove("test.db");
// remove("test.log");
// }
// TEST(BPlusTreeConcurrentTest, DeleteTest2) {
// // create KeyComparator and index schema
// auto key_schema = ParseCreateStatement("a bigint");
// GenericComparator<8> comparator(key_schema.get());
// auto *disk_manager = new DiskManager("test.db");
// BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// // create b+ tree
// BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// GenericKey<8> index_key;
// // create and fetch header_page
// page_id_t page_id;
// auto header_page = bpm->NewPage(&page_id);
// (void)header_page;
// // sequential insert
// std::vector<int64_t> keys = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
// InsertHelper(&tree, keys);
// std::vector<int64_t> remove_keys = {1, 4, 3, 2, 5, 6};
// LaunchParallelTest(2, DeleteHelperSplit, &tree, remove_keys, 2);
// int64_t start_key = 7;
// int64_t current_key = start_key;
// int64_t size = 0;
// index_key.SetFromInteger(start_key);
// for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
// auto location = (*iterator).second;
// EXPECT_EQ(location.GetPageId(), 0);
// EXPECT_EQ(location.GetSlotNum(), current_key);
// current_key = current_key + 1;
// size = size + 1;
// }
// EXPECT_EQ(size, 4);
// bpm->UnpinPage(HEADER_PAGE_ID, true);
// delete disk_manager;
// delete bpm;
// remove("test.db");
// remove("test.log");
// }
// TEST(BPlusTreeConcurrentTest, MixTest) {
// // create KeyComparator and index schema
// auto key_schema = ParseCreateStatement("a bigint");
// GenericComparator<8> comparator(key_schema.get());
// auto *disk_manager = new DiskManager("test.db");
// BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// // create b+ tree
// BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// GenericKey<8> index_key;
// // create and fetch header_page
// page_id_t page_id;
// auto header_page = bpm->NewPage(&page_id);
// (void)header_page;
// // first, populate index
// std::vector<int64_t> keys = {1, 2, 3, 4, 5};
// InsertHelper(&tree, keys);
// // concurrent insert
// keys.clear();
// for (int i = 6; i <= 10; i++) {
// keys.push_back(i);
// }
// LaunchParallelTest(1, InsertHelper, &tree, keys);
// // concurrent delete
// std::vector<int64_t> remove_keys = {1, 4, 3, 5, 6};
// LaunchParallelTest(1, DeleteHelper, &tree, remove_keys);
// int64_t start_key = 2;
// int64_t size = 0;
// index_key.SetFromInteger(start_key);
// for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
// size = size + 1;
// }
// EXPECT_EQ(size, 5);
// bpm->UnpinPage(HEADER_PAGE_ID, true);
// delete disk_manager;
// delete bpm;
// remove("test.db");
// remove("test.log");
// }
// } // namespace bustub
//===----------------------------------------------------------------------===//
//
// BusTub
//
// b_plus_tree_concurrent_test.cpp
//
// Identification: test/storage/b_plus_tree_concurrent_test.cpp
//
// Copyright (c) 2015-2021, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//
/**
* grading_b_plus_tree_checkpoint_2_concurrent_test.cpp
*/
#include <chrono> // NOLINT
#include <cstdio>
#include <functional>
#include <future> // NOLINT
#include <thread> // NOLINT
#include "common/logger.h"
#include "test_util.h" // NOLINT
#include "buffer/buffer_pool_manager_instance.h"
#include "gtest/gtest.h"
#include "storage/index/b_plus_tree.h"
// Macro for time out mechanism
#define TEST_TIMEOUT_BEGIN                           \
  std::promise<bool> promisedFinished;               \
  auto futureResult = promisedFinished.get_future(); \
  std::thread([](std::promise<bool>& finished) {
#define TEST_TIMEOUT_FAIL_END(X)                                                                  \
  finished.set_value(true);                                                                       \
  }, std::ref(promisedFinished)).detach();                                                        \
  EXPECT_TRUE(futureResult.wait_for(std::chrono::milliseconds(X)) != std::future_status::timeout) \
      << "Test Failed Due to Time Out";
namespace bustub {
// helper function to launch multiple threads
template <typename... Args>
void LaunchParallelTest(uint64_t num_threads, uint64_t txn_id_start, Args &&... args) {
  std::vector<std::thread> thread_group;
  // Launch a group of threads
  for (uint64_t thread_itr = 0; thread_itr < num_threads; ++thread_itr) {
    thread_group.push_back(std::thread(args..., txn_id_start + thread_itr, thread_itr));
  }
  // Join the threads with the main thread
  for (uint64_t thread_itr = 0; thread_itr < num_threads; ++thread_itr) {
    thread_group[thread_itr].join();
  }
}
// helper function to insert
void InsertHelper(BPlusTree<GenericKey<8>, RID, GenericComparator<8>> *tree, const std::vector<int64_t> &keys,
                  uint64_t tid, __attribute__((unused)) uint64_t thread_itr = 0) {
  GenericKey<8> index_key;
  RID rid;
  // create transaction
  Transaction *transaction = new Transaction(tid);
  for (auto key : keys) {
    int64_t value = key & 0xFFFFFFFF;
    rid.Set(static_cast<int32_t>(key >> 32), value);
    index_key.SetFromInteger(key);
    tree->Insert(index_key, rid, transaction);
  }
  delete transaction;
}
// helper function to split the inserts across threads:
// each thread only inserts the keys with key % total_threads == thread_itr
void InsertHelperSplit(BPlusTree<GenericKey<8>, RID, GenericComparator<8>> *tree, const std::vector<int64_t> &keys,
                       int total_threads, uint64_t tid, __attribute__((unused)) uint64_t thread_itr) {
  GenericKey<8> index_key;
  RID rid;
  // create transaction
  Transaction *transaction = new Transaction(tid);
  for (auto key : keys) {
    if (static_cast<uint64_t>(key) % total_threads == thread_itr) {
      int64_t value = key & 0xFFFFFFFF;
      rid.Set(static_cast<int32_t>(key >> 32), value);
      index_key.SetFromInteger(key);
      tree->Insert(index_key, rid, transaction);
    }
  }
  delete transaction;
}
// helper function to delete
void DeleteHelper(BPlusTree<GenericKey<8>, RID, GenericComparator<8>> *tree, const std::vector<int64_t> &remove_keys,
                  uint64_t tid, __attribute__((unused)) uint64_t thread_itr = 0) {
  GenericKey<8> index_key;
  // create transaction
  Transaction *transaction = new Transaction(tid);
  for (auto key : remove_keys) {
    index_key.SetFromInteger(key);
    tree->Remove(index_key, transaction);
  }
  delete transaction;
}
// helper function to split the deletes across threads:
// each thread only removes the keys with key % total_threads == thread_itr
void DeleteHelperSplit(BPlusTree<GenericKey<8>, RID, GenericComparator<8>> *tree,
                       const std::vector<int64_t> &remove_keys, int total_threads, uint64_t tid,
                       __attribute__((unused)) uint64_t thread_itr) {
  GenericKey<8> index_key;
  // create transaction
  Transaction *transaction = new Transaction(tid);
  for (auto key : remove_keys) {
    if (static_cast<uint64_t>(key) % total_threads == thread_itr) {
      index_key.SetFromInteger(key);
      tree->Remove(index_key, transaction);
    }
  }
  delete transaction;
}
// helper function to look up keys that are expected to be present
void LookupHelper(BPlusTree<GenericKey<8>, RID, GenericComparator<8>> *tree, const std::vector<int64_t> &keys,
                  uint64_t tid, __attribute__((unused)) uint64_t thread_itr = 0) {
  Transaction *transaction = new Transaction(tid);
  GenericKey<8> index_key;
  RID rid;
  for (auto key : keys) {
    int64_t value = key & 0xFFFFFFFF;
    rid.Set(static_cast<int32_t>(key >> 32), value);
    index_key.SetFromInteger(key);
    std::vector<RID> result;
    bool res = tree->GetValue(index_key, &result, transaction);
    EXPECT_EQ(res, true);
    EXPECT_EQ(result.size(), 1);
    EXPECT_EQ(result[0], rid);
  }
  delete transaction;
}
const size_t NUM_ITERS = 100;
const size_t NUM_ITERS_DEBUG = 100;
void InsertTest1Call() {
for (size_t iter = 0; iter < NUM_ITERS_DEBUG; iter++) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// keys to Insert
std::vector<int64_t> keys;
int64_t scale_factor = 100;
for (int64_t key = 1; key < scale_factor; key++) {
keys.push_back(key);
}
LaunchParallelTest(4, 0, InsertHelper, &tree, keys);
std::vector<RID> rids;
GenericKey<8> index_key;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
}
void InsertTest2Call() {
for (size_t iter = 0; iter < NUM_ITERS_DEBUG; iter++) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// keys to Insert
std::vector<int64_t> keys;
int64_t scale_factor = 1000;
for (int64_t key = 1; key < scale_factor; key++) {
keys.push_back(key);
}
LaunchParallelTest(8, 0, InsertHelperSplit, &tree, keys, 2);
std::vector<RID> rids;
GenericKey<8> index_key;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
}
void DeleteTest1Call() {
for (size_t iter = 0; iter < NUM_ITERS; iter++) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// sequential insert
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
InsertHelper(&tree, keys, 1);
std::vector<int64_t> remove_keys = {1, 5, 3, 4};
LaunchParallelTest(2, 1, DeleteHelper, &tree, remove_keys);
int64_t start_key = 2;
int64_t current_key = start_key;
int64_t size = 0;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
size = size + 1;
}
EXPECT_EQ(size, 1);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
}
void DeleteTest2Call() {
for (size_t iter = 0; iter < NUM_ITERS; iter++) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// sequential insert
std::vector<int64_t> keys = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
InsertHelper(&tree, keys, 1);
std::vector<int64_t> remove_keys = {1, 4, 3, 2, 5, 6};
LaunchParallelTest(2, 1, DeleteHelperSplit, &tree, remove_keys, 2);
int64_t start_key = 7;
int64_t current_key = start_key;
int64_t size = 0;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
size = size + 1;
}
EXPECT_EQ(size, 4);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
}
void MixTest1Call() {
for (size_t iter = 0; iter < NUM_ITERS; iter++) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// first, populate index
std::vector<int64_t> for_insert;
std::vector<int64_t> for_delete;
size_t sieve = 2; // divide evenly
size_t total_keys = 1000;
for (size_t i = 1; i <= total_keys; i++) {
if (i % sieve == 0) {
for_insert.push_back(i);
} else {
for_delete.push_back(i);
}
}
// Insert all the keys to delete
InsertHelper(&tree, for_delete, 1);
auto insert_task = [&](int tid) { InsertHelper(&tree, for_insert, tid); };
auto delete_task = [&](int tid) { DeleteHelper(&tree, for_delete, tid); };
std::vector<std::function<void(int)>> tasks;
tasks.emplace_back(insert_task);
tasks.emplace_back(delete_task);
std::vector<std::thread> threads;
size_t num_threads = 10;
for (size_t i = 0; i < num_threads; i++) {
threads.emplace_back(std::thread{tasks[i % tasks.size()], i});
}
for (size_t i = 0; i < num_threads; i++) {
threads[i].join();
}
int64_t size = 0;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
EXPECT_EQ(((*iterator).first).ToString(), for_insert[size]);
size++;
}
EXPECT_EQ(size, for_insert.size());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
}
const size_t T2 = 200;
void MixTest2Call() {
for (size_t iter = 0; iter < T2; iter++) {
// create KeyComparator and index schema
// LOG_DEBUG("iteration %lu", iter);
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// Add perserved_keys
std::vector<int64_t> perserved_keys;
std::vector<int64_t> dynamic_keys;
size_t total_keys = 1000;
size_t sieve = 5;
for (size_t i = 1; i <= total_keys; i++) {
if (i % sieve == 0) {
perserved_keys.push_back(i);
} else {
dynamic_keys.push_back(i);
}
}
InsertHelper(&tree, perserved_keys, 1);
// Counter for the preserved keys found by the final scan
size_t size;
auto insert_task = [&](int tid) { InsertHelper(&tree, dynamic_keys, tid); };
auto delete_task = [&](int tid) { DeleteHelper(&tree, dynamic_keys, tid); };
auto lookup_task = [&](int tid) { LookupHelper(&tree, perserved_keys, tid); };
std::vector<std::thread> threads;
std::vector<std::function<void(int)>> tasks;
tasks.emplace_back(insert_task);
tasks.emplace_back(delete_task);
tasks.emplace_back(lookup_task);
size_t num_threads = 6;
for (size_t i = 0; i < num_threads; i++) {
threads.emplace_back(std::thread{tasks[i % tasks.size()], i});
}
for (size_t i = 0; i < num_threads; i++) {
threads[i].join();
}
size = 0;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
if (((*iterator).first).ToString() % sieve == 0) {
size++;
}
}
EXPECT_EQ(size, perserved_keys.size());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
}
void MixTest3Call() {
for (size_t iter = 0; iter < NUM_ITERS; iter++) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(10, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// first, populate index
std::vector<int64_t> for_insert;
std::vector<int64_t> for_delete;
size_t total_keys = 1000;
for (size_t i = 1; i <= total_keys; i++) {
if (i > 500) {
for_insert.push_back(i);
} else {
for_delete.push_back(i);
}
}
// Insert all the keys to delete
InsertHelper(&tree, for_delete, 1);
auto insert_task = [&](int tid) { InsertHelper(&tree, for_insert, tid); };
auto delete_task = [&](int tid) { DeleteHelper(&tree, for_delete, tid); };
std::vector<std::function<void(int)>> tasks;
tasks.emplace_back(insert_task);
tasks.emplace_back(delete_task);
std::vector<std::thread> threads;
size_t num_threads = 10;
for (size_t i = 0; i < num_threads; i++) {
threads.emplace_back(std::thread{tasks[i % tasks.size()], i});
}
for (size_t i = 0; i < num_threads; i++) {
threads[i].join();
}
int64_t size = 0;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
EXPECT_EQ(((*iterator).first).ToString(), for_insert[size]);
size++;
}
EXPECT_EQ(size, for_insert.size());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
}
void MixTest4Call() {
for (size_t iter = 0; iter < NUM_ITERS; iter++) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// first, populate index
std::vector<int64_t> for_insert;
std::vector<int64_t> for_delete;
size_t total_keys = 1000;
for (size_t i = 1; i <= total_keys; i++) {
if (i > total_keys / 2) {
for_insert.push_back(i);
} else {
for_delete.push_back(i);
}
}
// Insert all the keys to delete
InsertHelper(&tree, for_delete, 1);
int64_t size = 0;
auto insert_task = [&](int tid) { InsertHelper(&tree, for_insert, tid); };
auto delete_task = [&](int tid) { DeleteHelper(&tree, for_delete, tid); };
std::vector<std::function<void(int)>> tasks;
tasks.emplace_back(insert_task);
tasks.emplace_back(delete_task);
std::vector<std::thread> threads;
size_t num_threads = 10;
for (size_t i = 0; i < num_threads; i++) {
threads.emplace_back(std::thread{tasks[i % tasks.size()], i});
}
for (size_t i = 0; i < num_threads; i++) {
threads[i].join();
}
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
EXPECT_EQ(((*iterator).first).ToString(), for_insert[size]);
size++;
}
EXPECT_EQ(size, for_insert.size());
DeleteHelper(&tree, for_insert, 1);
size = 0;
// LOG_INFO("tree.IsEmpty() is %d", tree.IsEmpty());
EXPECT_EQ(tree.IsEmpty(), true);
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
// LOG_INFO("size = %ld", size);
EXPECT_EQ(((*iterator).first).ToString(), for_insert[size]);
size++;
}
EXPECT_EQ(size, 0);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
}
/*
* Score: 5
* Description: Concurrently insert a set of keys.
*/
TEST(BPlusTreeTestC2Con, InsertTest1) {
TEST_TIMEOUT_BEGIN
InsertTest1Call();
remove("test.db");
remove("test.log");
TEST_TIMEOUT_FAIL_END(1000 * 600)
}
/*
* Score: 5
* Description: Split the concurrent insert test to multiple threads
* without overlap.
*/
TEST(BPlusTreeTestC2Con, InsertTest2) {
TEST_TIMEOUT_BEGIN
InsertTest2Call();
remove("test.db");
remove("test.log");
TEST_TIMEOUT_FAIL_END(1000 * 600)
}
/*
* Score: 5
* Description: Concurrently delete a set of keys.
*/
TEST(BPlusTreeTestC2Con, DeleteTest1) {
TEST_TIMEOUT_BEGIN
DeleteTest1Call();
remove("test.db");
remove("test.log");
TEST_TIMEOUT_FAIL_END(1000 * 600)
}
/*
* Score: 5
* Description: Split the concurrent delete task to multiple threads
* without overlap.
*/
TEST(BPlusTreeTestC2Con, DeleteTest2) {
TEST_TIMEOUT_BEGIN
DeleteTest2Call();
remove("test.db");
remove("test.log");
TEST_TIMEOUT_FAIL_END(1000 * 600)
}
/*
* Score: 5
* Description: First insert a set of keys.
* Then concurrently delete those already inserted keys and
* insert different set of keys. Check if all old keys are
* deleted and new keys are added correctly.
*/
TEST(BPlusTreeTestC2Con, MixTest1) {
TEST_TIMEOUT_BEGIN
MixTest1Call();
remove("test.db");
remove("test.log");
TEST_TIMEOUT_FAIL_END(1000 * 600)
}
/*
* Score: 5
* Description: Insert a set of keys. Concurrently insert and delete
* a different set of keys.
* At the same time, concurrently get the previously inserted keys.
* Check that the keys retrieved are the same set of keys as previously
* inserted.
*/
TEST(BPlusTreeTestC2Con, MixTest2) {
TEST_TIMEOUT_BEGIN
MixTest2Call();
remove("test.db");
remove("test.log");
TEST_TIMEOUT_FAIL_END(1000 * 600)
}
/*
* Score: 5
* Description: First insert a set of keys.
* Then concurrently delete those already inserted keys and
* insert different set of keys. Check if all old keys are
* deleted and new keys are added correctly.
*/
TEST(BPlusTreeTestC2Con, MixTest3) {
TEST_TIMEOUT_BEGIN
MixTest3Call();
remove("test.db");
remove("test.log");
TEST_TIMEOUT_FAIL_END(1000 * 600)
}
TEST(BPlusTreeTestC2Con, MixTest4) {
TEST_TIMEOUT_BEGIN
MixTest4Call();
remove("test.db");
remove("test.log");
TEST_TIMEOUT_FAIL_END(1000 * 600)
}
/*
* Score: 5
* Description: The same test that has been run for checkpoint 1,
* but added iterator for value checking
*/
TEST(BPlusTreeConcurrentTestC2Seq, InsertTest1) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 5
* Description: The same test that has been run for checkpoint 1
* but added iterator for value checking
*/
TEST(BPlusTreeConcurrentTestC2Seq, InsertTest2) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {5, 4, 3, 2, 1};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
start_key = 3;
current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); !iterator.IsEnd(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 10
* Description: Insert a set of keys, use GetValue and iterator to
* check the inserted keys. Then delete a subset of the keys.
* Finally use the iterator to check the remaining keys.
*/
TEST(BPlusTreeConcurrentTestC2Seq, DeleteTest1) {
// create KeyComparator and index schema
std::string createStmt = "a bigint";
auto key_schema = ParseCreateStatement(createStmt);
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
std::vector<int64_t> remove_keys = {1, 5};
for (auto key : remove_keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
start_key = 2;
current_key = start_key;
int64_t size = 0;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
size = size + 1;
}
EXPECT_EQ(size, 3);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 10
* Description: Similar to DeleteTest1, except that, during the Remove step,
* a different subset of keys is removed.
*/
TEST(BPlusTreeConcurrentTestC2Seq, DeleteTest2) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(50, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// std::vector<int64_t> keys;
// for (int64_t key = 1; key < 10000; key++) {
// keys.push_back(key);
// }
std::vector<int64_t> keys = {1, 2, 3, 4, 5};
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
std::vector<int64_t> remove_keys = {1, 5, 3, 4};
for (auto key : remove_keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
start_key = 2;
current_key = start_key;
int64_t size = 0;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
size = size + 1;
}
EXPECT_EQ(size, 1);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 10
* Description: Insert keys 1 to 9999. Use GetValue and the iterator to iterate
* through the inserted keys. Then remove keys 1 to 9899. Finally, use
* the iterator to check that the remaining 100 keys are still present.
*/
TEST(BPlusTreeConcurrentTestC2Seq, ScaleTest) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(12, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
int64_t scale = 10000;
std::vector<int64_t> keys;
for (int64_t key = 1; key < scale; key++) {
keys.push_back(key);
}
for (auto key : keys) {
int64_t value = key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(key >> 32), value);
index_key.SetFromInteger(key);
tree.Insert(index_key, rid, transaction);
}
std::vector<RID> rids;
for (auto key : keys) {
rids.clear();
index_key.SetFromInteger(key);
tree.GetValue(index_key, &rids);
EXPECT_EQ(rids.size(), 1);
int64_t value = key & 0xFFFFFFFF;
EXPECT_EQ(rids[0].GetSlotNum(), value);
}
int64_t start_key = 1;
int64_t current_key = start_key;
for (auto iterator = tree.Begin(); iterator != tree.End(); ++iterator) {
(void)*iterator;
auto location = (*iterator).second;
EXPECT_EQ(location.GetPageId(), 0);
EXPECT_EQ(location.GetSlotNum(), current_key);
current_key = current_key + 1;
}
EXPECT_EQ(current_key, keys.size() + 1);
int64_t remove_scale = 9900;
std::vector<int64_t> remove_keys;
for (int64_t key = 1; key < remove_scale; key++) {
remove_keys.push_back(key);
}
// std::random_shuffle(remove_keys.begin(), remove_keys.end());
for (auto key : remove_keys) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
start_key = 9900;
current_key = start_key;
int64_t size = 0;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
(void)*iterator;
current_key = current_key + 1;
size = size + 1;
}
EXPECT_EQ(size, 100);
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
/*
* Score: 10
* Description: Sequentially insert two disjoint sets of keys (even and
* odd), interleaved. Then remove the odd keys and use the iterator to
* check that exactly the even keys remain.
*/
TEST(BPlusTreeConcurrentTestC2Seq, SequentialMixTest) {
// create KeyComparator and index schema
auto key_schema = ParseCreateStatement("a bigint");
GenericComparator<8> comparator(key_schema.get());
DiskManager *disk_manager = new DiskManager("test.db");
BufferPoolManager *bpm = new BufferPoolManagerInstance(5, disk_manager);
// create b+ tree
BPlusTree<GenericKey<8>, RID, GenericComparator<8>> tree("foo_pk", bpm, comparator);
GenericKey<8> index_key;
RID rid;
// create transaction
Transaction *transaction = new Transaction(0);
// create and fetch header_page
page_id_t page_id;
auto header_page = bpm->NewPage(&page_id);
(void)header_page;
// first, populate index
std::vector<int64_t> for_insert;
std::vector<int64_t> for_delete;
size_t sieve = 2; // divide evenly
size_t total_keys = 1000;
for (size_t i = 1; i <= total_keys; i++) {
if (i % sieve == 0) {
for_insert.push_back(i);
} else {
for_delete.push_back(i);
}
}
// Insert all the keys, including the ones that will remain at the end and
// the ones that are going to be removed next.
for (size_t i = 0; i < total_keys / 2; i++) {
int64_t insert_key = for_insert[i];
int64_t insert_value = insert_key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(insert_key >> 32), insert_value);
index_key.SetFromInteger(insert_key);
tree.Insert(index_key, rid, transaction);
int64_t delete_key = for_delete[i];
int64_t delete_value = delete_key & 0xFFFFFFFF;
rid.Set(static_cast<int32_t>(delete_key >> 32), delete_value);
index_key.SetFromInteger(delete_key);
tree.Insert(index_key, rid, transaction);
}
// Remove the keys in for_delete
for (auto key : for_delete) {
index_key.SetFromInteger(key);
tree.Remove(index_key, transaction);
}
// Only half of the keys should remain
int64_t start_key = 2;
int64_t size = 0;
index_key.SetFromInteger(start_key);
for (auto iterator = tree.Begin(index_key); iterator != tree.End(); ++iterator) {
EXPECT_EQ(((*iterator).first).ToString(), for_insert[size]);
size++;
}
EXPECT_EQ(size, for_insert.size());
bpm->UnpinPage(HEADER_PAGE_ID, true);
delete transaction;
delete disk_manager;
delete bpm;
remove("test.db");
remove("test.log");
}
} // namespace bustub
Concurrent test results:

Deadlock?
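As can be seen, whenever a thread needs to hold several latches, it acquires them from top to bottom, so every thread acquires latches in the same direction. When a thread latches a sibling, it always holds the latch on the parent page already, so there cannot be another thread that holds both the sibling's latch and the parent page's latch and thereby creates a circular wait. Deadlock is therefore impossible.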
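The latch order described above can be sketched with a toy example. This is only an illustration under stated assumptions: `ToyLeaf`, `ToyParent` and `StealFromLeftSibling` are hypothetical names, `std::mutex` stands in for the page latch, and the "at least one key must stay" rule is a made-up minimum size, not the BusTub one.

```cpp
#include <cassert>
#include <mutex>
#include <vector>

// Toy stand-ins for tree pages; these are NOT the BusTub page classes.
struct ToyLeaf {
  std::mutex latch;
  std::vector<int> keys;
};

struct ToyParent {
  std::mutex latch;
  ToyLeaf *left = nullptr;
  ToyLeaf *right = nullptr;
};

// Hypothetical helper: the right leaf is underfull and steals the largest
// key from its left sibling. The latch order is parent -> child -> sibling,
// so the sibling is only ever latched while the parent latch is held.
bool StealFromLeftSibling(ToyParent *parent) {
  assert(parent != nullptr && parent->left != nullptr && parent->right != nullptr);
  std::lock_guard<std::mutex> parent_guard(parent->latch);         // 1. parent
  std::lock_guard<std::mutex> child_guard(parent->right->latch);   // 2. underfull child
  std::lock_guard<std::mutex> sibling_guard(parent->left->latch);  // 3. sibling
  if (parent->left->keys.size() <= 1) {
    return false;  // nothing to spare; a real tree would merge instead
  }
  // A real tree would also update the separator key stored in the parent.
  parent->right->keys.insert(parent->right->keys.begin(), parent->left->keys.back());
  parent->left->keys.pop_back();
  return true;  // the guards release the latches in reverse order
}
```

Because every caller goes through the parent latch first, two threads working under the same parent are serialized there and never wait on each other's child latches in a cycle.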
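Once the Index Iterator is taken into account, however, deadlock becomes possible. The Index Iterator acquires leaf page latches from left to right. If a page that needs to steal/merge tries to latch its left sibling at the same time, one thread is moving left to right and the other right to left, which can produce a circular wait, i.e. a deadlock. So when the Index Iterator cannot acquire a latch, it should give up rather than wait.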
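One way to make the iterator give up is to use a non-blocking try-lock on the right sibling. The sketch below is a minimal illustration, not the BusTub iterator: `ToyLatch`, `ToyLeaf` and `ToyLeafIterator` are hypothetical names, the latch is a simple exclusive spin latch (a real iterator would normally take a read latch), and throwing an exception is just one possible way to report that the iterator gave up.

```cpp
#include <atomic>
#include <cassert>
#include <stdexcept>
#include <thread>

// Toy exclusive latch with a non-blocking TryLock(); it stands in for the
// page latch and is NOT the BusTub ReaderWriterLatch.
class ToyLatch {
 public:
  bool TryLock() { return !held_.exchange(true, std::memory_order_acquire); }
  void Lock() {
    while (!TryLock()) {
      std::this_thread::yield();
    }
  }
  void Unlock() { held_.store(false, std::memory_order_release); }

 private:
  std::atomic<bool> held_{false};
};

struct ToyLeaf {
  ToyLatch latch;
  int value = 0;
  ToyLeaf *next = nullptr;  // right sibling
};

// Hypothetical iterator that walks a chain of leaves from left to right.
class ToyLeafIterator {
 public:
  explicit ToyLeafIterator(ToyLeaf *start) : leaf_(start) {
    assert(start != nullptr);
    leaf_->latch.Lock();
  }
  ~ToyLeafIterator() {
    if (leaf_ != nullptr) {
      leaf_->latch.Unlock();
    }
  }
  ToyLeafIterator(const ToyLeafIterator &) = delete;
  ToyLeafIterator &operator=(const ToyLeafIterator &) = delete;

  int Value() const { return leaf_->value; }
  bool IsEnd() const { return leaf_ == nullptr; }

  // Step to the right sibling. If its latch is busy, do not wait: release
  // the current latch and give up, so no circular wait can form.
  void Advance() {
    ToyLeaf *next = leaf_->next;
    if (next != nullptr && !next->latch.TryLock()) {
      leaf_->latch.Unlock();
      leaf_ = nullptr;
      throw std::runtime_error("right sibling is latched; iterator gave up");
    }
    leaf_->latch.Unlock();
    leaf_ = next;  // nullptr means the end of the chain
  }

 private:
  ToyLeaf *leaf_;
};
```

The key point is that the iterator never blocks while holding a leaf latch, so a thread moving right to left can always make progress.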
Optimization
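Latch crabbing admits a fairly simple optimization. In plain latch crabbing, Insert/Delete take write latches on the nodes along the path. The higher a node sits in the tree, the more likely it is to be visited and the fiercer the latch contention, so frequently taking exclusive write latches on upper-level nodes hurts performance considerably. The following optimization addresses this: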
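Search is unchanged. For Insert/Delete, we first optimistically assume that no split/steal/merge will happen: take read latches on the nodes along the path and release them promptly, and take a write latch on the leaf page. If the operation indeed causes no split/steal/merge on the leaf page, it completes right there. If it would cause the leaf page to split/steal/merge, we release all latches held and redo the operation pessimistically from the root page, i.e. taking write latches along the path.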
This optimization is fairly easy to implement; it only requires modifying FindLeaf().
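The optimistic-then-pessimistic flow can be sketched as follows. This is a rough illustration under several assumptions: `ToyNode`, `ChildFor`, `FindLeafOptimistic` and `ToyInsert` are hypothetical names, the tree shape is fixed (no real split is performed), routing uses `key % children.size()` instead of key comparison, and `std::shared_mutex` stands in for the page latch.

```cpp
#include <cassert>
#include <cstddef>
#include <shared_mutex>
#include <vector>

// Toy node with a fixed shape; NOT the BusTub page classes. The children
// vector never changes here, so reading it without a latch is safe in this
// sketch (a real tree must latch a page before inspecting it).
struct ToyNode {
  std::shared_mutex latch;
  std::vector<ToyNode *> children;  // empty for a leaf
  std::vector<unsigned> keys;       // used by leaves only
  std::size_t max_size = 4;
  bool IsLeaf() const { return children.empty(); }
};

enum class InsertPath { kOptimistic, kPessimistic };

// Toy routing rule standing in for the key comparison in an internal page.
ToyNode *ChildFor(ToyNode *node, unsigned key) { return node->children[key % node->children.size()]; }

// Optimistic descent: read latches on internal nodes, each released as soon
// as the child is latched; a write latch on the leaf only.
ToyNode *FindLeafOptimistic(ToyNode *root, unsigned key) {
  ToyNode *node = root;
  if (node->IsLeaf()) {
    node->latch.lock();
    return node;
  }
  node->latch.lock_shared();
  while (true) {
    ToyNode *child = ChildFor(node, key);
    if (child->IsLeaf()) {
      child->latch.lock();
    } else {
      child->latch.lock_shared();
    }
    node->latch.unlock_shared();
    if (child->IsLeaf()) {
      return child;
    }
    node = child;
  }
}

InsertPath ToyInsert(ToyNode *root, unsigned key) {
  assert(root != nullptr);
  ToyNode *leaf = FindLeafOptimistic(root, key);
  if (leaf->keys.size() + 1 < leaf->max_size) {  // no split needed
    leaf->keys.push_back(key);
    leaf->latch.unlock();
    return InsertPath::kOptimistic;
  }
  leaf->latch.unlock();  // drop everything and restart from the root

  // Pessimistic retry: write latches on the whole path. For brevity this toy
  // keeps them all; real crabbing would release safe ancestors on the way.
  std::vector<ToyNode *> held;
  ToyNode *node = root;
  node->latch.lock();
  held.push_back(node);
  while (!node->IsLeaf()) {
    node = ChildFor(node, key);
    node->latch.lock();
    held.push_back(node);
  }
  node->keys.push_back(key);  // a real tree would split the leaf here
  for (auto it = held.rbegin(); it != held.rend(); ++it) {
    (*it)->latch.unlock();
  }
  return InsertPath::kPessimistic;
}
```

In this toy, only the insert that would fill a leaf takes write latches on the upper levels; every other insert touches the root with a short-lived read latch.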
Summary
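That is roughly all of Project 2. Compared with Project 1, the difficulty rises steeply. The hard part of Checkpoint 1 is mainly getting the details right, while the hard part of Checkpoint 2 is understanding latch crabbing correctly. Seeing a B+ tree that I implemented from scratch run correctly, especially in the visualization, is quite rewarding.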
