# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import copy
import math
import os
import sys
from argparse import Namespace
from time import time
from typing import Iterable, Tuple
import logging
import torch
from tqdm import tqdm
from datasets import get_dataset
from datasets.utils.continual_dataset import ContinualDataset, MammothDatasetWrapper
from datasets.utils.gcl_dataset import GCLDataset
from models.utils.continual_model import ContinualModel
from models.utils.future_model import FutureModel
from utils import disable_logging
from utils.checkpoints import mammoth_load_checkpoint, save_mammoth_checkpoint
from utils.loggers import log_extra_metrics, log_accs, Logger
from utils.schedulers import get_scheduler
from utils.stats import track_system_stats
try:
import wandb
except ImportError:
wandb = None
def initialize_wandb(args: Namespace) -> None:
"""
Initializes wandb, if installed.
Args:
args: the arguments of the current execution
"""
assert wandb is not None, "Wandb not installed, please install it or run without wandb"
run_name = args.wandb_name if args.wandb_name is not None else args.model
run_id = args.conf_jobnum.split('-')[0]
name = f'{run_name}_{run_id}'
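    # force-disable wandb during unit tests (MAMMOTH_TEST=1); otherwise honor
    # the WANDB_MODE environment variable, defaulting to 'online'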
mode = 'disabled' if os.getenv('MAMMOTH_TEST', '0') == '1' else os.getenv('WANDB_MODE', 'online')
wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), name=name, mode=mode)
args.wandb_url = wandb.run.get_url()
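# Moves a single batch element to the given device; tensors whose field name
# contains 'label' or 'target' are cast to long, since loss functions such as
# cross-entropy expect integer class indices.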
def _to_device(name: str, x, device):
if isinstance(x, torch.Tensor):
if 'label' in name.lower() or 'target' in name.lower():
return x.to(device, dtype=torch.long)
return x.to(device)
return x
def train_single_epoch(model: ContinualModel,
train_loader: Iterable,
args: Namespace,
epoch: int,
pbar: tqdm,
system_tracker=None,
scheduler=None) -> int:
"""
Trains the model for a single epoch.
Args:
model: the model to be trained
train_loader: the data loader for the training set
args: the arguments from the command line
epoch: the current epoch
system_tracker: the system tracker to monitor the system stats
scheduler: the scheduler for the current epoch
Returns:
the number of iterations performed in the current epoch
"""
train_iter = iter(train_loader)
epoch_len = len(train_loader) if hasattr(train_loader, "__len__") else None
i = 0
previous_time = time()
while True:
try:
data = next(train_iter)
except StopIteration:
break
if args.debug_mode and i > model.get_debug_iters():
break
if args.fitting_mode == 'iters' and model.task_iteration >= model.args.n_iters:
break
inputs, labels, not_aug_inputs = data[0], data[1], data[2]
inputs, labels = inputs.to(model.device), labels.to(model.device, dtype=torch.long)
not_aug_inputs = not_aug_inputs.to(model.device)
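        # forward any extra fields returned by the dataset (entries past the first
        # three) to the model as keyword arguments, moved to the correct device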
extra_fields = {
train_loader.dataset.extra_return_fields[k]: _to_device(train_loader.dataset.extra_return_fields[k], data[3 + k], model.device)
for k in range(len(data) - 3)
}
loss = model.meta_observe(inputs, labels, not_aug_inputs, epoch=epoch, **extra_fields)
        assert not math.isnan(loss), "Loss is NaN"
if scheduler is not None and args.scheduler_mode == 'iter':
scheduler.step()
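        # synchronize CUDA when code optimizations are disabled, presumably so the
        # per-iteration timing below (and the system tracker) reflects the real step duration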
if args.code_optimization == 0 and 'cuda' in str(args.device):
torch.cuda.synchronize()
system_tracker()
i += 1
time_diff = time() - previous_time
previous_time = time()
bar_log = {'loss': loss, 'lr': model.opt.param_groups[0]['lr']}
if epoch_len:
ep_h = 3600 / (epoch_len * time_diff)
bar_log['ep/h'] = ep_h
pbar.set_postfix(bar_log, refresh=False)
pbar.update()
    if scheduler is not None and args.scheduler_mode == 'epoch':
        scheduler.step()

    return i
def train(model: ContinualModel, dataset: ContinualDataset,
args: Namespace) -> None:
"""
The training process, including evaluations and loggers.
Args:
model: the module to be trained
dataset: the continual dataset at hand
args: the arguments of the current execution
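
    Example (illustrative sketch; `build_model` is a hypothetical stand-in for
    the actual model-construction helper of the entry point):
        args = ...                       # a fully populated Namespace from the CLI
        dataset = get_dataset(args)
        model = build_model(args)        # hypothetical factory returning a ContinualModel
        train(model, dataset, args)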
"""
print(args)
is_fwd_enabled = True
can_compute_fwd_beforetask = True
random_results_class, random_results_task = [], []
if not args.nowand:
initialize_wandb(args)
    logger = Logger(args, dataset.SETTING, dataset.NAME, model.NAME) if not args.disable_log else None
model.net.to(model.device)
torch.cuda.empty_cache()
with track_system_stats(logger) as system_tracker:
results, results_mask_classes = [], []
if args.eval_future:
results_transf, results_mask_classes_transf = [], []
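        # if resuming from a later task, fast-forward the dataset and fire each
        # task's begin/end hooks so internal counters match the checkpoint loaded below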
if args.start_from is not None:
for i in range(args.start_from):
train_loader, _ = dataset.get_data_loaders()
model.meta_begin_task(dataset)
model.meta_end_task(dataset)
if args.loadcheck is not None:
model, past_res = mammoth_load_checkpoint(args, model)
if not args.disable_log and past_res is not None:
(results, results_mask_classes, csvdump) = past_res
logger.load(csvdump)
print('Checkpoint Loaded!')
print(file=sys.stderr)
start_task = 0 if args.start_from is None else args.start_from
end_task = dataset.N_TASKS if args.stop_after is None else args.stop_after
if args.eval_future:
assert isinstance(model, FutureModel), "Model must be an instance of FutureModel to evaluate on future tasks"
eval_dataset = get_dataset(args)
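            # step through all tasks once so that `eval_dataset` exposes every task's
            # test loader and transform; the train loaders are discarded since only
            # evaluation on future tasks is needed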
# disable logging for this loop
with disable_logging(logging.WARNING):
for _ in range(dataset.N_TASKS):
eval_dataset.get_data_loaders()
model.change_transform(eval_dataset)
del eval_dataset.train_loader
else:
eval_dataset = dataset
torch.cuda.empty_cache()
for t in range(start_task, end_task):
model.net.train()
train_loader, _ = dataset.get_data_loaders()
if not issubclass(dataset.__class__, GCLDataset):
                assert issubclass(train_loader.dataset.__class__, MammothDatasetWrapper), "Dataset must be an instance of MammothDatasetWrapper (did you forget to call `store_masked_loaders`?)"
if can_compute_fwd_beforetask and is_fwd_enabled and args.enable_other_metrics:
# try to compute accuracy at the beginning of the task
try:
logging.info("Evaluating model before task (for Forward Transfer metric)...")
random_res_class, random_res_task = dataset.evaluate(model, dataset, last=True) # the ugliness of this line is for backward compatibility
random_results_class.append(random_res_class)
random_results_task.append(random_res_task)
                except Exception:
                    logging.info("Could not evaluate before `begin_task`; will try again after")
                    # retry after `begin_task`, in case the model needs to set something up first
                    can_compute_fwd_beforetask = False
model.meta_begin_task(dataset)
if not can_compute_fwd_beforetask and is_fwd_enabled and args.enable_other_metrics:
if train_loader.dataset.num_times_iterated == 0: # compute only if the model has not been trained yet
try:
logging.info("Evaluating model before task (for Forward Transfer metric)...")
random_res_class, random_res_task = dataset.evaluate(model, dataset, last=True)
random_results_class.append(random_res_class)
random_results_task.append(random_res_task)
except Exception as e:
logging.error(f"Model `{model.NAME}` does not support pre-evaluation, will not compute Forward Transfer metric\n{e}")
is_fwd_enabled = False
else:
logging.info("Model used the training data, skipping Forward Transfer metric compute")
is_fwd_enabled = False
if not args.inference_only and args.n_epochs > 0:
if t and args.enable_other_metrics:
accs = eval_dataset.evaluate(model, eval_dataset, last=True)
results[t - 1] = results[t - 1] + accs[0]
if dataset.SETTING == 'class-il':
results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1]
            # The scheduler is automatically reloaded after each task if it is defined by the dataset.
            # If the model defines its own scheduler, reloading it is the model's responsibility.
            scheduler = get_scheduler(model, args, reload_optim=True) if not hasattr(model, 'scheduler') else model.scheduler
epoch = 0
best_ea_metric = None
best_ea_model = None
cur_stopping_patience = args.early_stopping_patience
n_iterations = None
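            # GCL datasets stream data without a fixed length, so the progress bar
            # total (and the epochs-to-iterations conversion) only applies to
            # standard datasets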
if not isinstance(dataset, GCLDataset):
n_iterations = model.args.n_epochs * len(train_loader) if model.args.fitting_mode == 'epochs' else model.args.n_iters
mininterval = 0.2 if n_iterations is not None and n_iterations > 1000 else 0.1
            # the loader passed to tqdm is only a placeholder: the total is set
            # explicitly and the bar is advanced manually inside `train_single_epoch`
            train_pbar = tqdm(train_loader, total=n_iterations,
                              disable=args.non_verbose, mininterval=mininterval)
if args.non_verbose:
logging.info(f"Task {t + 1}") # at least print the task number
while True:
model.begin_epoch(epoch, dataset)
train_pbar.set_description(f"Task {t + 1} - Epoch {epoch + 1}")
train_single_epoch(model, train_loader, args, pbar=train_pbar, epoch=epoch,
system_tracker=system_tracker, scheduler=scheduler)
model.end_epoch(epoch, dataset)
epoch += 1
if args.fitting_mode == 'epochs' and epoch >= model.args.n_epochs:
break
elif args.fitting_mode == 'iters' and model.task_iteration >= model.args.n_iters:
break
elif args.fitting_mode == 'early_stopping' and epoch % args.early_stopping_freq == 0 and epoch > 0:
epoch_accs, _, epoch_loss = eval_dataset.evaluate(model, eval_dataset, return_loss=True, last=True)
if args.early_stopping_metric == 'accuracy':
ea_metric = np.mean(epoch_accs) # Higher accuracy is better
elif args.early_stopping_metric == 'loss':
ea_metric = -epoch_loss # Lower loss is better
else:
raise ValueError(f'Unknown early stopping metric {args.early_stopping_metric}')
                    # `ea_metric` is signed so that higher is always better; count toward
                    # stopping once the improvement over the best value drops below epsilon
                    if best_ea_metric is not None and ea_metric - best_ea_metric < args.early_stopping_epsilon:
cur_stopping_patience -= args.early_stopping_freq
if cur_stopping_patience <= 0:
print(f"\nEarly stopping at epoch {epoch} with metric {abs(ea_metric)}", file=sys.stderr)
model.load_state_dict({k: v.to(model.device) for k, v in best_ea_model.items()})
break
print(f"\nNo improvement at epoch {epoch} (best {abs(best_ea_metric)} | current {abs(ea_metric)}). "
f"Waiting for {cur_stopping_patience} epochs to stop.", file=sys.stderr)
else:
print(f"\nFound better model with metric {abs(ea_metric)} at epoch {epoch}. "
f"Previous value was {abs(best_ea_metric) if best_ea_metric is not None else 'None'}", file=sys.stderr)
best_ea_metric = ea_metric
best_ea_model = copy.deepcopy({k: v.cpu() for k, v in model.state_dict().items()})
cur_stopping_patience = args.early_stopping_patience
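                # optional mid-task evaluation every `eval_epochs` epochs; the final
                # epoch is excluded because the task is evaluated right after training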
if args.eval_epochs is not None and (epoch > 0 or args.eval_epochs) and epoch % args.eval_epochs == 0 and epoch < model.args.n_epochs:
epoch_accs = eval_dataset.evaluate(model, eval_dataset)
eval_dataset.log(args, logger, epoch_accs, t, dataset.SETTING, epoch=epoch)
train_pbar.close()
model.meta_end_task(dataset)
accs = eval_dataset.evaluate(model, eval_dataset)
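            # when evaluating future tasks, split the accuracies into tasks seen so
            # far (<= t) and future tasks (> t); the latter are logged separately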
if args.eval_future and t < dataset.N_TASKS - 1:
transf_accs = accs[0][t + 1:], accs[1][t + 1:]
accs = accs[0][:t + 1], accs[1][:t + 1]
results_transf.append(transf_accs[0])
results_mask_classes_transf.append(transf_accs[1])
logged_accs = eval_dataset.log(args, logger, accs, t, dataset.SETTING)
if dataset.SETTING != 'biased-class-il':
results.append(accs[0])
results_mask_classes.append(accs[1])
else:
results.append(logged_accs[0]) # avg
results_mask_classes.append(logged_accs[1]) # worst
if args.eval_future:
avg_transf = np.mean([np.mean(task_) for task_ in results_transf])
print(f"Transfer Metrics - AVG Transfer {avg_transf:.2f}")
if t < dataset.N_TASKS - 1:
eval_dataset.log(args, logger, transf_accs, t, dataset.SETTING, future=True)
if args.savecheck:
save_mammoth_checkpoint(t, end_task, args,
model,
results=[results, results_mask_classes, logger.dump()],
optimizer_st=model.opt.state_dict() if hasattr(model, 'opt') else None,
scheduler_st=scheduler.state_dict() if scheduler is not None else None)
if args.validation:
# Final evaluation on the real test set
print("Starting final evaluation on the real test set...", file=sys.stderr)
del dataset
args.validation = None
args.validation_mode = 'current'
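            # rebuild the dataset without the validation split so the final
            # evaluation runs on the real test set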
final_dataset = get_dataset(args)
for _ in range(final_dataset.N_TASKS):
final_dataset.get_data_loaders()
accs = final_dataset.evaluate(model, final_dataset)
final_dataset.log(args, logger, accs, 'final', final_dataset.SETTING, prefix="FINAL")
if args.enable_other_metrics:
bwt, bwt_mask_class = logger.add_bwt(results, results_mask_classes)
log_extra_metrics(args, bwt, bwt_mask_class, 'Backward Transfer', t)
forgetting, forgetting_mask_class = logger.add_forgetting(results, results_mask_classes)
log_extra_metrics(args, forgetting, forgetting_mask_class, 'Forgetting', t)
if is_fwd_enabled:
fwt, fwt_mask_class = logger.add_fwt(results, random_results_class,
results_mask_classes, random_results_task)
log_extra_metrics(args, fwt, fwt_mask_class, 'Forward Transfer', t)
else:
logging.warning("Forward Transfer metric incompatible with the current model, skipped.")
system_tracker.print_stats()
if not args.disable_log:
logger.write(vars(args))
if not args.nowand:
d = logger.dump()
d['wandb_url'] = wandb.run.get_url()
wandb.log(d)
if not args.nowand:
wandb.finish()