在RNN中,loss第二个回合变为nan
最近多次遇到,在RNN的训练过程中,loss在几次迭代后变为了nan。
一般人的直观感受都是,哇,梯度爆炸了。
而我遇到的这几次都不是梯度的问题,而是网络输入的问题。
我的问题代码可以简化为:
import tensorflow as tf import numpy as np '''测试tf.cond basket_embedding = tf.cond(tf.constant(nonzero_nums==0,dtype=tf.bool),lambda:embedding_params[max_items_all],lambda: tf.reduce_max(xe[i,j,:nonzero_nums],axis=0)) ''' nonzero_nums = tf.placeholder(dtype=tf.int32,shape=[]) embedding_params = tf.tile(tf.reshape(tf.constant(list(range(10))),shape=(-1,1)),[1,5]) #[10,5] x = tf.placeholder(dtype=tf.int32,shape=[3]) xe = tf.nn.embedding_lookup(embedding_params,x) #[3,5] basket_embedding = tf.cond(tf.constant(nonzero_nums==0,dtype=tf.bool),true_fn=lambda:embedding_params[9],false_fn=lambda:tf.reduce_max(xe[:nonzero_nums],axis=0)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) feed_dict={nonzero_nums:0,x:np.array([1,2,3])} _basket_embedding = sess.run(basket_embedding,feed_dict=feed_dict) print(embedding_params.eval()) print(embedding_params[9].eval()) print(_basket_embedding) # print(tf.equal(nonzero_nums,tf.constant(0)).eval())
最后一个print _basket_embedding 应该为[9,9,9,9,9], 而实际情况确实[-inf,-inf,-inf,..],而cond的条件false的部分却可以正常执行。
然后打印
tf.constant(nonzero_nums==0,dtype=tf.bool)
发现无论nonzero_nums赋几,都是false。
所以问题在,一直执行的是
false_fn=lambda:tf.reduce_max(xe[:nonzero_nums],axis=0)
那当nonzero_nums==0时,得到的就是一个null,于是就会出先负无穷-inf了。
改为以下代码,使用tf.equal问题解决。
import tensorflow as tf import numpy as np '''测试tf.cond basket_embedding = tf.cond(tf.constant(nonzero_nums==0,dtype=tf.bool),lambda:embedding_params[max_items_all],lambda: tf.reduce_max(xe[i,j,:nonzero_nums],axis=0)) ''' nonzero_nums = tf.placeholder(dtype=tf.int32,shape=[]) embedding_params = tf.tile(tf.reshape(tf.constant(list(range(10))),shape=(-1,1)),[1,5]) #[10,5] x = tf.placeholder(dtype=tf.int32,shape=[3]) xe = tf.nn.embedding_lookup(embedding_params,x) #[3,5] basket_embedding = tf.cond(tf.equal(nonzero_nums,0),true_fn=lambda:embedding_params[9],false_fn=lambda:tf.reduce_max(xe[:nonzero_nums],axis=0)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) feed_dict={nonzero_nums:0,x:np.array([1,2,3])} _basket_embedding = sess.run(basket_embedding,feed_dict=feed_dict) print(embedding_params.eval()) print(embedding_params[9].eval()) print(_basket_embedding) # print(tf.equal(nonzero_nums,tf.constant(0)).eval())
估计是因为tf里没有==这个符号?><貌似可以正常使用(貌似)。

浙公网安备 33010602011771号