MindSpore分布式并行训练—分布式训练通信方法(二)AllGather
2.1.2. AllGather
- Gather
Gather
从好多进程里面收集数据到一个进程上面。这个机制对很多平行算法很有用,比如并行的排序和搜索。下图是这个算法的一个示例。
Gather
从其他进程收集元素到根进程上面。元素是根据接收到的进程的rank排序的。
- AllGather
很多时候发送多个元素到多个进程也很有用(也就是多对多通信模式)。Allgather
就是这个作用。
对于分发在所有进程上的一组数据来说,Allgather
会收集所有数据到所有进程上。从最基础的角度来看,Allgather
相当于一个 Gather
操作之后跟着一个 Broadcast
操作。下面的示意图显示了 Allgather
调用之后数据是如何分布的。
每个进程上的元素是根据他们的秩为顺序被收集起来的,只不过这次是收集到了所有进程上面。当然,Allgather
不需要root这个参数来指定根节点。
class AllGather(PrimitiveWithInfer):
"""
从指定的通信组收集张量。
Note:
张量在集合的所有过程中必须具有相同的形状和格式。
Args:
group (str): 要处理的通信组。 默认值:“hccl_world_group”。
Inputs:
- **input_x** (Tensor) - 张量的形状是 :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor. 如果组内设备个数为N, 那么输出的形状是 :math:`(N, x_1, x_2, ..., x_R)`.
Raises:
TypeError: 如果`group` 不是字符串。
ValueError: 如果组中调用进程的本地rank_id大于组的等级大小。
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
>>> import numpy as np
>>> import mindspore.ops as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor, context
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
... class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.allgather = ops.AllGather()
...
... def construct(self, x):
... return self.allgather(x)
...
>>> input_x = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_x)
>>> print(output)
[[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]]
"""
@prim_attr_register
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
"""Initialize AllGather."""
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.rank = get_rank(_get_group(group))
self.rank_size = get_group_size(_get_group(group))
validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('mean_flag', False)
self.add_prim_attr('no_elimilate', True)
def infer_shape(self, x_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
if x_shape[0] > 0:
x_shape[0] = x_shape[0] * self.rank_size
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
def __call__(self, tensor):
raise NotImplementedError
同时MindSpore还提供了一些仅供内部调用的特殊AllGather
class _MiniStepAllGather(PrimitiveWithInfer):
"""
自动并行虚拟运算符。 前向传播中什么都不做,在后向传播中在每一小步执行reducescatter。 它只是为了并行模块内部使用,用户不能调用。
Args:
group (str): 要处理的通信组。 默认值:None。
grad_accumulation_step (int): 梯度积累步骤。默认值: None.
"""
@prim_attr_register
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
"""Initialize _MiniStepAllGather."""
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.rank = get_rank(_get_group(group))
self.rank_size = get_group_size(_get_group(group))
validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 1)
self.grad_accumulation_step = grad_accumulation_step
self.mean_flag = mean_flag
def infer_shape(self, x_shape, z_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
if x_shape[0] > 0:
x_shape[0] = x_shape[0] * self.rank_size
return x_shape
def infer_dtype(self, x_dtype, z_shape):
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
class _MicroStepAllGather(PrimitiveWithInfer):
"""
自动并行虚拟运算符。 前向传播中什么都不做,在后向传播中在每一小步执行reducescatter。 它只是为了并行模块内部使用,用户不能调用。
Args:
group (str): 要处理的通信组。 默认值:None。
"""
@prim_attr_register
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None):
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.rank = get_rank(_get_group(group))
self.rank_size = get_group_size(_get_group(group))
validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 1)
self.mean_flag = mean_flag
def infer_shape(self, x_shape, z_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
if x_shape[0] > 0:
x_shape[0] = x_shape[0] * self.rank_size
return x_shape
def infer_dtype(self, x_dtype, z_dtype