## （原）人脸姿态识别：HYBRID COARSE-FINE CLASSIFICATION FOR HEAD POSE ESTIMATION

https://www.cnblogs.com/darkknightzh/p/12150128.html

HYBRID COARSE-FINE CLASSIFICATION FOR HEAD POSE ESTIMATION

https://arxiv.org/abs/1901.06778

# 1. 网络结构

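1)     网络以ResNet50为骨干网络，对yaw、pitch、roll分别输出5组分类结果，类别数依次为198、66、18、6、2；其中198个类别的角度区间既用于计算交叉熵损失，也用于计算MSE损失（先做softmax，再对区间求期望得到连续角度）。

2)     其他角度区间（除198个类别的角度区间之外）只用于计算交叉熵损失（如下图所示）。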

3)     不同角度区间的交叉熵损失权重不同：198、66、18、6、2个类别的交叉熵损失权重依次为7、5、3、1、1。

4)     本文MSE损失的权重较大（α=2，即训练代码中--alpha的默认值）。

5)     训练时使用softmax计算概率。测试时使用带temperature的softmax计算概率（由于代码中T=1，实际上等效于普通softmax）。

6)     由 https://arxiv.org/abs/1503.02531 可知，给定输入logit ${{z}_{i}}$，带temperature的softmax的输出${{q}_{i}}$计算如下：

${{q}_{i}}=\frac{\exp ({{z}_{i}}/T)}{\sum\nolimits_{j}{\exp ({{z}_{j}}/T)}}$

总损失为：

$Loss=\alpha \cdot MSE(y,{{y}^{*}})+\sum\limits_{i=1}^{num}{{{\beta }_{i}}\cdot H({{y}_{i}},y_{i}^{*})}$

其中$y$、${{y}^{*}}$分别为预测的连续角度和真值，$H$为交叉熵，$num=5$为每个角度的分类头个数，$\alpha=2$，${{\beta }_{i}}$依次为7、5、3、1、1。

# 2. 代码

## 2.1 网络结构

class Multinet(nn.Module):
    """ResNet backbone with fine and coarse classification heads per Euler angle.

    Hopenet with 3 output layers for yaw, pitch and roll (the angle is
    predicted by binning plus regression on the expected value), extended
    with four coarser classification heads per angle (66, 18, 6 and 2
    classes).  The training and test scripts build it with ``num_bins=198``.

    Args:
        block: residual block class (e.g. torchvision's ``Bottleneck``).  It is
            called as ``block(inplanes, planes[, stride, downsample])`` and
            must define an ``expansion`` attribute.
        layers: number of residual blocks in each of the four ResNet stages.
        num_bins: number of classes of the fine heads ``fc_yaw``,
            ``fc_pitch`` and ``fc_roll``.
    """

    # Class counts of the coarse heads fc_{angle}_1 ... fc_{angle}_4.
    # 66 matches Hopenet ("deep head pose"); the other three are new heads.
    _COARSE_BINS = (66, 18, 6, 2)

    def __init__(self, block, layers, num_bins):
        super(Multinet, self).__init__()
        self.inplanes = 64
        # ResNet stem and the four residual stages (the backbone).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pooling to 1x1.  The original AvgPool2d(7) gives the
        # same result for 224x224 inputs (7x7 final map).  It silently ignored
        # border cells for 8x8-13x13 maps and broke the Linear heads for
        # larger maps; adaptive pooling averages the whole map for any size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        feat_dim = 512 * block.expansion
        # Fine heads (same role as in Hopenet, but with num_bins classes).
        self.fc_yaw = nn.Linear(feat_dim, num_bins)
        self.fc_pitch = nn.Linear(feat_dim, num_bins)
        self.fc_roll = nn.Linear(feat_dim, num_bins)
        # Coarse heads, registered in the original order:
        # fc_yaw_1..4, then fc_pitch_1..4, then fc_roll_1..4.
        for angle in ('yaw', 'pitch', 'roll'):
            for k, n_classes in enumerate(self._COARSE_BINS, start=1):
                setattr(self, 'fc_%s_%d' % (angle, k), nn.Linear(feat_dim, n_classes))

        # Vestigial layer from previous experiments: not used in forward(),
        # kept so the state_dict keys stay the same as before.
        self.fc_finetune = nn.Linear(feat_dim + 3, 3)

        # He-style normal init for convolutions, identity init for BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one ResNet stage of ``blocks`` residual blocks.

        The first block gets a 1x1 conv + BN shortcut when the stride or the
        channel count changes.  Updates ``self.inplanes`` to the stage's
        output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits of all 15 heads as a tuple.

        Order: yaw, yaw_1, yaw_2, yaw_3, yaw_4, pitch, pitch_1 ... pitch_4,
        roll, roll_1 ... roll_4.  The unsuffixed heads have ``num_bins``
        classes; suffixes 1-4 have 66, 18, 6 and 2 classes.
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))
        # Backbone feature vector of shape (batch, 512 * block.expansion).
        feats = self.avgpool(x).view(x.size(0), -1)

        outputs = []
        for angle in ('yaw', 'pitch', 'roll'):
            outputs.append(getattr(self, 'fc_' + angle)(feats))
            for k in range(1, len(self._COARSE_BINS) + 1):
                outputs.append(getattr(self, 'fc_%s_%d' % (angle, k))(feats))
        return tuple(outputs)
View Code

## 2.2 训练代码

def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with gpu_id, num_epochs, batch_size, lr, dataset,
        data_dir, filename_list, output_string, alpha and snapshot.
    """
    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    # (flag, dest, help, default, type), registered in this order.
    option_table = (
        ('--gpu', 'gpu_id', 'GPU device id to use [0]', 0, int),
        ('--num_epochs', 'num_epochs', 'Maximum number of training epochs.', 25, int),
        ('--batch_size', 'batch_size', 'Batch size.', 32, int),
        ('--lr', 'lr', 'Base learning rate.', 0.000001, float),
        ('--dataset', 'dataset', 'Dataset type.', 'AFLW_multi', str),
        ('--data_dir', 'data_dir', 'Directory path for data.', '', str),
        ('--filename_list', 'filename_list',
         'Path to text file containing relative paths for every example.',
         '/tools/AFLW_train.txt', str),
        ('--output_string', 'output_string', 'String appended to output snapshots.', '', str),
        ('--alpha', 'alpha', 'Regression loss coefficient.', 2, float),
        ('--snapshot', 'snapshot', 'Path of model snapshot.', '', str),
    )
    for flag, dest, help_text, default, value_type in option_table:
        parser.add_argument(flag, dest=dest, help=help_text, default=default, type=value_type)
    return parser.parse_args()
17
def get_ignored_params(model):
    """Yield the parameters of the stem (conv1, bn1) and the unused fc_finetune.

    The training script puts these in an optimizer group with lr=0, i.e. they
    stay frozen.  As a side effect, every sub-module whose name contains
    'bn' is switched to eval mode while the generator is consumed.

    Each parameter is yielded exactly once.  The original nested loop called
    the recursive ``named_parameters()`` on every module returned by
    ``named_modules()``, which re-yields parameters of non-leaf modules.
    """
    for part in (model.conv1, model.bn1, model.fc_finetune):
        for module_name, module in part.named_modules():
            # NOTE(review): the part itself is reported with the name '', so a
            # top-level BatchNorm such as model.bn1 is NOT switched to eval.
            if 'bn' in module_name:
                module.eval()
        # parameters() walks the sub-modules and yields every parameter once.
        for param in part.parameters():
            yield param
27
def get_non_ignored_params(model):
    """Yield the parameters of the four ResNet stages (layer1 .. layer4).

    The training script optimizes these with the base learning rate.  As in
    the original, every sub-module whose name contains 'bn' is switched to
    eval mode while the generator is consumed.

    Each parameter is yielded exactly once.  The original nested loop called
    the recursive ``named_parameters()`` on every module returned by
    ``named_modules()``.  As a result, each parameter appeared once per
    enclosing module (e.g. three times for layer1[0].conv1.weight).  That put
    duplicates in the optimizer's param group, which makes the optimizer step
    the same tensor several times per iteration.
    """
    for stage in (model.layer1, model.layer2, model.layer3, model.layer4):
        for module_name, module in stage.named_modules():
            if 'bn' in module_name:
                module.eval()
        # parameters() walks the sub-modules and yields every parameter once.
        for param in stage.parameters():
            yield param
37
def get_fc_params(model):
    """Yield the parameters of all 15 classification heads.

    The training script optimizes these at 5x the base learning rate.  The
    original list stopped at the *_3 heads.  The 2-class heads fc_yaw_4,
    fc_pitch_4 and fc_roll_4 received gradients from their cross-entropy
    terms but were never updated, so they kept their random initialization.
    """
    heads = [model.fc_yaw, model.fc_pitch, model.fc_roll,
             model.fc_yaw_1, model.fc_pitch_1, model.fc_roll_1,
             model.fc_yaw_2, model.fc_pitch_2, model.fc_roll_2,
             model.fc_yaw_3, model.fc_pitch_3, model.fc_roll_3,
             model.fc_yaw_4, model.fc_pitch_4, model.fc_roll_4]
    for head in heads:
        # parameters() yields every parameter of the head exactly once.
        for param in head.parameters():
            yield param
48
50     # By user apaszke from discuss.pytorch.org
       # NOTE(review): listing lines 49 (the enclosing `def`) and 54 are missing here.
       # The visible lines merge the entries of `snapshot` whose keys also exist in
       # the model's state dict; line 54 presumably loads `model_dict` back into the
       # model -- confirm against the original source.
51     model_dict = model.state_dict()  # current parameters/buffers of the model
52     snapshot = {k: v for k, v in snapshot.items() if k in model_dict}  # keep only keys the model also has
53     model_dict.update(snapshot)  # overwrite the matching entries with the snapshot's values
55
56 if __name__ == '__main__':
57     args = parse_args()
58
59     cudnn.enabled = True
60     num_epochs = args.num_epochs
61     batch_size = args.batch_size
62     gpu = args.gpu_id
63
64     if not os.path.exists('output/snapshots'):  # snapshots are written here at the end of each epoch
65         os.makedirs('output/snapshots')
66
67     # ResNet50 structure
68     model = hopenet.Multinet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 198)   # build the model (ResNet50 stages, 198 fine bins)
69
70     if args.snapshot == '':  # NOTE(review): the branch bodies (listing lines 71, 73-74) and line 76 are missing from the listing
72     else:
75
77
78     transformations = transforms.Compose([transforms.Resize(240),
79     transforms.RandomCrop(224), transforms.ToTensor(),
80     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
81     # Training augmentation: resize to 240, random 224x224 crop, ImageNet mean/std normalization.
82     if args.dataset == 'Pose_300W_LP':
83         pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
84     elif args.dataset == 'Pose_300W_LP_multi':
85         pose_dataset = datasets.Pose_300W_LP_multi(args.data_dir, args.filename_list, transformations)
86     elif args.dataset == 'Pose_300W_LP_random_ds':
87         pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
89         pose_dataset = datasets.Synhead(args.data_dir, args.filename_list, transformations)  # NOTE(review): listing line 88 (presumably its `elif args.dataset == ...:`) is missing; as shown, this would overwrite the branch above
90     elif args.dataset == 'AFLW2000':
91         pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
92     elif args.dataset == 'BIWI':
93         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
94     elif args.dataset == 'BIWI_multi':
95         pose_dataset = datasets.BIWI_multi(args.data_dir, args.filename_list, transformations)
96     elif args.dataset == 'AFLW':
97         pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
98     elif args.dataset == 'AFLW_multi':        # default --dataset; this is the dataset actually used for loading
99         pose_dataset = datasets.AFLW_multi(args.data_dir, args.filename_list, transformations)
100     elif args.dataset == 'AFLW_aug':
101         pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
102     elif args.dataset == 'AFW':
103         pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
104     else:
105         print('Error: not a valid dataset name')
106         sys.exit()
107
109     # NOTE(review): listing line 108 is missing; `train_loader`, used by the training loop, is presumably created there
110     model.cuda(gpu)
111     criterion = nn.CrossEntropyLoss().cuda(gpu)  # classification loss shared by all heads
112     reg_criterion = nn.MSELoss().cuda(gpu)  # regression loss on the expected angle
113     # Regression loss coefficient
114     alpha = args.alpha
115
116     softmax = nn.Softmax(dim=1).cuda(gpu)
117     idx_tensor = [idx for idx in range(198)]  # bin indices 0..197 of the fine heads
118     idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
119     # Three param groups: stem + fc_finetune frozen (lr 0), layer1-4 at the base lr, get_fc_params() heads at 5x the base lr.
120     optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
121                                   {'params': get_non_ignored_params(model), 'lr': args.lr},
122                                   {'params': get_fc_params(model), 'lr': args.lr * 5}],
123                                    lr = args.lr)
124
126     for epoch in range(num_epochs):  # NOTE(review): listing line 125 is missing just above
127         for i, (images, labels, labels_0, labels_1, labels_2, labels_3, cont_labels, name) in enumerate(train_loader):  # train_loader is not defined in the visible listing
128             images = Variable(images).cuda(gpu)
129
130             # Binned labels: columns 0/1/2 are the yaw/pitch/roll class indices of the fine (198-bin) heads
131             label_yaw = Variable(labels[:,0]).cuda(gpu)
132             label_pitch = Variable(labels[:,1]).cuda(gpu)
133             label_roll = Variable(labels[:,2]).cuda(gpu)
134             # labels_0 .. labels_3 are the targets of the coarse heads *_1 .. *_4 (66/18/6/2 classes)
135             label_yaw_1 = Variable(labels_0[:,0]).cuda(gpu)
136             label_pitch_1 = Variable(labels_0[:,1]).cuda(gpu)
137             label_roll_1 = Variable(labels_0[:,2]).cuda(gpu)
138
139             label_yaw_2 = Variable(labels_1[:,0]).cuda(gpu)
140             label_pitch_2 = Variable(labels_1[:,1]).cuda(gpu)
141             label_roll_2 = Variable(labels_1[:,2]).cuda(gpu)
142
143             label_yaw_3 = Variable(labels_2[:,0]).cuda(gpu)
144             label_pitch_3 = Variable(labels_2[:,1]).cuda(gpu)
145             label_roll_3 = Variable(labels_2[:,2]).cuda(gpu)
146
147             label_yaw_4 = Variable(labels_3[:,0]).cuda(gpu)
148             label_pitch_4 = Variable(labels_3[:,1]).cuda(gpu)
149             label_roll_4 = Variable(labels_3[:,2]).cuda(gpu)
150
151             # Continuous labels (targets of the MSE loss on the expected angle)
152             label_yaw_cont = Variable(cont_labels[:,0]).cuda(gpu)
153             label_pitch_cont = Variable(cont_labels[:,1]).cuda(gpu)
154             label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu)
155
156             # Forward pass
157             yaw,yaw_1,yaw_2,yaw_3,yaw_4, pitch,pitch_1,pitch_2,pitch_3,pitch_4, roll,roll_1,roll_2,roll_3,roll_4 = model(images)     # 得到各个特征
158
159             # Cross entropy loss  # 各个交叉熵损失
160             loss_yaw,loss_yaw_1,loss_yaw_2,loss_yaw_3,loss_yaw_4 = criterion(yaw, label_yaw),criterion(yaw_1, label_yaw_1),criterion(yaw_2, label_yaw_2),criterion(yaw_3, label_yaw_3),criterion(yaw_4, label_yaw_4)
161             loss_pitch,loss_pitch_1,loss_pitch_2,loss_pitch_3,loss_pitch_4 = criterion(pitch, label_pitch),criterion(pitch_1, label_pitch_1),criterion(pitch_2, label_pitch_2),criterion(pitch_3, label_pitch_3),criterion(pitch_4, label_pitch_4)
162             loss_roll,loss_roll_1,loss_roll_2,loss_roll_3,loss_roll_4 = criterion(roll, label_roll),criterion(roll_1, label_roll_1),criterion(roll_2, label_roll_2),criterion(roll_3, label_roll_3),criterion(roll_4, label_roll_4)
163
164             # MSE loss  # 归一化特征
165             yaw_predicted = softmax(yaw)
166             pitch_predicted = softmax(pitch)
167             roll_predicted = softmax(roll)
168
169             yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) - 99  # 此部分和deep head pose计算一致
170             pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) - 99
171             roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) - 99
172
173             loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont)  # 此部分和deep head pose计算一致
174             loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont)
175             loss_reg_roll = reg_criterion(roll_predicted, label_roll_cont)
176
177             # Total loss
178             total_loss_yaw = alpha * loss_reg_yaw + 7*loss_yaw + 5*loss_yaw_1 + 3*loss_yaw_2 + 1*loss_yaw_3 + 1*loss_yaw_4  # 各个角度区间的加权总损失
179             total_loss_pitch = alpha * loss_reg_pitch + 7*loss_pitch + 5*loss_pitch_1 + 3*loss_pitch_2 + 1*loss_pitch_3 + 1*loss_pitch_4
180             total_loss_roll = alpha * loss_reg_roll + 7*loss_roll + 5*loss_roll_1 + 3*loss_roll_2 + 1*loss_roll_3 + 1*loss_pitch_4
181
182             loss_seq = [total_loss_yaw, total_loss_pitch, total_loss_roll]
183             grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]  # a unit gradient for each scalar loss
186             optimizer.step()  # NOTE(review): listing lines 184-185 are missing; presumably they zero the gradients and backpropagate loss_seq with grad_seq -- confirm
187
188             if (i+1) % 100 == 0:  # log the weighted total losses every 100 iterations
189                 print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
190                        %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, total_loss_yaw.item(), total_loss_pitch.item(), total_loss_roll.item()))
191         # Save models at numbered epochs.
192         if epoch % 1 == 0 and epoch < num_epochs:  # always true for epoch in range(num_epochs): a snapshot is written after every epoch
193             print('Taking snapshot...')
194             torch.save(model.state_dict(),
195             'output/snapshots/' + args.output_string + '_epoch_'+ str(epoch+1) + '.pkl')
View Code

## 2.3 测试代码

def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with gpu_id, data_dir, filename_list, snapshot,
        batch_size, save_viz and dataset.
    """
    def str2bool(text):
        # argparse's type=bool maps every non-empty string to True, so
        # '--save_viz False' used to switch visualisation ON.
        lowered = str(text).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (text,))

    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int)
    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.', default='/AFLW2000/', type=str)
    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
          default='/tools/AFLW2000_filename_filtered.txt', type=str)
    parser.add_argument('--snapshot', dest='snapshot', help='Name of model snapshot.',
          default='/output/snapshots/AFLW2000/_epoch_9.pkl', type=str)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.', default=1, type=int)
    parser.add_argument('--save_viz', dest='save_viz', help='Save images with pose cube.', default=False, type=str2bool)
    parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='AFLW2000', type=str)

    args = parser.parse_args()

    return args
17
18 if __name__ == '__main__':
19     args = parse_args()
20
21     cudnn.enabled = True
22     gpu = args.gpu_id
23     snapshot_path = args.snapshot  # not used in the visible lines; presumably loaded on the missing lines below
24
25     # ResNet50 structure
26     model = hopenet.Multinet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 198)
27     # NOTE(review): listing lines 28-31 and 33 are missing here (presumably snapshot loading) -- confirm
32
34     # Deterministic test-time preprocessing: resize to 224, center 224x224 crop, ImageNet mean/std normalization.
35     transformations = transforms.Compose([transforms.Resize(224),
36     transforms.CenterCrop(224), transforms.ToTensor(),
37     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
38
39     if args.dataset == 'Pose_300W_LP':
40         pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
41     elif args.dataset == 'Pose_300W_LP_multi':
42         pose_dataset = datasets.Pose_300W_LP_multi(args.data_dir, args.filename_list, transformations)
43     elif args.dataset == 'Pose_300W_LP_random_ds':
44         pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
46         pose_dataset = datasets.Synhead(args.data_dir, args.filename_list, transformations)  # NOTE(review): listing line 45 (presumably its `elif ...:`) is missing
47     elif args.dataset == 'AFLW2000':
48         pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
49     elif args.dataset == 'BIWI':
50         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
51     elif args.dataset == 'BIWI_multi':
52         pose_dataset = datasets.BIWI_multi(args.data_dir, args.filename_list, transformations)
53     elif args.dataset == 'AFLW':
54         pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
55     elif args.dataset == 'AFLW_multi':
56         pose_dataset = datasets.AFLW_multi(args.data_dir, args.filename_list, transformations)
57     elif args.dataset == 'AFLW_aug':
58         pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
59     elif args.dataset == 'AFW':
60         pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
61     else:
62         print('Error: not a valid dataset name')
63         sys.exit()
65                                                batch_size=args.batch_size,
66                                                num_workers=2)
67     # NOTE(review): listing line 64 (the start of the call ending above, presumably creating `test_loader`) is missing
68     model.cuda(gpu)
69
71     # NOTE(review): listing line 70 is missing just above
72     # Test the Model
73     model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
74     total = 0
75
76     idx_tensor = [idx for idx in range(198)]  # bin indices 0..197 of the fine heads
77     idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
78
79     yaw_error = .0
80     pitch_error = .0
81     roll_error = .0
82
83     l1loss = torch.nn.L1Loss(size_average=False)  # unused below; NOTE(review): `size_average` is deprecated in newer PyTorch (reduction='sum')
84     for i, (images, labels, cont_labels, name) in enumerate(test_loader):
85     #for i, (images, labels, labels_0, labels_1, labels_2, labels_3, cont_labels, name) in enumerate(test_loader):
86         images = Variable(images).cuda(gpu)
87         total += cont_labels.size(0)
88
89         label_yaw = cont_labels[:,0].float()
90         label_pitch = cont_labels[:,1].float()
91         label_roll = cont_labels[:,2].float()
92
93         yaw,yaw_1,yaw_2,yaw_3,yaw_4, pitch,pitch_1,pitch_2,pitch_3,pitch_4, roll,roll_1,roll_2,roll_3,roll_4 = model(images)  # logits of all 15 heads; only the fine heads are used below
94
95         # Binned predictions (argmax bin; not used further in the visible code)
96         _, yaw_bpred = torch.max(yaw.data, 1)
97         _, pitch_bpred = torch.max(pitch.data, 1)
98         _, roll_bpred = torch.max(roll.data, 1)
99
100         # Continuous predictions
101         yaw_predicted = utils.softmax_temperature(yaw.data, 1)  # softmax with temperature T=1, i.e. a plain softmax
102         pitch_predicted = utils.softmax_temperature(pitch.data, 1)
103         roll_predicted = utils.softmax_temperature(roll.data, 1)
104
105         yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() - 99     # expected bin index minus 99, the same mapping as in training
106         pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() - 99
107         roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu() - 99
108
109         # Mean absolute error (sums of absolute errors here; divided by `total` in the final print)
110         yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw))
111         pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
112         roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
113
114         # Save first image in batch with pose cube or axis.
115         if args.save_viz:
116             name = name[0]
117             if args.dataset == 'BIWI':
118                 cv2_img = cv2.imread(os.path.join(args.data_dir, name + '_rgb.png'))
119             else:
120                 cv2_img = cv2.imread(os.path.join(args.data_dir, name + '.jpg'))
121             if args.batch_size == 1:  # the per-image error text is only drawn for single-image batches
122                 error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll)))
123                 cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=2)
124             # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100)
125             utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100)  # NOTE(review): tdx/tdy are fixed at 200 -- presumably the axis origin, independent of the face location
126             cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)  # output/images is not created in the visible code
127     # Per-angle mean absolute error over all images, plus their average (MAE)
128     print('Test error in degrees of the model on the ' + str(total) +
129     ' test images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f, MAE: %.4f' % (yaw_error / total,
130     pitch_error / total, roll_error / total, (yaw_error+pitch_error+roll_error)/(3.0*total)))
View Code

## 2.4 softmax_temperature代码

def softmax_temperature(tensor, temperature):
    """Softmax with temperature along dimension 1 (Hinton et al., arXiv:1503.02531).

    q_i = exp(z_i / T) / sum_j exp(z_j / T)

    Args:
        tensor: logits; normalized along dim 1, so it must be at least 2-D
            (e.g. one row of bin logits per image).
        temperature: positive T; T = 1 is the ordinary softmax and a larger T
            gives a flatter distribution.

    Returns:
        Tensor of the same shape whose slices along dim 1 sum to 1.
    """
    scaled = tensor / temperature
    # Subtracting the per-row maximum leaves the result unchanged but keeps
    # exp() from overflowing to inf (and the division from producing nan)
    # for large logits.
    scaled = scaled - torch.max(scaled, 1, keepdim=True)[0]
    result = torch.exp(scaled)
    return torch.div(result, torch.sum(result, 1, keepdim=True))
View Code

posted on 2020-01-04 19:36  darkknightzh  阅读(710)  评论(0)  编辑  收藏  举报