TensorFlow里面损失函数

损失算法的选取

损失函数的选取取决于输入标签数据的类型:

  • 如果输入的是实数、无界的值,损失函数使用平方差;
  • 如果输入标签是位矢量(分类标志),使用交叉熵会更适合。

1.均值平方差

 

 

在TensorFlow没有单独的MSE函数,不过由于公式比较简单,往往开发者都会自己组合,而且也可以写出n种写法,例如:

MSE=tf.reduce_mean(tf.pow(tf.sub(logits, outputs), 2.0))
MSE=tf.reduce_mean(tf.square(tf.sub(logits, outputs)))
MSE=tf.reduce_mean(tf.square(logits- outputs))

代码中logits代表标签值,outputs代表预测值

2.交叉熵

交叉熵(crossentropy)也是loss算法的一种,一般用在分类问题上,表达的意识为预测输入样本属于某一类的概率 。其表达式如下,其中y代表真实值分类(0或1),a代表预测值。

 

 

在TensorFlow中常见的交叉熵函数有:

  • Sigmoid交叉熵;
  • softmax交叉熵;
  • Sparse交叉熵;
  • 加权Sigmoid交叉熵。

图:在TensorFlow里常用的损失函数如表所示。

 

当然,也可以像MSE那样使用自己组合的公式计算交叉熵,举例,对于softmax后的结果logits我们可以对其使用公式-tf.reduce_sum(labels*tf.log(logits),1),就等同于softmax_cross_entropy_with_logits得到的结果。

 


import tensorflow as tf

labels = [[0, 0, 1], [0, 1, 0]]
logits = [[2, 0.5, 6], [0.1, 0, 3]]
logits_scaled = tf.nn.softmax(logits)
logits_scaled2 = tf.nn.softmax(logits_scaled)

result1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
result2 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits_scaled)
result3 = -tf.reduce_sum(labels * tf.log(logits_scaled), 1)

with tf.Session() as sess:
print("scaled=", sess.run(logits_scaled))
print("scaled2=", sess.run(logits_scaled2))
# 经过第二次的softmax后,分布概率会有变化

print("rel1=", sess.run(result1), "\n") # 正确的方式
print("rel2=", sess.run(result2), "\n")
# 如果将softmax变换完的值放进去会,就相当于算第二次softmax的loss,所以会出错
print("rel3=", sess.run(result3))

 

运行上面代码,输出结果如下:

scaled= [[ 0.01791432 0.00399722 0.97808844]
[ 0.04980332 0.04506391 0.90513283]]
scaled2= [[ 0.21747023 0.21446465 0.56806517]
[ 0.2300214 0.22893383 0.54104471]]
rel1= [ 0.02215516 3.09967351]
rel2= [ 0.56551915 1.47432232]
rel3= [ 0.02215518 3.09967351]

 

下面开始验证下前面所说的实验:

  • 比较scaled和scaled2可以看到:经过第二次的softmax后,分布概率会有变化,而scaled才是我们真实转化的softmax值。

  • 比较rel1和rel2可以看到:传入softmax_cross_entropy_with_logits的logits是不需要进行softmax的。如果将softmax后的值scaled传入softmax_cross_entropy_with_logits就相当于进行了两次的softmax转换。



 




posted @ 2020-04-15 12:45  CrescentTing  阅读(761)  评论(0编辑  收藏  举报