JadenFK
There are no quiet, peaceful years; someone is simply carrying the load for us! http://JadenFK.github.io

Video link: https://www.bilibili.com/video/BV12741177Cu?from=search&seid=17209581732555565064

The video implements this in a Jupyter notebook; I wrote the code in PyCharm instead.

The fizz/buzz/fizzbuzz game works like this: if a number is divisible by 3, print fizz; if divisible by 5, print buzz; if divisible by 15, print fizzbuzz. This could be done with a single function, but since we are learning neural networks, we let a two-layer network learn the rules and play the game by itself (no UI, of course).
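For reference, the plain rule-based version mentioned above really is just one small function (a minimal sketch, for comparison with the network below):

```python
def fizz_buzz(i):
    # Direct rule-based implementation of the game
    if i % 15 == 0:
        return 'fizzbuzz'
    elif i % 5 == 0:
        return 'buzz'
    elif i % 3 == 0:
        return 'fizz'
    else:
        return str(i)

print([fizz_buzz(i) for i in range(1, 16)])
```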

There are three .py files: utils.py holds the helper functions, model.py trains the model, and paragraph2.py uses the trained model for prediction.

utils.py:

import numpy as np

def binary_encode(i, num_digits):   # encode i as a binary vector
    return np.array([i >> d & 1 for d in range(num_digits)])[::-1]   # [::-1] reverses the array, since the bits come out least-significant first

def fizz_buzz_encode(i):
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def fizz_buzz_decode(i, prediction):
    return [str(i), 'fizz', 'buzz', 'fizzbuzz'][prediction]   # a fun indexing trick -- first time I'd seen it; try printing it yourself
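A quick sanity check of these helpers (restated here so the snippet runs on its own):

```python
import numpy as np

def binary_encode(i, num_digits):
    # Encode i as a binary vector, most significant bit first
    return np.array([i >> d & 1 for d in range(num_digits)])[::-1]

def fizz_buzz_encode(i):
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def fizz_buzz_decode(i, prediction):
    return [str(i), 'fizz', 'buzz', 'fizzbuzz'][prediction]

print(binary_encode(12, 4))     # [1 1 0 0] -- 12 in binary, MSB first
print(fizz_buzz_encode(15))     # 3
print(fizz_buzz_decode(7, 0))   # '7' -- class 0 means "just print the number"
```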

model.py:

import torch
from p2.utils import binary_encode, fizz_buzz_encode
NUM_DIGITS = 10

trX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])   # training data: numbers 101 and up, 923 samples; 1..100 are held out for testing
trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])   # x can be float, but y holds class labels, so it must be a LongTensor

NUM_HIDDEN = 100
model = torch.nn.Sequential(    # model definition; the activation is ReLU
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)

if torch.cuda.is_available():   # move the model to the GPU if one is available
    model = model.cuda()

loss_fn = torch.nn.CrossEntropyLoss()   # cross-entropy loss for classification
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)   # SGD is stochastic gradient descent; torch ships several optimizers, try the others yourself
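Note that CrossEntropyLoss takes the raw logits from the last Linear layer plus integer class labels; internally it is log-softmax followed by negative log-likelihood. A NumPy sketch of that same computation for a single sample:

```python
import numpy as np

def cross_entropy(logits, target):
    # loss = -log( exp(logits[target]) / sum(exp(logits)) )
    shifted = logits - logits.max()                     # shift for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum()) # log-softmax
    return -log_probs[target]

logits = np.array([2.0, 0.5, 0.1, -1.0])   # one sample, 4 classes
print(cross_entropy(logits, 0))            # small loss: class 0 has the largest logit
print(cross_entropy(logits, 3))            # large loss: class 3 is very unlikely
```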

BATCH_SIZE = 128

def main():
    for epoch in range(1000):    # 1000 epochs; the teacher in the video uses 10000, which I found too slow, but its results really are better -- try it yourself
        for start in range(0, len(trX), BATCH_SIZE):   # mini-batches of size BATCH_SIZE
            end = start + BATCH_SIZE
            batchX = trX[start:end]
            batchY = trY[start:end]

            if torch.cuda.is_available():   # move the batch to the GPU
                batchX = batchX.cuda()
                batchY = batchY.cuda()

            y_pred = model(batchX)

            loss = loss_fn(y_pred, batchY)
            print("Epoch", epoch, loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    torch.save(model, 'fbmodel.pkl')

if __name__ == '__main__':
    main()
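`torch.save(model, ...)` pickles the whole module object, which ties the file to this exact code layout. PyTorch's usual recommendation is to save only the `state_dict` and rebuild the architecture before loading; a sketch of that alternative (the sizes match the model above, the filename is my own choice):

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(10, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 4),
)

# Save only the parameters, not the pickled class
torch.save(model.state_dict(), 'fbmodel_state.pkl')

# To load: rebuild the same architecture first, then restore the weights
model2 = torch.nn.Sequential(
    torch.nn.Linear(10, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 4),
)
model2.load_state_dict(torch.load('fbmodel_state.pkl'))
```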

paragraph2.py:

import torch
from p2.utils import binary_encode, fizz_buzz_decode

model = torch.load('p2/fbmodel.pkl')

NUM_DIGITS = 10

testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
if torch.cuda.is_available():
    testX = testX.cuda()

with torch.no_grad():
    testY = model(testX)

predictions = zip(range(1, 101), testY.max(1)[1].cpu().data.tolist())      # a neat trick: testY.max(1)[1] is the argmax over the 4 classes; try printing testY.max(1)[1].cpu().data.tolist() yourself
print([fizz_buzz_decode(i, x) for i, x in predictions])
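The printed list can also be scored against the true answers. A small helper for that, assuming `predictions` is a list of 100 strings for the numbers 1..100 like the lists below (the function names here are my own):

```python
def true_answer(i):
    # Ground-truth fizzbuzz, for scoring the model's output
    if i % 15 == 0: return 'fizzbuzz'
    if i % 5 == 0: return 'buzz'
    if i % 3 == 0: return 'fizz'
    return str(i)

def accuracy(predictions):
    # predictions: list of 100 strings for the numbers 1..100
    correct = sum(p == true_answer(i) for i, p in enumerate(predictions, 1))
    return correct / len(predictions)

perfect = [true_answer(i) for i in range(1, 101)]
print(accuracy(perfect))   # 1.0
```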

Results with 1000 training epochs:

['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', '10', '11', 'fizz', '13', '14', 'fizz', '16', '17', 'fizz', '19', '20', 'fizz', '22', '23', 'fizz', '25', '26', 'fizz', '28', '29', 'fizz', '31', '32', 'fizz', '34', '35', 'fizz', '37', '38', 'fizz', '40', '41', '42', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', '50', 'fizz', '52', '53', 'fizz', '55', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', 'fizz', 'fizz', '70', '71', 'fizz', '73', '74', 'fizzbuzz', '76', 'buzz', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'fizz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', '100']
Results with 10000 training epochs:

['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', 'fizz', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', '66', '67', '68', 'fizz', '70', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', '78', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', 'fizz', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
The amount of training really does make a difference.

I'm a total beginner. Let's learn and explore together -- keep at it!

 

posted on 2020-06-13 12:03 by 郭心全
