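Notes on a model-debugging issue: TextLSTM/RNN fails to learn, loss and acc do not change
Problem
On the Tsinghua news classification dataset, TextCNN works well, but TextLSTM/RNN does not learn at all: neither the loss nor the acc changes.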
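Locating the problem
- TextCNN does improve, which shows the training code and the data are fine;
- Changing the RNN/LSTM architecture and the loss function still had no effect;
- Tuning lr, embed_dim and num_layers made no difference either;
- Stepping through the code locally in the debugger, I found that much of each input was padding (0). I suspected that short texts were padded so heavily that the model could not learn any features: if the padding is appended at the end, the last hidden state the classifier reads is produced only after a long run of padding steps.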
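You can reproduce the padding check without a debugger by measuring how much of each batch is padding. Below is a minimal sketch in plain Python. It assumes each sample is a list of token ids and that the pad id is 0, as in the post. The names `padding_ratio` and `real_lengths` are hypothetical and do not come from the linked repo.

```python
from typing import List

PAD_ID = 0  # assumption: padding id is 0, matching the post


def padding_ratio(batch: List[List[int]], pad_id: int = PAD_ID) -> float:
    """Fraction of all positions in a padded batch that are padding tokens."""
    total = sum(len(seq) for seq in batch)
    if total == 0:
        return 0.0
    pads = sum(1 for seq in batch for tok in seq if tok == pad_id)
    return pads / total


def real_lengths(batch: List[List[int]], pad_id: int = PAD_ID) -> List[int]:
    """Number of non-padding tokens in each sample."""
    return [sum(1 for tok in seq if tok != pad_id) for seq in batch]


if __name__ == "__main__":
    # Toy batch padded to length 6 (made-up ids, for illustration only)
    batch = [[5, 8, 2, 0, 0, 0], [7, 1, 0, 0, 0, 0]]
    print(padding_ratio(batch))  # 7 of 12 positions are padding
    print(real_lengths(batch))   # [3, 2]
```

If most batches report a ratio well above one half, short texts dominate and padding is a plausible cause. This matches the symptom described above.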
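Fixes
- Changed seq_length from 50 to 32 (short texts are padded with 0, long texts are truncated) and retrained. This solved the problem and the model started learning.
# After changing seq_length, the model improves slowly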
epoch:0 item:50 loss:2.3013758659362793 train_acc:0.109375 dev_acc:0.1474609375
epoch:0 item:100 loss:2.2781167030334473 train_acc:0.1953125 dev_acc:0.17578125
epoch:0 item:150 loss:2.205510377883911 train_acc:0.21875 dev_acc:0.1767578125
epoch:0 item:200 loss:2.183446168899536 train_acc:0.2421875 dev_acc:0.294921875
epoch:0 item:250 loss:2.144759178161621 train_acc:0.2265625 dev_acc:0.2470703125
epoch:0 item:300 loss:2.143526792526245 train_acc:0.265625 dev_acc:0.28515625
epoch:0 item:350 loss:2.078019857406616 train_acc:0.34375 dev_acc:0.279296875
epoch:0 item:400 loss:2.096219301223755 train_acc:0.296875 dev_acc:0.3203125
epoch:0 item:450 loss:2.0016613006591797 train_acc:0.3984375 dev_acc:0.3974609375
epoch:0 item:500 loss:1.9698306322097778 train_acc:0.390625 dev_acc:0.4150390625
epoch:0 item:550 loss:1.9621188640594482 train_acc:0.4453125 dev_acc:0.47265625
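The seq_length change amounts to forcing every sample to the same fixed length. Below is a minimal sketch under two assumptions: padding is appended at the end with id 0, and long texts keep their first tokens (head truncation). The name `pad_or_truncate` is hypothetical and does not come from the linked repo.

```python
from typing import List


def pad_or_truncate(ids: List[int], seq_length: int = 32, pad_id: int = 0) -> List[int]:
    """Return exactly seq_length ids: cut long texts, right-pad short ones with pad_id."""
    if len(ids) >= seq_length:
        # Long text: keep the first seq_length tokens (head truncation is an assumption)
        return ids[:seq_length]
    # Short text: append padding at the end
    return ids + [pad_id] * (seq_length - len(ids))


if __name__ == "__main__":
    print(pad_or_truncate([4, 9, 3], seq_length=5))  # [4, 9, 3, 0, 0]
    print(len(pad_or_truncate(list(range(1, 51)))))  # 32
```

With seq_length 32 instead of 50, a 10-token text is followed by 22 padding steps rather than 40. The last hidden state is therefore produced closer to the real tokens, which is presumably why the shorter length helped.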
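- Switching to the mean or the max over all token outputs also lets the model learn features, and works better than taking the last hidden state.
# Mean pooling: better than the last hidden state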
epoch:0 item:50 loss:1.9982243776321411 train_acc:0.4375 dev_acc:0.474609375
epoch:0 item:100 loss:1.7818747758865356 train_acc:0.6015625 dev_acc:0.6044921875
epoch:0 item:150 loss:1.762993574142456 train_acc:0.5625 dev_acc:0.625
epoch:0 item:200 loss:1.71768057346344 train_acc:0.6484375 dev_acc:0.673828125
epoch:0 item:250 loss:1.6551015377044678 train_acc:0.65625 dev_acc:0.64453125
epoch:0 item:300 loss:1.661691427230835 train_acc:0.6640625 dev_acc:0.640625
epoch:0 item:350 loss:1.6576321125030518 train_acc:0.65625 dev_acc:0.6865234375
# Max pooling: works well, and the improvement is already obvious after 2 batches
epoch:0 item:50 loss:1.9646885395050049 train_acc:0.578125 dev_acc:0.5859375
epoch:0 item:100 loss:1.7058160305023193 train_acc:0.7890625 dev_acc:0.7041015625
epoch:0 item:150 loss:1.743579626083374 train_acc:0.6328125 dev_acc:0.7392578125
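Mean and max pooling replace the single last hidden state with a summary over all time steps. The post does not say whether padding positions were excluded from the pooling. The sketch below masks them out, which is an assumption. It uses plain Python lists of shape [seq_len][hidden] for one sample, and the function names are hypothetical.

```python
from typing import List

Vector = List[float]


def _kept_steps(outputs: List[Vector], mask: List[int]) -> List[Vector]:
    """Keep only the time steps whose mask is 1 (real tokens)."""
    return [h for h, m in zip(outputs, mask) if m]


def masked_mean(outputs: List[Vector], mask: List[int]) -> Vector:
    """Mean pooling over the hidden states of real tokens."""
    kept = _kept_steps(outputs, mask)
    if not kept:
        return [0.0] * len(outputs[0])
    return [sum(h[d] for h in kept) / len(kept) for d in range(len(kept[0]))]


def masked_max(outputs: List[Vector], mask: List[int]) -> Vector:
    """Max pooling (per hidden dimension) over the hidden states of real tokens."""
    kept = _kept_steps(outputs, mask)
    if not kept:
        return [0.0] * len(outputs[0])
    return [max(h[d] for h in kept) for d in range(len(kept[0]))]


if __name__ == "__main__":
    # 3 time steps, hidden size 2; the last step is padding (made-up numbers)
    outputs = [[1.0, -2.0], [3.0, 0.0], [9.0, 9.0]]
    mask = [1, 1, 0]
    print(masked_mean(outputs, mask))  # [2.0, -1.0]
    print(masked_max(outputs, mask))   # [3.0, 0.0]
    print(outputs[-1])                 # "last hidden state": a padding step
```

In PyTorch, with an LSTM output of shape [batch, seq_len, hidden] (`batch_first=True`), the unmasked equivalents are `outputs.mean(dim=1)` and `outputs.max(dim=1).values`. Either way, every real token contributes to the sentence vector, instead of only the state left after the padding.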
Code and dataset
https://github.com/haibincoder/NlpSummary/tree/master/torchcode/classification
Time records everything.