gpu加速
这会显著降低训练和推理的速度。PyTorch 中启用 GPU 加速非常简单,只需要将模型和数据移动到 GPU 上即可。
下面是修改后的代码,添加了 GPU 支持:
python
运行
#coding:utf8
import torch
import torch.nn as nn
import numpy as np
import random
import json
from transformers import BertModel
"""
基于pytorch的网络编写
实现一个网络完成一个简单nlp任务
判断文本中是否有某些特定字符出现
week2的例子,修改引入bert
"""
# 判断是否有GPU可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
class TorchModel(nn.Module):
def __init__(self, input_dim, sentence_length, vocab):
super(TorchModel, self).__init__()
# 原始代码
# self.embedding = nn.Embedding(len(vocab) + 1, input_dim)
# self.layer = nn.Linear(input_dim, input_dim)
# self.pool = nn.MaxPool1d(sentence_length)
self.bert = BertModel.from_pretrained("./bert-base-chinese", return_dict=False)
self.classify = nn.Linear(input_dim, 3)
self.activation = torch.sigmoid #sigmoid做激活函数
self.dropout = nn.Dropout(0.5)
self.loss = nn.functional.cross_entropy
# 将模型移到GPU
self.to(device)
#当输入真实标签,返回loss值;无真实标签,返回预测值
def forward(self, x, y=None):
# 原始代码
# x = self.embedding(x) #input shape:(batch_size, sen_len) (10,6)
# x = self.layer(x) #input shape:(batch_size, sen_len, input_dim) (10,6,20)
# x = self.dropout(x) #input shape:(batch_size, sen_len, input_dim)
# x = self.activation(x) #input shape:(batch_size, sen_len, input_dim)
# x = self.pool(x.transpose(1,2)).squeeze() #input shape:(batch_size, sen_len, input_dim)
# 将输入移到GPU
x = x.to(device)
if y is not None:
y = y.to(device)
sequence_output, pooler_output = self.bert(x)
x = self.classify(pooler_output)
y_pred = self.activation(x)
if y is not None:
return self.loss(y_pred, y.squeeze())
else:
return y_pred
#字符集随便挑了一些汉字,实际上还可以扩充
#为每个字生成一个标号
#{"a":1, "b":2, "c":3...}
#abc -> [1,2,3]
def build_vocab():
chars = "abcdefghijklmnopqrstuvwxyz" #字符集
vocab = {}
for index, char in enumerate(chars):
vocab[char] = index + 1 #每个字对应一个序号
vocab['unk'] = len(vocab)+1
return vocab
#随机生成一个样本
#从所有字中选取sentence_length个字
#反之为负样本
def build_sample(vocab, sentence_length):
#随机从字表选取sentence_length个字,可能重复
x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]
#A类样本
if set("abc") & set(x) and not set("xyz") & set(x):
y = 0