MindSpore分布式并行训练—分布式训练通信方法(二)AllGather

2.1.2. AllGather

  • Gather

Gather 从好多进程里面收集数据到一个进程上面。这个机制对很多平行算法很有用,比如并行的排序和搜索。下图是这个算法的一个示例。

Gather 从其他进程收集元素到根进程上面。元素是根据接收到的进程的rank排序的。

  • AllGather

很多时候发送多个元素到多个进程也很有用(也就是多对多通信模式)。Allgather 就是这个作用。

对于分发在所有进程上的一组数据来说,Allgather 会收集所有数据到所有进程上。从最基础的角度来看,Allgather 相当于一个 Gather 操作之后跟着一个 Broadcast 操作。下面的示意图显示了 Allgather 调用之后数据是如何分布的。

每个进程上的元素是根据他们的秩为顺序被收集起来的,只不过这次是收集到了所有进程上面。当然,Allgather 不需要root这个参数来指定根节点。

class AllGather(PrimitiveWithInfer):
    """
    从指定的通信组收集张量。

    Note:
        张量在集合的所有过程中必须具有相同的形状和格式。

    Args:
        group (str): 要处理的通信组。 默认值:“hccl_world_group”。

    Inputs:
        - **input_x** (Tensor) - 张量的形状是 :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor. 如果组内设备个数为N, 那么输出的形状是 :math:`(N, x_1, x_2, ..., x_R)`.

    Raises:
        TypeError: 如果`group` 不是字符串。
        ValueError: 如果组中调用进程的本地rank_id大于组的等级大小。

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> import numpy as np
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor, context
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        ... class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allgather = ops.AllGather()
        ...
        ...     def construct(self, x):
        ...         return self.allgather(x)
        ...
        >>> input_x = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_x)
        >>> print(output)
        [[1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]]
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllGather."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 0)
        self.add_prim_attr('mean_flag', False)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError

同时MindSpore还提供了一些仅供内部调用的特殊AllGather

class _MiniStepAllGather(PrimitiveWithInfer):
    """
    自动并行虚拟运算符。 前向传播中什么都不做,在后向传播中在每一小步执行reducescatter。 它只是为了并行模块内部使用,用户不能调用。

    Args:
        group (str): 要处理的通信组。 默认值:None。
        grad_accumulation_step (int): 梯度积累步骤。默认值: None.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
        """Initialize _MiniStepAllGather."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.grad_accumulation_step = grad_accumulation_step
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype, z_shape):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype


class _MicroStepAllGather(PrimitiveWithInfer):
    """
    自动并行虚拟运算符。 前向传播中什么都不做,在后向传播中在每一小步执行reducescatter。 它只是为了并行模块内部使用,用户不能调用。

    Args:
        group (str): 要处理的通信组。 默认值:None。
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None):
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype, z_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype


class _HostAllGather(PrimitiveWithInfer):
    """
    从主机上的指定通信组收集张量。

    Note:
        张量在集合的所有过程中必须具有相同的形状和格式。
        _HostAllGather 是一个主机端操作符,它依赖于 OpenMPI 并且必须使用构建选项 -M on 启用它。 
        使用 mpirun 命令运行它:
        mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py

    Args:
        group (Union[tuple[int],list[int]]): 要处理的通信组的rand_id。默认值: None.

    Raises:
        TypeError: 如果 group 不是列表或元组,或者 group 的元素不是 int。
        ValueError: 如果未设置组,或组中的 rank_id 不在 [0, 7] 中。

    Inputs:
        - **input_x** (Tensor) - 张量的形状是:math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor. 如果组内设备个数为N, 那么输出的形状是 :math:`(N, x_1, x_2, ..., x_R)`.
    """

    @prim_attr_register
    def __init__(self, group=None):
        """Initialize _HostAllGather."""
        if group is None:
            raise ValueError(f"For '{self.name}' group must be set.")
        validator.check_value_type('group', group, (tuple, list), self.name)
        validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
        for r in group:
            validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
            validator.check_value_type("rank_id", r, (int,), self.name)
        self.group_size = len(group)
        self.add_prim_attr('group', group)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.group_size
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError
posted @ 2021-12-20 15:15  MS小白  阅读(435)  评论(0)    收藏  举报