PYTORCH基本语法
PYTORCH基本语法
文件
.pkl文件(pickle库)
pkl文件是python里面保存文件的一种格式,如果直接打开会显示一堆序列化的东西(二进制文件)。常用于保存神经网络训练的模型或者各种需要存储的数据。
- 保存神经网络训练模型举例(使用pytorch进行保存)
- 保存整个网络:torch.save(net, ‘net.pkl’)
- 保存网络的状态信息:torch.save(net.state_dict(), ‘net_params.pkl’)
- 提取神经网络的方法:torch.load(‘net.pkl’)
————————————————
原文链接:https://blog.csdn.net/Ving_x/article/details/114488844
- 存储数据
Series(常用数据结构)
pandas两个主要的数据结构:Series和DataFrame。
Series是一种类似于一维数组的对象,它由一组数据(各种NumPy数据类型)以及一组与之相关的数据标签(即索引)组成。
项目结构
- models
- backbone.py 模型核心部分
- ResNet
- VGG
- Inception
- mobileNet
- model.py
- backbone.py 模型核心部分
- Augmentation
- 图片增强
- 图片翻转
- visualization
- metric 评估指标
- dataset
- read_dataset.py
- train
- image
- label
- validation
- image
- label
- test
- image
- label
- utils(optional)
- loss.py
- prior_box:自定义获得方式先验框
- iou: 交并比
- match:先验框与真实边界框匹配方法
- nms: 非极大值抑制
- oonx(optional) 转oonx格式
- train.py
- valid.py
- test.py
- weights:权重文件夹,用于保存权重参数的文件夹
数据
处理的基础数据都是dataset, 数据集索引tensor
dataset
并不是图片数据本身,而是一个个类的实例,提供了__getitem__、__len__等属性方法
__getitem__: 根据索引得到对应元素
simpledataset.__getitem__(0):{'x':tensor(-1.),'y':tensor(1.)}
simpledataset[0]:{'x':tensor(-1.),'y':tensor(1.)}
__len__:长度
simpledataset.__len__():10
len(simpledataset):10
from torch.utils.data import DataLoader
from dataset import FlatFolderDataset,PairWiseDataset
将图片转化为Tensor格式
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=(0.5,), std=(0.5,)) #标准化到0-1之间
])
使用的时候,将定义的transform赋值给图片的transform属性,则图片将转化为tensor
train_dataset[0][0].shape: torch.Size([3, 28, 28])
其中,需要注意的是,图片转化为tensor,首位表示的是图片的通道数
DataLoader:批量加载数据集
分batch加载训练数据集,是一个list,一个回合叫epoch,得到可迭代对象iterator
train_loader = DataLoader( dataset=train_dataset,
batch_size=args.batch_size,
shuffle=shuffle,
num_workers=args.num_worker,
sampler=sampler,
drop_last=True )
enumberate: 同时返回下标索引和数据本身
for i,source_path in enumerate(test_source_images):
source_img = Image.open(source_path).convert('RGB')
导包
tqdm: 进度条
tqdm来源于阿拉伯词汇:taqaddum,意思是“progress,”
from tqdm import tqdm
progress_bar = tqdm(range(iteration+1,args.num_iter))
参数
parser.add_argument('--num_worker',type=int,default=8)
- num_worker: worker : 工作进程,负责把batch 加载到内存RAM

浙公网安备 33010602011771号