薄书's PyTorch project in practice, lesson 49: sentiment classification + scrounging a free GPU

Project source

Bilibili video: PyTorch project in practice - the sentiment classification problem

GitHub: lesson49 - sentiment classification in practice


1 Experiment environment

Let me recommend a platform for running ML and DL experiments: Google's Colaboratory, or, to put it bluntly, a platform for freeloading a GPU.

Just search for "colab" on Google and sign in with your Google account to start using it.

What is Colaboratory?

With Colaboratory (or "Colab" for short), you can write and execute Python code in your browser, with:

  • No configuration required
  • Free access to GPUs
  • Easy sharing

Whether you're a student, a data scientist, or an AI researcher, Colab can make your work easier. You can watch the Introduction to Colab video to learn more, or check out the getting-started guide!

With Colab notebooks, you can combine executable code, rich text, images, HTML, LaTeX, and more in a single document. When you create your own Colab notebooks, they are stored in your Google Drive account. You can easily share a Colab notebook with colleagues or friends and allow them to comment on or even edit it. For more details, see the Colab overview. To create a new Colab notebook, use the "File" menu above, or use this link: Create a new Colab notebook.

Colab notebooks are Jupyter notebooks hosted by Colab. To learn more about the Jupyter project, visit jupyter.org.

While using it, remember to enable GPU acceleration under Menu bar > Runtime > Change runtime type.
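
To confirm the GPU runtime is actually active, a quick check can be run in any cell (a minimal sketch):

import torch
print(torch.cuda.is_available())      # True once the GPU runtime is enabled
print(torch.cuda.get_device_name(0))  # e.g. 'Tesla K80' on the free tier (varies by allocation)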

2 Experiment

2.1 Environment setup and imports

!pip install torch
!pip install torchtext
!python -m spacy download en
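
(A note on the last command: under spaCy 2.x, which is what Colab shipped at the time, "python -m spacy download en" installs en_core_web_sm and links it to the shorthand name 'en'. spaCy 3.x removed shortcut links, so on newer environments the model must be downloaded and loaded as 'en_core_web_sm'.)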


# Colab's free tier: a K80 GPU for sessions of up to 12 hours
import torch
from torch import nn, optim
from torchtext import data, datasets
print('GPU:', torch.cuda.is_available())

torch.manual_seed(123)
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.7.0+cu101)
Requirement already satisfied: torchtext in /usr/local/lib/python3.6/dist-packages (0.3.1)
Requirement already satisfied: en_core_web_sm==2.2.5 in /usr/local/lib/python3.6/dist-packages (2.2.5)
(transitive dependency lines trimmed)
✔ Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')
✔ Linking successful
/usr/local/lib/python3.6/dist-packages/en_core_web_sm -->
/usr/local/lib/python3.6/dist-packages/spacy/data/en
You can now load the model via spacy.load('en')
GPU: True
<torch._C.Generator at 0x7f7acf579b10>
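
A compatibility note: torchtext.data.Field, LabelField, and BucketIterator used below belong to the old torchtext API; they were moved to torchtext.legacy in torchtext 0.9 and removed in later releases, so reproducing this notebook on a newer environment means pinning an older torchtext (the run above used 0.3.1).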

2.2 Setting up the dataset

TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
# IMDB is a dataset provided by torchtext
print('len of train data:', len(train_data))
print('len of test data:', len(test_data))
len of train data: 25000
len of test data: 25000

A sample review from the training set (tokenized) and its label:

['First', 'I', 'was', 'caught', 'totally', 'off', 'guard', 'by', 'the', 'film', "'s", 'initial', 'lyricism', 'and', 'then', 'I', 'became', 'totally', 'enchanted', 'with', 'the', 'unfolding', 'story', 'and', 'engrossed', 'with', 'the', 'brilliant', 'directing', '.', 'The', 'characters', 'were', 'all', 'fully', 'developed', ',', 'not', 'bigger', '-', 'than', '-', 'life', 'but', 'just', 'like', 'the', 'people', 'we', 'live', 'among', 'anywhere', 'we', 'are', 'in', 'the', 'world', ',', 'in', 'Sweden', ',', 'in', 'Turkey', 'or', 'in', 'America', ',', 'all', 'completely', 'believable', 'human', 'beings', 'with', 'foibles', 'and', 'nobility', '.', 'Hollywood', 'could', 'learn', 'so', 'much', 'from', 'this', 'beautiful', 'film', '.', 'It', 'shows', 'that', 'there', 'is', 'no', 'need', 'to', 'go', 'into', 'every', 'little', 'detail', 'behind', 'every', 'action', 'to', 'bring', 'out', 'the', 'whole', 'theme', 'clear', 'and', 'bright', ',', 'and', 'that', 'shows', 'the', 'brilliance', 'of', 'the', 'director', '!', 'Hearfelt', 'thanks', 'to', 'Kay', 'Pollak', 'and', 'the', 'wonderful', 'cast', 'for', 'this', 'superb', 'treat', '!', '!']
pos
# pretrained word vectors (word2vec, GloVe); here: 100-dim GloVe
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')
LABEL.build_vocab(train_data)
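
With max_size=10000, build_vocab keeps the 10000 most frequent words and adds the two special tokens <unk> and <pad>, for 10002 entries in total; LABEL.build_vocab maps the two classes to 0/1. A quick inspection (a minimal sketch; the exact top tokens may vary):

print(len(TEXT.vocab))       # 10002
print(TEXT.vocab.itos[:4])   # e.g. ['<unk>', '<pad>', 'the', ',']
print(LABEL.vocab.stoi)      # e.g. {'neg': 0, 'pos': 1}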


batchsz = 30
device = torch.device('cuda')
train_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, test_data),
    batch_size = batchsz,
    device=device
)
.vector_cache/glove.6B.zip: 862MB [06:28, 2.22MB/s]                          
100%|█████████▉| 398704/400000 [00:16<00:00, 24622.99it/s]
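
BucketIterator batches examples of similar length together to minimize padding. To peek at one batch (a minimal sketch; seq_len differs from batch to batch):

batch = next(iter(train_iterator))
print(batch.text.shape)    # e.g. torch.Size([1134, 30]) -- [seq_len, batch_size]
print(batch.label.shape)   # torch.Size([30])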

2.3 Building the LSTM network

class RNN(nn.Module):
    
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """
        """
        super(RNN, self).__init__()
        
        # [0-10001] => [100]: vocab_size=10002, embedding_dim=100. Of the 10002
        # tokens, 10000 are real words; the other two are <unk> (unknown word)
        # and <pad> (padding). Each word is represented by a 100-dim vector.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # [Intro to Embedding 1](https://zhuanlan.zhihu.com/p/53194407)
        # [Intro to Embedding 2](https://www.cnblogs.com/USTC-ZCC/p/11068791.html)
        # [100] => [256]
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2, 
                           bidirectional=True, dropout=0.5)  
        # [Intro to bidirectional RNNs](https://shenxiaohai.me/2018/10/19/pytorch-tutorial-intermediate-04/)
        # [256*2] => [1]
        self.fc = nn.Linear(hidden_dim*2, 1)
        self.dropout = nn.Dropout(0.5)
        
        
    def forward(self, x):
        """
        x: [seq_len, b] vs [b, 3, 28, 28]
        """
        # [seq, b, 1] => [seq, b, 100]
        embedding = self.dropout(self.embedding(x))
        
        # output: [seq, b, hid_dim*2]
        # hidden/h: [num_layers*2, b, hid_dim]
        # cell/c: [num_layers*2, b, hid_dim]
        output, (hidden, cell) = self.rnn(embedding)
        
        # [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        
        # [b, hid_dim*2] => [b, 1]
        hidden = self.dropout(hidden)
        out = self.fc(hidden)
        
        return out
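
Before feeding in real data, the tensor shapes can be sanity-checked with a dummy batch of word indices (a minimal sketch, assuming vocab_size=10002 as above):

model = RNN(vocab_size=10002, embedding_dim=100, hidden_dim=256)
dummy = torch.randint(0, 10002, (50, 30))  # [seq_len=50, batch=30] of word indices
print(model(dummy).shape)                  # torch.Size([30, 1]) -- one logit per sample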

2.4 Embedding and network optimization

rnn = RNN(len(TEXT.vocab), 100, 256)
# copy the pretrained GloVe vectors into the embedding layer
pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')

optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
criteon = nn.BCEWithLogitsLoss().to(device)
rnn.to(device)

pretrained_embedding: torch.Size([10002, 100])
embedding layer inited.
RNN(
  (embedding): Embedding(10002, 100)
  (rnn): LSTM(100, 256, num_layers=2, dropout=0.5, bidirectional=True)
  (fc): Linear(in_features=512, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
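
nn.BCEWithLogitsLoss fuses a sigmoid and binary cross-entropy into one numerically stable operation, which is why the model returns raw logits rather than probabilities. A quick equivalence check (a minimal sketch):

logits  = torch.tensor([0.5, -1.2, 2.0])
targets = torch.tensor([1.0, 0.0, 1.0])
loss_a = nn.BCEWithLogitsLoss()(logits, targets)
loss_b = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(loss_a, loss_b))  # True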

2.5 Training and testing

import numpy as np

def binary_acc(preds, y):
    """
    get accuracy
    """
    preds = torch.round(torch.sigmoid(preds))
    correct = torch.eq(preds, y).float()
    acc = correct.sum() / len(correct)
    return acc
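
binary_acc applies a sigmoid to the raw logits and thresholds at 0.5 via torch.round. For instance (a minimal sketch):

preds = torch.tensor([2.0, -1.0, 0.3])  # raw logits
y     = torch.tensor([1.0,  0.0, 0.0])
# sigmoid(0.3) ≈ 0.57 rounds to 1 while the label is 0, so accuracy is 2/3
print(binary_acc(preds, y))             # tensor(0.6667)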

def train(rnn, iterator, optimizer, criteon):
    
    avg_acc = []
    rnn.train()
    
    for i, batch in enumerate(iterator): # iterate over all training batches
        
        # [seq, b] => [b, 1] => [b]
        pred = rnn(batch.text).squeeze(1)
        # BCEWithLogitsLoss applies the sigmoid internally, so pred holds raw logits
        loss = criteon(pred, batch.label)
        acc = binary_acc(pred, batch.label).item()
        avg_acc.append(acc)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i%100 == 0:
            print(i, acc)
        
    avg_acc = np.array(avg_acc).mean()
    print('avg acc:', avg_acc)
    
    
def eval(rnn, iterator, criteon):
    
    avg_acc = []
    
    rnn.eval()
    
    with torch.no_grad():
        for batch in iterator:

            # [b, 1] => [b]
            pred = rnn(batch.text).squeeze(1)

            # loss is computed here but not used; only accuracy is tracked
            loss = criteon(pred, batch.label)

            acc = binary_acc(pred, batch.label).item()
            avg_acc.append(acc)
        
    avg_acc = np.array(avg_acc).mean()
    
    print('>>test:', avg_acc)
        
for epoch in range(10):

    # evaluate first, then train: each '>>test' line below reflects the model
    # state before that epoch's training pass (note that `eval` shadows the
    # Python builtin of the same name)
    eval(rnn, test_iterator, criteon)
    train(rnn, train_iterator, optimizer, criteon)
Per-epoch results (the per-batch accuracy printouts at steps 0, 100, ..., 800 are trimmed; each '>>test' value is measured before that epoch's training pass):

epoch    >>test (before epoch)    avg train acc
1        0.8731                   0.9349
2        0.8765                   0.9394
3        0.8713                   0.9445
4        0.8791                   0.9481
5        0.8758                   0.9529
6        0.8763                   0.9550
7        0.8747                   0.9585
8        0.8732                   0.9630
9        0.8703                   0.9651
10       0.8726                   0.9668
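
Training accuracy keeps climbing toward ~97% while test accuracy plateaus around 87%, so the model begins to overfit after the first few epochs; early stopping or stronger regularization would help.

Once trained, the model can also score an arbitrary sentence. The helper below is a hypothetical sketch (predict_sentiment is not part of the original code); it assumes the spaCy 'en' model and the TEXT vocabulary built above:

import spacy
nlp = spacy.load('en')  # hypothetical inference helper, not from the original notebook

def predict_sentiment(sentence):
    # tokenize and map tokens to vocab indices (OOV words fall back to <unk>)
    tokens = [tok.text for tok in nlp.tokenizer(sentence)]
    indices = [TEXT.vocab.stoi[t] for t in tokens]
    tensor = torch.LongTensor(indices).unsqueeze(1).to(device)  # [seq_len, 1]
    rnn.eval()
    with torch.no_grad():
        prob = torch.sigmoid(rnn(tensor))
    return prob.item()  # near 1.0 => positive, near 0.0 => negative

print(predict_sentiment('This film is wonderful'))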