# Source code for utils.checkpoints

import signal
import sys
import uuid
import functools
import os
from argparse import Namespace
import copy
import logging
import random
import string
from typing import Callable, Dict, Optional, Tuple, Union
import numpy as np
import torch

from tqdm.auto import tqdm
import urllib.request as request

from utils import smart_joint, to_parsable_obj, in_notebook
from utils.globals import GLOBALS 
from utils.conf import get_checkpoint_path

def _load_mammoth_model(dict_keys, model: torch.nn.Module, args):
    """
    Load a full Mammoth checkpoint state dict (the ``'model'`` entry) into ``model``.

    Key prefixes are adapted to the current parallelism setting:
    when ``args.distributed != 'dp'`` every ``'module.'`` is removed from the keys,
    otherwise (DataParallel) ``'net.'`` is rewritten to ``'net.module.'`` for keys
    that do not already contain ``'module'``. Keys whose name contains
    ``'_features'`` are discarded before loading.

    Args:
        dict_keys: the state dict to load. It is modified in place.
        model: the continual model receiving the weights; must expose ``net`` and ``device``.
        args: the arguments; ``args.distributed`` and ``args.model`` are read.

    Returns:
        the same ``model``, loaded with ``load_state_dict`` (strict) and with
        ``model.net`` moved to ``model.device``.
    """
    for k in list(dict_keys):
        if args.distributed != 'dp':
            dict_keys[k.replace('module.', '')] = dict_keys.pop(k)
        elif 'module' not in k:
            dict_keys[k.replace('net.', 'net.module.')] = dict_keys.pop(k)

    # Match on each key name. Testing `'_features' in dict_keys` would only look
    # for a key literally called '_features' and then drop *every* entry.
    for k in list(dict_keys):
        if '_features' in k:
            dict_keys.pop(k)

    if 'lucir' in args.model.lower():
        # Register `classes_so_far` as a buffer shaped like the saved one, so the
        # strict load below finds a matching entry to fill.
        model.register_buffer('classes_so_far', torch.zeros_like(
            dict_keys['classes_so_far']).to('cpu'))

    model.load_state_dict(dict_keys)
    model.net.to(model.device)
    return model


def _load_net(dict_keys, model: torch.nn.Module, ignore_classifier=True):
    """
    Load a backbone-only checkpoint (a plain state dict) into ``model.net``.

    Handles keys saved through DataParallel / DistributedDataParallel wrappers
    (``module.`` prefix or ``.module.`` infix), keys saved from a whole continual
    model (leading ``net.``) and keys from wrapped backbones (``wrappee.``).

    Args:
        dict_keys: the state dict to load. It is modified in place.
        model: the continual model whose ``net`` receives the weights; must expose
            ``net`` and ``device`` (and ``net.classifier`` when ``ignore_classifier`` is False).
        ignore_classifier: if True, every key containing ``'classifier'`` is dropped and
            the current classifier is kept. If False and the checkpoint has classifier
            keys, ``model.net.classifier`` is replaced by a new ``torch.nn.Linear`` whose
            output size matches the checkpoint.

    Returns:
        the same ``model``, with ``model.net`` loaded and moved to ``model.device``.

    Raises:
        AssertionError: if the strict load fails and the non-strict fallback reports
            unexpected keys (other than ``_features``/``linear`` ones) that are not
            classifier keys (``ignore_classifier=True``), or any such key at all
            (``ignore_classifier=False``).
    """
    # Strip DataParallel / DDP wrappers.
    for k in list(dict_keys):
        if k.startswith('module.'):  # remove 'module.' prefix if present
            dict_keys[k.replace('module.', '')] = dict_keys.pop(k)
        else:  # remove '.module.' if present
            dict_keys[k.replace('.module.', '.')] = dict_keys.pop(k)

    if not ignore_classifier:
        cl_weights = [v for k, v in dict_keys.items() if 'classifier' in k]
        if cl_weights:
            # For a Linear layer, shape[0] of both weight and bias is the output size.
            cl_size = cl_weights[-1].shape[0]
            model.net.classifier = torch.nn.Linear(
                model.net.classifier.in_features, cl_size).to(model.device)
    else:
        for k in list(dict_keys):
            if 'classifier' in k:
                dict_keys.pop(k)

    # Match on each key name (not on dict membership of the literal '_features').
    for k in list(dict_keys):
        if '_features' in k:
            dict_keys.pop(k)
    # Only strip a genuine leading 'net.' prefix: a substring test followed by
    # k[4:] would mangle keys such as 'subnet.weight' into 'et.weight'.
    for k in list(dict_keys):
        if k.startswith('net.'):
            dict_keys[k[len('net.'):]] = dict_keys.pop(k)
    for k in list(dict_keys):
        if 'wrappee.' in k:
            dict_keys[k.replace('wrappee.', '')] = dict_keys.pop(k)

    try:
        model.net.load_state_dict(dict_keys)
    except RuntimeError:
        # load_state_dict reports key/shape mismatches as RuntimeError; retry
        # non-strictly and validate what could not be matched.
        # NOTE(review): `unm` is the *unexpected* keys list, while the messages below
        # talk about keys "not loaded"/"missing" — confirm which set is meant.
        _, unm = model.net.load_state_dict(dict_keys, strict=False)
        unm = [k for k in unm if '_features' not in k and 'linear' not in k]
        if ignore_classifier:
            assert all('classifier' in k for k in unm), \
                f"Some of the keys not loaded where not classifier keys: {unm}"
        else:
            assert unm is None or len(unm) == 0, f"Missing keys: {unm}"

    model.net.to(model.device)
    return model


def _get_random_filename(length=10):
    """Return ``length`` characters drawn (with replacement) from A-Z and 0-9."""
    alphabet = string.ascii_uppercase + string.digits
    picked = random.choices(alphabet, k=length)
    return ''.join(picked)


def _download_from_raw_url(url: str, root: str, filename: Optional[str] = None) -> str:
    """
    Download ``url`` into the directory ``root`` and return the local file path.

    Args:
        url: the URL to fetch.
        root: destination directory; created if it does not exist.
        filename: destination file name. If None, a random name ending in ``.pth`` is used.

    Returns:
        the path (``smart_joint(root, filename)``) the content was written to.

    Raises:
        RuntimeError: if the destination path exists and is not a regular file.
    """
    os.makedirs(root, exist_ok=True)
    if filename is None:
        filename = _get_random_filename() + '.pth'

    download_target = smart_joint(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    with request.urlopen(url) as source, open(download_target, "wb") as output:
        # Content-Length is optional (e.g. chunked responses): use an open-ended
        # progress bar instead of failing on int(None).
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length else None
        with tqdm(total=total, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    return download_target

class OnlyArgsError(Exception):
    """Signals that only the arguments were requested (``return_only_args=True``) but the checkpoint stores none."""
def _download_checkpoint(url: str) -> str:
    """
    Download a remote checkpoint into the checkpoint directory and return its local path.

    SharePoint links use the optional ``onedrivedownloader`` package, Google Drive
    links the optional ``googledrivedownloader`` package, HuggingFace links keep the
    file name from the URL, and any other URL is fetched as a raw file under a
    random name.
    """
    if 'sharepoint' in url:
        try:
            from onedrivedownloader import download
        except ImportError as err:
            raise ImportError('OneDriveDownloader is required to download from Sharepoint. Please install it with "pip install onedrivedownloader"') from err
        logging.info('Downloading checkpoint using OneDriveDownloader...')
        local_path = download(url, filename=get_checkpoint_path(), unzip=True, unzip_path=get_checkpoint_path(), clean=True)
    elif 'drive.google.com' in url:
        try:
            from google_drive_downloader import GoogleDriveDownloader as gdd
        except ImportError as err:
            raise ImportError('GoogleDriveDownloader is required to download from Google Drive. Please install it with "pip install googledrivedownloader"') from err
        logging.info('Downloading checkpoint using GoogleDriveDownloader...')
        # get random filename
        local_path = os.path.join(get_checkpoint_path(), _get_random_filename())
        # assumes share links of the form .../<FILE_ID>/<something>, so the id is
        # the second-to-last path segment — confirm for other link formats
        gdd.download_file_from_google_drive(file_id=url.split('/')[-2], dest_path=local_path, unzip=True)
    elif url.startswith('https://huggingface.co/'):
        logging.info('Downloading checkpoints from HuggingFace...')
        filename = url.split('/')[-1].split('?')[0]
        local_path = _download_from_raw_url(url, get_checkpoint_path(), filename=filename)
    else:
        logging.warning('Attempting to download raw checkpoint. Make sure to check the URL.')
        local_path = _download_from_raw_url(url, get_checkpoint_path())
    logging.info(f'Checkpoint downloaded to {local_path}')
    return local_path


def mammoth_load_checkpoint(checkpoint_path: str,
                            model: Optional[torch.nn.Module] = None,
                            ignore_classifier=False,
                            args: Optional[Namespace] = None,
                            return_only_args: bool = False) -> Union[Namespace, Tuple[torch.nn.Module, Optional[Dict[str, Union[float, int]]]]]:
    """
    Loads the keys from the given checkpoint.

    - Handles DataParallel and DistributedDataParallel checkpoints.
    - Handles checkpoints from previous versions of the code.
    - Handles head initialization for LUCIR.

    Args:
        checkpoint_path: the path to the checkpoint file or URL (anything starting with 'http').
        model: the model to be loaded. It can be None ONLY with `return_only_args=True`.
        ignore_classifier: whether to ignore the classifier weights (plain state-dict checkpoints only).
        args: the current arguments. If provided, it will check if the loaded arguments match the current ones.
        return_only_args: if True, only returns the loaded arguments and not the model.

    Returns:
        if ``return_only_args`` is True, the arguments (``args`` if given, else those
        stored in the checkpoint). Otherwise a tuple ``(model, results)`` where
        ``results`` is the checkpoint's ``'results'`` entry for a Mammoth checkpoint,
        or None for a plain state dict.

    Raises:
        ValueError: if a local path does not exist, or the checkpoint has an
            ``'args'`` key but no ``'model'`` key.
        OnlyArgsError: if ``return_only_args`` is True and the checkpoint stores no ``'args'``.
        ImportError: if a SharePoint / Google Drive URL needs a downloader that is not installed.
    """
    assert model is not None or return_only_args, "Model must be provided if return_only_args is False."

    if checkpoint_path.startswith('http'):
        checkpoint_path = _download_checkpoint(checkpoint_path)
    elif not os.path.exists(checkpoint_path):
        raise ValueError('The given checkpoint does not exist.')

    saved_obj = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)

    if 'args' not in saved_obj:
        if return_only_args:
            raise OnlyArgsError('The checkpoint does not contain any arguments. Cannot return only args.')
        # Model only checkpoint
        model = _load_net(saved_obj, model, ignore_classifier=ignore_classifier)
        return model, None

    ckpt_args = Namespace(**saved_obj['args'])  # convert back to Namespace
    if args is not None:
        _check_loaded_args(args, ckpt_args)
    else:
        args = ckpt_args

    if return_only_args:
        return args

    if 'model' not in saved_obj:
        raise ValueError("""The checkpoint is not in a valid format. Expect a checkpoint either with: - 'args' and 'model' keys (Mammoth checkpoint) - simple state_dict WITH NO 'args' KEY""")

    # Mammoth checkpoint
    model = _load_mammoth_model(saved_obj['model'], model, args)
    if 'buffer' in saved_obj:
        # Compare the current model name with the one recorded in the checkpoint.
        loading_model = getattr(ckpt_args, 'model', None)
        if loading_model is not None and args.model != loading_model:
            logging.warning(f'The loaded model was trained with a different model: {loading_model}')
        model.load_buffer(saved_obj['buffer'])

    return model, saved_obj['results']
def save_mammoth_checkpoint(task: int, end_task: int, args: Namespace, model: torch.nn.Module, results=None,
                            optimizer_st: Optional[Dict[str, torch.Tensor]] = None,
                            scheduler_st: Optional[Dict[str, torch.Tensor]] = None,
                            checkpoint_name: Optional[str] = None):
    """
    Save a checkpoint for the model for the given task to ``<checkpoint_name>.pt``.

    The format is selected by ``args.save_checkpoint_mode``:

    - ``'old_pickle'``: stores ``args`` as a Namespace and a CPU deep copy of the
      buffer object (loading it will require `weights_only=False`).
    - ``'safe'``: stores args and results converted with ``to_parsable_obj`` and the
      buffer via ``model.buffer.serialize()``, intended to be loadable with
      `weights_only=True`.

    Args:
        task: index of the current task.
        end_task: number of tasks; with ``args.savecheck == 'last'`` only
            ``task == end_task - 1`` is saved.
        args: the current arguments.
        model: the model to save; ``model.args`` is checked for ``buffer_size``.
        results: results stored alongside the weights.
        optimizer_st: optimizer state dict, if any.
        scheduler_st: scheduler state dict, if any.
        checkpoint_name: explicit path without extension. If None, it is built from
            ``get_checkpoint_path()``, ``args.ckpt_name`` and ``args.savecheck``.

    Raises:
        ValueError: if ``args.savecheck`` (when the name must be built) or
            ``args.save_checkpoint_mode`` is not a supported value.
    """
    if checkpoint_name is None:
        if args.savecheck == 'task':
            suffix = 'joint' if args.joint else str(task)
        elif args.savecheck == 'last':
            if task != end_task - 1:
                return  # 'last' mode only saves after the final task
            suffix = 'joint' if args.joint else 'last'
        else:
            raise ValueError(f'Invalid savecheck mode: {args.savecheck}')
        checkpoint_name = os.path.join(get_checkpoint_path(), f'{args.ckpt_name}_{suffix}')

    if args.save_checkpoint_mode == 'old_pickle':
        save_obj = {
            'model': model.state_dict(),
            'args': args,
            'results': results,
            'optimizer': optimizer_st,
            'scheduler': scheduler_st,
        }
        if 'buffer_size' in model.args:
            save_obj['buffer'] = copy.deepcopy(model.buffer).to('cpu')
    elif args.save_checkpoint_mode == 'safe':  # TODO CHECK
        save_obj = {
            'model': model.state_dict(),
            'optimizer': optimizer_st,
            'scheduler': scheduler_st,
            'args': to_parsable_obj(vars(args)),  # avoid Namespace and other non-builtin types
            'results': to_parsable_obj(results),  # avoid numpy, torch, and non-builtin types
        }
        if 'buffer_size' in model.args:
            save_obj['buffer'] = model.buffer.serialize()
    else:
        # Previously fell through and failed with UnboundLocalError on `save_obj`.
        raise ValueError(f'Invalid save_checkpoint_mode: {args.save_checkpoint_mode}')

    torch.save(save_obj, checkpoint_name + '.pt')
    logging.warning(f"Checkpoint for task {task} saved at {checkpoint_name}.pt")
def _check_loaded_args(args, loaded_args):
    """
    Compare the current arguments with those stored in a checkpoint.

    Every current argument (after ``to_parsable_obj``) that is not in
    ``ignored_args`` must exist in ``loaded_args`` with an equal value. Lists/tuples
    are compared by length and then element-wise, dicts key by key (recursively).

    If mismatches are found they are logged as warnings when ``args.force_compat``
    is absent or truthy; otherwise a ``ValueError`` is raised.
    """
    pruned_original_args = to_parsable_obj(vars(args))

    def _check_arg(arg, loaded_arg):
        """Return a truthy value if ``arg`` and ``loaded_arg`` differ."""
        if isinstance(arg, (list, tuple)):
            # zip() stops at the shorter sequence, so lengths must be compared explicitly.
            if not isinstance(loaded_arg, (list, tuple)) or len(arg) != len(loaded_arg):
                return True
            return any(_check_arg(a, la) for a, la in zip(arg, loaded_arg))
        elif isinstance(arg, dict):
            if not isinstance(loaded_arg, dict):
                return True
            return any(k not in loaded_arg or _check_arg(v, loaded_arg[k]) for k, v in arg.items())
        elif isinstance(arg, (torch.Tensor, np.ndarray)):
            return (arg != loaded_arg).any()
        return arg != loaded_arg

    ignored_args = {
        'loadcheck', 'start_from', 'stop_after', 'conf_jobnum', 'conf_host', 'conf_timestamp',
        'distributed', 'examples_log', 'examples_full_log', 'intensive_savecheck', 'job_number',
        'conf_git_commit', 'loss_log', 'tensorboard', 'seed', 'savecheck', 'notes', 'non_verbose',
        'autorelaunch', 'force_compat', 'conf_external_path', 'ckpt_name',
    }
    loaded_vars = vars(loaded_args)
    mismatched_args = [x for x in pruned_original_args if x not in ignored_args and (
        x not in loaded_vars or _check_arg(pruned_original_args[x], getattr(loaded_args, x)))]

    if mismatched_args:
        if 'force_compat' not in vars(args) or args.force_compat:
            logging.warning("The following arguments do not match between loaded and current model:")
            logging.warning(mismatched_args)
        else:
            raise ValueError('The loaded model was trained with different arguments: {}'.format(mismatched_args))
def can_save_and_exit(fn: Callable) -> Callable:
    """
    Wraps a function to catch KeyboardInterrupt and SigInt signals.

    If running in a Jupyter notebook, this will prevent the kernel from crashing when
    the user interrupts the execution of a cell and retain the current state.

    If running in a script, this will:

    - catch the KeyboardInterrupt and exit gracefully
    - catch the SigInt (and SIGTERM) and save a checkpoint before exiting

    This is useful for training scripts where you want to be able to stop the training
    process and save the current state of the model.

    Signal handlers are installed only by the wrapper of the first function decorated
    in the process. The interrupt checkpoint path is ``<checkpoint dir>/paused/<uuid4>``.

    Args:
        fn: the function to be wrapped

    Returns:
        the wrapped function; it returns ``fn``'s result, or None if ``fn`` was
        interrupted by KeyboardInterrupt or SystemExit.
    """
    wrapped = hasattr(can_save_and_exit, 'wrapped')
    ckpt_path = os.path.join(get_checkpoint_path(), 'paused', str(uuid.uuid4()))
    # exist_ok avoids the exists-then-create race between concurrent processes.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    @functools.wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if not wrapped:
            if not in_notebook():
                handler = _get_sigint_handler(fn, ckpt_path)
                signal.signal(signal.SIGINT, handler)
                signal.signal(signal.SIGTERM, handler)
            else:
                def _ignore_sigint(signum, frame):
                    logging.info("SIGINT received in notebook. Ignoring to prevent kernel crash.")
                    GLOBALS['SHOULD_STOP'] = True  # type: ignore[assignment]
                signal.signal(signal.SIGINT, _ignore_sigint)
        try:
            return fn(*args, **kwargs)
        except (KeyboardInterrupt, SystemExit) as exc:
            # Deliberately swallowed so the run ends gracefully (the exit code of a
            # SystemExit is not propagated); log it so the early stop is visible.
            logging.info(f"Execution interrupted ({type(exc).__name__}); exiting gracefully.")

    setattr(can_save_and_exit, 'wrapped', True)  # avoid re-registering the signal handler
    return wrapped_fn
def _get_sigint_handler(fn: Callable, ckpt_path: str) -> Callable:
    """
    Build a signal handler that saves a checkpoint of the run executing ``fn``.

    The handler walks from the interrupted frame up the stack to the frame running
    ``fn`` and reads its locals (``args``, ``model``, ``scheduler``, ``cur_task``,
    ``end_task``, ``results``, ``results_mask_classes``, ``logger``). If
    ``args.save_after_interrupt`` is truthy, ``save_mammoth_checkpoint`` writes to
    ``ckpt_path``. In every case ``GLOBALS['SHOULD_STOP']`` is set (presumably polled
    by the training loop to stop); a signal received while that flag is already set
    calls ``sys.exit(1)``.
    """
    def _handle_sigint_terminal(signum, frame):
        if GLOBALS['SHOULD_STOP']:  # should have stopped already, forcing
            logging.info("SIGINT received again. Forcing exit...")
            sys.exit(1)

        # Find the frame executing `fn` to reach its local training state.
        _locals = {}
        current = frame
        while current:
            if current.f_code == fn.__code__:
                _locals = current.f_locals
                break
            current = current.f_back

        if 'args' not in _locals:  # not initialized yet, can safely exit
            logging.info("SIGINT received before initialization. Exiting...")
            GLOBALS['SHOULD_STOP'] = True
            # Nothing to save: falling through would dereference `None` args below
            # and raise AttributeError inside the signal handler.
            return

        logging.info("SIGINT received. Saving checkpoint and exiting...")
        exp_args = _locals.get('args')
        model = _locals.get('model')
        scheduler = _locals.get('scheduler')
        if getattr(exp_args, 'save_after_interrupt', False):
            save_mammoth_checkpoint(_locals['cur_task'], _locals['end_task'], exp_args, model,
                                    results=[_locals['results'], _locals['results_mask_classes'], _locals['logger'].dump()],
                                    optimizer_st=model.opt.state_dict() if hasattr(model, 'opt') else None,
                                    scheduler_st=scheduler.state_dict() if scheduler is not None else None,
                                    checkpoint_name=ckpt_path)
        GLOBALS['SHOULD_STOP'] = True

    return _handle_sigint_terminal