add-pytorch代码实现

运行的是GitHub关于GAN算法的最高赞代码https://github.com/corenel/pytorch-adda
这个代码直接运行是不能直接出结果的,需要进行以下修改:

大佬的安装环境是

  • Python 3.6
  • PyTorch 0.2.0

但是我在anaconda上搭建的pytorch0.2.0显示的是cuda版本报错

所以我换成了

  • torch1.7.0
  • torchvision 0.8.1

这里torch和torchvision有版本对应要求,需要注意

torch和anaconda安装教程见我另一篇文档。

修改方案

首先会有一些版本不兼容需要修改的提示,如data[0]变成item()等,按照提示修改就可以

这里主要介绍一些难以寻找的大的修改。

导入MNIST和UPSP数据集

从torchvision.datasets.MNIST下载即可,代码在datasets文件夹里mnist.py 和usps.py下载后的MNIST图像都是灰度图像,只有一个通道。所以运行原来的程序会报错:RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]。

报错原因:这是因为mnist图像都是灰度图像,只有一个通道,而上面的transforms.Normalize 却对三个通道都归一化了,这肯定会报错,所以只要像下面修改即可:

    pre_process = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(
                                          (0.5,),(0.5,))])

需要注意的是mnist.py和usps.py这两个函数里面都需要这样修改

float->long的字符类型出错

天知道我找这个错误找了多久,修改为

pretrain.py

      acc += pred_cls.eq(labels.data).long().cpu().sum().item()#增加了.item(),.long()

优化后的精度没有之前的精度高
torchvision版本要为0.2.0版本,0.2.1版本每次加载数据都减去了一次平均值

posted on 2020-11-05 09:38  doubleqing  阅读(311)  评论(0编辑  收藏  举报

导航