C++ 读取 MNIST 数据集到数组

#include <cstddef>
#include <fstream>
#include <iostream>
#include <vector>


using namespace std; 

// Destination buffers for the decoded data set. They are namespace-scope
// arrays, so they have static storage duration and start out zero-filled.
// NOTE(review): the two image arrays together hold 70000 * 28 * 28 ints
// (roughly 220 MB if int is 4 bytes), which is too large for the stack.

// Training images, indexed [image][row][column]; written by
// convert_array_image() when is_train is true.
int train_image[60000][28][28];
// Training labels, one per image; written by convert_array_label() when
// is_train is true.
int train_label[60000];
// Test images, indexed [image][row][column]; written by
// convert_array_image() when is_train is false.
int test_image[10000][28][28];
// Test labels, one per image; written by convert_array_label() when
// is_train is false.
int test_label[10000];

/// Reverses the order of the four low bytes of an integer.
///
/// The readers in this file load the header fields of the MNIST files as raw
/// bytes into an `int` and then call this function to turn the file's
/// most-significant-byte-first layout into the value a little-endian host
/// expects.
///
/// @param i  Value whose bytes are to be swapped.
/// @return   `i` with byte 0 <-> byte 3 and byte 1 <-> byte 2 exchanged.
int ReverseInt(int i)
{
    // Do the shuffling on an unsigned copy: shifting a byte >= 0x80 into the
    // sign bit of a signed int is not portable, whereas unsigned shifts are
    // fully defined.
    const unsigned int value = static_cast<unsigned int>(i);
    const unsigned int swapped = ((value & 0x000000FFu) << 24)
                               | ((value & 0x0000FF00u) << 8)
                               | ((value & 0x00FF0000u) >> 8)
                               | ((value & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}
  
/// Reads the labels stored in an MNIST label file and appends them to `labels`.
///
/// The file is expected to start with two int-sized header fields (a magic
/// number and the item count, both byte-swapped with ReverseInt()) followed by
/// one unsigned byte per label.
///
/// @param filename  Path of the label file; it is opened in binary mode.
/// @param labels    Output vector. One element is appended per label that was
///                  actually read; existing elements are kept.
///
/// If the file cannot be opened or the header is incomplete, `labels` is left
/// untouched. If the file ends before the announced number of labels has been
/// read, only the labels that are really present are appended (no zero
/// padding), so callers can detect the problem by checking `labels.size()`.
void read_Mnist_Label(char filename[], std::vector<int>&labels)
{
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open())
        return;

    // The magic number is read only to advance past it; it is not validated.
    int magic_number = 0;
    int number_of_images = 0;
    file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
    if (!file)
        return;  // header is truncated, so the count cannot be trusted

    number_of_images = ReverseInt(number_of_images);

    for (int i = 0; i < number_of_images; ++i)
    {
        unsigned char label = 0;
        // Stop at the first failed read instead of storing a made-up 0 label.
        if (!file.read(reinterpret_cast<char*>(&label), sizeof(label)))
            break;
        labels.push_back(label);
    }
}
  
/// Reads the images stored in an MNIST image file and appends them to `images`.
///
/// The file is expected to start with four int-sized header fields (magic
/// number, image count, row count, column count, each byte-swapped with
/// ReverseInt()) followed by rows * cols unsigned bytes per image, stored row
/// by row.
///
/// @param filename  Path of the image file; it is opened in binary mode.
/// @param images    Output vector. One inner vector of rows * cols pixel
///                  values (0..255) is appended per complete image; existing
///                  elements are kept.
///
/// If the file cannot be opened or the header is incomplete, `images` is left
/// untouched. If the file ends early, reading stops at the first incomplete
/// image, so no zero-filled images are fabricated.
void read_Mnist_Images(char filename[], std::vector< std::vector<int> > &images)
{
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open())
        return;

    // The magic number is read only to advance past it; it is not validated.
    int magic_number = 0;
    int number_of_images = 0;
    int n_rows = 0;
    int n_cols = 0;
    file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
    file.read(reinterpret_cast<char*>(&n_rows), sizeof(n_rows));
    file.read(reinterpret_cast<char*>(&n_cols), sizeof(n_cols));
    if (!file)
        return;  // header is truncated, so the dimensions cannot be trusted

    number_of_images = ReverseInt(number_of_images);
    n_rows = ReverseInt(n_rows);
    n_cols = ReverseInt(n_cols);

    // Non-positive dimensions yield empty images, as nested row/column loops
    // over those bounds would.
    std::size_t pixels_per_image = 0;
    if (n_rows > 0 && n_cols > 0)
        pixels_per_image = static_cast<std::size_t>(n_rows) * static_cast<std::size_t>(n_cols);

    // One reusable byte buffer lets each image be fetched with a single read.
    std::vector<unsigned char> raw(pixels_per_image);

    for (int i = 0; i < number_of_images; ++i)
    {
        if (pixels_per_image > 0 &&
            !file.read(reinterpret_cast<char*>(&raw[0]),
                       static_cast<std::streamsize>(pixels_per_image)))
        {
            break;  // file is shorter than the header claims
        }
        // Widen every byte to int while copying into the result.
        images.push_back(std::vector<int>(raw.begin(), raw.end()));
    }
}

void convert_array_image(vector< vector<int> > &feature_vector, bool is_train)
{
    for (int i = 0; i < feature_vector.size(); i++)  
    {  
        for (int a = 0; a < 28; a++)  
        {  
            for (int b = 0; b < 28; b++)
            {
                if (is_train == 1) 
                    train_image[i][a][b] = feature_vector[i][a*28+b];
                else
                    test_image[i][a][b] = feature_vector[i][a*28+b];
            }
        }  
    }     
} 

void convert_array_label(vector<int>&labels, bool is_train)
{
    for (int i = 0; i < labels.size(); i++)
    {
        if (is_train == 1)
            train_label[i] = labels[i];
        else
            test_label[i] = labels[i];
    //    cout << train_labels[i] << "&" << train_l[i] << " ";
    }
}

/// Loads the four MNIST files from the current working directory into the
/// global image and label arrays.
///
/// @return 0 when both the training and the test split were read and each has
///         as many labels as images; 1 (with a message on stderr) otherwise.
int main()
{
    char train_image_name[] = "train-images.idx3-ubyte";
    char train_label_name[] = "train-labels.idx1-ubyte";
    char test_image_name[] = "t10k-images.idx3-ubyte";
    char test_label_name[] = "t10k-labels.idx1-ubyte";

    // read mnist training data
    std::vector< std::vector<int> > train_feature_vector;
    read_Mnist_Images(train_image_name, train_feature_vector);

    std::vector<int> train_labels;
    read_Mnist_Label(train_label_name, train_labels);

    // The readers report nothing themselves, so a missing or truncated file
    // only shows up as an empty or mismatched result.
    if (train_feature_vector.empty() ||
        train_feature_vector.size() != train_labels.size())
    {
        std::cerr << "failed to load training data: "
                  << train_feature_vector.size() << " images, "
                  << train_labels.size() << " labels\n";
        return 1;
    }
    convert_array_image(train_feature_vector, true);
    convert_array_label(train_labels, true);

    // read mnist test data
    std::vector< std::vector<int> > test_feature_vector;
    read_Mnist_Images(test_image_name, test_feature_vector);

    std::vector<int> test_labels;
    read_Mnist_Label(test_label_name, test_labels);

    if (test_feature_vector.empty() ||
        test_feature_vector.size() != test_labels.size())
    {
        std::cerr << "failed to load test data: "
                  << test_feature_vector.size() << " images, "
                  << test_labels.size() << " labels\n";
        return 1;
    }
    convert_array_image(test_feature_vector, false);
    convert_array_label(test_labels, false);

    return 0;
}

 

posted @ 2020-07-07 15:49  平平的圆圆  阅读(376)  评论(0)    收藏  举报