MindSpore分布式并行训练—分布式训练通信方法(三)ReduceScatter
2.1.3. ReduceScatter
提供group内的集合通信reducescatter功能。ReduceScatter是mindspore为实现通信算子的自动微分,为AllGather提供的反向算子。
class ReduceScatter(PrimitiveWithInfer):
"""
Reduces 并 scatters来自指定通信群的张量。
Note:
尚不支持 op 的反向传播。 请继续关注更多。
张量在集合的所有过程中必须具有相同的形状和格式。
Args:
op (str): 指定用于element-wise的操作,如SUM、MAX、AVG。默认值:ReduceOp.SUM。
group (str): 要处理的通信组。 默认值:“hccl_world_group”。
Raises:
TypeError: 如果操作和组中的任何一个不是字符串。
ValueError: 如果输入的第一个维度不能除以秩大小。
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
>>> from mindspore import Tensor, context
>>> from mindspore.communication import init
>>> from mindspore.ops import ReduceOp
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.reducescatter = ops.ReduceScatter(ReduceOp.SUM)
...
... def construct(self, x):
... return self.reducescatter(x)
...
>>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
>>> print(output)
[[2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2.]]
"""
@prim_attr_register
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
"""Initialize ReduceScatter."""
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.op = op
self.rank_size = get_group_size(_get_group(group))
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('no_elimilate', True)
def infer_shape(self, x_shape):
if self.rank_size == 0:
raise ValueError(f"For '{self.name}' rank_size can not be zero.")
if x_shape[0] % self.rank_size != 0:
raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.")
x_shape[0] = int(x_shape[0] / self.rank_size)
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
def __call__(self, tensor):
raise NotImplementedError
class _HostReduceScatter(PrimitiveWithInfer):
"""
从主机上的指定通信组Reduces并scatters 张量。
Note:
张量在集合的所有过程中必须具有相同的形状和格式。HostReduceScatter是一个主机端操作符,它依赖于OpenMPI,必须使用构建选项-M on来启用它。使用mpirun命令运行
mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py
Args:
op (str): 指定用于element-wise的操作,如SUM、MAX、AVG。默认值:ReduceOp.SUM。
group (Union[tuple[int],list[int]]): 要处理的通信组的 rand_ids。 默认值:无。
Raises:
TypeError: 如果 op 不是字符串并且 group 不是列表或元组,
或组的元素不是整数。
ValueError: 如果输入的第一个维度不能除以组大小,
或组未设置,或 rank_id 不在 [0, 7] 中。
"""
@prim_attr_register
def __init__(self, op=ReduceOp.SUM, group=None):
"""Initialize _HostReduceScatter."""
if group is None:
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
for r in group:
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name)
self.op = op
self.group_size = len(group)
self.add_prim_attr('group', group)
self.add_prim_attr('no_elimilate',