Chapter 1 of Modern Computer Vision with PyTorch: Object Detection, Image Processing, and Deep Learning

The listing below, from Chapter 1 of the book's companion repository, trains a tiny one-hidden-layer network on a single example, estimating gradients numerically with finite differences in plain NumPy.

# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch



###################  Chapter One #######################################
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt

# A single training example: two inputs and one target output.
x = np.array([[1, 1]])
y = np.array([[0]])
def feed_forward(inputs, outputs, weights):
    # Hidden layer: affine transform followed by a sigmoid activation.
    pre_hidden = np.dot(inputs, weights[0]) + weights[1]
    hidden = 1 / (1 + np.exp(-pre_hidden))
    # Output layer is purely affine; the loss is mean squared error.
    out = np.dot(hidden, weights[2]) + weights[3]
    mean_squared_error = np.mean(np.square(out - outputs))
    return mean_squared_error

def update_weights(inputs, outputs, weights, lr):
    original_weights = deepcopy(weights)
    updated_weights = deepcopy(weights)
    original_loss = feed_forward(inputs, outputs, original_weights)
    for i, layer in enumerate(original_weights):
        for index, weight in np.ndenumerate(layer):
            # Perturb one parameter at a time by +0.0001, recompute the loss,
            # and compare against the original loss to estimate the gradient
            # (a one-sided finite difference).
            temp_weights = deepcopy(weights)
            temp_weights[i][index] += 0.0001
            _loss_plus = feed_forward(inputs, outputs, temp_weights)
            grad = (_loss_plus - original_loss) / 0.0001
            updated_weights[i][index] -= grad * lr
    return updated_weights, original_loss
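
# A more accurate alternative (editor's sketch, not in the original listing):
# estimate each gradient with a two-sided central difference,
# (loss(w + eps) - loss(w - eps)) / (2 * eps), instead of the one-sided
# form used above.
def update_weights_central(inputs, outputs, weights, lr, eps=1e-4):
    updated_weights = deepcopy(weights)
    original_loss = feed_forward(inputs, outputs, weights)
    for i, layer in enumerate(weights):
        for index, _ in np.ndenumerate(layer):
            w_plus, w_minus = deepcopy(weights), deepcopy(weights)
            w_plus[i][index] += eps
            w_minus[i][index] -= eps
            grad = (feed_forward(inputs, outputs, w_plus)
                    - feed_forward(inputs, outputs, w_minus)) / (2 * eps)
            updated_weights[i][index] -= grad * lr
    return updated_weights, original_loss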

# Initial parameters for a 2 -> 3 -> 1 network:
# [hidden weights, hidden bias, output weights, output bias].
W = [
    np.array([[-0.0053, 0.3793],
              [-0.5820, -0.5204],
              [-0.2723, 0.1896]], dtype=np.float32).T,
    np.array([-0.0140, 0.5607, -0.0628], dtype=np.float32),
    np.array([[ 0.1528, -0.1745, -0.1135]], dtype=np.float32).T,
    np.array([-0.5516], dtype=np.float32)
]
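
# Quick sanity check (editor's addition): the loss at the initial weights,
# before any updates. The first value printed by the training loop below
# should match it.
print('initial loss:', feed_forward(x, y, W))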



# Run 100 epochs of numerical gradient descent with a learning rate of 0.01,
# tracking the loss at each step.
losses = []
for epoch in range(100):
    W, loss = update_weights(x, y, W, 0.01)
    losses.append(loss)
    print(epoch, loss)
plt.plot(losses)
plt.title('Loss over increasing number of epochs')
plt.xlabel('Epoch')
plt.ylabel('Mean squared error')
plt.show()
########################################################################
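
Since the book moves on to PyTorch next, it may help to see the same computation with autograd. The block below is a minimal sketch (an editor's addition, not from the book's repository): it rebuilds the same 2-3-1 network as torch tensors with the same initial values, and lets autograd compute, at those initial weights, the exact gradients that the finite-difference loop only approximates.

###################  Autograd check (editor's sketch) ##################
import torch

x_t = torch.tensor([[1., 1.]])
y_t = torch.tensor([[0.]])

# Same initial parameters as the NumPy version (W[0] already transposed
# to shape 2x3, W[2] to shape 3x1), now tracked by autograd.
W_t = [
    torch.tensor([[-0.0053, -0.5820, -0.2723],
                  [ 0.3793, -0.5204,  0.1896]], requires_grad=True),
    torch.tensor([-0.0140, 0.5607, -0.0628], requires_grad=True),
    torch.tensor([[ 0.1528], [-0.1745], [-0.1135]], requires_grad=True),
    torch.tensor([-0.5516], requires_grad=True),
]

def loss_fn(inputs, outputs, weights):
    hidden = torch.sigmoid(inputs @ weights[0] + weights[1])
    out = hidden @ weights[2] + weights[3]
    return torch.mean((out - outputs) ** 2)

loss = loss_fn(x_t, y_t, W_t)
loss.backward()          # exact gradients via the chain rule
for w in W_t:
    print(w.grad)        # should be close to the finite-difference estimates
########################################################################

Each printed gradient should agree with the corresponding one-sided estimate (_loss_plus - original_loss) / 0.0001 from the first call to update_weights, up to the finite-difference approximation error.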

 


