Text classification with TensorFlow 1.9: classifying text with a recurrent neural network
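
The script below builds the graph by hand in TensorFlow 1.x style: an embedding layer, a single BasicLSTMCell with output dropout, mean pooling over the time steps, and a fully connected layer that produces logits for the two classes. Training uses Adam on sparse softmax cross-entropy, and the model is checkpointed with tf.train.Saver. A toy preprocessing function and a runnable demo round out the example.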

import os

import numpy as np
import tensorflow as tf


class TextClassifier:
    def __init__(self, vocab_size=10000, max_len=200, embedding_dim=128):
        tf.reset_default_graph()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.embedding_dim = embedding_dim
        self._build_graph()

    def _build_graph(self):
        # Input placeholders
        self.inputs = tf.placeholder(tf.int32, [None, self.max_len], name='inputs')
        self.labels = tf.placeholder(tf.int32, [None], name='labels')
        self.dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')

        # Embedding layer (pinned to CPU)
        with tf.device('/cpu:0'), tf.name_scope('embedding'):
            W = tf.Variable(
                tf.random_uniform([self.vocab_size, self.embedding_dim], -1.0, 1.0),
                name='W')
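            # self.embedded has shape [batch_size, max_len, embedding_dim]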
            self.embedded = tf.nn.embedding_lookup(W, self.inputs)

        # LSTM layer
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(64)
        lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.dropout_keep_prob)
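        # dynamic_rnn unrolls the LSTM over all max_len steps; outputs: [batch_size, max_len, 64]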
        outputs, _ = tf.nn.dynamic_rnn(lstm_cell, self.embedded, dtype=tf.float32)

        # Fully connected output layer
        W = tf.get_variable('W', shape=[64, 2], initializer=tf.contrib.layers.xavier_initializer())
        b = tf.Variable(tf.constant(0.1, shape=[2]), name='b')
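        # Mean-pool the LSTM outputs over time to get a [batch_size, 64] sentence vector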
        self.logits = tf.nn.xw_plus_b(tf.reduce_mean(outputs, axis=1), W, b, name='logits')

        # Loss and optimizer
        self.loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.labels))
        self.optimizer = tf.train.AdamOptimizer(1e-3).minimize(self.loss)

        # Prediction and accuracy ops
        self.predictions = tf.argmax(self.logits, 1, name='predictions')
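        # tf.argmax returns int64, so cast the int32 labels before comparing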
        correct_predictions = tf.equal(self.predictions, tf.cast(self.labels, tf.int64))
        self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name='accuracy')

    def train(self, X_train, y_train, X_val, y_val, epochs=10, batch_size=32):
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for epoch in range(epochs):
                for i in range(0, len(X_train), batch_size):
                    batch_x = X_train[i:i + batch_size]
                    batch_y = y_train[i:i + batch_size]

                    # keep_prob 0.5 applies dropout during training; evaluation feeds 1.0
                    feed_dict = {
                        self.inputs: batch_x,
                        self.labels: batch_y,
                        self.dropout_keep_prob: 0.5
                    }
                    _, loss, acc = sess.run([self.optimizer, self.loss, self.accuracy], feed_dict=feed_dict)

                # Evaluate on the validation set
                val_feed = {
                    self.inputs: X_val,
                    self.labels: y_val,
                    self.dropout_keep_prob: 1.0
                }
                val_acc = sess.run(self.accuracy, feed_dict=val_feed)
                print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Train Acc: {acc:.4f}, Val Acc: {val_acc:.4f}')

            # Save the model (create the checkpoint directory first, or Saver.save fails)
            os.makedirs('./model', exist_ok=True)
            saver = tf.train.Saver()
            saver.save(sess, './model/text_classifier.ckpt')

    def predict(self, X_test):
        with tf.Session() as sess:
            saver = tf.train.Saver()
            saver.restore(sess, './model/text_classifier.ckpt')

            feed_dict = {
                self.inputs: X_test,
                self.dropout_keep_prob: 1.0
            }
            return sess.run(self.predictions, feed_dict=feed_dict)


def preprocess_text(texts, max_len=10):
    # Simple whitespace tokenization and vocabulary construction; id 0 is reserved for padding/unknown
    words = [word for text in texts for word in text.lower().split()]
    vocab = {word: i + 1 for i, word in enumerate(set(words))}

    # print(words)  # debug output
    # print(vocab)  # debug output
    # Convert each text to a fixed-length sequence of word ids
    sequences = []
    for text in texts:
        seq = [vocab.get(word, 0) for word in text.lower().split()]
        seq = seq[:max_len]                      # truncate to max_len
        seq += [0] * (max_len - len(seq))        # pad with the reserved id 0
        sequences.append(seq)

    return np.array(sequences), len(vocab) + 1  # +1 accounts for the reserved id 0


if __name__ == '__main__':
    # Toy example data
    texts = ['this is good', 'that is bad', 'great experience', 'poor quality']
    labels = [1, 0, 1, 0]
    m_len = 3
    # Preprocessing
    X, vocab_size = preprocess_text(texts, max_len=m_len)
    y = np.array(labels)

    print(X)
    # print(vocab_size)
    # print(y)
    # Train the model (the toy training set doubles as the validation set)
    model = TextClassifier(vocab_size=vocab_size, max_len=m_len)
    model.train(X, y, X, y, epochs=50)

    print(model.predict(X))
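
Note that preprocess_text builds a fresh vocabulary on every call, so the ids produced for new texts at inference time will not line up with the ids the model was trained on. Below is a minimal sketch of one way around this, assuming the vocab dict built during preprocessing is kept around; the texts_to_sequences helper is an illustration, not part of the original code:

def texts_to_sequences(texts, vocab, max_len):
    # Hypothetical helper: map words through an existing vocab so that
    # inference-time ids match the ones seen during training.
    sequences = []
    for text in texts:
        seq = [vocab.get(word, 0) for word in text.lower().split()]  # 0 = padding/unknown
        seq = seq[:max_len]
        seq += [0] * (max_len - len(seq))
        sequences.append(seq)
    return np.array(sequences)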