MindSpore分布式并行训练—自动并行架构介绍(一)半自动并行

3.1. 半自动并行

在MindSpore中系统自动分析张量的空间排布与所需的通信策略。仅需要指定算子的切分策略,系统就可以自行实现所需的数据并行与模型并行策略。如下图所示:在input的指定维度设置设备数即可在指定维度进行数据拆分;在模型参数weight的指定维度设置设备数即可在指定维度进行模型拆分。

MindSpore 较灵活,它支持用户指定的高级策略配置,称之为半自动并行(semi-auto-parallel)。下面展示一个从数据到模型的并行转换的例子。该子模型的结构为 BatchNorm 算子后跟一个 MatMul 算子,广泛应用于 ResNet、ReID 等分类任务。

class Submodel(nn.Cell):
    def _init_(self, shape):
        self.bn = BatchNorm(set_strategy={[4, 1]})
        self.matmul = MatMul(set_strategy={[1, 1], [1, 4]})
        self.W = Parameter(Tensor(shape), require_grad=True)

    def construct(self, X):
        Y = self.bn(X)
        Z = self.matmul(y, self.W)
        return Z

在 BatchNorm 算子中,X 按行拆分为四部分,数据可以并行,效率非常高。在 MatMul 算子中,可学习参数的权重 W 被分成四部分,模型可以并行,由于参数数量较多,这部分的模型并行更有效。由于 BatchNormi 的输出布局与 MatMul 的输入布局不同,所以框架插入了一个张量重排布(该例中为 AllGather 和 ConCat),这一过程对用户是透明的。用户也不必关注哪个设备运行了模型的哪个部分,框架会自动安排。

上述示例中,用户无需指定图切分方案,仅需指定各算子的切分策略。这种方式对手动调优有很大帮助,但是还是具有一定的复杂度。多次配置 set_strategy,耗时耗力。

Mindspore通过装饰器设计为Cell提供了切分策略所需的集合通信的功能mindspore.parallel._cell_wrapper.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""自动并联单元"""

from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication import GlobalComm

_allgather_cell = None


class AllGatherCell(Cell):
    """
    Allgather 单元,用于模型并行场景。
    从每个设备中收集选定的参数切片。
    """
    def __init__(self, group):
        super(AllGatherCell, self).__init__(auto_prefix=False)

        self.allgather = AllGather(group)

    def construct(self, x):
        x = self.allgather(x)

        return x


class SaveOptShardCkptCell(Cell):
    """
    Allgather 单元,用于优化器并行场景。
    首先将张量收集到指定设备组中的原始布局。
    然后从所有设备收集整个参数切片。

    Note:
        这可以在以后以更少的通信消耗进行优化。
    """
    def __init__(self, group):
        super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
        self.allgather1 = AllGather(group)
        self.allgather2 = AllGather()

    def construct(self, x):
        x = self.allgather1(x)
        x = self.allgather2(x)

        return x


def get_allgather_cell(group, need_merge_twice=False):
    """得到AllGatherCell对象。"""
    global _allgather_cell
    if need_merge_twice:
        _allgather_cell = SaveOptShardCkptCell(group)
    else:
        if group:
            _allgather_cell = AllGatherCell(group)
        else:
            _allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP)
    return _allgather_cell


def destroy_allgather_cell():
    """销毁 AllGatherCell 对象。"""
    global _allgather_cell
    if _allgather_cell:
        _allgather_cell = None
posted @ 2021-12-20 15:11  MS小白  阅读(325)  评论(0)    收藏  举报