Shortcuts

Source code for torch.hub

from __future__ import absolute_import, division, print_function, unicode_literals
import errno
import hashlib
import os
import re
import shutil
import sys
import tempfile
import torch
import warnings
import zipfile

if sys.version_info[0] == 2:
    from urlparse import urlparse
    from urllib2 import urlopen  # noqa f811
else:
    from urllib.request import urlopen
    from urllib.parse import urlparse  # noqa: F401

try:
    from tqdm import tqdm
except ImportError:
    class tqdm(object):
        """Minimal stand-in used when the real ``tqdm`` package is missing.

        It supports only what this module needs: construction with
        ``total``/``disable``, ``update(n)``, and use as a context manager.
        Progress is written to stderr as a percentage when ``total`` is
        known, otherwise as a running byte count.
        """

        def __init__(self, total=None, disable=False,
                     unit=None, unit_scale=None, unit_divisor=None):
            # unit, unit_scale and unit_divisor exist only so the call
            # signature matches the real tqdm; they are deliberately unused.
            self.n = 0
            self.total = total
            self.disable = disable

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Terminate the '\r'-rewritten progress line, unless muted.
            if not self.disable:
                sys.stderr.write('\n')

        def update(self, n):
            """Advance the counter by ``n`` and redraw the progress line."""
            if self.disable:
                return

            self.n += n
            if self.total is not None:
                progress = "\r{0:.1f}%".format(100 * self.n / float(self.total))
            else:
                progress = "\r{0:.1f} bytes".format(self.n)
            sys.stderr.write(progress)
            sys.stderr.flush()

# Captures the hex run between the last-style '-' and the following '.',
# e.g. matches bfd8deac from resnet18-bfd8deac.pth
# NOTE(review): `*` also allows an empty capture (e.g. 'name-.pth').
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')

# Branch used by _parse_repo_info when the repo string has no ':branch' part.
MASTER_BRANCH = 'master'
# Environment variable that overrides the torch cache root (_get_torch_home).
ENV_TORCH_HOME = 'TORCH_HOME'
# Environment variable naming the base cache directory (_get_torch_home).
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# Base cache directory used when XDG_CACHE_HOME is not set.
DEFAULT_CACHE_DIR = '~/.cache'
# Name of the optional hubconf attribute listing required packages.
VAR_DEPENDENCY = 'dependencies'
# File name of the entrypoint module expected at the root of a hub repo.
MODULE_HUBCONF = 'hubconf.py'
# Size in bytes of one read chunk when streaming a download
# (the same value download_url_to_file reads per iteration).
READ_DATA_CHUNK = 8192
# Directory holding downloaded repos; None until set_dir() or
# _setup_hubdir() assigns it.
hub_dir = None


# Copied from tools/shared/module_loader to be included in torch package
def import_module(name, path):
    """Execute the Python source file at ``path`` as a module called ``name``.

    The loading mechanism is picked from the running interpreter version:
    ``imp`` on Python 2, ``SourceFileLoader`` on Python 3.0-3.4, and the
    ``importlib.util`` spec API on Python 3.5 and later.

    Args:
        name (string): module name to give the loaded module.
        path (string): path of the source file to load.

    Returns:
        The loaded module object.
    """
    version = sys.version_info

    if version < (3, 0):
        import imp
        return imp.load_source(name, path)

    if version < (3, 5):
        from importlib.machinery import SourceFileLoader
        return SourceFileLoader(name, path).load_module()

    import importlib.util
    module_spec = importlib.util.spec_from_file_location(name, path)
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded


def _remove_if_exists(path):
    """Delete ``path`` if it exists: a file is unlinked, anything else is
    removed recursively as a directory tree. A missing path is a no-op."""
    if not os.path.exists(path):
        return

    if os.path.isfile(path):
        os.remove(path)
        return

    shutil.rmtree(path)


def _git_archive_link(repo_owner, repo_name, branch):
    """Build the GitHub URL of the zip archive for ``branch`` of
    ``repo_owner/repo_name``."""
    template = 'https://github.com/{}/{}/archive/{}.zip'
    return template.format(repo_owner, repo_name, branch)


def _load_attr_from_module(module, func_name):
    """Return attribute ``func_name`` of ``module``, or ``None`` when
    ``dir(module)`` does not list that name."""
    if func_name in dir(module):
        return getattr(module, func_name)
    return None


def _get_torch_home():
    """Resolve the torch cache root directory.

    ``$TORCH_HOME`` wins when set; otherwise the result is the ``torch``
    folder under ``$XDG_CACHE_HOME`` (or under ``DEFAULT_CACHE_DIR`` when
    that is unset too). A leading ``~`` is expanded.
    """
    xdg_cache = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback = os.path.join(xdg_cache, 'torch')
    return os.path.expanduser(os.getenv(ENV_TORCH_HOME, fallback))


def _setup_hubdir():
    """Make sure the global ``hub_dir`` is set and present on disk.

    When ``hub_dir`` has not been set (see ``set_dir``), it defaults to the
    ``hub`` folder under the torch home directory. A warning is issued if
    the deprecated ``TORCH_HUB`` environment variable is set.

    Raises:
        OSError: if the directory cannot be created and does not exist.
    """
    global hub_dir
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_HUB'):
        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')

    if hub_dir is None:
        torch_home = _get_torch_home()
        hub_dir = os.path.join(torch_home, 'hub')

    # Attempt the creation first instead of checking beforehand: another
    # process may create the directory between a check and the makedirs call.
    try:
        os.makedirs(hub_dir)
    except OSError:
        # Only a failure that leaves the path missing is a real error.
        if not os.path.exists(hub_dir):
            raise


def _parse_repo_info(github):
    """Split ``"owner/name[:branch]"`` into ``(owner, name, branch)``.

    ``branch`` falls back to ``MASTER_BRANCH`` when no ``:`` is present.
    A string that does not unpack into exactly these parts raises
    ``ValueError`` from the tuple unpacking.
    """
    if ':' in github:
        repo_info, branch = github.split(':')
    else:
        repo_info, branch = github, MASTER_BRANCH

    repo_owner, repo_name = repo_info.split('/')
    return repo_owner, repo_name, branch


def _get_cache_or_reload(github, force_reload, verbose=True):
    """Return the local directory of a hub repo, downloading it if needed.

    Relies on the global ``hub_dir`` already being set (``_setup_hubdir``).

    Args:
        github (string): repo spec in the form ``owner/name[:branch]``.
        force_reload (bool): when true, ignore any cached copy and download
            the archive again.
        verbose (bool, optional): when false, do not print the cache-hit
            message. The download message is always printed.

    Returns:
        Path of the directory holding the extracted repo.
    """
    # Parse github repo information
    repo_owner, repo_name, branch = _parse_repo_info(github)

    # A branch or tag may contain '/', which would otherwise turn the cache
    # folder and zip path into nested, non-existent directories. Flatten it
    # for local paths only; the download URL keeps the real branch name.
    normalized_branch = branch.replace('/', '_')

    # Github renames folder repo-v1.x.x to repo-1.x.x
    # We don't know the repo name before downloading the zip file
    # and inspect name from it.
    # To check if cached repo exists, we need to normalize folder names.
    repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_branch]))

    use_cache = (not force_reload) and os.path.exists(repo_dir)

    if use_cache:
        if verbose:
            sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
        return repo_dir

    cached_file = os.path.join(hub_dir, normalized_branch + '.zip')
    _remove_if_exists(cached_file)

    url = _git_archive_link(repo_owner, repo_name, branch)
    sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file))
    download_url_to_file(url, cached_file, progress=False)

    with zipfile.ZipFile(cached_file) as cached_zipfile:
        # The first archive entry is taken as the top-level folder name.
        extracted_repo_name = cached_zipfile.infolist()[0].filename
        extracted_repo = os.path.join(hub_dir, extracted_repo_name)
        _remove_if_exists(extracted_repo)
        # Unzip the code and rename the base folder
        cached_zipfile.extractall(hub_dir)

    _remove_if_exists(cached_file)
    _remove_if_exists(repo_dir)
    shutil.move(extracted_repo, repo_dir)  # rename the repo

    return repo_dir


def _check_module_exists(name):
    """Report whether module ``name`` can be located, without importing it
    where the interpreter allows (parent packages of a dotted name may
    still be imported by the lookup).

    Args:
        name (string): module name, possibly dotted.

    Returns:
        bool: True if the module can be found, False otherwise. Lookup
        errors (for example a missing parent package) count as "not found"
        rather than propagating.
    """
    if sys.version_info >= (3, 4):
        import importlib.util
        try:
            return importlib.util.find_spec(name) is not None
        except (ImportError, AttributeError):
            # Raised for dotted names whose parent is missing or is not a
            # package; either way the requested module is unavailable.
            return False
        except ValueError:
            # find_spec raises ValueError when an already-imported module
            # has no usable __spec__ (the module exists), and for invalid
            # names such as '' (no such entry in sys.modules).
            return name in sys.modules
    elif sys.version_info >= (3, 3):
        # Special case for python3.3
        # find_loader is a function of importlib, not a submodule.
        import importlib
        try:
            return importlib.find_loader(name) is not None
        except ImportError:
            return False
    else:
        # NB: Python2.7 imp.find_module() doesn't respect PEP 302,
        #     it cannot find a package installed as .egg(zip) file.
        #     Here we use workaround from:
        #     https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1
        #     Also imp doesn't handle hierarchical module names (names contains dots).
        try:
            # 1. Try imp.find_module(), which searches sys.path, but does
            # not respect PEP 302 import hooks.
            import imp
            result = imp.find_module(name)
            if result:
                return True
        except ImportError:
            pass
        path = sys.path
        for item in path:
            # 2. Scan path for import hooks. sys.path_importer_cache maps
            # path items to optional "importer" objects, that implement
            # find_module() etc.  Note that path must be a subset of
            # sys.path for this to work.
            importer = sys.path_importer_cache.get(item)
            if importer:
                try:
                    result = importer.find_module(name, [item])
                    if result:
                        return True
                except ImportError:
                    pass
        return False

def _check_dependencies(m):
    """Verify that every package named in the hubconf module's
    ``dependencies`` attribute can be found.

    Does nothing when the attribute is absent or ``None``.

    Raises:
        RuntimeError: listing the packages that could not be found.
    """
    dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)
    if dependencies is None:
        return

    missing_deps = []
    for pkg in dependencies:
        if not _check_module_exists(pkg):
            missing_deps.append(pkg)

    if missing_deps:
        raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))


def _load_entry_from_hubconf(m, model):
    """Look up the entrypoint callable ``model`` in hubconf module ``m``.

    Args:
        m: the imported hubconf module.
        model (string): name of the entrypoint function.

    Returns:
        The callable attribute of ``m`` named ``model``.

    Raises:
        ValueError: if ``model`` is not a ``str``.
        RuntimeError: if a declared dependency is missing, or if ``m`` has no
            callable attribute called ``model``.
    """
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of function name')

    # A dependency imported at the top level of hubconf fails earlier, while
    # hubconf itself is being imported: the dependency list lives inside the
    # very module that needs those packages. Python's own ImportError is
    # clear enough in that case, so only the declared list is checked here.
    _check_dependencies(m)

    entry = _load_attr_from_module(m, model)
    # A missing attribute comes back as None, which is not callable either.
    if callable(entry):
        return entry

    raise RuntimeError('Cannot find callable {} in hubconf'.format(model))


def set_dir(d):
    r"""
    Optionally set hub_dir to a local dir to save downloaded models & weights.

    If ``set_dir`` is not called, default path is ``$TORCH_HOME/hub`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the XDG Base Directory specification of the
    Linux filesystem layout, with a default value ``~/.cache`` if the
    environment variable is not set.

    Args:
        d (string): path to a local folder to save downloaded models & weights.
            A leading ``~`` is expanded to the user's home directory.
            Passing ``None`` resets to the default path.
    """
    global hub_dir
    # Expand '~' here so the directory is not created literally as './~/...',
    # matching how the default torch home path is resolved.
    hub_dir = None if d is None else os.path.expanduser(d)
def list(github, force_reload=False):
    r"""
    List all entrypoints available in `github` hubconf.

    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.

    Returns:
        entrypoints: a list of available entrypoint names

    Example:
        >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
    """
    # Setup hub_dir to save downloaded files
    _setup_hubdir()

    repo_dir = _get_cache_or_reload(github, force_reload, True)

    # The repo is put on sys.path only while hubconf is imported; the
    # finally clause takes it off again even if the import raises.
    sys.path.insert(0, repo_dir)
    try:
        hub_module = import_module(MODULE_HUBCONF, os.path.join(repo_dir, MODULE_HUBCONF))
    finally:
        sys.path.remove(repo_dir)

    # We take functions starts with '_' as internal helper functions
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]

    return entrypoints
def help(github, model, force_reload=False):
    r"""
    Show the docstring of entrypoint `model`.

    Args:
        github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model (string): a string of entrypoint name defined in repo's hubconf.py
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.

    Returns:
        The ``__doc__`` of the entrypoint (``None`` if it has no docstring).

    Example:
        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
    """
    # Setup hub_dir to save downloaded files
    _setup_hubdir()

    repo_dir = _get_cache_or_reload(github, force_reload, True)

    # The repo is put on sys.path only while hubconf is imported; the
    # finally clause takes it off again even if the import raises.
    sys.path.insert(0, repo_dir)
    try:
        hub_module = import_module(MODULE_HUBCONF, os.path.join(repo_dir, MODULE_HUBCONF))
    finally:
        sys.path.remove(repo_dir)

    entry = _load_entry_from_hubconf(hub_module, model)

    return entry.__doc__
# Ideally this should be `def load(github, model, *args, force_reload=False, **kwargs):`,
# but Python2 complains syntax error for it. We have to skip force_reload in function
# signature here but detect it in kwargs instead.
# TODO: fix it after Python2 EOL
def load(github, model, *args, **kwargs):
    r"""
    Load a model from a github repo, with pretrained weights.

    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model (string): a string of entrypoint name defined in repo's hubconf.py
        *args (optional): the corresponding args for callable `model`.
        force_reload (bool, optional): whether to force a fresh download of github repo unconditionally.
            Default is `False`.
        verbose (bool, optional): If False, mute messages about hitting local caches. Note that the message
            about first download cannot be muted.
            Default is `True`.
        **kwargs (optional): the corresponding kwargs for callable `model`.

    Returns:
        The object returned by the entrypoint, typically a single model with
        corresponding pretrained weights.

    Example:
        >>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
    """
    # Setup hub_dir to save downloaded files
    _setup_hubdir()

    # force_reload and verbose are consumed here and not forwarded to the
    # entrypoint.
    force_reload = kwargs.pop('force_reload', False)
    verbose = kwargs.pop('verbose', True)

    repo_dir = _get_cache_or_reload(github, force_reload, verbose)

    # The repo stays on sys.path while hubconf is imported and the
    # entrypoint runs; the finally clause removes it even if either raises.
    sys.path.insert(0, repo_dir)
    try:
        hub_module = import_module(MODULE_HUBCONF, os.path.join(repo_dir, MODULE_HUBCONF))

        entry = _load_entry_from_hubconf(hub_module, model)

        model = entry(*args, **kwargs)
    finally:
        sys.path.remove(repo_dir)

    return model
def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    r"""Download object at the given URL to a local path.

    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
        hash_prefix (string, optional): If not None, the SHA256 hex digest of the downloaded
            file should start with `hash_prefix`.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Raises:
        RuntimeError: if `hash_prefix` is given and the digest does not start with it.
            Nothing is written to `dst` in that case.

    Example:
        >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

    """
    file_size = None
    u = urlopen(url)
    try:
        meta = u.info()
        # The header object exposes getheaders() on Python 2 and get_all()
        # on Python 3.
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # We deliberately save it in a temp file and move it after
        # download is complete. This prevents a local working checkpoint
        # being overridden by a broken download.
        dst = os.path.expanduser(dst)
        dst_dir = os.path.dirname(dst)
        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

        try:
            if hash_prefix is not None:
                sha256 = hashlib.sha256()
            with tqdm(total=file_size, disable=not progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = u.read(READ_DATA_CHUNK)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    if hash_prefix is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            if hash_prefix is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                       .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            # After a successful move the temp name no longer exists; on any
            # failure the partial download is discarded here.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
    finally:
        # Always release the HTTP response; try/finally (rather than `with`)
        # because the Python 2 response object is not a context manager.
        u.close()
def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
    """Deprecated alias of ``download_url_to_file``.

    Emits a warning and forwards all arguments unchanged.
    """
    # Adjacent literals instead of backslash continuations, which would
    # embed the source indentation into the warning text.
    warnings.warn('torch.hub._download_url_to_file has been renamed to '
                  'torch.hub.download_url_to_file to be a public API, '
                  '_download_url_to_file will be removed after the 1.3 release')
    download_url_to_file(url, dst, hash_prefix, progress)
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False):
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the XDG Base Directory specification of the
    Linux filesystem layout, with a default value ``~/.cache`` if not set.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False

    Raises:
        RuntimeError: if `check_hash` is True but the file name carries no hash,
            if the downloaded content does not match the hash, or if a downloaded
            zip archive does not contain exactly one member.

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        hash_prefix = None
        if check_hash:
            # Fail with a clear message instead of an AttributeError on
            # None when the file name has no '-<hash>.' part.
            hash_match = HASH_REGEX.search(filename)
            if hash_match is None:
                raise RuntimeError(
                    'check_hash=True requires a file name of the form '
                    'filename-<sha256>.ext, got "{}"'.format(filename))
            hash_prefix = hash_match.group(1)
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    #       E.g. resnet18-5c106cde.pth which is widely used.
    if zipfile.is_zipfile(cached_file):
        with zipfile.ZipFile(cached_file) as cached_zipfile:
            members = cached_zipfile.infolist()
            if len(members) != 1:
                raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
            cached_zipfile.extractall(model_dir)
            extracted_name = members[0].filename
            cached_file = os.path.join(model_dir, extracted_name)

    # NOTE(security): torch.load deserializes with pickle, which can execute
    # arbitrary code; only use this function with URLs you trust.
    return torch.load(cached_file, map_location=map_location)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources