声明的类模板中成员函数、成员模板函数、友元函数、友元模板函数的定义方式
示例代码如下：
/**
 * TensorBase<T>: tensor descriptor holding a name, an element count,
 * per-dimension extents and strides, plus raw buffers of T obtained via
 * dev_malloc() / host_malloc().
 *
 * @tparam T element type of the dev_ptr / host_ptr buffers.
 */
template<typename T>
class TensorBase {
public:
    /**
     * Set name, element count, dimensions and strides (defined out of class).
     * The `name` default used to be `NULL`; constructing a std::string from a
     * null pointer is undefined behaviour, so it is now an empty string.
     * NOTE(review): the default `num = 0` does not match the default
     * `dimA = {1}` (element count 1), so calling init() with defaults for
     * both is rejected by the count check — confirm which default is meant.
     */
    void init(std::string name = std::string(), int num = 0, int nbDims = 1, std::vector<int> dimA = {1});

    /**
     * Return the shared per-T instance, creating it on first call.
     * NOTE(review): the check-then-create is unsynchronized, so concurrent
     * first calls can race, and the instance is never deleted.
     */
    static TensorBase<T>* getInstance() {
        if (instance == nullptr) {
            instance = new TensorBase<T>();
        }
        return instance;
    }
    T* dev_malloc(int nums);
    T* host_malloc(int nums);
public:
    int size() const;
    // Friend function templates use their own parameter names (T2, Args, Tc):
    // re-declaring the class template parameter name T inside the class
    // template would be ill-formed.
    template<typename T2, typename... Args>
    friend void print_display(const T2 &val, const Args&... rest);
    template<typename Tc>
    friend int compare(const Tc &val1, const Tc &val2);

    /// Print every stride followed by a space (no trailing newline).
    void print_stride() const {
        for (const int x : strideA_m) {
            std::cout << x << ' ';
        }
    }
private:
    std::string name_m;
    int num_m = 0;
    int nbDims_m = 0;
    std::vector<int> dimA_m;
    std::vector<int> strideA_m;
    // Buffers start null so an object that never allocated does not carry
    // indeterminate pointer values.
    T* dev_ptr = nullptr;
    T* host_ptr = nullptr;
    static TensorBase<T> *instance;
};
/* Out-of-class definition of the static singleton pointer: it starts null
   and is set by getInstance() on first use. */
template<typename T>
TensorBase<T>* TensorBase<T>::instance = nullptr;
/**
 * Validate the shape parameters, then store them together with row-major
 * strides.
 *
 * @param name   tensor name.
 * @param num    total element count; must equal the product of dimA.
 * @param nbDims number of dimensions; must equal dimA.size().
 * @param dimA   extent of each dimension.
 *
 * On inconsistent parameters prints "input param error!" to std::cout and
 * returns without modifying any member.
 */
template<typename T>
void TensorBase<T>::init(std::string name, int num, int nbDims, std::vector<int> dimA) {
    using Index = std::vector<int>::size_type;

    // Accumulate in 64 bits so moderately large shapes do not overflow int
    // (signed overflow is undefined behaviour) before the comparison.
    long long nums = 1;
    for (const int extent : dimA) {
        nums *= extent;
    }
    // Reject negative nbDims explicitly instead of relying on the implicit
    // int -> size_t conversion of a signed/unsigned comparison.
    if (nbDims < 0 || dimA.size() != static_cast<Index>(nbDims) || nums != num) {
        std::cout << "input param error!" << std::endl;
        return;
    }

    // Row-major strides: the last dimension has stride 1, and each earlier
    // stride is the following stride times the following extent. A single
    // backward pass replaces the nested O(nbDims^2) product.
    std::vector<int> strides(dimA.size(), 1);
    for (Index i = dimA.size(); i > 1; --i) {
        strides[i - 2] = strides[i - 1] * dimA[i - 1];
    }

    // The parameters are by-value copies, so take over their buffers.
    name_m.swap(name);
    num_m = num;
    nbDims_m = nbDims;
    dimA_m.swap(dimA);
    strideA_m.swap(strides);
}
/* size(): the element count currently stored in num_m (written by init()). */
template<typename T>
int TensorBase<T>::size() const {
    const int count = this->num_m;
    return count;
}
/**
 * dev_malloc(): allocate a zero-filled buffer of `nums` elements of T, store
 * it in dev_ptr and return it. Despite the name, this uses the C heap
 * allocator; no device-side allocation happens here.
 *
 * @param nums number of elements.
 * @return the buffer, or a null pointer if allocation fails (a negative
 *         nums converts to a huge unsigned request, which calloc rejects).
 *
 * NOTE(review): a buffer previously held in dev_ptr is not freed, nothing
 * here releases the memory, and all-zero bytes are only a meaningful value
 * for trivially-copyable T.
 */
template<typename T>
T* TensorBase<T>::dev_malloc(int nums) {
    // calloc zeroes all nums * sizeof(T) bytes and checks that product for
    // overflow; the former memset(dev_ptr, 0, nums) cleared only `nums` bytes
    // and wrote through the pointer without checking it for null.
    dev_ptr = static_cast<T*>(calloc(nums, sizeof(T)));
    return dev_ptr;
}
/**
 * host_malloc(): allocate a zero-filled buffer of `nums` elements of T,
 * store it in host_ptr and return it.
 *
 * @param nums number of elements.
 * @return the buffer, or a null pointer if allocation fails (a negative
 *         nums converts to a huge unsigned request, which calloc rejects).
 *
 * NOTE(review): a buffer previously held in host_ptr is not freed, nothing
 * here releases the memory, and all-zero bytes are only a meaningful value
 * for trivially-copyable T.
 */
template<typename T>
T* TensorBase<T>::host_malloc(int nums) {
    // The former code memset dev_ptr (the wrong buffer, possibly null) with
    // only `nums` bytes, leaving host_ptr uninitialized. calloc zeroes the
    // whole host buffer and checks nums * sizeof(T) for overflow.
    host_ptr = static_cast<T*>(calloc(nums, sizeof(T)));
    return host_ptr;
}
/* Variadic function template: print each argument followed by a space. */

// Recursion terminator, reached once the argument pack is empty. It is not
// a template, so it must be `inline` to avoid multiple-definition errors when
// this code is included from several translation units.
inline void print_display() {}

// Declare every overload before defining any of them. The recursive call
// print_display(rest...) only sees overloads declared before the point of
// definition (argument-dependent lookup for std types searches namespace std,
// not the global namespace), so an overload declared later would be ignored
// and e.g. print_display(1, vec) would fail to compile.
template<typename T2, typename... Args>
void print_display(const T2 &val, const Args&... rest);
template<typename Tv, typename... Args>
void print_display(const std::vector<Tv> &val, const Args&... rest);

/* Generic case: stream the first argument, then recurse on the rest. */
template<typename T2, typename... Args>
void print_display(const T2 &val, const Args&... rest) {
    std::cout << val << ' ';
    print_display(rest...);
}

/* std::vector overload (an overload, not a specialization; partial ordering
   prefers it over the generic case for any std::vector<Tv>): print every
   element followed by a space, then recurse on the rest. */
template<typename Tv, typename... Args>
void print_display(const std::vector<Tv> &val, const Args&... rest) {
    for (const auto &x : val) { std::cout << x << ' '; }
    print_display(rest...);
}
/* operator<< overload for std::vector: writes every element followed by a
   space. The vector is taken by const reference (formerly a non-const
   reference), so const vectors and temporaries can be printed too.
   NOTE: it lives in the global namespace, so argument-dependent lookup for
   std::vector (which searches namespace std) will not find it from templates
   defined before this point. */
template<typename To>
std::ostream& operator<< (std::ostream& out, const std::vector<To> &arr) {
    for (const auto &x : arr) { out << x << ' '; }
    return out;
}
/*
 * compare(): an equality test despite its name. It converts the result of
 * `val1 == val2` to int, i.e. 1 when equal and 0 otherwise for a
 * bool-returning operator==.
 */
template<typename Tc>
int compare(const Tc &val1, const Tc & val2) {
    const int equal = (val1 == val2);
    return equal;
}
注：**在类模板内声明友元函数模板时，其模板参数名须与类模板的模板参数名不同**（如本例使用 `T2`、`Tc` 而非 `T`）。类模板参数的作用域覆盖整个类体，在其中重新声明同名模板参数是非法的，编译会报错。
---
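此代码演示了可变参数模板的用法。由于可变参数是通过递归逐个展开的，需要额外定义一个终止递归的无参函数（本例中的 `print_display()`）【==该终止函数不是模板，若定义在头文件中，必须声明为 inline，否则被多个源文件包含时会出现重定义（链接）错误==】。此外，递归调用 `print_display(rest...)` 只能找到在该模板定义处之前已声明的重载（对 std 类型的参数，ADL 也不会查找全局命名空间），因此所有重载都应在第一个定义之前先行声明。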