// shared_ptr 模拟实现 (a simplified, educational re-implementation of shared_ptr)

#include <iostream>
#include <atomic>

/**
 * A minimal reference-counted owning pointer.
 *
 * Every instance that shares ownership of the same object also shares one
 * heap-allocated atomic counter. The object and the counter are deleted by
 * whichever instance drops the count to zero.
 *
 * Invariant: `ref_count` is nullptr exactly when `ptr` is nullptr, so an empty
 * instance owns no heap memory at all.
 *
 * Only the counter updates are atomic. Modifying the same shared_ptr instance
 * from several threads without external synchronization is not supported.
 *
 * @tparam T type of the managed object; it is destroyed with `delete`.
 */
template <typename T>
class shared_ptr {
private:
    T* ptr;                       // managed object, or nullptr when empty
    std::atomic<int>* ref_count;  // shared owner count, or nullptr when empty

    // Gives up this instance's share of ownership and leaves it empty.
    // The instance that observes the count going from 1 to 0 frees both the
    // managed object and the counter.
    void release() noexcept {
        if (ref_count && ref_count->fetch_sub(1) == 1) {
            delete ptr;
            delete ref_count;
        }
        ptr = nullptr;
        ref_count = nullptr;
    }

public:
    /**
     * Takes ownership of `p`.
     *
     * @param p object allocated with `new`, or nullptr for an empty pointer.
     *          No counter is allocated for nullptr.
     * @throws std::bad_alloc if the counter cannot be allocated; `p` is
     *         deleted before the exception propagates so it does not leak.
     */
    explicit shared_ptr(T* p = nullptr) : ptr(p), ref_count(nullptr) {
        if (ptr) {
            try {
                ref_count = new std::atomic<int>(1);
            } catch (...) {
                delete ptr;
                throw;
            }
        }
    }

    /** Copy constructor: shares ownership with `other` (count + 1 if non-empty). */
    shared_ptr(const shared_ptr<T>& other) noexcept : ptr(other.ptr), ref_count(other.ref_count) {
        if (ref_count) {
            ref_count->fetch_add(1);
        }
    }

    /** Move constructor: steals ownership from `other`, leaving it empty. */
    shared_ptr(shared_ptr<T>&& other) noexcept : ptr(other.ptr), ref_count(other.ref_count) {
        other.ptr = nullptr;
        other.ref_count = nullptr;
    }

    /** Destructor: drops this instance's share of ownership. */
    ~shared_ptr() {
        release();
    }

    /**
     * Copy assignment.
     *
     * `other` is copied into a temporary BEFORE the current object is
     * released. This keeps `other` alive when it is stored inside the object
     * currently owned by *this (e.g. `head = head->next`), and it makes
     * self-assignment harmless.
     */
    shared_ptr<T>& operator=(const shared_ptr<T>& other) noexcept {
        shared_ptr<T> incoming(other);
        swap(incoming);
        return *this;  // `incoming` now holds the old value and releases it
    }

    /**
     * Move assignment.
     *
     * Ownership is first moved out of `other` into a temporary, so that
     * releasing the current object cannot destroy `other` while it is still
     * being read. Self-move-assignment leaves *this unchanged.
     */
    shared_ptr<T>& operator=(shared_ptr<T>&& other) noexcept {
        shared_ptr<T> incoming(static_cast<shared_ptr<T>&&>(other));
        swap(incoming);
        return *this;  // `incoming` now holds the old value and releases it
    }

    /** Exchanges the managed object and counter with `other`; counts are unchanged. */
    void swap(shared_ptr<T>& other) noexcept {
        T* const other_ptr = other.ptr;
        std::atomic<int>* const other_count = other.ref_count;
        other.ptr = ptr;
        other.ref_count = ref_count;
        ptr = other_ptr;
        ref_count = other_count;
    }

    /** Dereferences the managed object. The pointer must not be empty. */
    T& operator*() const {
        return *ptr;
    }

    /** Member access on the managed object. Returns nullptr when empty. */
    T* operator->() const {
        return ptr;
    }

    /** Number of instances currently sharing the object; 0 when empty. */
    int use_count() const {
        return ref_count ? ref_count->load() : 0;
    }

    /** True when no object is managed. */
    bool operator!() const {
        return ptr == nullptr;
    }

    /** True when an object is managed; enables `if (sp)`. */
    explicit operator bool() const noexcept {
        return ptr != nullptr;
    }

    /** Returns the raw pointer without affecting ownership. */
    T* get() const {
        return ptr;
    }
};

// 测试
// Small demo type whose only job is to prove that a call through the smart
// pointer reaches a live object.
struct Test {
    // Writes a fixed greeting to std::cout, then ends the line and flushes.
    void hello() {
        std::cout << "Hello, shared_ptr!";
        std::cout << std::endl;
    }
};

int main() {
    shared_ptr<Test> sp1(new Test());
    std::cout << "Use count of sp1: " << sp1.use_count() << std::endl;

    {
        shared_ptr<Test> sp2 = sp1; // 拷贝构造
        std::cout << "Use count of sp1: " << sp1.use_count() << std::endl;
        sp2->hello();
    }

    std::cout << "Use count of sp1 after sp2 goes out of scope: " << sp1.use_count() << std::endl;

    return 0;
}
// posted @ 2025-01-15 11:14  陈浩辉  阅读(26)  评论(0)    收藏  举报
// ヾ(≧O≦)〃嗷~