#include <atomic>
#include <iostream>
#include <utility>
/// Minimal reference-counted owning pointer, modelled loosely on std::shared_ptr.
///
/// Representation invariant:
///  - Empty: ptr == nullptr and ref_count == nullptr. Nothing is owned and
///    nothing is allocated.
///  - Non-empty: ptr != nullptr and ref_count points to a heap counter shared by
///    every copy. Whichever owner drops the count from 1 to 0 deletes both the
///    managed object and the counter.
///
/// The counter is a std::atomic<int> with default (seq_cst) ordering, so
/// different shared_ptr objects that share ownership may update the count from
/// different threads. A single shared_ptr object is not safe to mutate
/// concurrently.
template <typename T>
class shared_ptr {
private:
    T* ptr;                       // managed object, or nullptr when empty
    std::atomic<int>* ref_count;  // shared owner count, or nullptr when empty

public:
    /// Takes ownership of p.
    /// A null p yields an empty pointer and allocates no counter.
    /// If allocating the counter throws, p is deleted before the exception
    /// propagates, so ownership is never lost.
    explicit shared_ptr(T* p = nullptr) : ptr(p), ref_count(nullptr) {
        if (ptr) {
            try {
                ref_count = new std::atomic<int>(1);
            } catch (...) {
                delete p;
                throw;
            }
        }
    }

    /// Copy constructor: shares ownership with other (no-op for an empty other).
    shared_ptr(const shared_ptr<T>& other) noexcept : ptr(other.ptr), ref_count(other.ref_count) {
        if (ref_count) {
            ref_count->fetch_add(1);  // one more owner
        }
    }

    /// Move constructor: steals other's ownership and leaves other empty.
    shared_ptr(shared_ptr<T>&& other) noexcept : ptr(other.ptr), ref_count(other.ref_count) {
        other.ptr = nullptr;
        other.ref_count = nullptr;
    }

    /// Destructor: drops this owner.
    /// The last owner deletes the object and the counter.
    ~shared_ptr() {
        if (ref_count && ref_count->fetch_sub(1) == 1) {
            delete ptr;
            delete ref_count;
        }
    }

    /// Copy assignment (copy-and-swap).
    /// The new reference is taken before the old one is released. Self-assignment
    /// and the case where destroying the old object would also destroy `other`
    /// are therefore both safe.
    shared_ptr<T>& operator=(const shared_ptr<T>& other) noexcept {
        shared_ptr<T>(other).swap(*this);
        return *this;
    }

    /// Move assignment: takes other's ownership and leaves other empty.
    /// The previously owned object, if any, is released through a temporary.
    shared_ptr<T>& operator=(shared_ptr<T>&& other) noexcept {
        shared_ptr<T>(std::move(other)).swap(*this);
        return *this;
    }

    /// Exchanges the contents of *this and other without touching the counts.
    void swap(shared_ptr<T>& other) noexcept {
        std::swap(ptr, other.ptr);
        std::swap(ref_count, other.ref_count);
    }

    /// Releases current ownership and optionally adopts p (same rules as the
    /// constructor).
    void reset(T* p = nullptr) {
        shared_ptr<T>(p).swap(*this);
    }

    /// Dereference. Precondition: *this is non-empty.
    T& operator*() const noexcept {
        return *ptr;
    }

    /// Member access. Returns nullptr when empty.
    T* operator->() const noexcept {
        return ptr;
    }

    /// Number of owners sharing the object; 0 when empty.
    int use_count() const noexcept {
        return ref_count ? ref_count->load() : 0;
    }

    /// True when empty.
    bool operator!() const noexcept {
        return ptr == nullptr;
    }

    /// True when non-empty.
    explicit operator bool() const noexcept {
        return ptr != nullptr;
    }

    /// Raw (non-owning) access to the managed pointer.
    T* get() const noexcept {
        return ptr;
    }
};
// 测试
// Trivial payload type used by main() to exercise shared_ptr.
class Test {
public:
    // Writes a fixed greeting to std::cout, followed by std::endl (newline + flush).
    void hello() {
        std::ostream& sink = std::cout;
        sink << "Hello, shared_ptr!";
        sink << std::endl;
    }
};
int main() {
shared_ptr<Test> sp1(new Test());
std::cout << "Use count of sp1: " << sp1.use_count() << std::endl;
{
shared_ptr<Test> sp2 = sp1; // 拷贝构造
std::cout << "Use count of sp1: " << sp1.use_count() << std::endl;
sp2->hello();
}
std::cout << "Use count of sp1 after sp2 goes out of scope: " << sp1.use_count() << std::endl;
return 0;
}