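## Handwritten digit recognition
# In a classification task the labels should not be represented as plain numbers; they should be vectors (vectors admit a similarity measure). We therefore convert the labels to one-hot encoding.
# Import the required libraries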
import numpy as np # numpy库
import os # 操作系统库
import struct # 数据结构库
import matplotlib.pyplot as plt # 绘图库
def load_images(path):
    """Load an IDX3 image file (the MNIST image format) as a uint8 array.

    The file begins with a 16-byte big-endian header of four 32-bit
    integers: magic number, number of images, rows and columns. Every
    remaining byte is one pixel value.

    Args:
        path: Path of the IDX3 file to read.

    Returns:
        A writable ``numpy.ndarray`` of dtype ``uint8`` with shape
        ``(num_items, rows, cols)``, where the dimensions are taken from
        the file header.

    Raises:
        struct.error: If the file is shorter than the 16-byte header.
        ValueError: If the pixel payload does not match the header dimensions.
    """
    with open(path, "rb") as f:
        data = f.read()

    # The magic number is parsed to consume the header but is not validated.
    _magic, num_items, rows, cols = struct.unpack(">iiii", data[:16])

    # View the payload in place instead of copying it through a bytearray.
    pixels = np.frombuffer(data, dtype=np.uint8, offset=16)

    # Use the dimensions stored in the header rather than assuming 28x28;
    # copy() makes the result writable and independent of `data`.
    return pixels.reshape(num_items, rows, cols).copy()
def load_labels(file):
    """Read the labels stored in an IDX1 file as an int32 array.

    The first 8 bytes of the file are skipped without being inspected;
    each remaining byte becomes one element of the result.

    Args:
        file: Path of the IDX1 label file to read.

    Returns:
        A one-dimensional ``numpy.ndarray`` of dtype ``int32`` holding one
        value per byte that follows the 8-byte header.
    """
    with open(file, "rb") as stream:
        raw = stream.read()

    # Drop the 8-byte header and expose the rest as unsigned bytes.
    label_bytes = bytearray(raw[8:])
    as_uint8 = np.asanyarray(label_bytes)

    # Widen to int32 so the values can be used directly as indices.
    return as_uint8.astype(np.int32)
def softmax(x, axis=None):
    """Compute a numerically stable softmax.

    Args:
        x: Array-like of scores.
        axis: Axis along which to normalise. The default ``None`` normalises
            over all elements of ``x`` together; pass e.g. ``axis=-1`` to
            normalise each row of a batch independently.

    Returns:
        An array with the same shape as ``x`` whose entries are positive and
        sum to 1 along ``axis`` (or over the whole array when ``axis`` is
        ``None``). An empty input yields an empty array.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(x)

    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps np.exp from overflowing to inf (and the ratio from becoming nan).
    shifted = x - np.max(x, axis=axis, keepdims=True)
    ex = np.exp(shifted)
    return ex / np.sum(ex, axis=axis, keepdims=True)
def make_onehot(labels, class_num):
    """Convert class indices into one-hot encoded rows.

    Args:
        labels: Array of class indices; its first dimension is the number
            of samples.
        class_num: Number of classes, i.e. the width of each encoded row.

    Returns:
        A float array of shape ``(labels.shape[0], class_num)`` that is zero
        everywhere except for a 1 at each row's label position.
    """
    sample_count = labels.shape[0]
    encoded = np.zeros((sample_count, class_num))
    # Each `row` is a view into `encoded`, so writing to it fills the matrix.
    for row, cls in zip(encoded, labels):
        row[cls] = 1
    return encoded
if __name__ == "__main__":
    # Train a single-layer softmax classifier on the MNIST training set with
    # per-sample stochastic gradient descent.
    #
    # The data paths are relative to the working directory, e.g. the files
    # live under D:\my code\Python\NLP basic\data\minist\ on the author's
    # machine.
    class_num = 10  # digits 0-9

    # Load the images and scale the pixel values from [0, 255] to [0, 1].
    train_images = (
        load_images(
            os.path.join(
                "Python", "NLP basic", "data", "minist", "train-images.idx3-ubyte"
            )
        )
        / 255
    )
    # Load the labels and convert them to one-hot vectors.
    train_labels = make_onehot(
        load_labels(
            os.path.join(
                "Python", "NLP basic", "data", "minist", "train-labels.idx1-ubyte"
            )
        ),
        class_num,
    )

    # Flatten every image into one feature vector. The sizes are derived from
    # the loaded data instead of hard-coding 60000 samples of 784 pixels.
    sample_num = train_images.shape[0]
    feature_num = int(np.prod(train_images.shape[1:]))
    train_images = train_images.reshape(sample_num, feature_num)

    w = np.random.normal(0, 1, size=(feature_num, class_num))  # weights
    b = np.random.normal(0, 1, size=(1, class_num))  # bias, one per class

    epochs = 10  # passes over the training set
    lr = 0.01  # learning rate

    for e in range(epochs):
        for idx in range(sample_num):
            # Slicing (rather than indexing) keeps a leading batch axis of 1.
            image = train_images[idx : idx + 1]
            label = train_labels[idx : idx + 1]

            pre = image @ w + b  # class scores (logits)
            p = softmax(pre)  # predicted class probabilities
            loss = -np.sum(label * np.log(p))  # cross-entropy loss

            # Gradient of the cross-entropy loss with respect to the logits.
            G = p - label
            delta_w = image.T @ G
            # Keep one gradient entry per class. Summing over every element
            # of G always yields ~0 (p and the one-hot label both sum to 1),
            # which would leave the bias untrained.
            delta_b = np.sum(G, axis=0, keepdims=True)

            w = w - lr * delta_w
            b = b - lr * delta_b

        # NOTE(review): the original indentation of this print was lost;
        # reporting the last sample's loss once per epoch is assumed.
        print("损失值: ", loss)