MindSpore分布式并行训练—分布式训练通信方法(四)AllReduce

2.1.4. AllReduce

  • Reduce

与 Gather 类似,Reduce 在每个进程上获取一个输入元素数组,并将输出元素数组返回给根进程。

查看当进程拥有多个元素时会发生什么也很有用。 下图显示了每个进程归约多个数字的情况。

上图中的每个进程都有两个元素。 结果求和基于每个元素进行。 换句话说,不是将所有数组中的所有元素累加到一个元素中,而是将每个数组中的第 i 个元素累加到进程 0 结果数组中的第 i 个元素中。

  • AllReduce

许多并行程序中,需要在所有进程而不是仅仅在根进程中访问归约的结果。 以与 Gather 相似的补充方式,Allreduce 将归约值并将结果分配给所有进程。您可能已经注意到,Allreduce 与 Reduce 相同,不同之处在于它不需要根进程 ID(因为结果分配给所有进程)。 下图介绍了 Allreduce 的通信模式:

Allreduce 等效于先执行 Reduce,然后执行 Bcast

class AllReduce(PrimitiveWithInfer):
    """
    Reduce所有设备上的张量数据,以便所有设备都获得相同的最终结果。

    Note:
        AllReduce 的操作目前不支持“prod”。
        张量在集合的所有过程中必须具有相同的形状和格式。

    Args:
        op (str): 指定用于逐元素归约的操作,如总和、最大值和最小值。 默认值:ReduceOp.SUM。
        group (str): 要处理的通信组。 默认值:“hccl_world_group”。

    Inputs:
        - **input_x** (Tensor) - 张量的形状是:math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, 具有与输入相同的形状, i.e., :math:`(x_1, x_2, ..., x_R)`.
        context取决于指定的操作。

    Raises:
        TypeError: 如果 `op` 和 `group` 中的任何一个不是 str,或者 fusion 不是整数,或者输入的 dtype 是 bool。
        ValueError: 如果 `op` 是 "prod".

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor
        >>> from mindspore.ops import ReduceOp
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allreduce_sum = ops.AllReduce(ReduceOp.SUM)
        ...
        ...     def construct(self, x):
        ...         return self.allreduce_sum(x)
        ...
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
        >>> print(output)
        [[2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]]
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllReduce."""
        if not isinstance(op, type(ReduceOp.SUM)):
            raise TypeError("The operation of AllReduce should be str.")
        if not isinstance(_get_group(group), str):
            raise TypeError("The group of AllReduce should be str.")
        check_hcom_group_valid(group)
        self.op = op
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 0)
        self.add_prim_attr('index', 0)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    class ReduceOp:
    """
    减少张量的操作选项。 这是一个枚举类型,而不是一个运算符。
    主要用于数据并行模式。

    主要调用方法如下:

    - SUM: ReduceOp.SUM.
    - MAX: ReduceOp.MAX.
    - MIN: ReduceOp.MIN.
    - PROD: ReduceOp.PROD.

    有四种运算选项,“SUM”、“MAX”、“MIN”和“PROD”。

    - SUM: 得到 sum.
    - MAX: 得到 maximum.
    - MIN: 得到 minimum.
    - PROD: 得到 product.

    Note:
        更多内容请参考示例。 这需要在具有多个图形卡的环境中运行。

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations.comm_ops import ReduceOp
        >>> import mindspore.nn as nn
        >>> import mindspore.ops.operations as ops
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allreduce_sum = ops.AllReduce(ReduceOp.SUM, group="nccl_world_group")
        ...
        ...     def construct(self, x):
        ...         return self.allreduce_sum(x)
        ...
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
        >>> print(output)
        [[4. 5. 6. 0. 0. 0. 0. 0.]
         [0. 0. 0. 0. 0. 0. 0. 0.]]
    """
    SUM = "sum"
    MAX = "max"
    MIN = "min"
    PROD = "prod"
posted @ 2021-12-20 15:14  MS小白  阅读(91)  评论(0)    收藏  举报