# 鸢尾草分类 - 手写简单神经网络 - 向量化实现。鸢尾草数据集是一个经典的数据集,可以到网上下载。前向/反向传播的算法思路来自吴恩达 deeplearning.ai 系列视频。
import random
from csv import reader
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
def load_dataset(dataset_path, n_train_data):
    """Load the Iris CSV, min-max scale the features, shuffle, and split it.

    Each CSV row must hold four numeric feature fields followed by a species
    name ('Iris-setosa', 'Iris-versicolor' or 'Iris-virginica'). Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped.

    Args:
        dataset_path: Path to the CSV file.
        n_train_data: Number of shuffled rows placed in the training split;
            all remaining rows form the validation split.

    Returns:
        ``(train_data, val_data)``: float arrays transposed so that each
        column is one sample. Rows 0-3 hold the features scaled to [0, 1]
        per column (a constant column becomes 0); row 4 holds the label
        0, 1 or 2.

    Raises:
        KeyError: if a row's label is not one of the three species names.
        ValueError: if a feature field is not numeric or the file has no
            data rows.
        IndexError: if a non-blank row has fewer than five fields.
    """
    label_dict = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
    rows = []
    # newline='' is what the csv module expects for files passed to reader().
    with open(dataset_path, 'r', newline='') as file:
        for row in reader(file, delimiter=','):
            # Skip empty / whitespace-only lines instead of crashing on row[4].
            if not any(field.strip() for field in row):
                continue
            values = [float(value) for value in row[0:4]]
            rows.append(values + [label_dict[row[4].strip()]])
    if not rows:
        raise ValueError(f"no data rows found in {dataset_path!r}")

    data = np.array(rows, dtype=float)
    features = data[:, :4]
    # Per-column min-max scaling to [0, 1], matching sklearn's MinMaxScaler
    # with its default feature_range: a zero range is replaced by 1 so that
    # constant columns map to 0 instead of dividing by zero.
    col_min = features.min(axis=0)
    col_range = features.max(axis=0) - col_min
    col_range[col_range == 0] = 1.0
    data[:, :4] = (features - col_min) / col_range

    # Shuffle a list of rows (random.shuffle on a 2-D ndarray is unsafe).
    dataset = data.tolist()
    random.shuffle(dataset)
    train_data = np.array(dataset[0:n_train_data]).T
    val_data = np.array(dataset[n_train_data:]).T
    return train_data, val_data
def sigmoid(z):
    """Return the element-wise logistic function 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + np.exp(-z))


def train(a0, y, l_rate=0.01, epochs=300, w1=None, w2=None):
    """Train the two-layer network with full-batch gradient descent.

    The network has a 5-unit sigmoid hidden layer and a single linear
    output unit, both without bias terms. The output is regressed directly
    onto the numeric class label.

    Args:
        a0: Feature matrix, one sample per column (n_features x m).
        y: Label row vector (1 x m).
        l_rate: Gradient-descent step size.
        epochs: Number of full-batch updates.
        w1: Optional initial hidden weights (n_hidden x n_features). When
            omitted, it is drawn from ``np.random.random`` with 5 hidden units.
        w2: Optional initial output weights (1 x n_hidden). When omitted, it
            is drawn from ``np.random.random`` after ``w1``.

    Returns:
        ``(w1, w2, cost)``: the trained weights and a list holding the mean
        squared error measured before each update.
    """
    m = a0.shape[1]
    if w1 is None:
        w1 = np.random.random((5, a0.shape[0]))
    if w2 is None:
        w2 = np.random.random((1, w1.shape[0]))
    cost = []
    for _ in range(epochs):
        # forward
        a1 = sigmoid(np.dot(w1, a0))        # n_hidden x m
        z2 = np.dot(w2, a1)                 # 1 x m
        # backward: gradients of 0.5 * sum of squared errors. They are summed
        # over the batch, not averaged, so the step size scales with m.
        dz2 = z2 - y                        # 1 x m
        cost.append(np.sum(np.square(dz2)) / m)
        dw2 = np.dot(dz2, a1.T)             # 1 x n_hidden
        da1 = np.dot(w2.T, dz2)             # n_hidden x m
        dz1 = da1 * a1 * (1 - a1)           # sigmoid derivative
        dw1 = np.dot(dz1, a0.T)             # n_hidden x n_features
        w1 = w1 - l_rate * dw1
        w2 = w2 - l_rate * dw2
    return w1, w2, cost


def predict(a0, w1, w2, n_classes=3):
    """Return rounded network outputs clipped to labels 0..n_classes-1."""
    z2 = np.dot(w2, sigmoid(np.dot(w1, a0)))
    # A linear output can round to a value that is not a class (e.g. -1 or 3).
    return np.clip(np.rint(z2), 0, n_classes - 1)


if __name__ == "__main__":
    file_path = './iris.csv'
    l_rate = 0.01
    epochs = 300
    n_train = 130

    # load data: row 4 of each split is the label, rows 0-3 the features.
    # Sample counts come from the arrays rather than being hard-coded.
    train_data, val_data = load_dataset(file_path, n_train)
    y = train_data[4].reshape(1, -1)            # 1 x m
    a0 = np.delete(train_data, 4, axis=0)       # 4 x m

    w1, w2, cost = train(a0, y, l_rate, epochs)

    # test on the held-out rows
    y_test = val_data[4].reshape(1, -1)
    a0_test = np.delete(val_data, 4, axis=0)
    z2_test = predict(a0_test, w1, w2)
    dz2_test = z2_test - y_test
    print(z2_test)
    print(y_test)
    print(np.sum(np.square(dz2_test)) / y_test.shape[1])
    print('accuracy:', np.mean(z2_test == y_test))

    plt.xlabel('epochs')
    plt.ylabel('cost')
    plt.plot(cost)
    plt.show()