milvus源码vector_index中Dataset数据类的使用

1、Dataset数据类定义

// Type-erased value stored in a Dataset, and the shared handle that holds it.
using Value = std::any;
using ValuePtr = std::shared_ptr<Value>;

/// String-keyed container of type-erased values.
///
/// Set() and Get() serialize on an internal mutex. data() hands out a reference
/// to the underlying map WITHOUT taking that mutex, so callers must not use the
/// returned reference concurrently with Set().
class Dataset {
 public:
    Dataset() = default;

    /// Stores `v` under key `k`, replacing any value previously stored there.
    ///
    /// @param k  key to store the value under
    /// @param v  value to store; forwarded into a newly allocated std::any
    template <typename T>
    void
    Set(const std::string& k, T&& v) {
        // Box the value before locking so the critical section only covers the map update.
        auto boxed = std::make_shared<Value>(std::forward<T>(v));
        std::lock_guard<std::mutex> lk(mutex_);
        data_.insert_or_assign(k, std::move(boxed));
    }

    /// Returns the value stored under key `k`, converted with std::any_cast<T>.
    ///
    /// @param k  key to look up
    /// @return   the stored value as a `T`
    /// @throws std::logic_error  if `k` is not present ("Can't find this key"), or
    ///                           if the stored value is not of type `T` (a distinct
    ///                           message naming the key)
    template <typename T>
    T
    Get(const std::string& k) const {
        std::lock_guard<std::mutex> lk(mutex_);
        const auto it = data_.find(k);
        if (it == data_.end()) {
            throw std::logic_error("Can't find this key");
        }
        try {
            return std::any_cast<T>(*(it->second));
        } catch (const std::bad_any_cast&) {
            // Report a type mismatch as such instead of disguising it as a missing key.
            throw std::logic_error("Value stored under key '" + k + "' is not of the requested type");
        }
    }

    /// Read-only view of the underlying key -> value map.
    /// The mutex is not taken here; see the class comment.
    const std::map<std::string, ValuePtr>&
    data() const {
        return data_;
    }

 private:
    // mutable so that the read-only Get() can still lock.
    mutable std::mutex mutex_;
    std::map<std::string, ValuePtr> data_;
};
using DatasetPtr = std::shared_ptr<Dataset>;

2、定义一个宏解析数据类

该宏直接解析dataset_ptr，并在调用处的作用域内定义三个局部变量：dim、rows和数据指针p_data

// Unpacks a dataset pointer into three locals declared in the CALLER's scope:
//   dim    - value stored under meta::DIM, read as int64_t
//   rows   - value stored under meta::ROWS, read as int64_t
//   p_data - value stored under meta::TENSOR, read as const void*
// The expansion already ends with ';', so it can be used without a trailing semicolon.
// The argument is parenthesized so that any pointer-valued expression expands correctly;
// note that it is evaluated three times, so it should be free of side effects.
#define GETTENSOR(dataset_ptr)                              \
    int64_t dim = (dataset_ptr)->Get<int64_t>(meta::DIM);   \
    int64_t rows = (dataset_ptr)->Get<int64_t>(meta::ROWS); \
    const void* p_data = (dataset_ptr)->Get<const void*>(meta::TENSOR);

3、使用示例

调用宏之后可以拿到解析后的数据dim、rows和数据指针p_data

/// Builds a faiss IVF-Flat index, trains it on the vectors in `dataset_ptr`, and
/// stores it in `index_`.
///
/// @param dataset_ptr  dataset read through GETTENSOR (meta::DIM, meta::ROWS, meta::TENSOR)
/// @param config       read for IndexParams::nlist (int64_t) and Metric::TYPE (string)
void
IVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
    // Declares dim, rows and p_data in this scope.
    GETTENSOR(dataset_ptr)

    const int64_t nlist = config[IndexParams::nlist].get<int64_t>();
    const faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());

    // Keep the quantizer in a unique_ptr until the index has taken ownership, so it is
    // not leaked if constructing the IVF index throws.
    auto coarse_quantizer = std::make_unique<faiss::IndexFlat>(dim, metric_type);
    auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer.get(), dim, nlist, metric_type);
    // NOTE(review): per the faiss docs, own_fields = true makes the index delete the quantizer.
    index->own_fields = true;
    static_cast<void>(coarse_quantizer.release());

    // p_data is an untyped pointer; it is interpreted here as float data for training.
    index->train(rows, static_cast<const float*>(p_data));
    index_ = index;
}

 

posted on 2021-03-31 14:33  wulc++  阅读(149)  评论(0)  编辑  收藏  举报