关于model.train()训练时传入数据参数的问题

mindspore有一键训练功能,可以很方便的使用model.train()进行训练,示例代码如下

net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits()
optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(net, loss_fn=loss, optimizer=optim)
model.train(epoch_size, dataset)

其中关于Model()的参数的定义,官方文档给出的参考如下所示:

classmindspore.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level="O0", **kwargs, **kwargs)

其中,只有network为必传参数,其他参数都有默认值。所以,要训练网络就有如下三种表示方法:

model = Model(net, loss_fn=loss, optimizer=optim)  
model = Model(net_with_loss, optimizer=optim) 
model = Model(net_with_loss_and_optimizer)

以上三种网络表示都可以用model.train()进行训练,但是实际上这里存在一个坑。

我的模型的dataset有5列,分别是content、aspect、sen_len、label。我用上述第二种Model构建好网络后,使用model.train(epoch,dataset)进行训练时,出现数据集参数个数不对的错误。

为这个问题纠结了挺久,最后查阅源码发现,前两种Model的定义方式,dataset默认只有两列(data,label),只有第三种Model的定义方式,dataset才支持任意数据集参数个数的传入。

那如果我不使用第三种Model定义方式可以传入任意数量的参数呢,答案也是可以的,只不过需要再用nn.TrainOneStepCell()将网络包起来:

train_net = nn.TrainOneStepCell(net_with_loss, optimizer)
model = Model(train_net)
model.train(epoch_size, dataset)
posted @ 2021-12-30 19:03  MS小白  阅读(208)  评论(0)    收藏  举报