手动下载CIFAR-10 数据集
手动下载CIFAR-10 数据集,不使用torchvision.datasets.CIFAR10()方法
CIFAR-10 数据集下载
下载地址:https://www.cs.toronto.edu/~kriz/cifar.html
数据集简单说明:
batches.meta 文件包含 label_names 列表,为各批次 labels 数组中的数字标签提供有意义的名称。例如,label_names[0] == "airplane"(飞机),label_names[1] == "automobile"(汽车)等。
data_batch { data_batch_1、 data_batch_2、...、 data_batch_5 } 文件为训练集
test_batch {test_batch}文件为测试集
数据集解压配置(Python 版本)
1、先解压下载的压缩包,得到名为 cifar-10-batches-py 的文件夹

2、使用CIFAR-10 数据集官网提供的unpickle(file)方法
import os
import sys
sys.path.append("..")
import numpy as np
import matplotlib.pyplot as plt
# Read a pickled batch file and return its contents as a dict
def unpickle(file):
    """Load one pickled file and return the object stored in it.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. Callers in this file index it with str keys
        such as ``'data'``, ``'labels'`` and ``'label_names'``.
    """
    import pickle

    # NOTE: pickle can execute arbitrary code while loading. Only call this
    # on files obtained from a trusted source.
    with open(file, 'rb') as fo:
        # encoding='latin1' controls how 8-bit strings inside the pickle are
        # decoded, so the loaded keys can be looked up as ordinary str.
        batch = pickle.load(fo, encoding='latin1')
    return batch
#读取数据
def get_data(dir_path):
    """Load the train/test batches and label names stored in a directory.

    Every entry of ``dir_path`` is classified by its file name: a name
    containing ``'data'`` is a training batch, otherwise one containing
    ``'test'`` is a test batch, otherwise one containing ``'meta'`` holds
    the label names. Other entries are ignored.

    Args:
        dir_path: Directory that contains the pickled batch files.

    Returns:
        A tuple ``(data_train, data_test, train_labels, test_labels,
        label_names)``. The two data arrays are reshaped to
        ``(N, 3, 32, 32)`` and transposed to ``(N, 32, 32, 3)``; the two
        label values are NumPy arrays; ``label_names`` is whatever the meta
        file stores under ``'label_names'`` (an empty list if no meta file
        was found).

    Raises:
        ValueError: If no training batch or no test batch is found.
    """
    # os.listdir returns entries in arbitrary order; sorting keeps the order
    # of the concatenated training batches reproducible between runs.
    data_files = sorted(os.listdir(dir_path))
    print("data_files: ", data_files)

    train_chunks, test_chunks = [], []
    train_labels, test_labels, label_names = [], [], []
    for name in data_files:
        path = os.path.join(dir_path, name)
        # Match against the file name only: testing the joined path would
        # misclassify every file whenever a parent directory happens to
        # contain 'data', 'test' or 'meta'.
        if 'data' in name:
            batch = unpickle(path)  # load once, use for both data and labels
            train_chunks.append(batch['data'])
            train_labels.extend(batch['labels'])
        elif 'test' in name:
            batch = unpickle(path)
            test_chunks.append(batch['data'])
            test_labels.extend(batch['labels'])
        elif 'meta' in name:
            label_names = unpickle(path)['label_names']

    if not train_chunks:
        raise ValueError(f"no training batch files found in {dir_path!r}")
    if not test_chunks:
        raise ValueError(f"no test batch files found in {dir_path!r}")

    # Each row is a flattened channel-first image; -1 lets the sample count
    # follow the files actually present instead of a fixed 50000 / 10000.
    data_train = np.concatenate(train_chunks)
    data_train = data_train.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
    data_test = np.concatenate(test_chunks)
    data_test = data_test.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)

    return (
        data_train,
        data_test,
        np.array(train_labels),
        np.array(test_labels),
        label_names,
    )
if __name__ == "__main__":
    # Folder that holds the extracted dataset files.
    dir_path = r'D:\rgzn\images\cifar-10-batches-py'
    data_train, data_test, tarin_labels, test_labels, label_names = get_data(
        dir_path=dir_path
    )

    # Report the sample counts and the class names.
    print(f"训练集样本数: {len(tarin_labels)}")
    print(f"测试集样本数: {len(test_labels)}")
    print(f"类别: {label_names}")

    # Sanity check: show the first training images with their class names
    # in a 5 x 2 grid.
    grid_rows, grid_cols = 5, 2
    plt.figure(figsize=(10, 8))
    for slot in range(1, grid_rows * grid_cols + 1):
        sample = slot - 1  # subplot slots are 1-based, samples are 0-based
        plt.subplot(grid_rows, grid_cols, slot)
        plt.imshow(data_train[sample])
        plt.title(label_names[tarin_labels[sample]])
        plt.axis('off')
    plt.show()
                    
                
                
            
        
浙公网安备 33010602011771号