# Source code for megengine.distributed.functional

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

# Public API of this module: the collective-communication wrappers (each one
# delegates to ``collective_comm``) followed by the point-to-point helpers.
__all__ = [
    # collective communication
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    # point-to-point communication
    "remote_send",
    "remote_recv",
]


def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions.

    Builds a ``CollectiveComm`` op for ``group`` and applies it to ``inp``.

    :param inp: input tensor.
    :param mode: a ``CollectiveComm.Mode`` member selecting the collective.
    :param group: communication group. ``None`` makes the call a no-op that
        returns ``inp`` unchanged.
    :param device: execution device. ``None`` is treated as ``""``, the
        default used by the public wrappers.
    :return: the single output of the collective op, or ``inp`` itself when
        ``group`` is ``None``.
    """
    # The None check must come before the type assertion. Previously the
    # assertion ran first, so the ``Optional[Group]`` no-op path was
    # unreachable and ``group=None`` raised AssertionError.
    if group is None:
        return inp
    assert isinstance(group, Group)
    # Wrappers annotate ``device`` as ``Optional[str]`` with default "".
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key,
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=get_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape: for these modes the output
    # has the input's shape, so scalar-ness carries over to the result.
    if (
        mode
        in (
            CollectiveComm.Mode.REDUCE_SUM,
            CollectiveComm.Mode.BROADCAST,
            CollectiveComm.Mode.ALL_REDUCE_SUM,
            CollectiveComm.Mode.ALL_REDUCE_MAX,
            CollectiveComm.Mode.ALL_REDUCE_MIN,
        )
        and isscalar(inp)
    ):
        setscalar(result)
    return result


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_sum operator for collective communication.

    Runs the ``REDUCE_SUM`` collective on ``inp`` over ``group``; the member
    with rank 0 in ``group`` is flagged as the root. If ``inp`` is a scalar,
    the result is marked as a scalar too.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.REDUCE_SUM, group=group, device=device
    )
def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create broadcast operator for collective communication.

    Runs the ``BROADCAST`` collective on ``inp`` over ``group``; the member
    with rank 0 in ``group`` is flagged as the root. If ``inp`` is a scalar,
    the result is marked as a scalar too.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.BROADCAST, group=group, device=device
    )
def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_gather operator for collective communication.

    Runs the ``ALL_GATHER`` collective on ``inp`` over ``group``.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.ALL_GATHER, group=group, device=device
    )
def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_scatter_sum operator for collective communication.

    Runs the ``REDUCE_SCATTER_SUM`` collective on ``inp`` over ``group``.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.REDUCE_SCATTER_SUM, group=group, device=device
    )
def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_sum operator for collective communication.

    Runs the ``ALL_REDUCE_SUM`` collective on ``inp`` over ``group``. If
    ``inp`` is a scalar, the result is marked as a scalar too.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.ALL_REDUCE_SUM, group=group, device=device
    )
def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_max operator for collective communication.

    Runs the ``ALL_REDUCE_MAX`` collective on ``inp`` over ``group``. If
    ``inp`` is a scalar, the result is marked as a scalar too.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.ALL_REDUCE_MAX, group=group, device=device
    )
def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_min operator for collective communication.

    Runs the ``ALL_REDUCE_MIN`` collective on ``inp`` over ``group``. If
    ``inp`` is a scalar, the result is marked as a scalar too.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.ALL_REDUCE_MIN, group=group, device=device
    )
def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create gather operator for collective communication.

    Runs the ``GATHER`` collective on ``inp`` over ``group``; the member with
    rank 0 in ``group`` is flagged as the root.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.GATHER, group=group, device=device
    )
def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create scatter operator for collective communication.

    Runs the ``SCATTER`` collective on ``inp`` over ``group``; the member with
    rank 0 in ``group`` is flagged as the root.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.SCATTER, group=group, device=device
    )
def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_to_all operator for collective communication.

    Runs the ``ALL_TO_ALL`` collective on ``inp`` over ``group``.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :return: output tensor of the collective op.
    """
    return collective_comm(
        inp, CollectiveComm.Mode.ALL_TO_ALL, group=group, device=device
    )
class _RemoteSend(PyOpBase):
    """Python-level op wrapping a ``RemoteSend`` builtin op.

    In the grad-traced forward pass it remembers the input's dtype, shape and
    device plus the dummy output. ``backward`` uses them to receive the
    gradient back from the peer through ``remote_recv``.
    """

    def __init__(self, op: RemoteSend):
        self.op = op

    def _default_rule(self, data):
        # Plain forward: run the wrapped builtin op.
        return apply(self.op, data)

    def _grad_rule(self, data):
        # Record metadata of ``data`` for the matching remote_recv in backward.
        self.dtype = data.dtype
        self.shape = data.shape
        self.device = data.device
        outputs = self._default_rule(data)
        (self.dummy,) = outputs
        return self.dummy, self.backward

    def backward(self, grad):
        # The send op's output is only a dummy; no gradient is expected for it.
        assert grad is None
        if not get_client().check_is_grad(self.op.key):
            # The receiving side reported (via set_is_grad) no gradient for
            # this key, so there is nothing to receive.
            return None
        return remote_recv(
            self.op.rank_to,
            self.shape,
            self.dtype,
            device=str(self.device),
            inp=self.dummy,
        )


class _RemoteRecv(PyOpBase):
    """Python-level op wrapping a ``RemoteRecv`` builtin op.

    Its ``backward`` publishes whether a gradient exists for this key and, if
    so, sends it back to the source rank with ``remote_send``.
    """

    def __init__(self, op: RemoteRecv):
        self.op = op

    def _default_rule(self, dummy):
        # Plain forward: run the wrapped builtin op on the dummy input.
        return apply(self.op, dummy)

    def _grad_rule(self, dummy):
        outputs = self._default_rule(dummy)
        return outputs, self.backward

    def backward(self, grad):
        has_grad = grad is not None
        # Publish the flag before sending: the peer's _RemoteSend.backward
        # reads it with check_is_grad to decide whether to receive at all.
        get_client().set_is_grad(self.op.key, has_grad)
        if has_grad:
            remote_send(grad, self.op.rank_from)
def remote_send(inp: Tensor, dest_rank: int) -> None:
    """
    Send a Tensor to a remote process.

    Every grad manager in ``_grad_manager_dict`` that ``inp`` is attached to is
    registered with the client under the key ``"<this rank>-><dest_rank>"``.
    ``remote_recv`` on the destination builds the same key and looks these
    managers up. A ``RemoteSend`` op is then applied to ``inp``.

    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    :return: ``None``. The former ``-> Tensor`` annotation was incorrect;
        this function has never returned a value.
    """
    key = "{}->{}".format(get_rank(), dest_rank)
    # Grad managers currently tracking ``inp``, keyed by their registry name.
    grad_keys = {
        name: manager
        for name, manager in _grad_manager_dict.items()
        if manager._is_attached_to(inp)
    }
    get_client().set_remote_tracer(key, grad_keys)

    op = RemoteSend()
    op.key = key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    (dummy,) = apply(_RemoteSend(op), inp)

    # Keep the dummy output referenced from every attached grad manager
    # (presumably so it stays alive until backward runs -- TODO confirm).
    for manager in grad_keys.values():
        manager._refkeeper.append(dummy)
    return None
def remote_recv(
    src_rank: int,
    shape: Tuple[int, ...],
    dtype: type,
    device: Optional[str] = None,
    inp=None,
) -> Tensor:
    """
    Receive a Tensor from a remote process.

    :param src_rank: source process rank.
    :param shape: the shape of the tensor to receive. An empty shape ``()``
        receives a scalar.
    :param dtype: the data type of the tensor to receive.
    :param device: the device to place the received tensor. When ``None``,
        ``get_default_device()`` is used.
    :param inp: dummy input to determine the received tensor's type. When
        ``None``, ``Tensor([0])`` on ``device`` is created.
    :return: the received tensor.
    """
    key = "{}->{}".format(src_rank, get_rank())
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor([0], device=device)

    # The sender's remote_send registered, under this same key, the names of
    # the grad managers attached to the sent tensor. Make the local dummy
    # input a differentiation target of each of those managers that exists
    # here.
    tracer_set = get_client().check_remote_tracer(key)
    for name in tracer_set:
        manager = _grad_manager_dict.get(name)
        if manager is not None:
            manager.wrt(inp)
            manager._refkeeper.append(inp)

    # A 0-dim shape is received as a one-element tensor and re-marked as a
    # scalar after the op runs.
    is_scalar = len(shape) == 0
    if is_scalar:
        shape = (1,)

    op = RemoteRecv()
    op.key = key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank

    (ret,) = apply(_RemoteRecv(op), inp)
    if is_scalar:
        setscalar(ret)
    return ret