pytorch实践(五) 定义一个简单的神经网络
# 定义神经网络模型 class NeuralNetwork(nn.Module): def __init__(self): super().__init__() self.flatten = nn.Flatten() # 将 1x28x28 展平为 784 self.linear_relu_stack = nn.Sequential( nn.Linear(28*28, 128), nn.ReLU(), nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10) # 最终10类输出 ) def forward(self, x): x = self.flatten(x) logits = self.linear_relu_stack(x) return logits
 
                    
                     
                    
                 
                    
                
 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号