# -*- coding: utf-8 -*-
import functools
import hashlib
import os
import sys
import types
from typing import Any, List
from urllib.parse import urlparse
from megengine.utils.http_download import download_from_url
from ..distributed import is_distributed
from ..logger import get_logger
from ..serialization import load as _mge_load_serialized
from .const import (
    DEFAULT_CACHE_DIR,
    DEFAULT_GIT_HOST,
    DEFAULT_PROTOCOL,
    ENV_MGE_HOME,
    ENV_XDG_CACHE_HOME,
    HUBCONF,
    HUBDEPENDENCY,
)
from .exceptions import InvalidProtocol
from .fetcher import GitHTTPSFetcher, GitSSHFetcher
from .tools import cd, check_module_exists, load_module
# Logger obtained for this module's name.
logger = get_logger(__name__)
# Maps each accepted ``protocol`` argument value to the fetcher class whose
# ``fetch`` method is used by ``_get_repo`` to obtain the repository.
PROTOCOLS = {
    "HTTPS": GitHTTPSFetcher,
    "SSH": GitSSHFetcher,
}
def _get_megengine_home() -> str:
    r"""Return the MegEngine home directory.

    The lookup follows the XDG Base Directory Specification: the
    ``ENV_MGE_HOME`` environment variable wins when set; otherwise the
    directory ``megengine`` under ``ENV_XDG_CACHE_HOME`` (falling back to
    ``DEFAULT_CACHE_DIR``) is used. A leading ``~`` is expanded.

    Returns:
        path of the MegEngine home directory.
    """
    xdg_cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(xdg_cache_root, "megengine")
    configured_home = os.getenv(ENV_MGE_HOME, fallback_home)
    return os.path.expanduser(configured_home)
def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    r"""Fetch a repository into the hub cache and return its location.

    Args:
        git_host: host address of git repo.
        repo_info: repository descriptor forwarded to the fetcher.
        use_cache: forwarded to the fetcher's ``fetch`` method.
        commit: commit id forwarded to the fetcher's ``fetch`` method.
        protocol: key into ``PROTOCOLS`` selecting the fetcher class.

    Returns:
        the fetched repo directory joined onto the hub cache directory.

    Raises:
        InvalidProtocol: if ``protocol`` is not a key of ``PROTOCOLS``.
    """
    if protocol not in PROTOCOLS:
        supported = ", ".join(PROTOCOLS.keys())
        raise InvalidProtocol(
            "Invalid protocol, the value should be one of {}.".format(supported)
        )

    fetcher_cls = PROTOCOLS[protocol]
    hub_cache = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    # The fetcher is invoked while the working directory is the hub cache.
    with cd(hub_cache):
        fetched_dir = fetcher_cls.fetch(git_host, repo_info, use_cache, commit)
        return os.path.join(hub_cache, fetched_dir)
def _check_dependencies(module: types.ModuleType) -> None:
    r"""Verify that every dependency declared by a hub module is available.

    The declaration is read from the module attribute named by
    ``HUBDEPENDENCY``; a missing or empty declaration means nothing to check.

    Args:
        module: the loaded hubconf module.

    Raises:
        RuntimeError: if at least one declared dependency is reported missing
            by ``check_module_exists``.
    """
    declared = getattr(module, HUBDEPENDENCY, None)
    if not declared:
        return

    unavailable = []
    for dep_name in declared:
        if not check_module_exists(dep_name):
            unavailable.append(dep_name)

    if unavailable:
        raise RuntimeError("Missing dependencies: {}".format(", ".join(unavailable)))
def _init_hub(
    repo_info: str,
    git_host: str,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
):
    r"""Imports hubmodule like python import.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        a python module.
    """
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    # The cache directory must exist before ``_get_repo`` changes into it.
    os.makedirs(cache_dir, exist_ok=True)
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # Put the repo first on ``sys.path`` while hubconf is loaded so that its
    # imports can resolve against the repo, and always undo that afterwards,
    # even when loading fails, so a broken hubconf cannot leave a stale entry
    # on ``sys.path`` for the rest of the process.
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        # Guarded so cleanup never masks the original error if the loaded
        # code already removed the entry itself.
        if absolute_repo_dir in sys.path:
            sys.path.remove(absolute_repo_dir)
    return hubmodule
# Public entry point for ``_init_hub``. ``functools.wraps`` copies
# ``_init_hub``'s metadata (including its docstring) onto this wrapper, so the
# documentation lives on ``_init_hub`` rather than here.
@functools.wraps(_init_hub)
def import_module(*args, **kwargs):
    return _init_hub(*args, **kwargs)
def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    r"""Lists all entrypoints available in repo hubconf.

    An entrypoint is any callable attribute of the hubconf module whose name
    does not start with a double underscore.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        all entrypoint names of the model.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    return [
        name
        for name in dir(hubmodule)
        if not name.startswith("__") and callable(getattr(hubmodule, name))
    ]
def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    r"""Loads model from github or gitlab repo, with pretrained weights.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.
        *args: positional arguments forwarded to the entrypoint.
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.
        **kwargs: keyword arguments forwarded to the entrypoint.

    Returns:
        a single model with corresponding pretrained weights.

    Raises:
        RuntimeError: if ``entry`` is not a callable attribute of the hubconf
            module, or if a dependency declared by the hubconf is missing.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    # A missing attribute yields ``None``, which is not callable, so one check
    # covers both "absent" and "present but not callable".
    entry_func = getattr(hubmodule, entry, None)
    if not callable(entry_func):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
    # Dependencies are validated only after the entrypoint is known to exist.
    _check_dependencies(hubmodule)
    return entry_func(*args, **kwargs)
def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    r"""This function returns docstring of entrypoint ``entry`` by following steps:

    1. Pull the repo code specified by git and repo_info.
    2. Load the entry defined in repo's hubconf.py
    3. Return docstring of function entry.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.py
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        docstring of entrypoint ``entry``; this is ``None`` when the
        entrypoint has no docstring.

    Raises:
        RuntimeError: if ``entry`` is not a callable attribute of the hubconf
            module.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    # A missing attribute yields ``None``, which is not callable, so one check
    # covers both "absent" and "present but not callable".
    entry_func = getattr(hubmodule, entry, None)
    if not callable(entry_func):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
    return entry_func.__doc__
def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
    """Loads MegEngine serialized object from the given URL.

    If the object is already present in ``model_dir``, it's deserialized and
    returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.

    The cached file name is the first six hex digits of the SHA-256 of ``url``,
    an underscore, and the base name of the URL path.

    Args:
        url: url to serialized object.
        model_dir: dir to cache target serialized file.

    Returns:
        loaded object.
    """
    if model_dir is None:
        model_dir = os.path.join(_get_megengine_home(), "serialized")
    os.makedirs(model_dir, exist_ok=True)

    basename = os.path.basename(urlparse(url).path)
    # use hash as prefix to avoid filename conflict from different urls
    digest = hashlib.sha256(url.encode()).hexdigest()[:6]
    cached_file = os.path.join(model_dir, digest + "_" + basename)
    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )

    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                "    File may be downloaded multiple times. We recommend\n"
                "    users to download in single process first."
            )
        download_from_url(url, cached_file)

    # NOTE(review): the deserializer is presumably pickle-based, in which case
    # only trusted URLs should be passed here -- confirm against
    # ``serialization.load``.
    return _mge_load_serialized(cached_file)
class pretrained:
    r"""Decorator which helps to download pretrained weights from the given url. Including fs, s3, http(s).

    For example, we can decorate a resnet18 function as follows

    .. code-block::

        @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
        def resnet18(**kwargs):
            ...

    The decorated function gains a leading ``pretrained`` parameter
    (default ``False``); every other argument must be passed by keyword and is
    forwarded unchanged to the original function.

    Returns:
        When decorated function is called with ``pretrained=True``, MegEngine will automatically
        download and fill the returned model with pretrained weights.
    """

    def __init__(self, url):
        # Location of the serialized weights; fetched lazily at call time.
        self.url = url

    def __call__(self, func):
        @functools.wraps(func)
        def pretrained_model_func(
            pretrained=False, **kwargs
        ):  # pylint: disable=redefined-outer-name
            # Build the model first; weights are only fetched on request.
            model = func(**kwargs)
            if pretrained:
                weights = load_serialized_obj_from_url(self.url)
                model.load_state_dict(weights)
            return model

        return pretrained_model_func
# Names exported by a star-import of this module (its public API).
__all__ = [
    "list",
    "load",
    "help",
    "load_serialized_obj_from_url",
    "pretrained",
    "import_module",
]