Source code for megengine.distributed.server

# -*- coding: utf-8 -*-
import multiprocessing as mp
import threading
import time
from collections import defaultdict
from functools import partial
from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer

from ..core._imperative_rt.utils import create_mm_server
from ..utils.future import Future


class Methods:
    r"""Distributed Server Method.
    Used to exchange information between distributed nodes.

    Every public method is meant to be registered on an RPC server and may be
    called from several request threads at once, so all accesses to the
    per-key tables go through ``self.lock``.

    Args:
        mm_server_port: multiple machine rpc server port.
    """

    def __init__(self, mm_server_port):
        # Guards lookups, inserts and deletes on every table below.
        self.lock = threading.Lock()
        self.mm_server_port = mm_server_port
        # Each table lazily creates one Future per key on first access, so the
        # "set" side and the "check/get" side can reach a key in either order.
        # NOTE(review): the bool handed to Future is its first constructor
        # argument; its meaning lives in ..utils.future -- confirm there.
        self.dict_is_grad = defaultdict(partial(Future, True))
        self.dict_remote_tracer = defaultdict(partial(Future, True))
        self.dict_pack_list = defaultdict(partial(Future, False))
        # Barrier bookkeeping: arrivals counted per key, one Event per key.
        self.dict_barrier_counter = defaultdict(int)
        self.dict_barrier_event = defaultdict(threading.Event)
        self.user_dict = defaultdict(partial(Future, False))
        # key -> [Future, number of callers still expected]
        self.bcast_dict = {}

    def connect(self):
        r"""Method for checking connection success."""
        return True

    def get_mm_server_port(self):
        r"""Get multiple machine rpc server port."""
        return self.mm_server_port

    def set_is_grad(self, key, is_grad):
        r"""Mark send/recv need gradients by key.

        Args:
            key: key to match send/recv op.
            is_grad: whether this op need grad.

        Returns:
            Always ``True``.
        """
        with self.lock:
            future = self.dict_is_grad[key]
        # The future is set outside the lock so the table is not held while
        # the future does its own work.
        future.set(is_grad)
        return True

    def check_is_grad(self, key):
        r"""Check whether send/recv need gradients.

        The entry for ``key`` is deleted once its value has been read.

        Args:
            key: key to match send/recv op.

        Returns:
            The value obtained from the future stored under ``key``.
        """
        with self.lock:
            future = self.dict_is_grad[key]
        # Presumably blocks until set_is_grad provides a value -- confirm
        # against Future.get.
        ret = future.get()
        with self.lock:
            del self.dict_is_grad[key]
        return ret

    def set_remote_tracer(self, key, tracer_set):
        r"""Set tracer dict for tracing send/recv op.

        Args:
            key: key to match send/recv op.
            tracer_set: valid tracer set.

        Returns:
            Always ``True``.
        """
        with self.lock:
            future = self.dict_remote_tracer[key]
        future.set(tracer_set)
        return True

    def check_remote_tracer(self, key):
        r"""Get tracer dict for send/recv op.

        The entry for ``key`` is deleted once its value has been read.

        Args:
            key: key to match send/recv op.

        Returns:
            The value obtained from the future stored under ``key``.
        """
        with self.lock:
            future = self.dict_remote_tracer[key]
        ret = future.get()
        with self.lock:
            del self.dict_remote_tracer[key]
        return ret

    def group_barrier(self, key, size):
        r"""A barrier wait for all group member.

        Args:
            key: group key to match each other.
            size: group size.

        Returns:
            ``True`` once ``size`` callers have arrived for ``key``.
        """
        with self.lock:
            self.dict_barrier_counter[key] += 1
            counter = self.dict_barrier_counter[key]
            event = self.dict_barrier_event[key]
            if counter == size:
                # Last arrival: retire the bookkeeping while the lock is still
                # held. Otherwise a caller starting the next round on the same
                # key could pick up the old counter and the event that is
                # about to be set, and walk straight through the barrier.
                del self.dict_barrier_counter[key]
                del self.dict_barrier_event[key]
        if counter == size:
            # Waiters hold their own reference to the event, so releasing
            # them after the table entry is gone is safe.
            event.set()
        else:
            event.wait()
        return True

    def user_set(self, key, val):
        r"""Set user defined key-value pairs across processes.

        Returns:
            Always ``True``.
        """
        with self.lock:
            future = self.user_dict[key]
        future.set(val)
        return True

    def user_get(self, key):
        r"""Get user defined key-value pairs across processes.

        The entry is kept; use :meth:`user_pop` to read and remove it.
        """
        with self.lock:
            future = self.user_dict[key]
        return future.get()

    def bcast_val(self, val, key, size):
        r"""Hand one value to every caller that uses the same key.

        A caller passing a non-``None`` ``val`` stores it in the future kept
        under ``key`` and gets ``None`` back; a caller passing ``None`` gets
        the stored value. The entry is removed after ``size`` calls.

        Args:
            val: value to publish, or ``None`` to receive.
            key: key shared by all participants.
            size: number of calls expected for ``key``; only the value given
                by the first caller to reach the key is used.

        Returns:
            ``None`` for the publishing caller, the published value otherwise.
        """
        with self.lock:
            if key not in self.bcast_dict:
                self.bcast_dict[key] = [Future(False), size]
            arr = self.bcast_dict[key]
        if val is not None:
            arr[0].set(val)
            val = None
        else:
            val = arr[0].get()
        with self.lock:
            # Count this caller as done; the last one drops the entry.
            cnt = arr[1] - 1
            arr[1] = cnt
            if cnt == 0:
                del self.bcast_dict[key]
        return val

    def _del(self, key):
        r"""Delete ``key`` from the user key-value table under the lock."""
        with self.lock:
            del self.user_dict[key]

    # thread safe function
    def user_pop(self, key):
        r"""Get a user defined value and delete its entry afterwards."""
        ret = self.user_get(key)
        self._del(key)
        return ret


class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
    r"""``SimpleXMLRPCServer`` combined with ``ThreadingMixIn``.

    Adds no members of its own; it only mixes the two base classes together.
    """


def _start_server(py_server_port, queue):
    r"""Launch the python distributed server and the multiple machine server.

    Runs until the RPC server stops serving; intended as a process target.

    Args:
        py_server_port: port the python server binds to on ``0.0.0.0``.
        queue: receives ``(py_server_port, mm_server_port)`` once the python
            server is bound; receives the exception instead if one is raised
            while starting or serving.
    """
    try:
        mm_port = create_mm_server("0.0.0.0", 0)
        bind_address = ("0.0.0.0", py_server_port)
        rpc_server = ThreadXMLRPCServer(
            bind_address, logRequests=False, allow_none=True
        )
        rpc_server.register_instance(Methods(mm_port))
        # Read the port back from the bound socket and report that one.
        _host, bound_port = rpc_server.server_address
        queue.put((bound_port, mm_port))
        rpc_server.serve_forever()
    except Exception as err:
        # Hand the failure to whoever is waiting on the queue.
        queue.put(err)


class Server:
    r"""Distributed Server for distributed training.
    Should be running at master node.

    Starts ``_start_server`` in a daemon child process and waits for it to
    report the ports it is serving on.

    Args:
        port: python server port.

    Raises:
        Exception: whatever exception the child process reported through the
            queue instead of the port pair.
    """

    def __init__(self, port=0):
        q = mp.Queue()
        self.proc = mp.Process(target=_start_server, args=(port, q), daemon=True)
        self.proc.start()
        # The child puts either the (py_port, mm_port) pair or an exception.
        ret = q.get()
        if isinstance(ret, Exception):
            raise ret
        self.py_server_port, self.mm_server_port = ret

    def __del__(self):
        # __del__ also runs when __init__ failed before ``proc`` was assigned;
        # in that case there is nothing to terminate.
        proc = getattr(self, "proc", None)
        if proc is not None:
            proc.terminate()
class Client:
    r"""Distributed Client for distributed training.

    Args:
        master_ip: ip address of master node.
        port: port of server at master node.
    """

    def __init__(self, master_ip, port):
        self.master_ip = master_ip
        self.port = port
        self.connect()
        # Per-key count of bcast_val calls made by this client.
        self.bcast_dict = defaultdict(int)

    def connect(self):
        r"""Check connection success.

        Keeps retrying until the server answers ``connect`` with a true value.
        """
        while True:
            try:
                self.proxy = ServerProxy(
                    "http://{}:{}".format(self.master_ip, self.port), allow_none=True
                )
                if self.proxy.connect():
                    break
            except Exception:
                # Retry on failures, but let KeyboardInterrupt/SystemExit
                # propagate so the loop can be interrupted.
                time.sleep(1)

    def get_mm_server_port(self):
        r"""Get multiple machine server port."""
        while True:
            try:
                return self.proxy.get_mm_server_port()
            except Exception:
                # Retry on failures; interrupts are deliberately not caught.
                time.sleep(0.5)

    def set_is_grad(self, key, is_grad):
        r"""Mark send/recv need gradients by key.

        Args:
            key: key to match send/recv op.
            is_grad: whether this op need grad.
        """
        self.proxy.set_is_grad(key, is_grad)

    def check_is_grad(self, key):
        r"""Check whether send/recv need gradients.

        Args:
            key: key to match send/recv op.
        """
        return self.proxy.check_is_grad(key)

    def set_remote_tracer(self, key, tracer_set):
        r"""Set tracer dict for tracing send/recv op.

        Args:
            key: key to match send/recv op.
            tracer_set: valid tracer set.
        """
        self.proxy.set_remote_tracer(key, tracer_set)

    def check_remote_tracer(self, key):
        r"""Get tracer dict for send/recv op.

        Args:
            key: key to match send/recv op.
        """
        return self.proxy.check_remote_tracer(key)

    def group_barrier(self, key, size):
        r"""A barrier wait for all group member.

        Args:
            key: group key to match each other.
            size: group size.
        """
        # FIXME: group_barrier is not idempotent
        while True:
            try:
                self.proxy.group_barrier(key, size)
                return
            except Exception:
                # Retry on failures; interrupts are deliberately not caught.
                time.sleep(0.5)

    def user_set(self, key, val):
        r"""Set user defined key-value pairs across processes."""
        return self.proxy.user_set(key, val)

    def user_get(self, key):
        r"""Get user defined key-value pairs across processes."""
        return self.proxy.user_get(key)

    def user_pop(self, key):
        r"""Get user defined key-value pairs and delete the resources when the get is done"""
        return self.proxy.user_pop(key)

    def bcast_val(self, val, key, size):
        r"""Forward a broadcast call to the server under a per-call key.

        The key sent to the server is ``key + "_bcast_" + str(idx)``, where
        ``idx`` counts this client's calls for ``key`` starting at 1.

        Args:
            val: value passed through to the server's ``bcast_val``.
            key: base key for this broadcast.
            size: passed through to the server's ``bcast_val``.

        Returns:
            Whatever the server's ``bcast_val`` returns.
        """
        idx = self.bcast_dict[key] + 1
        self.bcast_dict[key] = idx
        key = key + "_bcast_" + str(idx)
        return self.proxy.bcast_val(val, key, size)


def _str_to_bool(text):
    r"""Parse a command line boolean such as ``true``/``false`` or ``1``/``0``.

    ``type=bool`` cannot be used with argparse because ``bool("False")`` is
    ``True``: every non-empty string is truthy.

    Raises:
        ValueError: if ``text`` is not a recognised spelling; argparse reports
            this as an invalid argument value.
    """
    lowered = str(text).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError("expected a boolean value, got {!r}".format(text))


def main(port=0, verbose=True):
    r"""Run the distributed server in the current process.

    Args:
        port: python server port; the port actually bound is printed.
        verbose: passed to the server as ``logRequests``.
    """
    mm_server_port = create_mm_server("0.0.0.0", 0)
    # allow_none is required because the registered bcast_val returns None to
    # the publishing caller, and None cannot be marshalled without it.
    server = ThreadXMLRPCServer(
        ("0.0.0.0", port), logRequests=verbose, allow_none=True
    )
    server.register_instance(Methods(mm_server_port))
    _, port = server.server_address
    print("serving on port", port)
    server.serve_forever()


if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("-p", "--port", type=int, default=0)
    ap.add_argument("-v", "--verbose", type=_str_to_bool, default=True)
    args = ap.parse_args()
    main(port=args.port, verbose=args.verbose)