PYTORCH基本语法

PYTORCH基本语法

文件

.pkl文件(pickle库)

pkl文件是python里面保存文件的一种格式,如果直接打开会显示一堆序列化的东西(二进制文件)。常用于保存神经网络训练的模型或者各种需要存储的数据。

  1. 保存神经网络训练模型举例(使用pytorch进行保存)
  • 保存整个网络:torch.save(net, ‘net.pkl’)
  • 保存网络的状态信息:torch.save(net.state_dict(), ‘net_params.pkl’)
  • 提取神经网络的方法:torch.load(‘net.pkl’)
    ————————————————
    原文链接:https://blog.csdn.net/Ving_x/article/details/114488844
  1. 存储数据

Series(常用数据结构)

pandas两个主要的数据结构:Series和DataFrame。
Series是一种类似于一维数组的对象,它由一组数据(各种NumPy数据类型)以及一组与之相关的数据标签(即索引)组成。

项目结构

  1. models
    • backbone.py 模型核心部分
      • ResNet
      • VGG
      • Inception
      • mobileNet
    • model.py
  2. Augmentation
    • 图片增强
    • 图片翻转
  3. visualization
  4. metric 评估指标
  5. dataset
    • read_dataset.py
    • train
      • image
      • label
    • validation
      • image
      • label
    • test
      • image
      • label
  6. utils(optional)
    • loss.py
    • prior_box:自定义获得方式先验框
    • iou: 交并比
    • match:先验框与真实边界框匹配方法
    • nms: 非极大值抑制
  7. oonx(optional) 转oonx格式
  8. train.py
  9. valid.py
  10. test.py
  11. weights:权重文件夹,用于保存权重参数的文件夹

数据

处理的基础数据都是dataset, 数据集索引tensor

dataset

并不是图片数据本身,而是一个个类的实例,提供了__getitem__、__len__等属性方法

__getitem__: 根据索引得到对应元素

simpledataset.__getitem__(0):{'x':tensor(-1.),'y':tensor(1.)}
simpledataset[0]:{'x':tensor(-1.),'y':tensor(1.)}

__len__:长度

simpledataset.__len__():10
len(simpledataset):10
from torch.utils.data import DataLoader

from dataset import FlatFolderDataset,PairWiseDataset

将图片转化为Tensor格式

transform = transforms.Compose(
[
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)) #标准化到0-1之间 
])

使用的时候,将定义的transform赋值给图片的transform属性,则图片将转化为tensor

train_dataset[0][0].shape: torch.Size([3, 28, 28])

其中,需要注意的是,图片转化为tensor,首位表示的是图片的通道数

DataLoader:批量加载数据集

分batch加载训练数据集,是一个list,一个回合叫epoch,得到可迭代对象iterator

 train_loader = DataLoader( dataset=train_dataset,
                            batch_size=args.batch_size,
                            shuffle=shuffle,
                            num_workers=args.num_worker,
                            sampler=sampler,
                            drop_last=True )

enumberate: 同时返回下标索引和数据本身

for i,source_path in enumerate(test_source_images):
    source_img = Image.open(source_path).convert('RGB')

导包

tqdm: 进度条

tqdm来源于阿拉伯词汇:taqaddum,意思是“progress,”

from tqdm import tqdm
progress_bar = tqdm(range(iteration+1,args.num_iter))

参数

parser.add_argument('--num_worker',type=int,default=8)

  • num_worker: worker : 工作进程,负责把batch 加载到内存RAM
posted @ 2022-11-20 23:06  Hecto  阅读(180)  评论(0)    收藏  举报