MindSpore分布式并行训练—自动并行架构介绍(一)半自动并行
3.1. 半自动并行
在MindSpore中系统自动分析张量的空间排布与所需的通信策略。仅需要指定算子的切分策略,系统就可以自行实现所需的数据并行与模型并行策略。如下图所示:在input的指定维度设置设备数即可在指定维度进行数据拆分;在模型参数weight的指定维度设置设备数即可在指定维度进行模型拆分。
MindSpore 较灵活,它支持用户指定的高级策略配置,称之为半自动并行(semi-auto-parallel)。下面展示一个从数据到模型的并行转换的例子。该子模型的结构为 BatchNorm 算子后跟一个 MatMul 算子,广泛应用于 ResNet、ReID 等分类任务。
class Submodel(nn.Cell):
def _init_(self, shape):
self.bn = BatchNorm(set_strategy={[4, 1]})
self.matmul = MatMul(set_strategy={[1, 1], [1, 4]})
self.W = Parameter(Tensor(shape), require_grad=True)
def construct(self, X):
Y = self.bn(X)
Z = self.matmul(y, self.W)
return Z
在 BatchNorm 算子中,X 按行拆分为四部分,数据可以并行,效率非常高。在 MatMul 算子中,可学习参数的权重 W 被分成四部分,模型可以并行,由于参数数量较多,这部分的模型并行更有效。由于 BatchNormi 的输出布局与 MatMul 的输入布局不同,所以框架插入了一个张量重排布(该例中为 AllGather 和 ConCat),这一过程对用户是透明的。用户也不必关注哪个设备运行了模型的哪个部分,框架会自动安排。
上述示例中,用户无需指定图切分方案,仅需指定各算子的切分策略。这种方式对手动调优有很大帮助,但是还是具有一定的复杂度。多次配置 set_strategy,耗时耗力。
Mindspore通过装饰器设计为Cell提供了切分策略所需的集合通信的功能mindspore.parallel._cell_wrapper.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""自动并联单元"""
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication import GlobalComm
_allgather_cell = None
class AllGatherCell(Cell):
"""
Allgather 单元,用于模型并行场景。
从每个设备中收集选定的参数切片。
"""
def __init__(self, group):
super(AllGatherCell, self).__init__(auto_prefix=False)
self.allgather = AllGather(group)
def construct(self, x):
x = self.allgather(x)
return x
class SaveOptShardCkptCell(Cell):
"""
Allgather 单元,用于优化器并行场景。
首先将张量收集到指定设备组中的原始布局。
然后从所有设备收集整个参数切片。
Note:
这可以在以后以更少的通信消耗进行优化。
"""
def __init__(self, group):
super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
self.allgather1 = AllGather(group)
self.allgather2 = AllGather()
def construct(self, x):
x = self.allgather1(x)
x = self.allgather2(x)
return x
def get_allgather_cell(group, need_merge_twice=False):
"""得到AllGatherCell对象。"""
global _allgather_cell
if need_merge_twice:
_allgather_cell = SaveOptShardCkptCell(group)
else:
if group:
_allgather_cell = AllGatherCell(group)
else:
_allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP)
return _allgather_cell
def destroy_allgather_cell():
"""销毁 AllGatherCell 对象。"""
global _allgather_cell
if _allgather_cell:
_allgather_cell = None