Multi-Label Loss Function and Model Training with MindSpore

This article focuses on two points:

1. defining a multi-label dataset;

2. defining a multi-label loss function.

 

I. Defining a Multi-Label Dataset

First, define the dataset:

1. get_multilabel_data generates two labels, y1 and y2;

2. the column_names parameter of GeneratorDataset is set to ['data', 'label1', 'label2'].

In this way, the dataset produced by create_multilabel_dataset has one data column, data, and two label columns, label1 and label2.

 

Code:

import numpy as np
from mindspore import dataset as ds


def get_multilabel_data(num, w=2.0, b=3.0):
    # Each sample follows y = x * w + b; the two labels differ only in their noise terms.
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise1 = np.random.normal(0, 1)
        noise2 = np.random.normal(-1, 1)
        y1 = x * w + b + noise1
        y2 = x * w + b + noise2
        yield np.array([x]).astype(np.float32), np.array([y1]).astype(np.float32), np.array([y2]).astype(np.float32)


def create_multilabel_dataset(num_data, batch_size=64):
    # Three columns: one data column and two label columns, matching the generator's outputs.
    dataset = ds.GeneratorDataset(list(get_multilabel_data(num_data)), column_names=['data', 'label1', 'label2'])
    dataset = dataset.batch(batch_size)
    return dataset
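
As a quick sanity check (a small sketch that is not part of the original post; the batch size of 16 is arbitrary), one batch can be inspected to confirm that the three columns come out as expected:

# Illustrative check only; assumes the definitions above have been run.
ds_check = create_multilabel_dataset(num_data=160, batch_size=16)
for batch in ds_check.create_dict_iterator(output_numpy=True):
    print(batch['data'].shape, batch['label1'].shape, batch['label2'].shape)
    # Expected: (16, 1) (16, 1) (16, 1)
    break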

 

II. Defining a Multi-Label Loss Function

For the dataset created in the previous step, define the loss function L1LossForMultiLabel. Its construct now takes three inputs: the prediction base and the ground-truth values target1 and target2. Inside construct, the error between the prediction and target1 and the error between the prediction and target2 are computed separately, and the mean of the two errors is taken as the final loss value. The code is as follows:

import mindspore.ops as ops
from mindspore.nn import LossBase


class L1LossForMultiLabel(LossBase):
    def __init__(self, reduction="mean"):
        super(L1LossForMultiLabel, self).__init__(reduction)
        self.abs = ops.Abs()

    def construct(self, base, target1, target2):
        # Absolute error of the prediction against each of the two labels.
        x1 = self.abs(base - target1)
        x2 = self.abs(base - target2)
        # get_loss applies the configured reduction (mean by default); average the two results.
        return self.get_loss(x1) / 2 + self.get_loss(x2) / 2
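
As a small illustrative check (not from the original post; the tensor values below are made up), with reduction="mean" the result is simply the average of the two mean absolute errors:

import numpy as np
from mindspore import Tensor

# Example values only: mean(|base - target1|) = 0.5 and mean(|base - target2|) = 0.5,
# so the combined loss is 0.5 / 2 + 0.5 / 2 = 0.5.
loss_fn = L1LossForMultiLabel()
base = Tensor(np.array([[1.0], [2.0]]).astype(np.float32))
target1 = Tensor(np.array([[1.5], [2.5]]).astype(np.float32))
target2 = Tensor(np.array([[0.5], [1.5]]).astype(np.float32))
print(loss_fn(base, target1, target2))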

 

Since the dataset now returns three columns (data, label1, label2) while the built-in nn.WithLossCell only accepts a single data and a single label input, a custom CustomWithLossCell is used to connect the backbone network to the multi-label loss function. The complete code is as follows:

import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Model
from mindspore import dataset as ds
from mindspore.nn import LossBase
from mindspore.common.initializer import Normal
from mindspore.train.callback import LossMonitor


class LinearNet(nn.Cell):
    # A single Dense layer fitting y = x * w + b.
    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        return self.fc(x)


class L1LossForMultiLabel(LossBase):
    def __init__(self, reduction="mean"):
        super(L1LossForMultiLabel, self).__init__(reduction)
        self.abs = ops.Abs()

    def construct(self, base, target1, target2):
        x1 = self.abs(base - target1)
        x2 = self.abs(base - target2)
        return self.get_loss(x1) / 2 + self.get_loss(x2) / 2


class CustomWithLossCell(nn.Cell):
    # Connects the backbone and the loss; construct's signature matches the dataset columns.
    def __init__(self, backbone, loss_fn):
        super(CustomWithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label1, label2):
        output = self._backbone(data)
        return self._loss_fn(output, label1, label2)


def get_multilabel_data(num, w=2.0, b=3.0):
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise1 = np.random.normal(0, 1)
        noise2 = np.random.normal(-1, 1)
        y1 = x * w + b + noise1
        y2 = x * w + b + noise2
        yield np.array([x]).astype(np.float32), np.array([y1]).astype(np.float32), np.array([y2]).astype(np.float32)


def create_multilabel_dataset(num_data, batch_size=64):
    dataset = ds.GeneratorDataset(list(get_multilabel_data(num_data)), column_names=['data', 'label1', 'label2'])
    dataset = dataset.batch(batch_size)
    return dataset


net = LinearNet()
loss = L1LossForMultiLabel()
# build loss network
loss_net = CustomWithLossCell(net, loss)

opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
# No loss_fn is passed to Model because loss_net already contains the loss computation.
model = Model(network=loss_net, optimizer=opt)
ds_train = create_multilabel_dataset(num_data=160)
model.train(epoch=1, train_dataset=ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)

 

Output:

epoch: 1 step: 1, loss is 12.013865

epoch: 1 step: 2, loss is 8.693487

epoch: 1 step: 3, loss is 8.687659

epoch: 1 step: 4, loss is 8.543764

epoch: 1 step: 5, loss is 6.957058

epoch: 1 step: 6, loss is 7.432212

epoch: 1 step: 7, loss is 7.896543

epoch: 1 step: 8, loss is 6.334256

epoch: 1 step: 9, loss is 6.234437

epoch: 1 step: 10, loss is 5.546875
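
As an optional follow-up that is not in the original post: since the data is generated with w=2.0 and b=3.0, the learned parameters can be inspected after training; with more data or epochs they should move toward those values.

# Run after model.train above; prints the current weight and bias of the Dense layer.
for param in net.trainable_params():
    print(param.name, param.asnumpy())
# fc.weight should trend toward 2.0 and fc.bias toward 3.0 as training progresses.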
