add-pytorch代码实现
运行的是GitHub关于GAN算法的最高赞代码https://github.com/corenel/pytorch-adda
这个代码直接运行是不能直接出结果的,需要进行以下修改:
大佬的安装环境是
- Python 3.6
- PyTorch 0.2.0
但是我在anaconda上搭建的pytorch0.2.0显示的是cuda版本报错
所以我换成了
- torch1.7.0
- torchvision 0.8.1
这里torch和torchvision有版本对应要求,需要注意
torch和anaconda安装教程见我另一篇文档。
修改方案
首先会有一些版本不兼容需要修改的提示,如data[0]变成item()等,按照提示修改就可以
这里主要介绍一些难以寻找的大的修改。
导入MNIST和UPSP数据集
从torchvision.datasets.MNIST下载即可,代码在datasets文件夹里mnist.py 和usps.py下载后的MNIST图像都是灰度图像,只有一个通道。所以运行原来的程序会报错:RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]。
报错原因:这是因为mnist图像都是灰度图像,只有一个通道,而上面的transforms.Normalize 却对三个通道都归一化了,这肯定会报错,所以只要像下面修改即可:
pre_process = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(
(0.5,),(0.5,))])
需要注意的是mnist.py和usps.py这两个函数里面都需要这样修改
float->long的字符类型出错
天知道我找这个错误找了多久,修改为
pretrain.py
acc += pred_cls.eq(labels.data).long().cpu().sum().item()#增加了.item(),.long()
优化后的精度没有之前的精度高
torchvision版本要为0.2.0版本,0.2.1版本每次加载数据都减去了一次平均值
posted on 2020-11-05 09:38 doubleqing 阅读(311) 评论(0) 编辑 收藏 举报