# Source code for utils.checkpoints


import random
import string
import numpy as np
import torch
from torch import distributed as dist
import os

from tqdm import tqdm
import urllib.request as request

from utils import smart_joint


def _load_mammoth_model(dict_keys, model: torch.nn.Module, args):
    for k in list(dict_keys):
        if args.distributed != 'dp':
            dict_keys[k.replace('module.', '')] = dict_keys.pop(k)
        elif 'module' not in k:
            dict_keys[k.replace('net.', 'net.module.')] = dict_keys.pop(k)

    for k in list(dict_keys):
        if '_features' in dict_keys:
            dict_keys.pop(k)

    if 'lucir' in args.model.lower():
        model.register_buffer('classes_so_far', torch.zeros_like(
            dict_keys['classes_so_far']).to('cpu'))

    model.load_state_dict(dict_keys)
    model.net.to(model.device)
    return model


def _load_net(dict_keys, model: torch.nn.Module, args, ignore_classifier=True):
    for k in list(dict_keys):
        if args.distributed != 'dp':
            dict_keys[k.replace('module.', '')] = dict_keys.pop(k)
        elif 'module' not in k:
            if 'net' in k:
                dict_keys[k.replace('net.', 'net.module.')] = dict_keys.pop(k)
            else:
                dict_keys[f'module.{k}'] = dict_keys.pop(k)

    if not ignore_classifier:
        cl_weights = [dict_keys[k] for k in list(dict_keys.keys()) if 'classifier' in k]
        if len(cl_weights) > 0:
            cl_size = cl_weights[-1].shape[0]
            model.net.classifier = torch.nn.Linear(
                model.net.classifier.in_features, cl_size).to(model.device)
    else:
        for k in list(dict_keys):
            if 'classifier' in k:
                dict_keys.pop(k)

    for k in list(dict_keys):
        if '_features' in dict_keys:
            dict_keys.pop(k)
    for k in list(dict_keys):
        if 'net' in k:
            dict_keys[k[4:]] = dict_keys.pop(k)
    for k in list(dict_keys):
        if 'wrappee.' in k:
            dict_keys[k.replace('wrappee.', '')] = dict_keys.pop(k)

    try:
        model.net.load_state_dict(dict_keys)
    except BaseException:
        _, unm = model.net.load_state_dict(dict_keys, strict=False)
        unm = [k for k in unm if '_features' not in k and 'linear' not in k]
        if ignore_classifier:
            assert all(['classifier' in k for k in unm]
                       ), f"Some of the keys not loaded where not classifier keys: {unm}"
        else:
            assert unm is None or len(unm) == 0, f"Missing keys: {unm}"

    model.net.to(model.device)
    return model


def _get_random_filename(length=10):
    return ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))


def _download_from_raw_url(url: str, root: str):
    """
    Download ``url`` into a randomly named file inside ``root``.

    Args:
        url: the URL to fetch with ``urllib.request.urlopen``.
        root: destination directory; created if missing.

    Returns:
        The path of the downloaded file.

    Raises:
        RuntimeError: if the target path exists and is not a regular file.
    """
    os.makedirs(root, exist_ok=True)
    download_target = smart_joint(root, _get_random_filename())

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    with request.urlopen(url) as source:
        # Content-Length is optional (e.g. chunked responses); tqdm accepts total=None.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length else None

        with open(download_target, "wb") as output:
            try:
                with tqdm(total=total, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
                    while True:
                        buffer = source.read(8192)
                        if not buffer:
                            break

                        output.write(buffer)
                        loop.update(len(buffer))
            except BaseException:
                # Do not leave a truncated checkpoint behind; cleanup is best-effort
                # and must not mask the original error, which is re-raised.
                output.close()
                try:
                    os.remove(download_target)
                except OSError:
                    pass
                raise

    return download_target


def _download_checkpoint(url: str) -> str:
    """Download a remote checkpoint into ``checkpoints/`` and return its local path."""
    if 'sharepoint' in url:
        try:
            from onedrivedownloader import download
        except ImportError as err:
            raise ImportError('OneDriveDownloader is required to download from Sharepoint. Please install it with "pip install onedrivedownloader"') from err

        print('Downloading checkpoint using OneDriveDownloader...')
        return download(url, filename='checkpoints/', unzip=True, unzip_path='checkpoints/', clean=True)

    if 'drive.google.com' in url:
        try:
            from google_drive_downloader import GoogleDriveDownloader as gdd
        except ImportError as err:
            raise ImportError('GoogleDriveDownloader is required to download from Google Drive. Please install it with "pip install googledrivedownloader"') from err

        print('Downloading checkpoint using GoogleDriveDownloader...')
        filename = _get_random_filename()
        # The file id is taken as the second-to-last path component of the URL.
        gdd.download_file_from_google_drive(file_id=url.split('/')[-2], dest_path=f'checkpoints/{filename}', unzip=True)
        return f'checkpoints/{filename}'

    print('Attempting to download raw checkpoint...')
    return _download_from_raw_url(url, 'checkpoints/')


def mammoth_load_checkpoint(args, model: torch.nn.Module, ignore_classifier=False) -> tuple:
    """
    Loads the keys from the given checkpoint.

    - Handles DataParallel and DistributedDataParallel checkpoints.
    - Handles checkpoints from previous versions of the code.
    - Handles head initialization for LUCIR.

    Args:
        args: the run arguments. ``args.loadcheck`` is the checkpoint path or an
            ``http(s)`` URL. After a download it is overwritten with the local path.
        model: the model to be loaded.
        ignore_classifier: whether to ignore the classifier weights (only used for
            model-only checkpoints).

    Returns:
        A ``(model, results)`` tuple:

        - model: the model with the checkpoint loaded.
        - results: ``saved_obj['results']`` for a full Mammoth checkpoint (one
          containing both ``'args'`` and ``'model'``), or ``None`` for a
          model-only checkpoint.

    Raises:
        ValueError: if ``args.loadcheck`` is a local path that does not exist.
        ImportError: if the downloader needed for a Sharepoint/Google Drive URL is missing.
    """
    if args.loadcheck.startswith('http'):
        args.loadcheck = _download_checkpoint(args.loadcheck)
        print(f'Checkpoint downloaded to {args.loadcheck}')
    elif not os.path.exists(args.loadcheck):
        raise ValueError('The given checkpoint does not exist.')

    # NOTE(security): depending on the torch version, torch.load may unpickle
    # arbitrary objects. Only load checkpoints (including downloaded ones) from
    # trusted sources.
    saved_obj = torch.load(args.loadcheck, map_location=torch.device("cpu"))

    if 'args' in saved_obj and 'model' in saved_obj:
        # Full Mammoth checkpoint
        _check_loaded_args(args, saved_obj['args'])
        model = _load_mammoth_model(saved_obj['model'], model, args)
        if 'buffer' in saved_obj:
            loading_model = saved_obj['args'].model
            if args.model != loading_model:
                print(f'WARNING: The loaded model was trained with a different model: {loading_model}')
            model.load_buffer(saved_obj['buffer'])

        return model, saved_obj['results']

    # Model-only checkpoint
    model = _load_net(saved_obj, model, args, ignore_classifier=ignore_classifier)
    return model, None
def _check_loaded_args(args, loaded_args): def _check_arg(arg, loaded_arg): if isinstance(arg, (list, tuple)): return any([a != la for a, la in zip(arg, loaded_arg)]) elif isinstance(arg, dict): return any([k not in loaded_arg or _check_arg(v, loaded_arg[k]) for k, v in arg.items()]) elif isinstance(arg, (torch.Tensor, np.ndarray)): return (arg != loaded_arg).any() return arg != loaded_arg ignored_args = ['loadcheck', 'start_from', 'stop_after', 'conf_jobnum', 'conf_host', 'conf_timestamp', 'distributed', 'examples_log', 'examples_full_log', 'intensive_savecheck', 'job_number', 'conf_git_commit', 'loss_log', 'tensorboard', 'seed', 'savecheck', 'notes', 'non_verbose', 'autorelaunch', 'force_compat', 'conf_external_path', 'ckpt_name'] mismatched_args = [x for x in vars(args) if x not in ignored_args and ( x not in vars(loaded_args) or _check_arg(getattr(args, x), getattr(loaded_args, x)))] if len(mismatched_args): if 'force_compat' not in vars(args) or args.force_compat: print( "WARNING: The following arguments do not match between loaded and current model:") print(mismatched_args) else: raise ValueError( 'The loaded model was trained with different arguments: {}'.format(mismatched_args))