MindSpore分布式并行训练—分布式训练通信方法(三)ReduceScatter

2.1.3. ReduceScatter

提供group内的集合通信reducescatter功能。ReduceScatter是mindspore为实现通信算子的自动微分,为AllGather提供的反向算子。

class ReduceScatter(PrimitiveWithInfer):
    """
    Reduces 并 scatters来自指定通信群的张量。

    Note:
        尚不支持 op 的反向传播。 请继续关注更多。
        张量在集合的所有过程中必须具有相同的形状和格式。

    Args:
        op (str): 指定用于element-wise的操作,如SUM、MAX、AVG。默认值:ReduceOp.SUM。
        group (str): 要处理的通信组。 默认值:“hccl_world_group”。

    Raises:
        TypeError: 如果操作和组中的任何一个不是字符串。
        ValueError: 如果输入的第一个维度不能除以秩大小。

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> from mindspore import Tensor, context
        >>> from mindspore.communication import init
        >>> from mindspore.ops import ReduceOp
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> import numpy as np
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.reducescatter = ops.ReduceScatter(ReduceOp.SUM)
        ...
        ...     def construct(self, x):
        ...         return self.reducescatter(x)
        ...
        >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
        >>> print(output)
        [[2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]]
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize ReduceScatter."""
        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.op = op
        self.rank_size = get_group_size(_get_group(group))
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 0)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        if self.rank_size == 0:
            raise ValueError(f"For '{self.name}' rank_size can not be zero.")
        if x_shape[0] % self.rank_size != 0:
            raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.")
        x_shape[0] = int(x_shape[0] / self.rank_size)
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class _HostReduceScatter(PrimitiveWithInfer):
    """
    从主机上的指定通信组Reduces并scatters 张量。

    Note:
        张量在集合的所有过程中必须具有相同的形状和格式。HostReduceScatter是一个主机端操作符,它依赖于OpenMPI,必须使用构建选项-M on来启用它。使用mpirun命令运行
        mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py

    Args:
        op (str): 指定用于element-wise的操作,如SUM、MAX、AVG。默认值:ReduceOp.SUM。
        group (Union[tuple[int],list[int]]): 要处理的通信组的 rand_ids。 默认值:无。

    Raises:
        TypeError: 如果 op 不是字符串并且 group 不是列表或元组,
                   或组的元素不是整数。
        ValueError: 如果输入的第一个维度不能除以组大小,
                   或组未设置,或 rank_id 不在 [0, 7] 中。
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=None):
        """Initialize _HostReduceScatter."""
        if group is None:
            raise ValueError(f"For '{self.name}' group must be set.")
        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
        validator.check_value_type('group', group, (tuple, list), self.name)
        validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
        for r in group:
            validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
            validator.check_value_type("rank_id", r, (int,), self.name)
        self.op = op
        self.group_size = len(group)
        self.add_prim_attr('group', group)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        if x_shape[0] % self.group_size != 0:
            raise ValueError(f"For '{self.name}' the first dimension of x should be divided by group_size.")
        x_shape[0] = int(x_shape[0] / self.group_size)
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError
posted @ 2021-12-20 15:15  MS小白  阅读(111)  评论(0)    收藏  举报