**"mindspore\lite\examples\transfer_learning\src\dataset.cc" 注释 2**
一、代码用途
该处代码实现了文件读取函数，以及数据集的加载与划分（训练集、测试集、验证集）逻辑。
二、代码注释
/// Reads the whole content of `file` into a newly allocated byte buffer.
///
/// @param file Path of the file to read.
/// @param size Out-parameter receiving the number of bytes read. It is set to 0
///             first, so it is well-defined on failure. Must not be null.
/// @return A buffer allocated with new[] holding the file bytes. The caller owns it
///         and must release it with delete[]. Returns nullptr on any failure.
char *ReadFile(const std::string &file, size_t *size) {
  // Checked explicitly (not only asserted) so builds where assertions compile away
  // do not dereference a null pointer.
  if (size == nullptr) {
    std::cerr << "ReadFile: size output pointer is null" << std::endl;
    return nullptr;
  }
  *size = 0;
  const std::string &realPath = file;
  // Binary mode: the content is raw bytes (e.g. bitmaps). A text-mode stream may
  // translate line endings, making the data and byte count differ from the file.
  std::ifstream ifs(realPath, std::ios::in | std::ios::binary);
  if (!ifs.good()) {
    std::cerr << "file: " << realPath << " does not exist" << std::endl;
    return nullptr;
  }
  // Determine the file length by seeking to the end.
  ifs.seekg(0, std::ios::end);
  const std::streamoff length = ifs.tellg();
  if (length < 0) {  // tellg() reports failure as -1
    std::cerr << "file: " << realPath << " size query failed" << std::endl;
    return nullptr;
  }
  std::unique_ptr<char[]> buf(new (std::nothrow) char[static_cast<size_t>(length)]);
  if (buf == nullptr) {
    std::cerr << "malloc buf failed, file: " << realPath << std::endl;
    return nullptr;
  }
  // Rewind and read everything; a short read means the data is incomplete.
  ifs.seekg(0, std::ios::beg);
  ifs.read(buf.get(), static_cast<std::streamsize>(length));
  if (ifs.gcount() != static_cast<std::streamsize>(length)) {
    std::cerr << "file: " << realPath << " read failed" << std::endl;
    return nullptr;
  }
  *size = static_cast<size_t>(length);
  // Hand ownership of the buffer to the caller.
  return buf.release();
}
/// Destroys the dataset.
/// For every tuple stored in train_data_, test_data_ and val_data_ (in that order),
/// the pointer held in element 0 is released with delete[].
/// NOTE(review): presumes those buffers were allocated with new[] by the loader — confirm.
DataSet::~DataSet() {
  for (auto &sample : train_data_) {
    delete[] std::get<0>(sample);
  }
  for (auto &sample : test_data_) {
    delete[] std::get<0>(sample);
  }
  for (auto &sample : val_data_) {
    delete[] std::get<0>(sample);
  }
}
/// Initializes the dataset by loading the folder-per-class bitmap database.
///
/// @param data_base_directory Root directory handed to InitializeBMPFoldersDatabase().
/// @param type Not consulted: only the BMP-folder layout is loaded.
/// @return Always 0.
int DataSet::Init(const std::string &data_base_directory, database_type type) {
  (void)type;  // intentionally unused; the signature is kept unchanged
  InitializeBMPFoldersDatabase(data_base_directory);
  const int status = 0;
  return status;
}
void DataSet::InitializeBMPFoldersDatabase(std::string dpath) {//初始化 BMP 文件夹数据库
size_t file_size = 0;
const int ratio = 5;
auto vec = ReadDir(dpath);//读取目录
int running_index = 1;
for (const auto ft : vec) {//对目录每一项进行操作
int label;
std::string file_name;
std::tie(label, file_name) = ft;
char *data = ReadBitmapFile(file_name, &file_size);//文件数据
DataLabelTuple data_entry = std::make_tuple(data, label);将data和label变成tuple数据
if ((expected_data_size_ == 0) || (file_size == expected_data_size_)) {
if (running_index % ratio == 0) {
val_data_.push_back(data_entry);//根据筛选条件来添加数据
} else if (running_index % ratio == 1) {
test_data_.push_back(data_entry);
} else {
train_data_.push_back(data_entry);
}
running_index++;
}
}
}
/// Lists the bitmap files of a folder-per-class dataset.
///
/// Sets num_of_classes_ to 10, then scans `dpath/0` ... `dpath/9`. It returns one
/// (class_id, "<dpath>/<class_id>/<entry>") tuple per directory entry whose name
/// ends in ".bmp". A class folder that cannot be opened is reported on stderr and skipped.
///
/// @param dpath Root directory of the dataset.
/// @return The collected (class_id, file path) tuples.
std::vector<FileTuple> DataSet::ReadDir(const std::string dpath) {
  std::vector<FileTuple> vec;
  num_of_classes_ = 10;
  const std::string bmp_ext = ".bmp";
  for (int class_id = 0; class_id < num_of_classes_; class_id++) {
    const std::string dirname = dpath + "/" + std::to_string(class_id);
    DIR *dp = opendir(dirname.c_str());
    if (dp == nullptr) {
      std::cerr << "open directory: " << dirname << " failed." << std::endl;
      continue;
    }
    // NOTE(review): readdir() returns entries in a filesystem-dependent order, so the
    // resulting order (and the downstream train/test/val split) may differ across machines.
    for (struct dirent *entry = readdir(dp); entry != nullptr; entry = readdir(dp)) {
      const std::string name(entry->d_name);
      // Match the extension on the entry name only. Searching the whole path would
      // also match ".bmp" inside dpath and accept every entry, including "." and "..".
      const bool is_bmp = name.size() >= bmp_ext.size() &&
                          name.compare(name.size() - bmp_ext.size(), bmp_ext.size(), bmp_ext) == 0;
      if (is_bmp) {
        FileTuple ft = std::make_tuple(class_id, dirname + "/" + name);
        vec.push_back(ft);
      }
    }
    closedir(dp);
  }
  return vec;
}