可视化loss和metrics
tensorboard 可视化
from tensorboardX import SummaryWriter

# Log the generator's loss terms and image-quality metrics to TensorBoard at step `count`.
writer = SummaryWriter(args.log_dir)
G_loss, G_loss_list = generator_loss(input_images, masks, outputs, ground_truths, extractor)

# (tag, value) pairs in logging order. Entries 0-4 and the total loss are
# converted with .item(); entries 5 and 6 (PSNR, SSIM) are passed through as-is.
tb_scalars = [
    ('G hole loss', G_loss_list[0].item()),
    ('G valid loss', G_loss_list[1].item()),
    ('G perceptual loss', G_loss_list[2].item()),
    ('G style loss', G_loss_list[3].item()),
    ('G tv loss', G_loss_list[4].item()),
    ('Psnr', G_loss_list[5]),
    ('ssim', G_loss_list[6]),
    ('G total loss', G_loss.item()),
]
for tb_tag, tb_value in tb_scalars:
    writer.add_scalar(tb_tag, tb_value, count)
writer.close()
PyTorch 中 `.item()` 的作用
Use torch.Tensor.item() to get a Python number from a tensor containing a single value.
`.item()` 方法把只含一个元素的张量转换为对应的 Python 数值（int 或 float）；若张量包含多个元素，则会报错。
>>> import torch
>>> x = torch.tensor([[1]])
>>> x
tensor([[1]])
>>> x.item()
1
>>> x = torch.tensor(2.5)
>>> x
tensor(2.5000)
>>> x.item()
2.5
将 loss 和 metric 记录到 .dat 文件中
parser.add_argument('--log_dat', type=str, default='runs/celeba', help='log with dat')

# Each call to log() appends one whitespace-separated row to <log_dat>/log_dat.dat.
log_file = os.path.join(args.log_dat, 'log_dat.dat')
if args.log_dat:
    # open(..., 'a') below raises FileNotFoundError if the directory does not exist yet.
    os.makedirs(args.log_dat, exist_ok=True)


def _as_number(value):
    """Return ``value.item()`` for objects that have it (tensor/NumPy scalars), else ``value``."""
    # NOTE(review): PSNR/SSIM are logged without .item(); if they are tensors,
    # str() would write "tensor(...)", which float() cannot parse when plotting.
    return value.item() if hasattr(value, 'item') else value


def log(logs):
    """Append one row holding the values of *logs* to ``log_file``.

    Each entry of *logs* is a sequence whose element 1 is the value to write;
    element 0 is a descriptive name and is not written. Values are joined by
    single spaces in list order, so the list order defines the column order.
    """
    row = ' '.join(str(_as_number(entry[1])) for entry in logs)
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f'{row}\n')


# (name, value) pairs; only the values are written.
# NOTE(review): 'G tv loss' (G_loss_list[4]) is not logged here; adding it
# would shift every later column, so update any readers of this file too.
logs = [
    ('G hole loss', G_loss_list[0].item()),
    ('G valid loss', G_loss_list[1].item()),
    ('G perceptual loss', G_loss_list[2].item()),
    ('G style loss', G_loss_list[3].item()),
    ('Psnr', G_loss_list[5]),
    ('ssim', G_loss_list[6]),
    ('G total loss', G_loss.item()),
]
# Prepend epoch and step counter so they become columns 0 and 1 of every row.
logs = [('epoch', epoch), ('count', count)] + logs
log(logs)
绘制.dat文件的曲线图
# Plot one column of a whitespace-separated .dat log against another.


def read_dat_columns(path, x_col=1, y_col=2):
    """Read two numeric columns from a whitespace-separated text log.

    Blank lines and lines starting with ``@`` or ``#`` are skipped.

    Args:
        path: File to read.
        x_col: 0-based index of the column used for x (default 1).
        y_col: 0-based index of the column used for y (default 2).

    Returns:
        A ``(xs, ys)`` pair of lists of floats, in file order.
    """
    xs, ys = [], []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip() or line.startswith(('@', '#')):
                continue
            row = line.split()
            xs.append(float(row[x_col]))
            ys.append(float(row[y_col]))
    return xs, ys


def main(path="log_edge.dat", out="sigma.png", dpi=300):
    """Plot column 2 against column 1 of *path*, save it to *out*, then display it."""
    # Imported here so read_dat_columns can be used without a plotting backend.
    import matplotlib.pyplot as plt

    # NOTE(review): if the file follows the "epoch count loss ..." row layout,
    # column 1 is the step counter and column 2 the first loss — confirm for log_edge.dat.
    x, y1 = read_dat_columns(path)
    plt.plot(x, y1)
    # Save before show(): once the window is closed the figure may be released,
    # and saving afterwards can produce a blank image.
    plt.savefig(out, dpi=dpi)
    plt.show()


if __name__ == "__main__":
    main()

浙公网安备 33010602011771号