基于keras采用LSTM实现多标签文本分类

首先抓取博客园知识库(kb.cnblogs.com)的文章标题和分类。

代码:

#coding=utf-8

import os
import sys
import requests
from lxml import etree,html
import lxml
import time
import re

# Output file for the crawler: one "[category] title" line is appended per article.
filepath = 'data/bokeyuan_fenlei.csv'


def zhuaqudata(max_pages=None):
    """Crawl the kb.cnblogs.com listing pages and append every article to ``filepath``.

    Pages are fetched in order starting at page 1 and crawling continues while
    ``getwenzhangandnext`` reports that a next page exists. Each article is
    printed and appended to ``filepath`` as a ``"[category] title"`` line.

    Args:
        max_pages: optional upper bound on the number of pages to fetch;
            ``None`` (the default) crawls until the site reports no next page.
    """
    page = 1
    haslast = 1
    # A single loop handles page 1 and all following pages (previously the
    # whole body was duplicated for page 1 and had to be kept in sync by hand).
    while haslast and (max_pages is None or page <= max_pages):
        print("开始抓取%s页..." % page)
        (haslast, titles, fenleis) = getwenzhangandnext(page)
        # titles and fenleis are built in lockstep, so zip pairs them correctly.
        for title, fenlei in zip(titles, fenleis):
            print('[%s] %s' % (fenlei, title))
            writefile(filepath, "[%s] %s\n" % (fenlei, title))
        print()
        page = page + 1
        
def getwenzhangandnext(page):
    """Fetch one kb.cnblogs.com listing page and extract its articles.

    Args:
        page: 1-based page number. Page 1 is the site root; page N > 1 is
            ``https://kb.cnblogs.com/N/``.

    Returns:
        A tuple ``(haslast, titles, fenleis)``. ``haslast`` is 1 when the text
        of the pager's last link contains "next" (another page exists) and 0
        otherwise. ``titles`` and ``fenleis`` are parallel lists of article
        titles and category names, both cleaned with ``formatstr``.

    Sleeps 3 seconds before returning to throttle requests to the site.
    """
    baseurl = 'https://kb.cnblogs.com/'
    url = baseurl if page == 1 else baseurl + str(page) + '/'
    print(url)
    htmlcontent = etree.HTML(geturl(url))

    titles = []
    fenleis = []
    if htmlcontent is None:
        # etree.HTML returns None for an empty document; treat it as a page
        # with no articles and no next page instead of crashing on .xpath.
        time.sleep(3)
        return 0, titles, fenleis

    ps = htmlcontent.xpath('//div[@class="list_block"]//div[@class="msg_title"]//p')
    for p in ps:
        # Query the <p> element directly with relative paths instead of
        # serialising and re-parsing it, and require a real <span> element
        # rather than the substring "span" anywhere in the markup.
        if not p.xpath('.//span'):
            continue
        title_attrs = p.xpath('.//a//@title')
        span_texts = p.xpath('.//span//text()')
        if not title_attrs or not span_texts:
            # Malformed entry: skip it instead of raising IndexError.
            continue
        titles.append(formatstr(title_attrs[0]))
        fenleis.append(formatstr(span_texts[0]))

    # The pager may be absent (single-page listing, error page); treat that
    # as "no next page" rather than raising IndexError.
    pager_texts = htmlcontent.xpath('//div[@id="pager_block"]//div[@id="pager"]//a[last()]//text()')
    haslasttext = str(pager_texts[0]) if pager_texts else ''
    haslast = 1 if 'next' in haslasttext.lower() else 0

    time.sleep(3)
    return haslast, titles, fenleis

def formatstr(str):
    """Strip every character except ASCII letters/digits, CJK ideographs and ``:、?!,``.

    The parameter keeps the name ``str`` for compatibility with existing
    callers; it shadows the builtin only inside this function.
    """
    # Deleting everything outside the allowed set is the same as keeping
    # only the characters inside it.
    return re.sub('[^0-9a-zA-Z\u4e00-\u9fa5:、?!,]', '', str)

def readfile(filepath):
    """Return the whole contents of the UTF-8 text file at ``filepath``."""
    # ``with`` closes the handle even if reading/decoding raises.
    with open(filepath, 'r', encoding='utf-8') as fp:
        return fp.read()

def writefile(filepath, s):
    """Append the string ``s`` to the UTF-8 text file at ``filepath``.

    The file is created if it does not exist; existing content is kept.
    """
    # ``with`` closes (and flushes) the handle even if the write raises.
    with open(filepath, 'a', encoding='utf-8') as fp:
        fp.write(s)

def geturl(url, timeout=30):
    """Download ``url`` with a browser User-Agent and return the decoded body.

    The body is decoded using the encoding guessed from its content
    (``apparent_encoding``) rather than the HTTP header.

    Args:
        url: address to fetch.
        timeout: seconds to wait for the server; without it ``requests`` can
            block forever on a stalled connection.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so
            an error page is never parsed as if it were a listing page.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    header = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:95.0) Gecko/20100101 Firefox/95.0'
    }
    res = requests.get(url, headers=header, timeout=timeout)
    res.raise_for_status()
    res.encoding = res.apparent_encoding
    return res.text

if __name__ == '__main__':
    # Run the full crawl only when executed as a script, not on import.
    zhuaqudata()

 

结果:

 

 

然后通过程序读取该文件,建立标题与分类标签的对应关系,再依次进行编码、建模、训练和测试。

代码:

#coding=utf-8

import os
import sys
import re
import jieba
from sklearn.preprocessing import MultiLabelBinarizer
from keras.preprocessing.text import Tokenizer
from keras_preprocessing.sequence import pad_sequences
from keras.models import Sequential,Model,load_model
import numpy as np
from keras.layers import Dense, Input, Flatten, Dropout, LSTM
from keras.layers import Conv1D, MaxPooling1D, Embedding, GlobalMaxPooling1D, SpatialDropout1D
import random

# Crawled data: one "[category] title" line per article.
filepath = 'data/bokeyuan_fenlei.csv'
# One stopword per line; the file may be empty.
stopwordfilepath = 'data/cn_stopwords.txt'

def readfile(filepath):
    """Return the whole contents of the UTF-8 text file at ``filepath``."""
    # ``with`` closes the handle even if reading/decoding raises.
    with open(filepath, 'r', encoding='utf-8') as fp:
        return fp.read()

def writefile(filepath, s):
    """Append the string ``s`` to the UTF-8 text file at ``filepath``.

    The file is created if it does not exist; existing content is kept.
    """
    # ``with`` closes (and flushes) the handle even if the write raises.
    with open(filepath, 'a', encoding='utf-8') as fp:
        fp.write(s)
    
def duqushuju(train_ratio=0.0, shuffle=False):
    """Load the crawled "[category] title" file and build train/test splits.

    Every ``[category] title`` line in ``filepath`` becomes one sample. The
    title is cleaned and segmented with ``contentsplit`` using the stopwords
    read from ``stopwordfilepath``. The category becomes a one-element label
    list, which is the format ``MultiLabelBinarizer`` expects.

    Args:
        train_ratio: fraction of samples (in file order, or shuffled order if
            ``shuffle``) used for training; the rest form the test set. With
            the default 0.0 no split is made and train and test are both the
            full data set, so "test" metrics only measure fit on training data.
        shuffle: shuffle the samples with ``random.shuffle`` before splitting.

    Returns:
        ``(all_data, all_fenlei, train_data, train_label, test_data, test_label)``
        where the ``*_data`` lists hold space-joined token strings and the
        label lists hold one-element lists of category names.
    """
    text = readfile(filepath)
    stop_text = readfile(stopwordfilepath)
    # contentsplit compares *stripped* tokens, so strip the stopwords too
    # (otherwise entries with stray whitespace never match); a set makes
    # each membership test O(1).
    stopwords = {i.strip() for i in stop_text.split('\n') if i.strip()}
    # Raw string ('\[' in a plain literal is an invalid escape sequence);
    # (?:\n|$) also accepts a final line without a trailing newline.
    res = re.findall(r'\[(.*?)\](.*?)(?:\n|$)', text)
    if shuffle:
        random.shuffle(res)

    titles = []
    fenleis = []
    for fenlei, title in res:
        fenleis.append([fenlei])
        titles.append(contentsplit(title, stopwords))

    trainlen = int(len(fenleis) * train_ratio)

    if trainlen > 0:
        train_data = titles[:trainlen]
        train_label = fenleis[:trainlen]
        test_data = titles[trainlen:]
        test_label = fenleis[trainlen:]
    else:
        # No split requested: train and test are (copies of) the full data.
        train_data = titles[:]
        train_label = fenleis[:]
        test_data = titles[:]
        test_label = fenleis[:]

    all_data = titles
    all_fenlei = fenleis

    return all_data,all_fenlei,train_data,train_label,test_data,test_label

def contentsplit(segment, stopwords):
    """Clean ``segment`` with ``formatstr``, segment it with jieba and return the kept tokens joined by spaces.

    A token is kept only when its stripped form is non-empty and not in
    ``stopwords``, and the token itself is longer than one character.
    """
    cleaned = formatstr(segment)
    kept = []
    for token in jieba.cut(cleaned):
        stripped = token.strip()
        if not stripped or stripped in stopwords:
            continue
        if len(token) <= 1:
            # Single characters carry little meaning for classification.
            continue
        kept.append(token)
    return " ".join(kept)

def formatstr(str):
    """Strip every character except ASCII letters/digits and CJK ideographs.

    The parameter keeps the name ``str`` for compatibility with existing
    callers; it shadows the builtin only inside this function.
    """
    # Deleting everything outside the allowed set is the same as keeping
    # only the characters inside it.
    return re.sub('[^0-9a-zA-Z\u4e00-\u9fa5]', '', str)
    
if __name__ == '__main__':
    all_data,all_fenlei,train_data,train_label,test_data,test_label = duqushuju()
    print('总分类大小:%s' % len(all_fenlei))
    print('总标题大小:%s' % len(all_data))
    print('训练分类大小:%s' % len(train_label))
    print('训练标题大小:%s' % len(train_data))
    print('测试分类大小:%s' % len(test_label))
    print('测试标题大小:%s' % len(test_data))

    maxlen = 100        # every title is padded/truncated to this many tokens
    num_words = 40000   # vocabulary size shared by the Tokenizer and the Embedding

    # Label vectorisation: fit the binarizer on the training labels only and
    # reuse it for the test labels, so both matrices share one column order.
    # (A second binarizer fitted on the test labels misaligns the columns as
    # soon as the two sets contain different categories.)
    mutil_lab = MultiLabelBinarizer()
    train_label_code = mutil_lab.fit_transform(train_label)
    test_label_code = mutil_lab.transform(test_label)
    # class_names[k] is the category name of model output unit k.
    class_names = mutil_lab.classes_

    # Raw string: '\]' in a plain literal is an invalid escape (same value).
    tokenizer = Tokenizer(num_words=num_words, filters=r'!"#$%&()*+,-./:;<=>?@[\]^_`{|}~', lower=True)
    tokenizer.fit_on_texts(train_data)

    # Turn the space-separated token strings into padded index sequences.
    x_data = pad_sequences(tokenizer.texts_to_sequences(train_data), maxlen)
    y_data = np.array(train_label_code)

    x_test_data = pad_sequences(tokenizer.texts_to_sequences(test_data), maxlen)
    y_test_data = np.array(test_label_code)

    print("训练集的大小为: ", x_data.shape, "训练集标签的大小为: ", y_data.shape)
    print("测试集的大小为: ", x_test_data.shape, "测试集标签的大小为: ", y_test_data.shape)

    model_path = 'models/wenben_fenlei_lstm.h5'
    model = None
    if os.path.exists(model_path):
        model = load_model(model_path)
        # A cached model trained on a different label set would score the
        # wrong categories; retrain instead of reusing it.
        if model.output_shape[-1] != y_data.shape[1]:
            model = None
    if model is None:
        # Build the model: embedding -> spatial dropout -> LSTM -> sigmoid per class.
        inputs = Input(shape=(maxlen,))
        embed = Embedding(num_words, 100, input_length=x_data.shape[1])(inputs)
        dropout = SpatialDropout1D(0.2)(embed)

        # These LSTM arguments (tanh / sigmoid activations, recurrent_dropout=0)
        # are chosen so the layer stays eligible for the cuDNN kernel.
        lstm = LSTM(100, dropout=0.2, recurrent_dropout=0, activation='tanh', recurrent_activation='sigmoid')(dropout)
        output = Dense(y_data.shape[1], activation='sigmoid')(lstm)
        model = Model(inputs, output)
        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        model.summary()  # print the architecture

        model.fit(x_data, y_data, batch_size=16, epochs=20, validation_data=(x_test_data, y_test_data))

        # model.save fails if the target directory does not exist yet.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        model.save(model_path)

    # Sanity check on the first few training titles.
    n = min(3, len(x_data))
    pre = model.predict(x_data[:n])
    for label, title, probs in zip(train_label, train_data, pre):
        print('[%s] %s' % (','.join(label), title))
        # Decode the highest-scoring output unit to its category name (the old
        # code looked up the label of *training sample* k instead of class k).
        print('预测值为:%s'  % class_names[probs.argmax()])
        print()

    ceshi_data = ['FWT/快速沃尔什变换 入门指南', '如何在 Apinto 实现 HTTP 与gRPC 的协议转换 (下)', '万字血书Vue—Vue语法', '云图说丨初识华为云安全云脑——新一代云安全运营中心']
    # New titles must be preprocessed exactly like the training titles
    # (cleaning + jieba segmentation + stopword filtering); otherwise the
    # whitespace-splitting Tokenizer sees unsegmented phrases that are almost
    # never in its vocabulary.
    stopwords = {w.strip() for w in readfile(stopwordfilepath).split('\n') if w.strip()}
    ceshi_tokens = [contentsplit(t, stopwords) for t in ceshi_data]
    x_ceshi_data = pad_sequences(tokenizer.texts_to_sequences(ceshi_tokens), maxlen)

    pre = model.predict(x_ceshi_data)
    for title, probs in zip(ceshi_data, pre):
        print('%s' % title)
        print('预测值为:%s'  % class_names[probs.argmax()])
        print()

 

停用词文件 data/cn_stopwords.txt 可以自行创建,内容为空也能正常运行。停用词只用于过滤 jieba 切出的词,并不影响切词本身,但会影响送入模型的特征质量,进而影响分类效果。

我先对训练集中的前三个标题做了预测,结果基本正确;随后对 4 篇博客文章的标题做了预测,至少能够输出结果。需要注意的是,代码默认不划分训练集和测试集(两者是同一份数据),因此这里的结果只能说明模型拟合了训练数据,并不代表它对新标题的泛化能力。

效果:

 

参考:https://blog.csdn.net/qq_56154355/article/details/125685955

posted @ 2023-03-17 16:27  河北大学-徐小波  阅读(119)  评论(0)  编辑  收藏  举报