import random
import string
import numpy as np
import torch
from torch import distributed as dist
import os

from tqdm import tqdm
import urllib.request as request

from utils import smart_joint

def _load_mammoth_model(dict_keys, model: torch.nn.Module, args):
    """
    Load the ``'model'`` state dict of a full Mammoth checkpoint into ``model``.

    Key names are first normalised for the current parallelism mode: unless
    ``args.distributed == 'dp'`` every ``'module.'`` component is stripped; in
    ``'dp'`` mode keys that do not mention ``module`` have ``'net.'``
    rewritten to ``'net.module.'``.

    For LUCIR models (``'lucir'`` in ``args.model``, case-insensitive) a
    ``classes_so_far`` buffer shaped like the checkpoint entry is registered
    before loading, so the strict load below has a matching slot for it.

    Args:
        dict_keys: the checkpoint state dict. It is modified in place (keys
            are renamed).
        model: the model that receives the weights.
        args: the parsed arguments; ``distributed`` and ``model`` are read.

    Returns:
        ``model``, after a strict ``load_state_dict``.
    """
    for k in list(dict_keys):
        if args.distributed != 'dp':
            dict_keys[k.replace('module.', '')] = dict_keys.pop(k)
        elif 'module' not in k:
            # NOTE(review): str.replace rewrites *every* 'net.' occurrence
            # (e.g. inside 'subnet.'), not only the leading prefix. Left
            # unchanged because keys of other sub-modules may rely on it.
            dict_keys[k.replace('net.', 'net.module.')] = dict_keys.pop(k)

    # NOTE(review): this tests whether the literal key '_features' is in the
    # dict, not whether `k` contains it. For normal checkpoints it is a no-op;
    # if such a key exists, every entry is popped. Switching to
    # `'_features' in k` would drop any `_features.*` entries, and the strict
    # load below would then report them as missing if the model defines them.
    # Confirm intent before changing.
    for k in list(dict_keys):
        if '_features' in dict_keys:
            dict_keys.pop(k)

    if 'lucir' in args.model.lower():
        # Give the model a buffer with the saved shape so the strict load
        # below can copy the checkpoint value into it.
        model.register_buffer('classes_so_far', torch.zeros_like(
            dict_keys['classes_so_far']))

    model.load_state_dict(dict_keys)
    return model

def _load_net(dict_keys, model: torch.nn.Module, args, ignore_classifier=True):
    """
    Load a model-only state dict (no Mammoth metadata) into ``model.net``.

    Steps:

    1. Normalise parallel-wrapper prefixes. Unless ``args.distributed == 'dp'``,
       every ``'module.'`` component is stripped. In ``'dp'`` mode, keys that
       do not mention ``module`` get it inserted: ``net.x`` -> ``net.module.x``
       and ``x`` -> ``module.x``.
    2. Handle the classifier. With ``ignore_classifier``, the checkpoint's
       ``classifier`` entries are dropped. Otherwise, if the checkpoint has
       classifier entries, ``model.net.classifier`` is replaced by a new
       ``torch.nn.Linear`` whose output size comes from the checkpoint.
    3. Strip a leading ``'net.'`` and every ``'wrappee.'`` component.
    4. Try a strict load. If that raises ``RuntimeError``, reload with
       ``strict=False`` and validate the unexpected (unused) checkpoint keys.

    Args:
        dict_keys: the checkpoint state dict; modified in place.
        model: the model whose ``net`` receives the weights. ``model.device``
            is read when a new classifier is built.
        args: the parsed arguments; ``distributed`` is read.
        ignore_classifier: drop the checkpoint's classifier weights instead of
            loading them.

    Returns:
        ``model``.

    Raises:
        AssertionError: (when assertions are enabled) if, in the non-strict
            fallback, checkpoint keys other than ``_features``/``linear`` ones
            were left unused.
        RuntimeError: if even the non-strict load fails (e.g. shape mismatch).
    """
    for k in list(dict_keys):
        if args.distributed != 'dp':
            dict_keys[k.replace('module.', '')] = dict_keys.pop(k)
        elif 'module' not in k:
            if k.startswith('net.'):
                # Rewrite only the prefix; str.replace would also hit nested
                # names such as 'subnet.'.
                dict_keys['net.module.' + k[len('net.'):]] = dict_keys.pop(k)
            else:
                dict_keys[f'module.{k}'] = dict_keys.pop(k)

    if not ignore_classifier:
        cl_weights = [v for k, v in dict_keys.items() if 'classifier' in k]
        if cl_weights:
            # Presumably the last classifier tensor is the bias, whose length
            # is the number of output classes - TODO confirm key ordering.
            cl_size = cl_weights[-1].shape[0]
            model.net.classifier = torch.nn.Linear(
                model.net.classifier.in_features, cl_size).to(model.device)
    else:
        for k in list(dict_keys):
            if 'classifier' in k:
                dict_keys.pop(k)

    # NOTE(review): this tests whether the literal key '_features' is in the
    # dict, not whether `k` contains it, so it is normally a no-op (and empties
    # the dict if such a key exists). Not changed: the fallback below only
    # inspects unexpected keys, so dropping `_features.*` entries would
    # silently skip loading them. Confirm whether they duplicate other
    # parameters before changing.
    for k in list(dict_keys):
        if '_features' in dict_keys:
            dict_keys.pop(k)

    for k in list(dict_keys):
        if k.startswith('net.'):
            dict_keys[k[len('net.'):]] = dict_keys.pop(k)
    for k in list(dict_keys):
        if 'wrappee.' in k:
            dict_keys[k.replace('wrappee.', '')] = dict_keys.pop(k)

    try:
        model.net.load_state_dict(dict_keys)
    except RuntimeError:
        # load_state_dict returns (missing_keys, unexpected_keys); only the
        # checkpoint keys that could not be used are validated here.
        _, unexpected = model.net.load_state_dict(dict_keys, strict=False)
        unexpected = [k for k in unexpected if '_features' not in k and 'linear' not in k]
        if ignore_classifier:
            assert all('classifier' in k for k in unexpected), \
                f"Some of the keys not loaded were not classifier keys: {unexpected}"
        else:
            assert not unexpected, f"Missing keys: {unexpected}"
    return model

def _get_random_filename(length=10):
    """
    Return a random name of ``length`` uppercase ASCII letters and digits.

    Uses ``random.SystemRandom`` (OS entropy) instead of the module-level
    generator. The name therefore does not repeat across runs that call
    ``random.seed`` with the same value, and generating it does not advance
    the global random stream.
    """
    alphabet = string.ascii_uppercase + string.digits
    return ''.join(random.SystemRandom().choices(alphabet, k=length))

def _download_from_raw_url(url: str, root: str) -> str:
    """
    Download ``url`` into a randomly named file inside ``root``.

    The response is streamed to disk in fixed-size chunks while a ``tqdm``
    progress bar counts the bytes written. If the response carries no
    ``Content-Length`` header, the bar is shown without a total. If the
    download fails or is interrupted, the partially written file is removed.

    Args:
        url: the URL to fetch.
        root: the directory to save into; created if it does not exist.

    Returns:
        The path of the downloaded file.

    Raises:
        RuntimeError: if the chosen target path exists and is not a regular file.
    """
    os.makedirs(root, exist_ok=True)
    filename = _get_random_filename()

    download_target = smart_joint(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    try:
        with request.urlopen(url) as source, open(download_target, "wb") as output:
            content_length = source.info().get("Content-Length")
            total = int(content_length) if content_length is not None else None
            with tqdm(total=total, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    loop.update(len(buffer))
    except BaseException:
        # Do not leave a truncated checkpoint behind; the error is re-raised.
        if os.path.isfile(download_target):
            os.remove(download_target)
        raise

    return download_target

def _download_checkpoint(url: str) -> str:
    """
    Download a remote checkpoint into ``checkpoints/`` and return its local path.

    SharePoint links use ``onedrivedownloader``, Google Drive links use
    ``googledrivedownloader``, and any other URL is fetched as a raw file.

    Raises:
        ImportError: if the optional downloader needed for the URL is missing.
    """
    if 'sharepoint' in url:
        try:
            from onedrivedownloader import download
        except ImportError as err:
            raise ImportError('OneDriveDownloader is required to download from Sharepoint. Please install it with "pip install onedrivedownloader"') from err
        print('Downloading checkpoint using OneDriveDownloader...')
        return download(url, filename='checkpoints/', unzip=True, unzip_path='checkpoints/', clean=True)

    if 'drive.google.com' in url:
        try:
            from google_drive_downloader import GoogleDriveDownloader as gdd
        except ImportError as err:
            raise ImportError('GoogleDriveDownloader is required to download from Google Drive. Please install it with "pip install googledrivedownloader"') from err
        print('Downloading checkpoint using GoogleDriveDownloader...')
        # get random filename
        filename = _get_random_filename()
        # The file id is taken as the second-to-last path component of the URL.
        gdd.download_file_from_google_drive(file_id=url.split('/')[-2], dest_path=f'checkpoints/{filename}', unzip=True)
        return f'checkpoints/{filename}'

    print('Attempting to download raw checkpoint...')
    return _download_from_raw_url(url, 'checkpoints/')


def mammoth_load_checkpoint(args, model: torch.nn.Module, ignore_classifier=False) -> "tuple[torch.nn.Module, object]":
    """
    Loads the keys from the given checkpoint.

    - Handles DataParallel and DistributedDataParallel checkpoints.
    - Handles checkpoints from previous versions of the code.
    - Handles head initialization for LUCIR.

    If ``args.loadcheck`` starts with ``http``, the checkpoint is first
    downloaded under ``checkpoints/``, and ``args.loadcheck`` is overwritten
    with the local path.

    Args:
        args: the parsed arguments; ``args.loadcheck`` holds the checkpoint
            path or URL.
        model: the model to be loaded.
        ignore_classifier: whether to ignore the classifier weights (used
            only for model-only checkpoints).

    Returns:
        A ``(model, results)`` tuple. For full Mammoth checkpoints (those with
        ``'args'`` and ``'model'`` entries), ``results`` is the checkpoint's
        ``'results'`` entry; for model-only checkpoints it is ``None``.

    Raises:
        ImportError: if the downloader needed for a SharePoint or Google
            Drive URL is not installed.
        ValueError: if a local checkpoint path does not exist, or if the
            saved arguments are incompatible (raised by ``_check_loaded_args``).
    """
    # check if checkpoint is a URL
    if args.loadcheck.startswith('http'):
        args.loadcheck = _download_checkpoint(args.loadcheck)
        print(f'Checkpoint downloaded to {args.loadcheck}')
    elif not os.path.exists(args.loadcheck):
        raise ValueError('The given checkpoint does not exist.')

    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints (especially downloaded ones) from trusted sources.
    saved_obj = torch.load(args.loadcheck, map_location=torch.device("cpu"))

    if 'args' in saved_obj and 'model' in saved_obj:
        # Mammoth checkpoint
        _check_loaded_args(args, saved_obj['args'])
        model = _load_mammoth_model(saved_obj['model'], model, args)
        if 'buffer' in saved_obj:
            loading_model = saved_obj['args'].model
            if args.model != loading_model:
                print(f'WARNING: The loaded model was trained with a different model: {loading_model}')
            model.load_buffer(saved_obj['buffer'])
        return model, saved_obj['results']

    # Model only checkpoint
    model = _load_net(saved_obj, model, args, ignore_classifier=ignore_classifier)
    return model, None
def _check_loaded_args(args, loaded_args):
    """
    Compare the current arguments with those stored in a checkpoint.

    Every attribute of ``args`` not in the ignore list below is compared
    with the same attribute of ``loaded_args``. An attribute missing from
    ``loaded_args`` counts as a mismatch. Attributes that exist only in
    ``loaded_args`` are not checked.

    Args:
        args: the current arguments (any object supporting ``vars``).
        loaded_args: the arguments saved in the checkpoint.

    Raises:
        ValueError: if some arguments differ and ``args.force_compat`` is
            present and falsy. If ``force_compat`` is absent or truthy, only
            a warning is printed.
    """

    def _check_arg(arg, loaded_arg):
        """Return True if ``arg`` and ``loaded_arg`` differ."""
        if isinstance(arg, (list, tuple)):
            # Check the length first: zip() alone silently ignores extra items.
            if not isinstance(loaded_arg, (list, tuple)) or len(arg) != len(loaded_arg):
                return True
            return any(_check_arg(a, la) for a, la in zip(arg, loaded_arg))
        if isinstance(arg, dict):
            if not isinstance(loaded_arg, dict):
                return True
            return any(k not in loaded_arg or _check_arg(v, loaded_arg[k]) for k, v in arg.items())
        if isinstance(arg, (torch.Tensor, np.ndarray)):
            # Differently shaped values cannot match; comparing them
            # element-wise could also fail on broadcasting.
            if tuple(getattr(loaded_arg, 'shape', ())) != tuple(arg.shape):
                return True
            return bool((arg != loaded_arg).any())
        return arg != loaded_arg

    ignored_args = {'loadcheck', 'start_from', 'stop_after', 'conf_jobnum', 'conf_host', 'conf_timestamp', 'distributed', 'examples_log', 'examples_full_log',
                    'intensive_savecheck', 'job_number', 'conf_git_commit', 'loss_log', 'tensorboard', 'seed', 'savecheck', 'notes', 'non_verbose', 'autorelaunch',
                    'force_compat', 'conf_external_path', 'ckpt_name'}

    current = vars(args)
    saved = vars(loaded_args)
    mismatched_args = [name for name in current
                       if name not in ignored_args
                       and (name not in saved or _check_arg(getattr(args, name), getattr(loaded_args, name)))]

    if mismatched_args:
        if 'force_compat' not in current or args.force_compat:
            print("WARNING: The following arguments do not match between loaded and current model:")
            print(mismatched_args)
        else:
            raise ValueError(f'The loaded model was trained with different arguments: {mismatched_args}')