神经网络初步实践

已知 x = 4, y = 2,模型为线性回归模型,根据计算可以得知,具体模型为 y = 0.5x,0.5即需要训练的参数。随机初始化参数,通过前向传播算法,求得参数值为0.5。

import random


class appleNetwork:
    def __init__(self) -> None:
        self.param = random.random()
        
    def train(self, input_data, output_data) -> None:
        result = self.forward(input_data)
        error = output_data - result
        diff = error / input_data
        self.param += diff

    def query(self, input_data):
        return self.forward(input_data)

    def forward(self, input_data):
        return self.param*input_data


if __name__ == "__main__":
    n = appleNetwork()
    n.train(4, 2)
    t = n.query(8)
    print(t)

输出结果为: 4.0

若输入数据增多,增加损失函数计算,设置学习率为0.05,epoch为10。

import random
import matplotlib.pyplot as plt

class appleNetwork:
    def __init__(self) -> None:
        self.param = random.random()
        self.lr = 0.05

    def train(self, input_data, output_data) -> None:
        result = self.forward(input_data)
        error = output_data - result
        diff = error / input_data
        self.param += diff * self.lr

    def query(self, input_data):
        return self.forward(input_data)


    def forward(self, input_data):
        return self.param*input_data

    def loss_func(self, test_data):
        total_loss = 0
        for item in test_data:
            result = self.forward(item[0])
            error =abs( item[1] - result)
            total_loss += error
        return total_loss


if __name__ == "__main__":
    n = appleNetwork()
    train_data = [[2,4],[1,2],[1.5,2.9],[3,6.2],[4,7],[3,6.1]]
    test_data = [[2,4],[5,10.1],[7,13.9]]
    x = []
    y = []
    step = 0
    for epoch in range(10):
        
        for item in train_data:
            n.train(item[0],item[1])
            error = n.loss_func(test_data)
            step += 1
            x.append(step)
            y.append(error)
    t = n.query(8)
    print(t)
    plt.figure(num=1,figsize=(5,5))
    plt.scatter(x,y)
    plt.show()

通过pylot的可视化操作,得出下图,其中x轴是更新次数,y轴为错误率。

将学习率改为0.01,epoch改为100后,发现经过400次迭代后,错误率几乎不再发生变化。输出结果为15.673927533912916,接近16。

posted @ 2021-05-24 15:10  Summer0127  阅读(88)  评论(0)    收藏  举报