MindSpore的图像分割模型
图像分割功能可以定位图片中的物体,并识别出物体的边界轮廓。
分割效果大致如下:
代码标注
训练模型:
if __name__ == '__main__':
    # Training entry point: build the dataset, instantiate the network,
    # optionally load pre-trained weights, then train with loss/time
    # monitoring and (optionally) checkpoint saving.
    rank = 0
    device_num = 1  # not read in this section; kept for any code that follows

    # Build the training dataset through the data-processing interface.
    dataset = create_new_dataset(image_dir=config.coco_root, batch_size=config.batch_size,
                                 is_training=True, num_parallel_workers=8)
    dataset_size = dataset.get_dataset_size()
    # NOTE(review): the message says "images", but dataset_size is used below
    # as the per-epoch step count (TimeMonitor, checkpoint interval) — confirm.
    print("total images num: ", dataset_size)
    print("Create dataset done!")

    # Instantiate the network and switch it to training mode.
    net = Mask_Rcnn_Resnet50(config=config)
    net = net.set_train()

    # Load a pre-trained checkpoint when a path was supplied.
    load_path = args_opt.pre_trained
    if load_path != "":
        param_dict = load_checkpoint(load_path)
        if config.pretrain_epoch_size == 0:
            # Keep only parameters whose names start with 'backbone' or
            # 'rcnn_mask'; every other entry is dropped before loading.
            kept_prefixes = ('backbone', 'rcnn_mask')
            for item in list(param_dict.keys()):
                if not item.startswith(kept_prefixes):
                    param_dict.pop(item)
        # Must stay inside the `if`: param_dict only exists on this branch.
        load_param_into_net(net, param_dict)

    # Loss function and optimizer (fixed learning rate of 0.0001).
    loss = LossNet()
    opt = Momentum(params=net.trainable_params(), learning_rate=0.0001, momentum=config.momentum,
                   weight_decay=config.weight_decay, loss_scale=config.loss_scale)

    # Wrap the network together with its loss.
    net_with_loss = WithLossCell(net, loss)

    # Custom single-step training cell; `sens` is given the same loss scale
    # value that is passed to the optimizer.
    net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)

    # Callbacks that monitor step time and loss during training.
    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]

    # Optionally save checkpoints during training.
    if config.save_checkpoint:
        # Checkpoint interval is expressed in steps: epochs * dataset_size.
        ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
                                      keep_checkpoint_max=config.keep_checkpoint_max)
        save_checkpoint_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
        # Register the checkpoint callback with the settings above.
        ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=save_checkpoint_path, config=ckptconfig)
        cb.append(ckpoint_cb)

    # Run training.
    model = Model(net)
    model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)
输出:
total loss is: 1.3203125
total loss is: 0.88623046875
total loss is: 0.794921875
total loss is: 0.7216796875
total loss is: 0.67138671875
epoch time: 133629.138 ms, per step time: 26725.828 ms
total loss is: 0.65625
total loss is: 0.646484375
total loss is: 0.5712890625
total loss is: 0.56982421875
total loss is: 0.5732421875
epoch time: 4547.359 ms, per step time: 909.472 ms
total loss is: 0.5595703125
total loss is: 0.5166015625
total loss is: 0.8193359375
total loss is: 0.389892578125
total loss is: 0.44970703125
epoch time: 4639.649 ms, per step time: 927.930 ms
total loss is: 0.360107421875
total loss is: 0.25830078125
total loss is: 0.30224609375
total loss is: 0.2236328125
total loss is: 0.1971435546875
epoch time: 4851.436 ms, per step time: 970.287 ms
total loss is: 0.2021484375
total loss is: 0.36376953125
total loss is: 0.1787109375
total loss is: 0.56884765625
total loss is: 0.1864013671875
epoch time: 4904.663 ms, per step time: 980.933 ms
total loss is: 0.184326171875
total loss is: 0.1395263671875
total loss is: 0.301025390625
total loss is: 0.1458740234375
total loss is: 0.36376953125
epoch time: 4838.346 ms, per step time: 967.669 ms
total loss is: 0.1260986328125
total loss is: 0.08843994140625
total loss is: 0.10125732421875
total loss is: 0.09942626953125
total loss is: 0.162109375
epoch time: 5143.703 ms, per step time: 1028.741 ms