MindSpore Image Segmentation Model

Image segmentation locates the objects in an image and identifies their boundary contours.
It looks roughly like this:
[Figures: example segmentation results]
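
To make "locate the object" and "identify its boundary" concrete: a segmentation model such as Mask R-CNN outputs, for each object, a binary mask with one value per pixel. The NumPy sketch below uses a toy 6x6 mask (not real model output). It derives the object's bounding box from the mask and marks as boundary every foreground pixel that has at least one background 4-neighbour. The function names mask_to_bbox and mask_boundary are illustrative only, not MindSpore APIs.

```python
import numpy as np

def mask_to_bbox(mask):
    """Return (x_min, y_min, x_max, y_max) of the foreground pixels in a binary mask."""
    ys, xs = np.nonzero(mask)
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

def mask_boundary(mask):
    """Foreground pixels with at least one background pixel among their 4 neighbours."""
    m = mask.astype(bool)
    # Pad with background so pixels on the image edge count as boundary pixels.
    p = np.pad(m, 1, constant_values=False)
    # A pixel is interior if its up, down, left and right neighbours are all foreground.
    interior = p[:-2, 1:-1] & p[2:, 1:-1] & p[1:-1, :-2] & p[1:-1, 2:]
    return m & ~interior

# Toy 6x6 mask containing one 3x4 rectangular "object".
mask = np.zeros((6, 6), dtype=np.uint8)
mask[1:4, 1:5] = 1

print(mask_to_bbox(mask))  # (1, 1, 4, 3)
print(mask_boundary(mask).astype(int))
```

The bounding box answers "where is the object", while the boundary pixels trace its contour; a real model produces such a mask per detected instance.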

Annotated code

Training the model:

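import os

from mindspore.nn import Momentum
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Project-specific modules (assumed to follow MindSpore's model_zoo Mask R-CNN layout);
# config, args_opt and create_new_dataset are defined elsewhere in the author's script.
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet

if __name__ == '__main__':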
    rank = 0
    device_num = 1
    # Build the training dataset
    dataset = create_new_dataset(image_dir=config.coco_root, batch_size=config.batch_size, is_training=True, num_parallel_workers=8)
    # get_dataset_size() returns the number of batches per epoch,
    # which equals the number of images only when batch_size is 1
    dataset_size = dataset.get_dataset_size()
    print("total images num: ", dataset_size)
    print("Create dataset done!")
    # Instantiate the network and put it in training mode
    net = Mask_Rcnn_Resnet50(config=config)
    net = net.set_train()
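    # Load pretrained weights if a checkpoint path was given
    load_path = args_opt.pre_trained
    if load_path != "":
        param_dict = load_checkpoint(load_path)
        # When pretrain_epoch_size is 0, keep only parameters whose names start with
        # "backbone" or "rcnn_mask" and discard the rest before loading
        if config.pretrain_epoch_size == 0: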
            for item in list(param_dict.keys()):
                if not (item.startswith('backbone') or item.startswith('rcnn_mask')):
                    param_dict.pop(item)
        load_param_into_net(net, param_dict)
    # Define the loss function, learning rate (fixed at 0.0001 here) and optimizer
    loss = LossNet()
    opt = Momentum(params=net.trainable_params(), learning_rate=0.0001, momentum=config.momentum,
                   weight_decay=config.weight_decay, loss_scale=config.loss_scale)
    # Wrap the network and the loss function into a single cell
    net_with_loss = WithLossCell(net, loss)
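    # TrainOneStepCell defines one training step (forward, backward, update);
    # sens is the value fed into backpropagation, set here to the loss scale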
    net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
    # Callbacks that monitor training: per-epoch timing and loss printing
    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]
    # Save checkpoints during training
    if config.save_checkpoint:
        # Checkpoint settings: save every save_checkpoint_epochs epochs, keep at most keep_checkpoint_max files
        ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
                                      keep_checkpoint_max=config.keep_checkpoint_max)
        save_checkpoint_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
        # Create the checkpoint callback with these settings (one directory per device rank)
        ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=save_checkpoint_path, config=ckptconfig)
        cb += [ckpoint_cb]
    # Start training; dataset_sink_mode=False feeds data step by step instead of sinking it to the device
    model = Model(net)
    model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)
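
The pretrained-weight branch (config.pretrain_epoch_size == 0) keeps only parameters whose names start with backbone or rcnn_mask and drops the rest before load_param_into_net. The plain-Python sketch below reproduces that filtering on an ordinary dict so it runs without MindSpore. The function name filter_pretrained and the parameter names are hypothetical examples; real checkpoint keys depend on the network definition. Iterating over list(param_dict.keys()) matters: popping from a dict while iterating over its live key view raises RuntimeError.

```python
def filter_pretrained(param_dict, keep_prefixes=("backbone", "rcnn_mask")):
    """Drop every entry whose name does not start with one of keep_prefixes (in place).

    Mirrors the loop in the training script; str.startswith accepts a tuple of prefixes.
    """
    # Copy the keys first: popping while iterating over the live key view raises RuntimeError.
    for name in list(param_dict.keys()):
        if not name.startswith(keep_prefixes):
            param_dict.pop(name)
    return param_dict

# Hypothetical parameter names, for illustration only.
params = {
    "backbone.layer1.0.conv1.weight": 1,
    "rcnn_mask.fc.weight": 2,
    "rpn_with_loss.rpn_conv.weight": 3,
    "rcnn_cls.cls_scores.weight": 4,
}
print(sorted(filter_pretrained(params)))
# ['backbone.layer1.0.conv1.weight', 'rcnn_mask.fc.weight']
```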

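TrainOneStepCell(..., sens=config.loss_scale) and Momentum(..., loss_scale=config.loss_scale) receive the same value. Assuming the TrainOneStepCell used here follows the semantics of MindSpore's nn.TrainOneStepCell, sens is the value fed into backpropagation as the initial gradient, so the gradients come out multiplied by it, and the optimizer's loss_scale divides the gradients by that value before the update. Together they form static loss scaling, which keeps small gradients from underflowing when computation runs in float16. The NumPy sketch below only illustrates the numerics with a hypothetical gradient of 1e-8 and an assumed scale of 1024; it does not call MindSpore.

```python
import numpy as np

loss_scale = 1024.0  # assumed value; the script reads it from config.loss_scale
true_grad = 1e-8     # hypothetical tiny gradient

# Without scaling, the gradient underflows to zero in float16.
unscaled_fp16 = np.float16(true_grad)

# With sens = loss_scale, backprop yields gradients multiplied by the scale,
# which are large enough to be representable in float16 ...
scaled_fp16 = np.float16(true_grad * loss_scale)

# ... and the optimizer divides by loss_scale (here in float32) before updating the weights.
recovered = np.float32(scaled_fp16) / np.float32(loss_scale)

print(unscaled_fp16)  # 0.0
print(recovered)      # close to 1e-8
```

Because the multiply and divide cancel, the update matches the unscaled one up to float16 rounding; only the intermediate gradients are shifted into a range float16 can hold.
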
Output:

total loss is: 1.3203125
total loss is: 0.88623046875
total loss is: 0.794921875
total loss is: 0.7216796875
total loss is: 0.67138671875
epoch time: 133629.138 ms, per step time: 26725.828 ms
total loss is: 0.65625
total loss is: 0.646484375
total loss is: 0.5712890625
total loss is: 0.56982421875
total loss is: 0.5732421875
epoch time: 4547.359 ms, per step time: 909.472 ms
total loss is: 0.5595703125
total loss is: 0.5166015625
total loss is: 0.8193359375
total loss is: 0.389892578125
total loss is: 0.44970703125
epoch time: 4639.649 ms, per step time: 927.930 ms
total loss is: 0.360107421875
total loss is: 0.25830078125
total loss is: 0.30224609375
total loss is: 0.2236328125
total loss is: 0.1971435546875
epoch time: 4851.436 ms, per step time: 970.287 ms
total loss is: 0.2021484375
total loss is: 0.36376953125
total loss is: 0.1787109375
total loss is: 0.56884765625
total loss is: 0.1864013671875
epoch time: 4904.663 ms, per step time: 980.933 ms
total loss is: 0.184326171875
total loss is: 0.1395263671875
total loss is: 0.301025390625
total loss is: 0.1458740234375
total loss is: 0.36376953125
epoch time: 4838.346 ms, per step time: 967.669 ms
total loss is: 0.1260986328125
total loss is: 0.08843994140625
total loss is: 0.10125732421875
total loss is: 0.09942626953125
total loss is: 0.162109375
epoch time: 5143.703 ms, per step time: 1028.741 ms
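
Each "epoch time" line follows five "total loss is" lines, and epoch time divided by per-step time is 5 in every epoch (for example 4547.359 / 909.472 ≈ 5), so the dataset yields 5 steps per epoch in this run; with save_checkpoint_steps = config.save_checkpoint_epochs * dataset_size, a checkpoint is therefore written every 5 × save_checkpoint_epochs steps. The first epoch takes about 133.6 s against roughly 4.5 to 5.1 s for the later ones, probably due to one-off start-up costs such as graph compilation. Per-step losses fluctuate (e.g. 0.389892578125 then 0.44970703125), so averaging per epoch shows the trend more clearly. The parser below is an illustrative sketch that assumes only the two line formats shown in the log; summarize_log is a hypothetical helper, and the sample text is the first and last epoch of the log above.

```python
import re

LOSS_RE = re.compile(r"total loss is: ([0-9.]+)")
EPOCH_RE = re.compile(r"epoch time: ([0-9.]+) ms, per step time: ([0-9.]+) ms")

def summarize_log(text):
    """Return one (steps, mean_loss, epoch_seconds) tuple per 'epoch time' line."""
    summary, losses = [], []
    for line in text.splitlines():
        line = line.strip()
        if m := LOSS_RE.match(line):
            losses.append(float(m.group(1)))
        elif m := EPOCH_RE.match(line):
            epoch_ms, step_ms = float(m.group(1)), float(m.group(2))
            steps = round(epoch_ms / step_ms)  # e.g. 4547.359 / 909.472 -> 5
            mean = sum(losses) / len(losses) if losses else float("nan")
            summary.append((steps, mean, epoch_ms / 1000))
            losses = []  # start collecting the next epoch
    return summary

# First and last epochs of the training log.
log = """\
total loss is: 1.3203125
total loss is: 0.88623046875
total loss is: 0.794921875
total loss is: 0.7216796875
total loss is: 0.67138671875
epoch time: 133629.138 ms, per step time: 26725.828 ms
total loss is: 0.1260986328125
total loss is: 0.08843994140625
total loss is: 0.10125732421875
total loss is: 0.09942626953125
total loss is: 0.162109375
epoch time: 5143.703 ms, per step time: 1028.741 ms
"""

for steps, mean, secs in summarize_log(log):
    print(f"{steps} steps, mean loss {mean:.4f}, {secs:.1f} s")
# 5 steps, mean loss 0.8789, 133.6 s
# 5 steps, mean loss 0.1155, 5.1 s
```

Pasting the full log into summarize_log gives one line per epoch.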
posted @ 2021-12-25 15:30  MS小白