MindSpore分布式并行训练—分布式训练通信方法(一)Broadcast

2. 分布式训练通信方法

2.1集合通信原语

MindSpore采用了集合通信模式来交互梯度或activation。所谓集合通信模式是指,模型切分后,通过集合通信原语来实现不同模型切片之间的数据交互。集合通信原语主要有Broadcast、AllGather、AllReduce、ReduceScatter等。

v2-e71943e6d89a7129ad94c65c0814ffa2_720w.jpg

 

集合通信原语相关操作被封装在 mindspore.ops.operations.comm_ops.py 。下面我们将对该文件中涉及集合通信的内容进行具体查看解析。

2.1.1. Broadcast

广播 (broadcast) 是标准的集体通信技术之一。一个广播发生的时候,一个进程会把同样一份数据传递给一个 communicator 里的所有其他进程。广播的主要用途之一是把用户输入传递给一个分布式程序,或者把一些配置参数传递给所有的进程。

广播的通信模式看起来像这样:

broadcast_pattern.png

 

在这个例子里,进程0是我们的进程,它持有一开始的数据。其他所有的进程都会从它这里接受到一份数据的副本。

 class Broadcast(PrimitiveWithInfer):
     """
    向整个组广播张量。
 ​
    Note:
        张量在集合的所有过程中必须具有相同的形状和格式。
 ​
    Args:
        root_rank (int): 数据来源编号。 除了一个进程外,所有进程都需要
                    这个变化的设备发送数据
        group (str): 要处理的通信组。 默认值:“hccl_world_group”。
 ​
    Inputs:
        - **input_x** (Tensor) - 张量的形状是 :math:`(x_1, x_2, ..., x_R)`.
 ​
    Outputs:
        Tensor, 具有与输入相同的形状, i.e., :math:`(x_1, x_2, ..., x_R)`.
        内容取决于“root_rank”设备的数据。
 ​
    Raises:
        如果 root_rank 不是整数或group不是字符串。
 ​
    Supported Platforms:
        ``Ascend`` ``GPU``
 ​
    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
        >>> from mindspore import Tensor
        >>> from mindspore import context
        >>> from mindspore.communication import init
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> import numpy as np
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.broadcast = ops.Broadcast(1)
        ...
        ...     def construct(self, x):
        ...         return self.broadcast((x,))
        ...
        >>> input_x = Tensor(np.ones([2, 4]).astype(np.int32))
        >>> net = Net()
        >>> output = net(input_x)
        >>> print(output)
        (Tensor(shape[2,4], dtype=Int32, value=
        [[1, 1, 1, 1],
          [1, 1, 1, 1]]),)
    """
 ​
     @prim_attr_register
     def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
         """Initialize Broadcast."""
         validator.check_value_type('root_rank', root_rank, (int,), self.name)
         validator.check_value_type('group', _get_group(group), (str,), self.name)
         check_hcom_group_valid(group)
         self.add_prim_attr('group', _get_group(group))
         self.add_prim_attr('no_elimilate', True)
 ​
     def infer_shape(self, x_shape):
         return x_shape
 ​
     def infer_dtype(self, x_dtype):
         if not isinstance(x_dtype, tuple):
             raise TypeError(f"{self.name}'s input should be a tuple!")
         for _ele in x_dtype:
             validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name)
         return x_dtype
posted @ 2021-12-20 15:16  MS小白  阅读(194)  评论(0)    收藏  举报