关于model.train()训练时传入数据参数的问题
mindspore有一键训练功能,可以很方便的使用model.train()进行训练,示例代码如下
net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits()
optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(net, loss_fn=loss, optimizer=optim)
model.train(epoch_size, dataset)
其中关于Model()的参数的定义,官方文档给出的参考如下所示:
classmindspore.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level="O0", **kwargs, **kwargs)
其中,只有network为必传参数,其他参数都有默认值。所以,要训练网络就有如下三种表示方法:
model = Model(net, loss_fn=loss, optimizer=optim)
model = Model(net_with_loss, optimizer=optim)
model = Model(net_with_loss_and_optimizer)
以上三种网络表示都可以用model.train()进行训练,但是实际上这里存在一个坑。
我的模型的dataset有5列,分别是content、aspect、sen_len、label。我用上述第二种Model构建好网络后,使用model.train(epoch,dataset)进行训练时,出现数据集参数个数不对的错误。
为这个问题纠结了挺久,最后查阅源码发现,前两种Model的定义方式,dataset默认只有两列(data,label),只有第三种Model的定义方式,dataset才支持任意数据集参数个数的传入。
那如果我不使用第三种Model定义方式可以传入任意数量的参数呢,答案也是可以的,只不过需要再用nn.TrainOneStepCell()将网络包起来:
train_net = nn.TrainOneStepCell(net_with_loss, optimizer)
model = Model(train_net)
model.train(epoch_size, dataset)