TensorFlow文本分类

 参考文章:https://zhuanlan.zhihu.com/p/59506402

import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.keras.layers as layers
# Load the IMDB dataset, restricting the vocabulary with num_words=10000
imdb = tf.keras.datasets.imdb
(train_x, train_y), (test_x, test_y) = imdb.load_data(num_words=10000)

# Inspect the raw data: print the first sample and the lengths of the first two
print(train_x[0])
print('len:', len(train_x[0]), len(train_x[1]))

# Build the word -> id and id -> word lookup tables.
# Every id from the dataset's word index is shifted by 3; ids 0-3 are then
# assigned to the special tokens below.
word_index = imdb.get_word_index()
word2id = {word: rank + 3 for word, rank in word_index.items()}
word2id.update({
    '<PAD>': 0,
    '<START>': 1,
    '<UNK>': 2,
    '<UNUSED>': 3,
})
id2word = {token_id: word for word, token_id in word2id.items()}
def get_words(sent_ids):
    """Decode a sequence of token ids into a space-separated string.

    Each id is looked up in the module-level ``id2word`` table; ids that
    are not present in the table are rendered as ``'?'``.
    """
    words = []
    for token_id in sent_ids:
        words.append(id2word.get(token_id, '?'))
    return ' '.join(words)
# Decode the first training sample back into text and print it
sent = get_words(train_x[0])
print(sent)

# Prepare the data: bring every sequence to length 256, padding at the end
# ('post') with the <PAD> id
train_x = tf.keras.preprocessing.sequence.pad_sequences(
    train_x,
    maxlen=256,
    padding='post',
    value=word2id['<PAD>'],
)
test_x = tf.keras.preprocessing.sequence.pad_sequences(
    test_x,
    maxlen=256,
    padding='post',
    value=word2id['<PAD>'],
)

# Print the first sample again, plus the lengths of the first two samples after padding
print(train_x[0])
print('len:', len(train_x[0]), len(train_x[1]))
# Build the model: embedding -> global average pooling -> Dense(16, relu)
# -> Dense(1, sigmoid)
vocab_size = 10000
model = tf.keras.Sequential([
    layers.Embedding(vocab_size, 16),
    layers.GlobalAveragePooling1D(),
    layers.Dense(16, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])
model.summary()

# Binary classification setup: adam optimizer, binary cross-entropy loss,
# accuracy tracked as the metric
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# Train and validate: the first 10000 training samples are held out as the
# validation set, the remainder is used for fitting
x_val, x_train = train_x[:10000], train_x[10000:]
y_val, y_train = train_y[:10000], train_y[10000:]
history = model.fit(
    x_train,
    y_train,
    batch_size=512,
    epochs=40,
    validation_data=(x_val, y_val),
    verbose=1,
)

# Evaluate on the test set and print whatever evaluate() returns
result = model.evaluate(test_x, test_y)
print(result)
# Plot the per-epoch training history (loss and accuracy curves)
history_dict = history.history
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']
epochs = range(1, len(acc) + 1)  # 1-based epoch numbers for the x-axis

# Figure 1: training loss (blue dots) vs. validation loss (blue line)
plt.plot(epochs, loss, 'bo', label='train loss')
plt.plot(epochs, val_loss, 'b', label='val loss')
plt.title('Train and val loss')
plt.xlabel('Epochs')
# Label the y-axis; calling xlabel a second time here would overwrite
# 'Epochs' and leave the y-axis unlabeled.
plt.ylabel('loss')
plt.legend()
plt.show()

# Figure 2: training accuracy (blue dots) vs. validation accuracy (blue line)
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

 

posted @ 2021-01-20 16:17  .HAHA  阅读(117)  评论(0)  编辑  收藏  举报