Source code for pugh_torch.models.io

import logging
import os
from pathlib import Path
from urllib.parse import urlparse

import torch

from . import ROOT_MODELS_PATH
from ..utils.io import gdrive_download

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


def load_state_dict_from_url(
    url, local=None, map_location=None, progress=True, force=False, **kwargs
):
    """Load a state dict from ``url``, caching the download under the torch.hub directory.

    URLs starting with ``https://drive.google.com`` are fetched with
    ``gdrive_download``. Every other URL is handed to
    ``torch.hub.load_state_dict_from_url`` (with ``check_hash=False``).

    Parameters
    ----------
    url : str-like
        Location of the serialized state dict; converted with ``str``.
    local : str-like, optional
        Local path of where to save or check for the cached file.
        If relative, it is relative to the torch.hub directory.
        Required for Google Drive URLs. For other URLs, a path with a file
        extension names the cached file. A path without an extension is
        treated as a directory (created if missing), and the file name is
        chosen by torch.hub.
    map_location : optional
        Forwarded to ``torch.load`` (Google Drive) or to
        ``torch.hub.load_state_dict_from_url`` (everything else).
    progress : bool
        Forwarded to ``torch.hub.load_state_dict_from_url``; unused for
        Google Drive URLs.
    force : bool
        Force the redownload, even if the local file exists.
    **kwargs
        Accepted for backward compatibility and ignored.

    Returns
    -------
    The object deserialized from the cached file (normally a state dict).

    Raises
    ------
    ValueError
        If ``url`` is a Google Drive link and ``local`` is not given.
    """
    url = str(url)

    gdrive_prefix = "https://drive.google.com"
    if url.startswith(gdrive_prefix):
        return _load_from_gdrive(url, local, map_location, force)

    if "drive.google.com" in url:
        # Looks like a Drive link that will NOT take the custom gdrive path
        # above (e.g. ``http://`` or a different prefix), so warn the caller.
        log.warning(
            f'URL {url} contains "drive.google.com", suggesting a google drive link. Google drive links MUST start with "{gdrive_prefix}"'
        )

    return _load_from_hub(url, local, map_location, progress, force)


def _load_from_gdrive(url, local, map_location, force):
    """Download a Google Drive ``url`` to ``local`` (unless cached) and ``torch.load`` it."""
    # Explicit raise instead of ``assert``: asserts are stripped under ``-O``.
    if local is None:
        raise ValueError("Must specify path for gdrive files.")
    local = Path(torch.hub.get_dir()) / local

    if force or not local.is_file():
        downloaded_path = gdrive_download(url, local)
        # Internal sanity check: the cache lookup above relies on the file
        # landing exactly at ``local``.
        assert downloaded_path == local

    return torch.load(local, map_location=map_location)


def _load_from_hub(url, local, map_location, progress, force):
    """Load ``url`` via ``torch.hub.load_state_dict_from_url``, honoring ``local`` and ``force``."""
    if local is None:
        # Let torch.hub pick both the directory and the file name.
        model_dir = None
        file_name = None
    else:
        local = Path(local)
        hub_dir = Path(torch.hub.get_dir())
        if local.suffix:
            # Has an extension: interpret as "<directory>/<file name>".
            model_dir = hub_dir / local.parent
            file_name = local.name
        else:
            # No extension: assume it was meant to be a directory.
            model_dir = hub_dir / local
            model_dir.mkdir(parents=True, exist_ok=True)
            file_name = None  # use whatever name torch.hub gives us

    if force:
        # Delete the cached file so torch.hub is forced to redownload it.
        _remove_cached_file(url, model_dir, file_name)

    return torch.hub.load_state_dict_from_url(
        url,
        model_dir=None if model_dir is None else str(model_dir),
        map_location=map_location,
        progress=progress,
        check_hash=False,
        file_name=file_name,
    )


def _remove_cached_file(url, model_dir, file_name):
    """Delete the file torch.hub would use as the cache for ``url``, if it exists."""
    # Mirror torch.hub.load_state_dict_from_url's cache-path resolution: the
    # default directory is ``<hub_dir>/checkpoints`` and the default file name
    # is the basename of the URL's path component.
    if model_dir is None:
        model_dir = Path(torch.hub.get_dir()) / "checkpoints"
    if file_name is None:
        file_name = os.path.basename(urlparse(url).path)
    if not file_name:
        # URL path has no basename (e.g. ends with "/"); nothing to delete.
        return
    try:
        (Path(model_dir) / file_name).unlink()
    except FileNotFoundError:
        # Not cached yet: nothing to remove.
        pass