PyTorch图像分类全流程实战--构建自己的图像分类数据01

前言

【教程地址】
同济子豪兄教学视频:https://space.bilibili.com/1900783/channel/collectiondetail?sid=606800

导入工具包

import numpy as np
import os
import math
import time
import requests
import urllib3
import pandas as pd
import cv2
urllib3.disable_warnings()
from PIL import Image#删除非三通道的图片
import matplotlib.pyplot as plt #画图
%matplotlib inline
from scipy.stats import gaussian_kde
from matplotlib.colors import LogNorm
#可视化图像
import matplotlib.image as mpimg
from mpl_toolkits.axes_grid1 import ImageGrid
#划分数据集
import shutil
import random
#进度条
from tqdm import tqdm
#HTTP请求参数
cookies = {
    'BDqhfp': '%E7%8B%97%E7%8B%97%26%26NaN-1undefined%26%2618880%26%2621',
    'BIDUPSID': '06338E0BE23C6ADB52165ACEB972355B',
    'PSTM': '1646905430',
    'BAIDUID': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
    'BDORZ': 'B490B5EBF6F3CD402E515D22BCDA1598',
    'H_PS_PSSID': '35836_35105_31254_36024_36005_34584_36142_36120_36032_35993_35984_35319_26350_35723_22160_36061',
    'BDSFRCVID': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
    'H_BDCLCKID_SF': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
    'BDSFRCVID_BFESS': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
    'H_BDCLCKID_SF_BFESS': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
    'indexPageSugList': '%5B%22%E7%8B%97%E7%8B%97%22%5D',
    'cleanHistoryStatus': '0',
    'BAIDUID_BFESS': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
    'BDRCVFR[dG2JNJb_ajR]': 'mk3SLVN4HKm',
    'BDRCVFR[-pGxjrCMryR]': 'mk3SLVN4HKm',
    'ab_sr': '1.0.1_Y2YxZDkwMWZkMmY2MzA4MGU0OTNhMzVlNTcwMmM2MWE4YWU4OTc1ZjZmZDM2N2RjYmVkMzFiY2NjNWM4Nzk4NzBlZTliYWU0ZTAyODkzNDA3YzNiMTVjMTllMzQ0MGJlZjAwYzk5MDdjNWM0MzJmMDdhOWNhYTZhMjIwODc5MDMxN2QyMmE1YTFmN2QyY2M1M2VmZDkzMjMyOThiYmNhZA==',
    'delPer': '0',
    'PSINO': '2',
    'BA_HECTOR': '8h24a024042g05alup1h3g0aq0q',
}
headers = {
    'Connection': 'keep-alive',
    'sec-ch-ua': '" Not;A Brand";v="99", "Google Chrome";v="97", "Chromium";v="97"',
    'Accept': 'text/plain, */*; q=0.01',
    'X-Requested-With': 'XMLHttpRequest',
    'sec-ch-ua-mobile': '?0',
    'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.99 Safari/537.36',
    'sec-ch-ua-platform': '"macOS"',
    'Sec-Fetch-Site': 'same-origin',
    'Sec-Fetch-Mode': 'cors',
    'Sec-Fetch-Dest': 'empty',
    'Referer': 'https://image.baidu.com/search/index?tn=baiduimage&ipn=r&ct=201326592&cl=2&lm=-1&st=-1&fm=result&fr=&sf=1&fmq=1647837998851_R&pv=&ic=&nc=1&z=&hd=&latest=&copyright=&se=1&showtab=0&fb=0&width=&height=&face=0&istype=2&dyTabStr=MCwzLDIsNiwxLDUsNCw4LDcsOQ%3D%3D&ie=utf-8&sid=&word=%E7%8B%97%E7%8B%97',
    'Accept-Language': 'zh-CN,zh;q=0.9',
}

图片爬虫

def craw_single_class(keyword, DOWNLOAD_NUM = 200):
    if os.path.exists('dataset/'+keyword):
        print('文件夹 dataset/{} 已存在,之后直接将爬取到的图片保存至该文件夹中'.format(keyword))
    else:
        os.makedirs('dataset/{}'.format(keyword))
        print('新建文件夹:dataset/{}'.format(keyword))
    count = 1

    with tqdm(total=DOWNLOAD_NUM, position=0, leave=True) as pbar:

        # 爬取第几张
        num = 0

        # 是否继续爬取
        FLAG = True

        while FLAG:

            page = 30 * count

            params = (
                ('tn', 'resultjson_com'),
                ('logid', '12508239107856075440'),
                ('ipn', 'rj'),
                ('ct', '201326592'),
                ('is', ''),
                ('fp', 'result'),
                ('fr', ''),
                ('word', f'{keyword}'),
                ('queryWord', f'{keyword}'),
                ('cl', '2'),
                ('lm', '-1'),
                ('ie', 'utf-8'),
                ('oe', 'utf-8'),
                ('adpicid', ''),
                ('st', '-1'),
                ('z', ''),
                ('ic', ''),
                ('hd', ''),
                ('latest', ''),
                ('copyright', ''),
                ('s', ''),
                ('se', ''),
                ('tab', ''),
                ('width', ''),
                ('height', ''),
                ('face', '0'),
                ('istype', '2'),
                ('qc', ''),
                ('nc', '1'),
                ('expermode', ''),
                ('nojc', ''),
                ('isAsync', ''),
                ('pn', f'{page}'),
                ('rn', '30'),
                ('gsm', '1e'),
                ('1647838001666', ''),
            )

            response = requests.get('https://image.baidu.com/search/acjson', headers=headers, params=params, cookies=cookies)
            if response.status_code == 200:
                try:
                    json_data = response.json().get("data")

                    if json_data:
                        for x in json_data:
                            type = x.get("type")
                            if type not in ["gif"]:
                                img = x.get("thumbURL")
                                fromPageTitleEnc = x.get("fromPageTitleEnc")
                                try:
                                    resp = requests.get(url=img, verify=False)
                                    time.sleep(1)
                                    # print(f"链接 {img}")

                                    # 保存文件名
                                    # file_save_path = f'dataset/{keyword}/{num}-{fromPageTitleEnc}.{type}'
                                    file_save_path = f'dataset/{keyword}/{num}.{type}'
                                    with open(file_save_path, 'wb') as f:
                                        f.write(resp.content)
                                        f.flush()
                                        # print('第 {} 张图像 {} 爬取完成'.format(num, fromPageTitleEnc))
                                        num += 1
                                        pbar.update(1) # 进度条更新

                                    # 爬取数量达到要求
                                    if num > DOWNLOAD_NUM:
                                        FLAG = False
                                        print('{} 张图像爬取完毕'.format(num))
                                        break

                                except Exception:
                                    pass
                except:
                    pass
            else:
                break

            count += 1

format用法:通过位置来填充字符串。会把参数按位置顺序来填充到字符串中,第一个参数是0,然后1 ……
也可以不输入数字,则会按照顺序自动分配,而且一个参数可以多次插入。
python中format用法

craw_single_class('猫猫',DOWNLOAD_NUM=200)

在这里遇到了两个错误:

  1. TypeError:'module' object is not callable
    原因是:模式名和函数名相同了,所以需要进行区别。修改import语句,导入模块内的函数或属性,而不是导入模块。
    改成from tqdm import tqdm
    Solved TypeError: ‘Module’ Object Is Not Callable in Python?
  2. ProxyError: Cannot connect to proxy
    原因:以前的vpn设置没有把注册表的代理删掉,更直接的解决办法是找到注册表里面用户的代理设置,把ProxyEnable的值改为0。
    python 错误:‘Cannot connect to proxy.‘由于目标计算机积极拒绝,无法连接
#爬取多个类
class_list = ['狗狗','机器猫','猫猫']
for i in class_list:
    craw_single_class(i,DOWNLOAD_NUM=200)

image

删除无关图片

#删除gif格式的图像文件
import imghdr
def del_type(path):
    if imghdr.what(path)=='gif':
        os.remove(path)
        print('remove--{}'.format(path))
for each_class in class_list:
    image_path=r'D:\01-learning\01-000-inbox\01-000-01-datawhale\Jan-pytorch-classify-image\dataset\{}'.format(each_class)
    img_list=os.listdir(image_path)

    for img in img_list:
        full_path=os.path.join(image_path,img)
        del_type(full_path)

删除非三通道的图像

#删除非三通道的图像
for fruit in tqdm(os.listdir(dataset_path)):
    for file in os.listdir(os.path.join(dataset_path, fruit)):
        file_path = os.path.join(dataset_path, fruit, file)
        img = np.array(Image.open(file_path))
        try:
            channel = img.shape[2]
            if channel != 3:
                print(file_path, '非三通道,删除')
                os.remove(file_path)
        except:
            print(file_path, '非三通道,删除')
            os.remove(file_path)

image

#封装
def del_channel(file_path):
    img = np.array(Image.open(file_path))
    try:
        channel = img.shape[2]
        if channel != 3:
            print(file_path, '非三通道,删除')
            os.remove(file_path)
        else:
            print(file_path,"此为"+str(channel)+"通道图片")
    except:
        print(file_path, '非三通道,删除')
        os.remove(file_path)
image_path=r'D:\01-learning\01-000-inbox\01-000-01-datawhale\Jan-pytorch-classify-image\dataset\dev'
img_list=os.listdir(image_path)

for img in img_list:
    full_path=os.path.join(image_path,img)
    del_channel(full_path)

图片的形状

#图片的形状
def get_imgsize(dataset_path):
    os.chdir(dataset_path)#cd 到dataset
    df = pd.DataFrame()
    for each in tqdm(os.listdir()):#遍历每个类别
        os.chdir(each)
        for file in os.listdir():#遍历每张图片
            try:
                img = cv2.imread(file)#从文件中读出图片
                df = df.append({'类别':each,'文件名':file,'图像宽':img.shape[1],'图像高':img.shape[0]},ignore_index=True)
            except:
                print(os.path.join(each,file),"读取错误")
        os.chdir('../')
    os.chdir('../')
    return df
dataset_path = r'D:\01-learning\01-000-inbox\01-000-01-datawhale\Jan-pytorch-classify-image\dataset'
df=get_imgsize(dataset_path)
df

image

#可视化图像尺寸分布
x = df['图像宽']
y = df['图像高']

xy = np.vstack([x,y])
z = gaussian_kde(xy)(xy)

# Sort the points by density, so that the densest points are plotted last
idx = z.argsort()
x, y, z = x[idx], y[idx], z[idx]

plt.figure(figsize=(10,10))
# plt.figure(figsize=(12,12))
plt.scatter(x, y, c=z,  s=5, cmap='Spectral_r')
# plt.colorbar()
# plt.xticks([])
# plt.yticks([])

plt.tick_params(labelsize=15)

xy_max = max(max(df['图像宽']), max(df['图像高']))
plt.xlim(xmin=0, xmax=xy_max)
plt.ylim(ymin=0, ymax=xy_max)

plt.ylabel('height', fontsize=25)
plt.xlabel('width', fontsize=25)

plt.savefig('图像尺寸分布.pdf', dpi=120, bbox_inches='tight')

plt.show()

image

划分train、val

#创建训练文件夹pre_data,之后建立train、test
pre_data_path = r'D:\01-learning\01-000-inbox\01-000-01-datawhale\Jan-pytorch-classify-image\pre_data'
#创建train文件夹
os.mkdir(os.path.join(pre_data_path,'train'))
#创建test文件夹
os.mkdir(os.path.join(pre_data_path,'val'))
#创建每个类别的子文件夹
for each in os.listdir(dataset_path):
    os.mkdir(os.path.join(pre_data_path,'train',each))
    os.mkdir(os.path.join(pre_data_path,'val',each))
	#划分训练集、测试集,移动文件
test_frac = 0.2  # 测试集比例
random.seed(123) # 随机数种子,便于复现
print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))
df = pd.DataFrame()
for fruit in os.listdir(dataset_path): # 遍历每个类别

    # 读取该类别的所有图像文件名
    old_dir = os.path.join(dataset_path, fruit)
    images_filename = os.listdir(old_dir)
    random.shuffle(images_filename) # 随机打乱

    # 划分训练集和测试集
    testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数
    testset_images = images_filename[:testset_numer]      # 获取拟移动至 test 目录的测试集图像文件名
    trainset_images = images_filename[testset_numer:]     # 获取拟移动至 train 目录的训练集图像文件名

    # 移动图像至 test 目录
    for image in testset_images:
        old_img_path = os.path.join(dataset_path, fruit, image)         # 获取原始文件路径
        new_test_path = os.path.join(pre_data_path, 'val', fruit, image) # 获取 test 目录的新文件路径
        shutil.copy(old_img_path, new_test_path) # 移动文件

    # 移动图像至 train 目录
    for image in trainset_images:
        old_img_path = os.path.join(dataset_path, fruit, image)           # 获取原始文件路径
        new_train_path = os.path.join(pre_data_path, 'train', fruit, image) # 获取 train 目录的新文件路径
        shutil.copy(old_img_path, new_train_path) # copy文件

    #
    # # 删除旧文件夹
    # assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走
    # shutil.rmtree(old_dir) # 删除文件夹

    # 工整地输出每一类别的数据个数
    print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))

    # 保存到表格中
    df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)

# # 重命名数据集文件夹
# shutil.move(dataset_path, dataset_name+'_split')
#
# 数据集各类别数量统计表格,导出为 csv 文件
df['total'] = df['trainset'] + df['testset']
df.to_csv('数据量统计.csv', index=False)

image

#可视化图片
#指定需要可视化的文件夹
folder_path = r'D:\01-learning\01-000-inbox\01-000-01-datawhale\Jan-pytorch-classify-image\pre_data\train'
#数目
N = 25
# n行,n列
n = math.floor(np.sqrt(N))
images = []
#更改成英文名os.rename(原文件名,新文件名) : 对文件或目录改名
os.rename(os.path.join(folder_path,'机器猫'),os.path.join(folder_path,'doraemon'))
os.rename(os.path.join(folder_path,'狗狗'),os.path.join(folder_path,'dog'))
os.rename(os.path.join(folder_path,'猫猫'),os.path.join(folder_path,'cats'))

更改成英文主要是下面遇到问题了

child_path = r'D:\01-learning\01-000-inbox\01-000-01-datawhale\Jan-pytorch-classify-image\pre_data\train\cats'
for each_img in os.listdir(child_path)[:N]:
    image_path = os.path.join(child_path,each_img)
    img_bgr = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img_bgr,cv2.COLOR_BGR2RGB)
    images.append(img_rgb)

遇到了问题是:

error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'
原因:是因为文件名中有中文,将处理后文件进行保存后发现英文文件名的图像正常,而中文错误。
解决:
Opencv 解决问题 !_src.empty() in function 'cv::cvtColor'

所以上一步改成英文名之后没有问题了。

画图

#画图
fig = plt.figure(figsize=(10,10),dpi=300)
grid = ImageGrid(fig,111,#子图111
                 nrows_ncols=(n,n),
                 axes_pad = 0.02,#网格间距
                 share_all=True
 )
for ax,im in zip(grid,images):
    ax.imshow(im)
    ax.axis('off')

plt.tight_layout()
plt.show()
plt.savefig('cat25.png')

image

参考文献

【1】python中format用法
【2】Solved TypeError: ‘Module’ Object Is Not Callable in Python?
【3】python 错误:‘Cannot connect to proxy.‘由于目标计算机积极拒绝,无法连接
【4】Opencv 解决问题 !_src.empty() in function 'cv::cvtColor'
【5】cv2库(OpenCV,opencv-python)

posted on 2023-01-18 00:49  琢磨亿下  阅读(358)  评论(0编辑  收藏  举报

导航