MindSpore分布式并行训练—Ascend平台通信方法

2.3 MindSpore的Ascend平台通信

对于Ascend AI处理器,MindSpore分布式并行训练的通信使用了华为集合通信库Huawei Collective Communication Library(HCCL)。mindspore.communication.management中封装了HCCL提供的集合通信接口,方便用户配置分布式信息。

在该文件中,MindSpore提供了HCCL分布式通信的后端初始化方法,同时还提供了多种检查和查询方法,也提供了相关资源的释放方法,是一个出色的封装接口。

mindspore.communication.management的代码如下:

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Communication management API"""
from mindspore import context
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
    _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
    _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
    _get_local_rank_helper, _get_local_size_helper, GlobalComm
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective


__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
           "get_local_rank_size", "get_world_rank_from_group_rank",
           "get_group_rank_from_world_rank", "create_group", "destroy_group",
           "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]

DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP


def _get_group(group):
    """返回全局通信组,如果`group`是`DEFAULT_WORLD_COMM_GROUP`."""
    if group == DEFAULT_WORLD_COMM_GROUP:
        return GlobalComm.WORLD_COMM_GROUP
    return group

def _check_task_sink_envs():
    """
    检查任务接收器环境变量是否导出。
    如果已导出任务接收器环境变量,则返回True,否则返回False。
    """
    import os
    task_sink = os.getenv("GRAPH_OP_RUN")
    if task_sink:
        try:
            if int(task_sink) == 1:
                return False
        except ValueError:
            return True
    return True


def _check_parallel_envs():
    """
    检查并行环境变量是否已导出。
    raises:
        RuntimeError:如果并行环境变量未导出或导出到错误值。
    """
    if not GlobalComm.CHECK_ENVS:
        return
    import os
    rank_id_str = os.getenv("RANK_ID")
    if not rank_id_str:
        raise RuntimeError("Environment variables RANK_ID has not been exported")
    try:
        int(rank_id_str)
    except ValueError:
        print("RANK_ID should be number")
    rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
    rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
    if not rank_table_file_str and not rank_table_file_str_old:
        raise RuntimeError("Get hccl rank_table_file failed, "
                           "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")

def init(backend_name=None):
    """
    初始化分布式后端,例如 HCCL/NCCL,在使用通讯服务前需要完成。

    注意:
        HCCL的全称为Huawei Collective Communication Library。
        NCCL的全称是NVIDIA Collective Communication Library。
        这个方法应该在设置context之后使用。

    参数:
        backend_name (str): 后端,使用 HCCL/NCCL。 如果未设置,则通过 device_target 进行推断。 默认值:无。

    Raises:
        TypeError:如果 `backend_name` 不是字符串。
        RuntimeError: 如果设备目标无效,
        或后端无效,
        或分布式初始化失败,
        或环境变量 RANK_ID/MINDSPORE_HCCL_CONFIG_PATH在后端是HCCL时未导出。
        ValueError:如果环境变量 RANK_ID 尚未导出为数字。

    Examples:
        >>> from mindspore.context import set_context
        >>> set_context(device_target="Ascend")
        >>> init()
    """
    if _is_role_pserver() or _is_role_sched():
        return
    task_sink = _check_task_sink_envs()
    device_target = context.get_context("device_target")
    mode = context.get_context("mode")
    mpi_init = False
    if not task_sink and mode == context.GRAPH_MODE:
        mpi_init = True

    if backend_name is None:
        if device_target == "Ascend":
            backend_name = "hccl"
        elif device_target == "GPU":
            backend_name = "nccl"
        else:
            raise RuntimeError("Device target {} is not supported in parallel initialization, "
                               "please use Ascend or GPU.".format(device_target))
    if not isinstance(backend_name, str):
        raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))

    if backend_name == "hccl":
        if device_target != "Ascend":
            raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
        if not mpi_init:
            _check_parallel_envs()
            GlobalComm.BACKEND = Backend("hccl")
        else:
            GlobalComm.BACKEND = Backend("hccl_mpi")
        init_hccl()
        GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
        GlobalComm.INITED = True
    elif backend_name == "nccl":
        init_gpu_collective()
        GlobalComm.BACKEND = Backend("nccl")
        GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
        GlobalComm.INITED = True
    else:
        raise RuntimeError("Backend name {} is not supported.".format(backend_name))


def release():
    """
    释放分布式资源。

    Note:
        这个方法应该在init()之后使用。

    Raises:
        RuntimeError:如果释放分布式资源失败。
    """
    finalize_hccl()


def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
    """
    获取当前设备在指定集体通信组中的rank ID。

    Note:
        这个方法应该在init()之后使用。

    Args:
        group (str):要处理的通信组。一般情况下,需要使用create group创建组,否则使用default group。默认值:WORLD_COMM_GROUP。

    Returns:
        int, 群组内主叫进程的等级号。

    Raises:
        TypeError: 如果group不是字符串。
        ValueError: 如果backend不合法。
        RuntimeError: 如果HCCL/NCCL不可用。
    """
    return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)


def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
    """
    获取指定集体通信组中当前设备的本地等级ID。

    Note:
        MindSpore的GPU版本不支持这个方法。
        这个方法应该在init()之后使用。

    Args:
        group (str):要处理的通信组。一般情况下,需要使用create group创建组,否则使用default group。默认值:WORLD_COMM_GROUP。

    Returns:
        int, 群组内主叫进程的等级号。

    Raises:
        TypeError: 如果group不是字符串。
        ValueError: 如果backend不合法。
        RuntimeError: 如果没有可用的hccl或MindSpore是GPU版本。
    """
    return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)


def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
    """
    获取指定的集体通信组的大小。

    Note:
        这个方法应该在init()之后使用。

    Args:
        要处理的通信组。一般情况下,需要使用create group创建组,否则使用default group。默认值:WORLD_COMM_GROUP。

    Returns:
        int, group的rank size值

    Raises:
        TypeError: 如果group不是字符串。
        ValueError: 如果backend不合法。
        RuntimeError: 如果HCCL/NCCL不可用。
    """
    return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)


def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
    """
    获取指定集合通信组的本地秩大小。

    Note:
        MindSpore的GPU版本不支持这个方法。
        这个方法应该在init()之后使用。

    Args:
        group (str): 要处理的通信组。一般情况下,需要使用create group创建组,否则使用default group。默认值:WORLD_COMM_GROUP。

    Returns:
        int, 调用进程在组内的本地等级大小。

    Raises:
        TypeError: 如果group不是字符串。
        ValueError: 如果backend不合法。
        RuntimeError: 如果没有可用的hccl或MindSpore是GPU版本。
    """
    return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)


def get_world_rank_from_group_rank(group, group_rank_id):
    """
    根据用户通信组中的rank ID, 获取对应的全局通讯组中的rank ID。

    Note:
        MindSpore的GPU版本不支持这个方法。
        这个方法应该在init()之后使用。
        参数组不应是“hccl_world_group”。

    Args:
        group (str): 要处理的通信组。 该group由 create_group 创建。
        group_rank_id (int): 通信组的rank ID。

    Returns:
        int, 全局通讯组中的rank ID

    Raises:
        TypeError: 如果`group_rank_id` 不是整数或group不是字符串。
        ValueError: 如果group是“hccl_world_group”或后端无效.
        RuntimeError: 如果 HCCL/NCCL 不可用或 MindSpore 是 GPU 版本。

    Examples:
        >>> from mindspore.context import set_context
        >>> set_context(device_target="Ascend")
        >>> init()
        >>> group = "0-4"
        >>> rank_ids = [0,4]
        >>> create_group(group, rank_ids)
        >>> world_rank_id = get_world_rank_from_group_rank(group, 1)
        >>> print("world_rank_id is: ", world_rank_id) # world_rank_id is: 4
    """
    return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)


def get_group_rank_from_world_rank(world_rank_id, group):
    """
    根据全局通讯组中的rank ID, 获取对应的指定用户通讯组中的rank ID。

    Note:
        MindSpore的GPU版本不支持这个方法。
        这个方法应该在init()之后使用。
        参数组不应是“hccl_world_group”。

    Args:
        world_rank_id (int): 全局通信组的rank ID。
        group (str): 要处理的通信组。 该group由 create_group 创建。

    Returns:
        int, 用户通讯组中的rank ID

    Raises:
        TypeError: 如果`world_rank_id` 不是整数或group不是字符串。
        ValueError: 如果group是“hccl_world_group”或后端无效.
        RuntimeError: 如果 HCCL/NCCL 不可用或 MindSpore 是 GPU 版本。

    Examples:
        >>> from mindspore.context import set_context
        >>> set_context(device_target="Ascend")
        >>> init()
        >>> group = "0-4"
        >>> rank_ids = [0,4]
        >>> create_group(group, rank_ids)
        >>> group_rank_id = get_group_rank_from_world_rank(4, group)
        >>> print("group_rank_id is: ", group_rank_id) # group_rank_id is: 1
    """
    return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)


def create_group(group, rank_ids):
    """
    创建用户集体通信组。

    Note:
        MindSpore 的 GPU 版本不支持此方法。
        rank_ids 的大小应该大于 1。
        Rank_ids 不应有重复数据。
        这个方法应该在 init()之后使用。
        仅在 PyNative 模式下支持全局单个通信组。

    Args:
        group (str): 要创建的通信组的名称。
        rank_ids (list): 设备id列表。

    Raises:
        TypeError: 如果group不是字符串或' rank id '不是列表。
        ValueError: 如果' rank id '大小不大于1,或' rank id '有重复的数据,或后端无效。
        RuntimeError: 如果没有可用的hccl或MindSpore是GPU版本。

    Examples:
        >>> from mindspore.context import set_context
        >>> set_context(device_target="Ascend")
        >>> init()
        >>> group = "0-8"
        >>> rank_ids = [0,8]
        >>> create_group(group, rank_ids)
    """
    _create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND)


def destroy_group(group):
    """
    销毁用户集体通信组。

    Note:
        MindSpore 的 GPU 版本不支持此方法。
        参数group不应是“hccl_world_group”。
        这个方法应该在 init() 之后使用。

    Args:
        group (str): 要销毁的通信组,该group应由create_group 创建。

    Raises:
        TypeError: 如果group不是字符串或' rank id '不是列表。
        ValueError: 如果' rank id '大小不大于1,或' rank id '有重复的数据,或后端无效。
        RuntimeError: 如果没有可用的hccl或MindSpore是GPU版本。
    """
    _destroy_group_helper(group, backend=GlobalComm.BACKEND)
posted @ 2021-12-20 15:13  MS小白  阅读(858)  评论(0编辑  收藏  举报