可视化loss和metrics

使用 TensorBoard 可视化

 1 from tensorboardX import SummaryWriter
 2 writer = SummaryWriter(args.log_dir)
 3         G_loss, G_loss_list = generator_loss(input_images, masks, outputs, ground_truths, extractor)
 4 
 5         writer.add_scalar('G hole loss', G_loss_list[0].item(), count)
 6         writer.add_scalar('G valid loss', G_loss_list[1].item(), count)
 7         writer.add_scalar('G perceptual loss', G_loss_list[2].item(), count)
 8         writer.add_scalar('G style loss', G_loss_list[3].item(), count)
 9         writer.add_scalar('G tv loss', G_loss_list[4].item(), count)
10         writer.add_scalar('Psnr', G_loss_list[5], count)
11         writer.add_scalar('ssim', G_loss_list[6], count)
12         writer.add_scalar('G total loss', G_loss.item(), count)
13 writer.close()
View Code

PyTorch 中 .item() 的作用

Use torch.Tensor.item() to get a Python number from a tensor containing a single value.

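.item() 方法把只含一个元素的张量转换为对应的 Python 数值（int 或 float）并返回；对含多个元素的张量调用会报错。因此记录 loss 时先调用 .item()，得到普通数值，便于写入 TensorBoard 或文本文件。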

>>> import torch
>>> x = torch.tensor([[1]])
>>> x
tensor([[1]])
>>> x.item()
1
>>> x = torch.tensor(2.5)
>>> x
tensor(2.5000)
>>> x.item()
2.5

将 loss 和 metric 记录到 .dat 文件中

 1 parser.add_argument('--log_dat', type=str, default='runs/celeba', help='log with dat')
 2 log_file = os.path.join(args.log_dat, 'log_dat' + '.dat')
 3 def log(logs):
 4     with open(log_file, 'a') as f:
 5         f.write('%s\n' % ' '.join([str(item[1]) for item in logs]))
 6 logs = [
 7             ("G hole loss", G_loss_list[0].item()),
 8             ("G valid loss", G_loss_list[1].item()),
 9             ('G perceptual loss', G_loss_list[2].item()),
10             ('G style loss', G_loss_list[3].item()),
11             ('Psnr', G_loss_list[5]),
12             ('ssim', G_loss_list[6]),
13             ('G total loss', G_loss.item(), count)
14         ]
15 logs = [
16             ("epoch", epoch),
17             ("count", count),
18           ] + logs
19 log(logs)
View Code

绘制 .dat 文件中数据的曲线图

 1 import numpy as np
 2 import matplotlib.pyplot as plt
 3 # log_edge.dat
 4 import numpy as np
 5 import numpy as np
 6 from matplotlib import pyplot as plt
 7 from scipy.interpolate import make_interp_spline
 8 import matplotlib.pyplot as pl
 9 with open("log_edge.dat", "r") as f:
10     x = []
11     y1 = []
12     y2 = []
13     for line in f:
14         if not line.strip() or line.startswith('@') or line.startswith('#'):
15             continue
16         row = line.split()
17         x.append(float(row[1]))
18         y1.append(float(row[2]))
19 
20     pl.plot(x, y1)
21     pl.show()
22     pl.savefig("sigma enter code here`ng", dpi=300)
View Code
posted @ 2021-04-16 20:02  临近边缘  阅读(226)  评论(0)    收藏  举报