pytorch模型测试中的代码理解
模型训练之后,测试是重中之重,目前,我所碰到的有以下两种方式:
🅰利用
torchvision.utils.save_image来对模型结果直接进行保存🅱先将结果转换为
numpy array之后进行提取与合成
利用torchvision.util.save_image进行测试
了解源码请看官方文档📒
- 若模型的输出是三通道的RGB图像,那么直接利用
torchvision.utils.save_image(model_out, './out.png', padding=0)
- 若模型是单通道的输出时,则需要做一些想要的变化工作:
HR_2_r = model_r(r) 
HR_2_g = model_g(g)
black_2 = torch.zeros(1, 1024, 1024).unsqueeze(0).cuda()
HR_2 = torch.cat((HR_2_r.squeeze(0), HR_2_g.squeeze(0), black_2.squeeze(0))).unsqueeze(0)
torchvision.utils.save_image(HR_2, './out.png', padding=0)
print("saved !")
- 
假设经过模型计算后的 HR_2_r.shape:[1, 1, 1024, 1024]
- 
假设经过模型计算后的 HR_2_g.shape:[1, 1, 1024, 1024]
- 
创建 black_2是因为想获得全黑色无信息的B通道而后结合模型输出的R and G通道来生成RGB的图像,利用unsqueeze(0)来使得torch.zeros(1, 1024, 1024)的形状变成[1, 1, 1024, 1024],使其与其他两个通道相同以便于合并
- 
HR_2_r.squeeze(0)使得其形状由[1, 1, 1024, 1024]变为[1, 1024, 1024],其他同理。之后利用torch.cat来进行合并通道,最后使用torchvision.utils.save_image来进行保存。
利用numpy array来保存图像
model = torch.load(opt.model).cuda()
transform = Compose([ToTensor(), Lambda(lambda x: x.repeat(3,1,1)),])
img = Image.open(opt.input).convert('RGB')
LR_r_spilt, _, _ = img.split()
LR_r = transform(LR_r_spilt)
LR_r = LR_r.unsqueeze(0)
LR_r = LR_r.cuda()
HR_2_r = model(LR_r)
HR_2_r = HR_2_r.cpu()
out_HR_2_r_detach = HR_2_r.data[0].numpy() 
out_HR_2_r = out_HR_2_r_detach * 255.0
out_HR_2_r = out_HR_2_r.clip(0, 255)
out_HR_2_r_result = Image.fromarray(np.uint8(out_HR_2_r[0]), mode='L')
out_HR_2_r_result.save(opt.outputHR2)
print('output image saved to ', opt.outputHR2)
- 假设img.size为[3, 512, 512],经过split后,LR_r_spilt.size为[512, 512],再经过transform后,则模型的输入LR_r.shape:torch.Size([3, 512, 512]),但是网络的输入应该是[batchSize, C, H, W]的4维向量,所以要进行unsqueeze(0)操作,使得LR_r变为torch.Size([1, 3, 512, 512])
- 经过模型计算后的HR_2_r.shape:torch.Size([1, 3, 1024, 1024]),因为模型是利用torch计算的,在进行numpy转换前,应该先用.cpu()做一个数据的转换
- 利用HR_2_r.data[0]是对数据进行分离,使得四维的数据变成三维torch.Size([3, 1024, 1024]),当然也有使用detach()来进行这样操作👉MORE。
- 由于网络的结果是[0, 1]之间的值,而之后要利用np.unit8()做转换,所以要将结果转为[0, 255]之间,所以要乘以255.0。而乘以255后可能使得数据大于255,故用clip(0, 255)来保证数据域是正确的
- 利用out_HR_2[0]来提取第一通道,即R通道,最后利用Image.fromarray来实现从array到Image的转变,此时,out_HR_r_result即为单通道的灰度图
以下是测试输出,可以不看:
Namespace(cuda=True, input='/content/drive/My Drive/TestImg/val/95-512pix-speed7-ave1.tif', model='/content/drive/My Drive/TestImg/H_Adagrad_epoch_r_400.pth', outputHR2='/content/drive/My Drive/TestImg/val2/95_X2.tif', outputHR4='/content/drive/My Drive/TestImg/val4/95_X4.tif')
After Split:(512, 512)
After transform:torch.Size([3, 512, 512])
After unsqueeze:torch.Size([1, 3, 512, 512])
After Model:torch.Size([1, 3, 1024, 1024])
After data[0]:torch.Size([3, 1024, 1024])
After data[0].numpy():(3, 1024, 1024)
# 网络输出
[[[0.0726895  0.05225104 0.05731031 ... 0.08038348 0.09176005 0.11681005]
  [0.05010414 0.03608295 0.05038059 ... 0.07422554 0.08273476 0.09623137]
  [0.05970517 0.04536656 0.04714845 ... 0.07344079 0.07673076 0.08257505]
  ...
  [0.10552698 0.10863981 0.1160759  ... 0.07840011 0.05360675 0.05854338]
  [0.12117186 0.11651129 0.12565961 ... 0.06953818 0.03986105 0.05407226]
  [0.13890806 0.13018039 0.12679848 ... 0.07166685 0.05829817 0.08199665]]
 [[0.06548429 0.04780073 0.05077077 ... 0.07777905 0.08438891 0.08978723]
  [0.04724503 0.03454134 0.04663415 ... 0.07442069 0.07981256 0.08830814]
  [0.05518261 0.04601964 0.04341775 ... 0.07343534 0.07746741 0.07461968]
  ...
  [0.10186043 0.10778397 0.11563006 ... 0.07755548 0.05533218 0.04998586]
  [0.11432448 0.11585081 0.12556481 ... 0.06927705 0.04027918 0.04754832]
  [0.12643194 0.12204635 0.12268594 ... 0.06363812 0.05656445 0.07072814]]
 [[0.06303221 0.04812427 0.05175997 ... 0.0773088  0.08727033 0.09965152]
  [0.04497686 0.03826225 0.04548033 ... 0.07697698 0.08157301 0.09077145]
  [0.05506983 0.04517868 0.04388288 ... 0.07452801 0.07780927 0.07592566]
  ...
  [0.10016826 0.10798413 0.11497539 ... 0.078601   0.05558252 0.05430539]
  [0.11343104 0.11644471 0.12352681 ... 0.06749362 0.03726557 0.0477092 ]
  [0.125687   0.12345901 0.12515822 ... 0.06406513 0.05620003 0.0672472 ]]]
  
# 乘以 255
[[[18.535824  13.324016  14.61413   ... 20.497787  23.398813  29.786564 ]
  [12.776556   9.201153  12.84705   ... 18.927513  21.097364  24.539    ]
  [15.224818  11.568472  12.022855  ... 18.727402  19.566343  21.056639 ]
  ...
  [26.90938   27.70315   29.599356  ... 19.992027  13.669721  14.928563 ]
  [30.898825  29.710379  32.0432    ... 17.732235  10.164569  13.788426 ]
  [35.421555  33.196     32.333614  ... 18.275047  14.866034  20.909145 ]]
 [[16.698492  12.189187  12.9465475 ... 19.833658  21.519173  22.895744 ]
  [12.047482   8.808042  11.891709  ... 18.977276  20.352201  22.518576 ]
  [14.071565  11.735009  11.071527  ... 18.726011  19.75419   19.028019 ]
  ...
  [25.974411  27.484913  29.485666  ... 19.776648  14.109707  12.746393 ]
  [29.152742  29.541956  32.019028  ... 17.665648  10.271191  12.124823 ]
  [32.240147  31.12182   31.284914  ... 16.22772   14.423935  18.035675 ]]
 [[16.073214  12.271688  13.198793  ... 19.713745  22.253935  25.411137 ]
  [11.469099   9.756873  11.597483  ... 19.629131  20.801117  23.14672  ]
  [14.042808  11.520564  11.190133  ... 19.004642  19.841366  19.361044 ]
  ...
  [25.542906  27.535952  29.318726  ... 20.043255  14.173544  13.847875 ]
  [28.924913  29.6934    31.499336  ... 17.210873   9.50272   12.165845 ]
  [32.050186  31.482048  31.915346  ... 16.336607  14.331007  17.148035 ]]]
# clip之后
[[[18.535824  13.324016  14.61413   ... 20.497787  23.398813  29.786564 ]
  [12.776556   9.201153  12.84705   ... 18.927513  21.097364  24.539    ]
  [15.224818  11.568472  12.022855  ... 18.727402  19.566343  21.056639 ]
  ...
  [26.90938   27.70315   29.599356  ... 19.992027  13.669721  14.928563 ]
  [30.898825  29.710379  32.0432    ... 17.732235  10.164569  13.788426 ]
  [35.421555  33.196     32.333614  ... 18.275047  14.866034  20.909145 ]]
 [[16.698492  12.189187  12.9465475 ... 19.833658  21.519173  22.895744 ]
  [12.047482   8.808042  11.891709  ... 18.977276  20.352201  22.518576 ]
  [14.071565  11.735009  11.071527  ... 18.726011  19.75419   19.028019 ]
  ...
  [25.974411  27.484913  29.485666  ... 19.776648  14.109707  12.746393 ]
  [29.152742  29.541956  32.019028  ... 17.665648  10.271191  12.124823 ]
  [32.240147  31.12182   31.284914  ... 16.22772   14.423935  18.035675 ]]
 [[16.073214  12.271688  13.198793  ... 19.713745  22.253935  25.411137 ]
  [11.469099   9.756873  11.597483  ... 19.629131  20.801117  23.14672  ]
  [14.042808  11.520564  11.190133  ... 19.004642  19.841366  19.361044 ]
  ...
  [25.542906  27.535952  29.318726  ... 20.043255  14.173544  13.847875 ]
  [28.924913  29.6934    31.499336  ... 17.210873   9.50272   12.165845 ]
  [32.050186  31.482048  31.915346  ... 16.336607  14.331007  17.148035 ]]]
# 提取第一通道
[[18.535824 13.324016 14.61413  ... 20.497787 23.398813 29.786564]
 [12.776556  9.201153 12.84705  ... 18.927513 21.097364 24.539   ]
 [15.224818 11.568472 12.022855 ... 18.727402 19.566343 21.056639]
 ...
 [26.90938  27.70315  29.599356 ... 19.992027 13.669721 14.928563]
 [30.898825 29.710379 32.0432   ... 17.732235 10.164569 13.788426]
 [35.421555 33.196    32.333614 ... 18.275047 14.866034 20.909145]]
After PostDeal:(1024, 1024)
output image saved to  /content/drive/My Drive/TestImg/val2/95_X2.tif
output image saved to  /content/drive/My Drive/TestImg/val4/95_X4.tif
Write by Gqq

 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号