"""
MindSpore 中神经网络的建立。

以构建 LeNet 网络为例,展示 MindSpore 是如何建立神经网络模型的。
"""
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
class LeNet5(nn.Cell):
    """Classic LeNet-5 convolutional network built on MindSpore.

    In MindSpore, ``nn.Cell`` is the base class of every network and its
    basic building unit.  A user-defined network subclasses ``Cell`` and
    overrides ``__init__`` (declare the layers) and ``construct`` (wire
    them into the forward pass).

    Args:
        num_class (int): number of output classes. Default: 10.
        num_channel (int): number of input image channels. Default: 1.
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # --- feature extractor ---
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # --- classifier head ---
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, num_class)

    def construct(self, x):
        """Forward pass: two conv/relu/pool stages, then three dense layers.

        Returns the raw (pre-softmax) class logits.
        """
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
'''
加入nn.Conv2d层,给网络中加入卷积函数,帮助神经网络提取特征
'''
# Demo: an nn.Conv2d layer extracting features from a dummy 1x1x32x32 input.
# With a 5x5 kernel and pad_mode='valid' (no padding), the 32x32 map shrinks
# to 28x28, so this prints (1, 6, 28, 28).
conv2d = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5,
                   has_bias=False, weight_init='normal', pad_mode='valid')
input_x = Tensor(np.ones(shape=[1, 1, 32, 32]), mindspore.float32)
print(conv2d(input_x).shape)
'''
加入nn.ReLU层,给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。
'''
# Demo: an nn.ReLU activation layer (element-wise max(x, 0)), applied to the
# sample tensor constructed below.
relu = nn.ReLU()
input_x = Tensor(np.array([-1, 2, -3, 2,