Source code for utils.training

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import copy
import math
import os
import sys
from argparse import Namespace
from time import time
from typing import Iterable, Tuple
import logging
import torch
from tqdm import tqdm

from datasets import get_dataset
from datasets.utils.continual_dataset import ContinualDataset, MammothDatasetWrapper
from datasets.utils.gcl_dataset import GCLDataset
from models.utils.continual_model import ContinualModel
from models.utils.future_model import FutureModel

from utils.checkpoints import mammoth_load_checkpoint
from utils.loggers import log_extra_metrics, log_accs, Logger
from utils.schedulers import get_scheduler
from utils.stats import track_system_stats

try:
    import wandb
except ImportError:
    wandb = None


def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int) -> None:
    """
    Restrict the scores in ``outputs`` to the classes belonging to task ``k``.

    Every column in ``[0, dataset.N_CLASSES)`` that lies outside the task's
    ``[start, end)`` offset range is overwritten with ``-inf``, in place, so an
    argmax over the result can only pick a class of task ``k``. This is how the
    task-il accuracy is obtained.

    Args:
        outputs: the output tensor; dimension 1 is indexed by class
        dataset: the continual dataset, supplying ``N_CLASSES`` and ``get_offsets``
        k: the task index
    """
    upper_bound = dataset.N_CLASSES
    task_first, task_end = dataset.get_offsets(k)
    minus_inf = -float('inf')
    # columns before the task's first class
    outputs[:, :task_first] = minus_inf
    # columns after the task's last class (but within the known classes)
    outputs[:, task_end:upper_bound] = minus_inf
@torch.no_grad()
def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False, return_loss=False) -> Tuple[list, list]:
    """
    Evaluates the accuracy of the model for each past task.

    The accuracy is evaluated for all the tasks up to the current one, only for the total number of classes seen so far.
    The network's original train/eval status is always restored and the progress bar is always closed,
    even if evaluation raises (callers such as ``train`` catch those exceptions and keep training).

    Args:
        model: the model to be evaluated
        dataset: the continual dataset at hand
        last: a boolean indicating whether to evaluate only the last task
        return_loss: a boolean indicating whether to return the loss in addition to the accuracy

    Returns:
        a tuple of lists, containing the class-il and task-il accuracy for each task. If return_loss is True, the loss
        is also returned as a third element: the sum of the per-batch values of ``dataset.get_loss()`` over every
        evaluated task, divided by the number of samples evaluated across those tasks (0.0 if there were none).
    """
    status = model.net.training
    model.net.eval()
    accs, accs_mask_classes = [], []
    n_classes = dataset.get_offsets()[1]
    loss_fn = dataset.get_loss()
    avg_loss = 0
    # Samples seen across *all* evaluated tasks; ``total`` below is reset for every task, so it cannot be
    # used as the denominator of ``avg_loss`` (which accumulates over every evaluated task).
    n_samples = 0

    test_loaders = dataset.test_loaders
    n_loaders = len(test_loaders)
    # With ``last`` only the final loader is iterated, so size the progress bar on that loader alone.
    evaluated_loaders = test_loaders[-1:] if last else test_loaders
    total_len = sum(len(x) for x in evaluated_loaders) if hasattr(test_loaders[0], '__len__') else None

    pbar = tqdm(test_loaders, total=total_len, desc='Evaluating', disable=model.args.non_verbose)
    try:
        for k, test_loader in enumerate(test_loaders):
            if last and k < n_loaders - 1:
                continue
            correct, correct_mask_classes, total = 0.0, 0.0, 0.0
            for i, data in enumerate(test_loader):
                # ``i`` is the number of batches already processed for this task
                if model.args.debug_mode and i > model.get_debug_iters():
                    break
                inputs, labels = data[0], data[1]
                inputs, labels = inputs.to(model.device), labels.to(model.device)
                if 'class-il' not in model.COMPATIBILITY and 'general-continual' not in model.COMPATIBILITY:
                    # task-aware models receive the task index
                    outputs = model(inputs, k)
                else:
                    if model.args.eval_future and k >= model.current_task:
                        outputs = model.future_forward(inputs)
                    else:
                        outputs = model(inputs)

                if return_loss:
                    loss = loss_fn(outputs, labels)
                    # NOTE(review): one value per batch is summed but later divided by a sample count;
                    # kept as-is so early-stopping thresholds keep their current scale — confirm intended.
                    avg_loss += loss.item()

                # class-il prediction: only the classes seen so far are eligible
                _, pred = torch.max(outputs[:, :n_classes].data, 1)
                correct += torch.sum(pred == labels).item()
                total += labels.shape[0]
                pbar.set_postfix({f'acc_task_{k+1}': max(0, correct / total * 100)}, refresh=False)
                pbar.set_description(f"Evaluating Task {k+1}", refresh=False)
                pbar.update(1)

                if dataset.SETTING == 'class-il':
                    # task-il prediction: mask (in place) every class outside task k
                    mask_classes(outputs, dataset, k)
                    _, pred = torch.max(outputs.data, 1)
                    correct_mask_classes += torch.sum(pred == labels).item()

            n_samples += total
            if total > 0:
                task_acc = correct / total * 100
                task_acc_masked = correct_mask_classes / total * 100
            else:
                # An empty loader would otherwise raise ZeroDivisionError.
                logging.warning(f"Test loader of task {k + 1} yielded no samples; reporting 0 accuracy.")
                task_acc = task_acc_masked = 0.0

            accs.append(task_acc if 'class-il' in model.COMPATIBILITY or 'general-continual' in model.COMPATIBILITY else 0)
            accs_mask_classes.append(task_acc_masked)
    finally:
        # Always close the bar and restore the train/eval status, even when evaluation fails.
        pbar.close()
        model.net.train(status)

    if return_loss:
        return accs, accs_mask_classes, (avg_loss / n_samples if n_samples > 0 else 0.0)
    return accs, accs_mask_classes
def initialize_wandb(args: Namespace) -> None:
    """
    Start a wandb run for the current execution.

    The run is named ``<wandb_name or model>_<part of conf_jobnum before the first '-'>``.
    It is logged to ``args.wandb_project`` / ``args.wandb_entity`` with the whole argument
    namespace as its config. Its URL is stored back into ``args.wandb_url``.
    The mode is ``'disabled'`` when the ``MAMMOTH_TEST`` environment variable equals ``'1'``.
    Otherwise it is taken from ``WANDB_MODE``, defaulting to ``'online'``.

    Args:
        args: the arguments of the current execution

    Raises:
        AssertionError: if wandb could not be imported
    """
    assert wandb is not None, "Wandb not installed, please install it or run without wandb"

    base_name = args.model if args.wandb_name is None else args.wandb_name
    job_prefix = args.conf_jobnum.split('-')[0]

    if os.getenv('MAMMOTH_TEST', '0') == '1':
        run_mode = 'disabled'
    else:
        run_mode = os.getenv('WANDB_MODE', 'online')

    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        config=vars(args),
        name=f'{base_name}_{job_prefix}',
        mode=run_mode,
    )
    args.wandb_url = wandb.run.get_url()
def _to_device(name: str, x, device):
    """
    Move a tensor-valued extra field to ``device``; return non-tensor values untouched.

    If the field ``name`` contains ``"label"`` or ``"target"`` (case-insensitive),
    the tensor is also cast to ``torch.long``.

    Args:
        name: the name of the extra field
        x: the field value
        device: the destination device

    Returns:
        the moved (and possibly cast) tensor, or ``x`` itself if it is not a tensor
    """
    if not isinstance(x, torch.Tensor):
        return x
    lowered = name.lower()
    is_target_like = any(tag in lowered for tag in ('label', 'target'))
    return x.to(device, dtype=torch.long) if is_target_like else x.to(device)
def train_single_epoch(model: ContinualModel,
                       train_loader: Iterable,
                       args: Namespace,
                       epoch: int,
                       current_task: int,
                       system_tracker=None,
                       data_len=None,
                       scheduler=None) -> int:
    """
    Trains the model for a single epoch.

    Args:
        model: the model to be trained
        train_loader: the data loader for the training set
        args: the arguments from the command line
        epoch: the current epoch
        current_task: the current task index
        system_tracker: optional callable invoked once per iteration to monitor the system stats (skipped if None)
        data_len: the length of the training data loader. If None, the progress bar will not show the training percentage
        scheduler: the scheduler for the current epoch (stepped per iteration or per epoch depending on ``args.scheduler_mode``)

    Returns:
        the number of iterations performed in the current epoch
    """
    train_iter = iter(train_loader)
    i = 0
    previous_time = time()

    # refresh the bar less often for long epochs
    mininterval = 0.5 if data_len is not None and data_len > 1000 else 0.1
    pbar = tqdm(train_iter, total=data_len, desc=f"Task {current_task + 1} - Epoch {epoch + 1}",
                disable=args.non_verbose, mininterval=mininterval)
    try:
        for data in train_iter:
            if args.debug_mode and i > model.get_debug_iters():
                break
            if args.fitting_mode == 'iters' and model.task_iteration >= model.args.n_iters:
                break

            inputs, labels, not_aug_inputs = data[0], data[1], data[2]
            inputs, labels = inputs.to(model.device), labels.to(model.device, dtype=torch.long)
            not_aug_inputs = not_aug_inputs.to(model.device)

            # Anything after (inputs, labels, not_aug_inputs) is forwarded by name to `meta_observe`.
            extra_fields = {
                train_loader.dataset.extra_return_fields[k]: _to_device(train_loader.dataset.extra_return_fields[k], data[3 + k], model.device)
                for k in range(len(data) - 3)
            }

            loss = model.meta_observe(inputs, labels, not_aug_inputs, epoch=epoch, **extra_fields)

            assert not math.isnan(loss)

            if scheduler is not None and args.scheduler_mode == 'iter':
                scheduler.step()

            if args.code_optimization == 0 and 'cuda' in str(args.device):
                torch.cuda.synchronize()
            # The tracker is optional (default None); calling None would raise TypeError.
            if system_tracker is not None:
                system_tracker()
            i += 1

            now = time()
            time_diff = now - previous_time
            previous_time = now
            bar_log = {'loss': loss, 'lr': model.opt.param_groups[0]['lr']}
            # Skip the throughput estimate when the clock did not advance (would divide by zero).
            if data_len and time_diff > 0:
                ep_h = 3600 / (data_len * time_diff)
                bar_log['ep/h'] = ep_h
            pbar.set_postfix(bar_log, refresh=False)
            pbar.update()
    finally:
        pbar.close()

    if scheduler is not None and args.scheduler_mode == 'epoch':
        scheduler.step()

    return i
def train(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace) -> None:
    """
    The training process, including evaluations and loggers.

    For each task in ``[start_task, end_task)``:
    - optionally evaluate before the task (Forward Transfer baseline);
    - train according to ``args.fitting_mode`` ('epochs', 'iters' or 'early_stopping');
    - evaluate on all tasks and log the results;
    - optionally save a checkpoint.

    Afterwards it optionally runs a final evaluation on the real test set (``args.validation``) and computes extra
    metrics (``args.enable_other_metrics``).

    Args:
        model: the module to be trained
        dataset: the continual dataset at hand
        args: the arguments of the current execution
    """
    print(args)

    is_fwd_enabled = True
    can_compute_fwd_beforetask = True
    random_results_class, random_results_task = [], []

    if not args.nowand:
        initialize_wandb(args)

    # Bind `logger` on both branches: it is referenced unconditionally below
    # (previously a NameError when `disable_log` was set).
    # NOTE(review): assumes `track_system_stats` and `log_accs` accept logger=None — confirm.
    if not args.disable_log:
        logger = Logger(args, dataset.SETTING, dataset.NAME, model.NAME)
    else:
        logger = None

    model.net.to(model.device)
    torch.cuda.empty_cache()

    with track_system_stats(logger) as system_tracker:
        results, results_mask_classes = [], []

        if args.eval_future:
            results_transf, results_mask_classes_transf = [], []

        if args.start_from is not None:
            # replay the task setup/teardown for the skipped tasks
            for _ in range(args.start_from):
                train_loader, _ = dataset.get_data_loaders()
                model.meta_begin_task(dataset)
                model.meta_end_task(dataset)

        if args.loadcheck is not None:
            model, past_res = mammoth_load_checkpoint(args, model)

            if not args.disable_log and past_res is not None:
                (results, results_mask_classes, csvdump) = past_res
                logger.load(csvdump)

            print('Checkpoint Loaded!')

        print(file=sys.stderr)
        start_task = 0 if args.start_from is None else args.start_from
        end_task = dataset.N_TASKS if args.stop_after is None else args.stop_after

        if args.eval_future:
            assert isinstance(model, FutureModel), "Model must be an instance of FutureModel to evaluate on future tasks"
            eval_dataset = get_dataset(args)
            for _ in range(dataset.N_TASKS):
                eval_dataset.get_data_loaders()
                model.change_transform(eval_dataset)
                del eval_dataset.train_loader
        else:
            eval_dataset = dataset

        torch.cuda.empty_cache()
        for t in range(start_task, end_task):
            model.net.train()
            train_loader, _ = dataset.get_data_loaders()
            # Reset per task so the checkpoint code below never sees an unbound (or stale) name
            # when no training happens (`inference_only` or `n_epochs == 0`).
            scheduler = None

            if not issubclass(dataset.__class__, GCLDataset):
                assert issubclass(train_loader.dataset.__class__, MammothDatasetWrapper), "Dataset must be an instance of MammothDatasetWrapper (did you forget to call the `store_masked_loaders`?)"

            if can_compute_fwd_beforetask and is_fwd_enabled and args.enable_other_metrics:
                # try to compute accuracy at the beginning of the task
                try:
                    logging.info("Evaluating model before task (for Forward Transfer metric)...")
                    random_res_class, random_res_task = evaluate(model, dataset, last=True)
                    random_results_class.append(random_res_class)
                    random_results_task.append(random_res_task)
                except Exception:
                    logging.info(f"Could not evaluate before `begin_task`, will try after")
                    # will try after the begin_task in case the model needs to setup something
                    can_compute_fwd_beforetask = False

            model.meta_begin_task(dataset)

            if not can_compute_fwd_beforetask and is_fwd_enabled and args.enable_other_metrics:
                if train_loader.dataset.num_times_iterated == 0:  # compute only if the model has not been trained yet
                    try:
                        logging.info("Evaluating model before task (for Forward Transfer metric)...")
                        random_res_class, random_res_task = evaluate(model, dataset, last=True)
                        random_results_class.append(random_res_class)
                        random_results_task.append(random_res_task)
                    except Exception as e:
                        logging.error(f"Model `{model.NAME}` does not support pre-evaluation, will not compute Forward Transfer metric\n{e}")
                        is_fwd_enabled = False
                else:
                    logging.info("Model used the training data, skipping Forward Transfer metric compute")
                    is_fwd_enabled = False

            if not args.inference_only and args.n_epochs > 0:
                if t and args.enable_other_metrics:
                    # append the current-task accuracy to the previous task's row (list concatenation)
                    accs = evaluate(model, eval_dataset, last=True)
                    results[t - 1] = results[t - 1] + accs[0]
                    if dataset.SETTING == 'class-il':
                        results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1]

                scheduler = get_scheduler(model, args, reload_optim=True) if not hasattr(model, 'scheduler') else model.scheduler

                epoch = 0
                best_ea_metric = None
                best_ea_model = None
                cur_stopping_patience = args.early_stopping_patience
                while True:
                    data_len = None
                    if not isinstance(dataset, GCLDataset):
                        data_len = len(train_loader)

                    model.begin_epoch(epoch, dataset)

                    train_single_epoch(model, train_loader, args, current_task=t, epoch=epoch,
                                       system_tracker=system_tracker, data_len=data_len, scheduler=scheduler)

                    model.end_epoch(epoch, dataset)

                    epoch += 1
                    if args.fitting_mode == 'epochs' and epoch >= model.args.n_epochs:
                        break
                    elif args.fitting_mode == 'iters' and model.task_iteration >= model.args.n_iters:
                        break
                    elif args.fitting_mode == 'early_stopping' and epoch % args.early_stopping_freq == 0 and epoch > 0:
                        epoch_accs, _, epoch_loss = evaluate(model, eval_dataset, return_loss=True, last=True)

                        if args.early_stopping_metric == 'accuracy':
                            ea_metric = np.mean(epoch_accs)  # Higher accuracy is better
                        elif args.early_stopping_metric == 'loss':
                            ea_metric = -epoch_loss  # Lower loss is better
                        else:
                            raise ValueError(f'Unknown early stopping metric {args.early_stopping_metric}')

                        # The metric is oriented so that higher is always better
                        if best_ea_metric is not None and ea_metric - best_ea_metric < args.early_stopping_epsilon:
                            cur_stopping_patience -= args.early_stopping_freq
                            if cur_stopping_patience <= 0:
                                print(f"\nEarly stopping at epoch {epoch} with metric {abs(ea_metric)}", file=sys.stderr)
                                # restore the best weights seen so far
                                model.load_state_dict({k: v.to(model.device) for k, v in best_ea_model.items()})
                                break
                            print(f"\nNo improvement at epoch {epoch} (best {abs(best_ea_metric)} | current {abs(ea_metric)}). "
                                  f"Waiting for {cur_stopping_patience} epochs to stop.", file=sys.stderr)
                        else:
                            print(f"\nFound better model with metric {abs(ea_metric)} at epoch {epoch}. "
                                  f"Previous value was {abs(best_ea_metric) if best_ea_metric is not None else 'None'}", file=sys.stderr)
                            best_ea_metric = ea_metric
                            # keep a CPU copy so the snapshot does not occupy device memory
                            best_ea_model = copy.deepcopy({k: v.cpu() for k, v in model.state_dict().items()})
                            cur_stopping_patience = args.early_stopping_patience

                    if args.eval_epochs is not None and (epoch > 0 or args.eval_epochs) and epoch % args.eval_epochs == 0 and epoch < model.args.n_epochs:
                        epoch_accs = evaluate(model, eval_dataset)

                        log_accs(args, logger, epoch_accs, t, dataset.SETTING, epoch=epoch)

            model.meta_end_task(dataset)

            accs = evaluate(model, eval_dataset)

            if args.eval_future and t < dataset.N_TASKS - 1:
                # split the accuracies into seen tasks and future tasks
                transf_accs = accs[0][t + 1:], accs[1][t + 1:]
                accs = accs[0][:t + 1], accs[1][:t + 1]
                results_transf.append(transf_accs[0])
                results_mask_classes_transf.append(transf_accs[1])

            results.append(accs[0])
            results_mask_classes.append(accs[1])

            log_accs(args, logger, accs, t, dataset.SETTING)

            if args.eval_future:
                avg_transf = np.mean([np.mean(task_) for task_ in results_transf])
                print(f"Transfer Metrics - AVG Transfer {avg_transf:.2f}")
                if t < dataset.N_TASKS - 1:
                    log_accs(args, logger, transf_accs, t, dataset.SETTING, future=True)

            if args.savecheck:
                # Decide the destination first so the snapshot (which deep-copies the buffer)
                # is only built when it will actually be written.
                checkpoint_name = None
                if args.savecheck == 'task':
                    checkpoint_name = f'checkpoints/{args.ckpt_name}_joint.pt' if args.joint else f'checkpoints/{args.ckpt_name}_{t}.pt'
                elif args.savecheck == 'last' and t == end_task - 1:
                    checkpoint_name = f'checkpoints/{args.ckpt_name}_joint.pt' if args.joint else f'checkpoints/{args.ckpt_name}_last.pt'

                if checkpoint_name is not None:
                    # Saving model checkpoint for the current task
                    save_obj = {
                        'model': model.state_dict(),
                        'args': args,
                        'results': [results, results_mask_classes, logger.dump() if logger is not None else None],
                        'optimizer': model.opt.state_dict() if hasattr(model, 'opt') else None,
                        'scheduler': scheduler.state_dict() if scheduler is not None else None,
                    }
                    if 'buffer_size' in model.args:
                        save_obj['buffer'] = copy.deepcopy(model.buffer).to('cpu')
                    torch.save(save_obj, checkpoint_name)

        if args.validation:
            # Final evaluation on the real test set
            print("Starting final evaluation on the real test set...", file=sys.stderr)
            del dataset
            args.validation = None
            args.validation_mode = 'current'

            final_dataset = get_dataset(args)
            for _ in range(final_dataset.N_TASKS):
                final_dataset.get_data_loaders()
            accs = evaluate(model, final_dataset)

            log_accs(args, logger, accs, 'final', final_dataset.SETTING, prefix="FINAL")

        if args.enable_other_metrics and logger is None:
            # These metrics are computed by the logger; without it they would raise AttributeError.
            logging.warning("Backward Transfer, Forgetting and Forward Transfer need the logger, skipped because logging is disabled.")
        elif args.enable_other_metrics:
            bwt, bwt_mask_class = logger.add_bwt(results, results_mask_classes)
            log_extra_metrics(args, bwt, bwt_mask_class, 'Backward Transfer', t)
            forgetting, forgetting_mask_class = logger.add_forgetting(results, results_mask_classes)
            log_extra_metrics(args, forgetting, forgetting_mask_class, 'Forgetting', t)
            if is_fwd_enabled:
                fwt, fwt_mask_class = logger.add_fwt(results, random_results_class,
                                                     results_mask_classes, random_results_task)
                log_extra_metrics(args, fwt, fwt_mask_class, 'Forward Transfer', t)
            else:
                logging.warning("Forward Transfer metric incompatible with the current model, skipped.")

        system_tracker.print_stats()

    if not args.disable_log:
        logger.write(vars(args))
        if not args.nowand:
            d = logger.dump()
            d['wandb_url'] = wandb.run.get_url()
            wandb.log(d)

    if not args.nowand:
        wandb.finish()