MindSpore分布式并行训练—GPU平台通信方法

2.2 MindSpore的GPU平台通信

在GPU硬件平台上,MindSpore分布式并行训练的通信使用的是NCCL;采用的多进程通信库是OpenMPI。NCCL是Nvidia Collective multi-GPU Communication Library的简称,是英伟达提供的多GPU集合通信方案,在实现上参考了MPI接口,同时进行了诸多针对性优化。它是一个实现多GPU的collective communication通信(all-gather, reduce, broadcast)库,Nvidia做了很多优化,以在PCIe、Nvlink、InfiniBand上实现较高的通信速度。Open MPI 项目是一个开源消息传递接口实现,由学术、研究和行业合作伙伴组成的联盟开发和维护。 因此,Open MPI 能够结合来自高性能计算社区的所有专业知识、技术和资源,以构建可用的最佳 MPI 库。 Open MPI 为系统和软件供应商、应用程序开发人员和计算机科学研究人员提供了优势。MPI的相关配置内容通过文件mindspore.parallel.mpi._mpi_config.py 进行环境配置。

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MPI配置,用于配置MPI环境。
"""
import threading
from mindspore._c_expression import MpiConfig
from mindspore._checkparam import args_type_check


class _MpiConfig:
    """
    _MpiConfig 是控制 MPI 的配置工具

    Note:
        不建议通过实例化 MpiConfig 对象来创建配置。
        应该使用 MpiConfig() 来获取配置,因为 MpiConfig 是单例的。
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self._mpiconfig_handle = MpiConfig.get_instance()

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance_lock.acquire()
            cls._instance = object.__new__(cls)
            cls._instance_lock.release()
        return cls._instance

    def __getattribute__(self, attr):
        value = object.__getattribute__(self, attr)
        if attr == "_mpiconfig_handle" and value is None:
            raise ValueError("mpiconfig handle is none in MpiConfig!!!")
        return value

    @property
    def enable_mpi(self):
        return self._mpiconfig_handle.get_enable_mpi()

    @enable_mpi.setter
    def enable_mpi(self, enable_mpi):
        self._mpiconfig_handle.set_enable_mpi(enable_mpi)

_k_mpi_config = None


def _mpi_config():
    """
    获取全局的mpi config,如果没有创建mpi config,则新建一个。

    Returns:
        _MpiConfig,全局 mpi 配置。
    """
    global _k_mpi_config
    if _k_mpi_config is None:
        _k_mpi_config = _MpiConfig()
    return _k_mpi_config


@args_type_check(enable_mpi=bool)
def _set_mpi_config(**kwargs):
    """
    为运行环境设置 mpi 配置。

    应该在运行程序之前配置mpi config。如果没有配置,默认情况下,mpi 模块将被禁用。

    Note:
        设置属性时需要属性名称。

    Args:
        enable_mpi (bool): 是否开启mpi。 默认值:False。

    Raises:
        ValueError: 如果输入键不是 mpi 配置中的属性。

    Examples:
        >>> mpiconfig.set_mpi_config(enable_mpi=True)
    """
    for key, value in kwargs.items():
        if not hasattr(_mpi_config(), key):
            raise ValueError("Set mpi config keyword %s is not recognized!" % key)
        setattr(_mpi_config(), key, value)


def _get_mpi_config(attr_key):
    """
    根据输入键获取mpi config属性值。

    Args:
        attr_key (str): 属性的键。

    Returns:
        Object, 给定属性键的值。

    Raises:
        ValueError: 如果输入键不是config中的属性。
    """
    if not hasattr(_mpi_config(), attr_key):
        raise ValueError("Get context keyword %s is not recognized!" % attr_key)
    return getattr(_mpi_config(), attr_key)
posted @ 2021-12-20 15:13  MS小白  阅读(295)  评论(0)    收藏  举报