pyCaffe测试接口与日志可视化分析(输出测试中识别错误的样本)

pyCaffe测试接口与日志可视化分析

https://blog.zhum.in/project/TrafficSignRecognition/pyCaffe%E6%B5%8B%E8%AF%95%E6%8E%A5%E5%8F%A3%E4%B8%8E%E6%97%A5%E5%BF%97%E5%8F%AF%E8%A7%86%E5%8C%96%E5%88%86%E6%9E%90/

日志可视化工具

我们如何监控模型的训练过程呢?Caffe官方在caffe/tools/extra中提供了日志可视化工具plot_training_log.py.example。这一脚本依赖于同目录下的parse_log.sh脚本解析日志文件。该文件默认只支持.log的日志文件,因此我们需要将日志保存为.log格式(其实内容与后缀无关,只是因为parse_log.sh文件中做了后缀名过滤,修改相关代码就可以绕过该限制)。

首先我们通过tee管道将Caffe训练过程中的日志保存至log文件(下面的脚本在前台运行;如需放入后台,可在训练命令末尾加上&):

1
2
3
4
#!/bin/bash
# Train with Caffe and mirror the console output (stdout and stderr) into a
# timestamped .log file; the ".log" suffix is what parse_log.sh filters on.
LOG_DIR=log
LOG="$LOG_DIR/train-$(date +%Y-%m-%d-%H-%M-%S).log"
CAFFE=./Caffe/build/tools/caffe

# tee cannot create missing parent directories, so make sure the log dir exists.
mkdir -p "$LOG_DIR"
"$CAFFE" train --solver=../model/solver.prototxt --gpu=0 2>&1 | tee "$LOG"

 

然后修改plot_training_log.py.example中的load_data方法,此处有一个bug,会导致list out of range错误:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def load_data(data_file, field_idx0, field_idx1):
    """Read two numeric columns from a whitespace-separated table file.

    Lines starting with ``#`` are header lines: the number of
    whitespace-separated tokens on the most recent header defines how many
    fields a data row must have, and rows with a different field count are
    skipped. Blank lines are ignored. Rows that appear before any header
    line are accepted without a field-count check.

    Args:
        data_file: Path of the file to read.
        field_idx0: Index of the column collected into the first list.
        field_idx1: Index of the column collected into the second list.

    Returns:
        A two-element list ``[column0_values, column1_values]`` of floats.

    Raises:
        IndexError: If a kept row has no field at one of the indices.
        ValueError: If a selected field is not a valid float.
    """
    data = [[], []]
    num_fields = None  # unknown until the first '#' header has been seen
    with open(data_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Indexing line[0] on an empty line would raise IndexError.
                continue
            if line.startswith('#'):
                num_fields = len(line.split())
                continue
            fields = line.split()
            # Skip incomplete rows (e.g. a partially written last line).
            if num_fields is not None and len(fields) != num_fields:
                continue
            data[0].append(float(fields[field_idx0]))
            data[1].append(float(fields[field_idx1]))
    return data

 

接下来我们使用命令就可以来将我们的训练日志可视化了,这个工具的用法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
$ python plot_training_log.py.example chart_type[0-7] /where/to/save.png /path/to/first.log ...
Notes:
1. Supporting multiple logs.
2. Log file name must end with the lower-cased ".log".
Supported chart types:
0: Test accuracy vs. Iters
1: Test accuracy vs. Seconds
2: Test loss vs. Iters
3: Test loss vs. Seconds
4: Train learning rate vs. Iters
5: Train learning rate vs. Seconds
6: Train loss vs. Iters
7: Train loss vs. Seconds

 

我们将loss随迭代次数(Iters)变化的曲线可视化,观察训练过程中loss的收敛情况(注意:下面命令中的chart type 2对应Test loss vs. Iters;若要绘制训练loss,应使用chart type 6,即Train loss vs. Iters):

1
$ python plot_training_log.py.example 2 train_loss_iter.png train.log

 

train_loss_iter

模型测试接口

deploy训练好的模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def work(imgdata, all_rect):
    """Run the module-level ``net`` on ``imgdata`` at several scales and collect boxes.

    For every scale the image is resized, cut into overlapping square tiles,
    each tile is forwarded through ``net``, and the tile outputs are stitched
    into whole-image maps. Boxes are then decoded by ``fix_box``, merged with
    ``cv2.groupRectangles`` and tagged with the most frequent non-zero type
    found inside each merged box.

    Integer arithmetic uses ``//`` and slice bounds are cast to ``int`` so the
    function behaves the same on Python 2 and Python 3 / recent NumPy.

    Args:
        imgdata: Image array. It is indexed with ``shape[0]``/``shape[1]`` and
            transposed with ``(2, 0, 1)``, so presumably H x W x C -- confirm.
        all_rect: List that is extended in place with tuples
            ``(type_id, rect_as_list, score, resize)``.

    Returns:
        None. Results are delivered through ``all_rect``.
    """
    data_layer = net.blobs['data']

    prob_th = 0.95       # threshold applied to the stitched score map
    gbox = 0.1           # eps argument of cv2.groupRectangles
    netsize = 1024       # tile stride along each spatial axis
    overlap_size = 256   # extra pixels per tile; tile side = netsize + overlap_size
    res1 = 4             # input-to-output size ratio of the pixel/bbox maps
    res2 = 16            # input-to-output size ratio of the type map

    for resize in [0.5, 1, 2, 4]:
        if resize < 1:
            # Sub-unit scale: shrink the whole image to the net input size and
            # record the effective scale factor actually applied.
            resize = data_layer.shape[2] * 1.0 / imgdata.shape[0]
            data = cv2.resize(imgdata, (data_layer.shape[2], data_layer.shape[3]))
        else:
            # NOTE(review): cv2.resize takes dsize as (width, height) while
            # shape[0] is normally the row count; equivalent only for square
            # images -- confirm inputs are square.
            data = cv2.resize(imgdata, (imgdata.shape[0] * resize, imgdata.shape[1] * resize))
        data = data.transpose(2, 0, 1)

        pixel_whole = np.zeros((1, data.shape[1] // res1, data.shape[2] // res1))
        bbox_whole = np.zeros((4, data.shape[1] // res1, data.shape[2] // res1))
        type_whole = np.zeros((1, data.shape[1] // res2, data.shape[2] // res2))

        for x in range((data.shape[1] - 1) // netsize + 1):
            # [xl, xr) is the tile fed to the net; [xsl, xsr) is the part of it
            # that is kept (half of the overlap is dropped on interior edges);
            # [xtl, xtr) is the kept part in tile-local coordinates.
            xl = min(x * netsize, data.shape[1] - netsize - overlap_size)
            xr = xl + netsize + overlap_size
            xsl = xl if xl == 0 else xl + overlap_size // 2
            xsr = xr if xr == data.shape[1] else xr - overlap_size // 2
            xtl = xsl - xl
            xtr = xsr - xl
            for y in range((data.shape[2] - 1) // netsize + 1):
                yl = min(y * netsize, data.shape[2] - netsize - overlap_size)
                yr = yl + netsize + overlap_size
                ysl = yl if yl == 0 else yl + overlap_size // 2
                ysr = yr if yr == data.shape[2] else yr - overlap_size // 2
                ytl = ysl - yl
                ytr = ysr - yl
                fdata = data[:, xl:xr, yl:yr]

                data_layer.data[...] = fdata
                net.forward()
                pixel = net.blobs['output_pixel'].data[0]
                # Two-channel softmax; channel 1 is the one kept as the score map.
                pixel = np.exp(pixel) / (np.exp(pixel[0]) + np.exp(pixel[1]))
                bbox = net.blobs['output_bb'].data[0]
                mtypes = net.blobs['output_type'].data[0]
                mtypes = np.argmax(mtypes, axis=0)

                pixel_whole[:, xsl // res1: xsr // res1, ysl // res1: ysr // res1] = \
                    pixel[1, xtl // res1: xtr // res1, ytl // res1: ytr // res1]
                bbox_whole[:, xsl // res1: xsr // res1, ysl // res1: ysr // res1] = \
                    bbox[:, xtl // res1: xtr // res1, ytl // res1: ytr // res1]
                type_whole[:, xsl // res2: xsr // res2, ysl // res2: ysr // res2] = \
                    mtypes[xtl // res2: xtr // res2, ytl // res2: ytr // res2]
                # At the shrunken scale only the first tile is processed.
                if resize < 1:
                    break
            if resize < 1:
                break

        rects = fix_box(bbox_whole, pixel_whole[0] > prob_th, imgdata.shape[0] * resize,
                        imgdata.shape[1] * resize, res1)
        merge_rects, scores = cv2.groupRectangles(rects.tolist(), 2, gbox)
        # Merged boxes expressed in original-image coordinates.
        merge_rects = np.array(merge_rects, np.float32) / resize

        # The same boxes in type-map coordinates, converted from
        # (x, y, w, h) to (x0, y0, x1, y1).
        mrect = merge_rects * resize / res2
        if len(mrect) > 0:
            mrect[:, [2, 3]] += mrect[:, [0, 1]]

        for i, rect in enumerate(mrect):
            # Cast to int: float slice bounds are rejected by recent NumPy.
            xl = int(np.floor(rect[0]))
            yl = int(np.floor(rect[1]))
            xr = int(np.ceil(rect[2])) + 1
            yr = int(np.ceil(rect[3])) + 1
            xl = int(np.clip(xl, 0, type_whole.shape[1]))
            yl = int(np.clip(yl, 0, type_whole.shape[2]))
            xr = int(np.clip(xr, 0, type_whole.shape[1]))
            yr = int(np.clip(yr, 0, type_whole.shape[2]))

            tp = type_whole[0, yl:yr, xl:xr]
            uni, num = np.unique(tp, return_counts=True)
            # Majority vote over non-zero type ids; ties keep the first seen.
            maxtp, maxc = 0, 0
            for tid, c in zip(uni, num):
                if tid != 0 and maxc < c:
                    maxtp, maxc = tid, c
            if maxtp != 0:
                all_rect.append((int(maxtp), merge_rects[i].tolist(), float(scores[i]), resize))

输出测试中识别错误的样本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def draw_wrong():
    """Plot one randomly chosen image that has wrong or missed detections.

    Loads the annotations from ``annos_augmented`` and the detection results
    from ``aug_aug_annos_results`` (module-level paths), evaluates them with
    ``anno_func.eval_annos`` and picks, uniformly at random, one image id
    whose ``wrong`` or ``miss`` entry contains at least one object. For that
    image the annotations, the results, the wrong detections and the misses
    are each drawn with category labels and saved under ``wrong/``.

    Returns:
        None. If no image has wrong or missed objects a message is printed
        and nothing is drawn (previously this case looped forever).
    """
    # Context managers close the JSON files instead of leaking the handles.
    with open(annos_augmented) as f:
        annos = json.load(f)
    with open(aug_aug_annos_results) as f:
        result_annos = json.load(f)

    sm = anno_func.eval_annos(annos, result_annos, 0.5, types=anno_func.type45,
                              minscore=40, maxboxsize=400)

    wrong_imgs = sm['wrong']["imgs"]
    miss_imgs = sm['miss']["imgs"]
    # Build the candidate list up front: this terminates even when nothing is
    # wrong, and avoids random.sample() on a dict view (an error on Python 3).
    candidates = [
        imgid for imgid in result_annos['imgs']
        if len(wrong_imgs[imgid]["objects"]) or len(miss_imgs[imgid]["objects"])
    ]
    if not candidates:
        print('no image with wrong or missed detections')
        return
    imgid = random.choice(candidates)
    print(imgid)

    imgdata = anno_func.load_img(annos, datadir, imgid)

    for anno in [annos, result_annos, sm['wrong'], sm['miss']]:
        imgdraw = anno_func.draw_all(anno, datadir, imgid, imgdata)
        pl.figure(figsize=(20, 20))
        for obj in anno['imgs'][imgid]['objects']:
            pl.text(obj['bbox']['xmin'], obj['bbox']['ymin'], obj['category'],
                    fontsize=20, color='red')
        pl.imshow(imgdraw)
        # NOTE(review): every pass saves to the same path, so only the last
        # figure (sm['miss']) remains on disk -- confirm whether one file per
        # annotation set was intended.
        pl.savefig('wrong/' + imgid)
        pl.close()
 
 

posted on 2017-11-21 19:40  塔上的樹  阅读(1168)  评论(0)    收藏  举报