PyTorch 深度学习实践 第7讲:处理多维特征的输入

处理多维特征的输入

视频教程

图片来自课程ppt

1.准备数据集:

行:样本
列:每个样本的特征 代码数据集x是前8列,y最后一列

image

2.多维的逻辑斯蒂回归模型

image
说明:

  1. 计算每个样本的8个特征分别去乘以权重

  2. 计算出1的值之后加上偏置量,即z(i)(第i个样本的线性计算值)

  3. 将z(i)代入logistic函数,求出y的预测值

3.n个样本的mini-batch

image
说明:

  1. 上标样本,下标特征
  2. 将方程组可以转换成向量矩阵的运算 输入8维,输出1维,因此,x,y都得是矩阵
4.线形层

image
说明:
通过线性层+激活函数 -> 降到6维->........一直降维,相反,也可以升高维!

5.代码图解

image

image

6.代码:

import  torch
import  numpy as np
import matplotlib.pyplot as plt

#1.准备数据:数据与源代码放在同一个文件内,老师给出的的数据集在视频置顶评论
xy = np.loadtxt('diabetes.csv', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:,:-1])#第一个“;”读取所有行,行全要,第二个:指截取除去最后一列的其他数据
y_data = torch.from_numpy(xy[:,[-1]])#y就是最后一列向量组成的矩阵

#2.设计模型
class Model(torch.nn.Module):#Model继承自nn.module
    def __init__(self):
        super(Model,self).__init__()
        self.linear1 = torch.nn.Linear(8,6)#输入数据x的特征8维,输出y的特征是6维
        self.linear1 = torch.nn.Linear(6,4)
        self.linear1 = torch.nn.Linear(4,1)#逐层输出y降维1维
        self.sigmoid = torch.nn.Sigmoid()
        #nn.Sigmoid:构建计算图,运算模块:非线性变换层
        #注意nn.Function.sigmoid是函数,这里通过线性层+激活函数 -> 降到6维-...->降到1维
        
    def forward(self,x):
        x = self.sigmoid(self.linear1(x))#计算O1
        x = self.sigmoid(self.linear2(x))#计算O2,O1的结果作为输入
        x = self.sigmoid(self.linear3(x))#计算y^,O2的结果作为输入
        return x#返回y^,为了防止误差,全部统一为x
    

model = Model()

#3.构建loss与optimizer
#criterion = torch.nn.BCELoss(size_average='True')
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(),lr = 0.1)

epoch_list = []
loss_list = []

#4.循环训练forward backward update
for epoch in range(100):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss.item())
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    
    optimizer.step()
    

plt.plot(epoch_list,loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

0 0.645651638507843
1 0.6456509828567505
2 0.645650327205658
3 0.6456496119499207
4 0.6456489562988281
5 0.6456483602523804
.......
93 0.645593523979187
94 0.6455929279327393
95 0.6455922722816467
96 0.645591676235199
97 0.645591139793396
98 0.6455904841423035
99 0.6455898880958557
image

posted @ 2022-08-09 12:27  Ling22  阅读(178)  评论(0)    收藏  举报