# Source code for utils.training

import json
import os
import numpy as np
import math
from argparse import Namespace
from typing import Iterable, Optional, TYPE_CHECKING
import sys
import torch

from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel

from utils.checkpoints import mammoth_load_checkpoint, save_mammoth_checkpoint
from utils.evaluate import evaluate

if TYPE_CHECKING:
    # For static type checkers, always resolve the canonical tqdm class.
    from tqdm import tqdm
else:
    # Check if we're in a notebook environment
    # At runtime, 'ipykernel' being importable/loaded indicates a Jupyter
    # kernel, where tqdm.notebook renders a widget-based bar instead of a
    # terminal one.
    if 'ipykernel' not in sys.modules:
        from tqdm import tqdm
    else:
        from tqdm.notebook import tqdm

def update_progress_bar(pbar: 'tqdm', model: 'ContinualModel'):
    """Advance the progress bar by one step, syncing its postfix from the model.

    If the model has accumulated postfix entries (e.g. loss, lr), they are
    pushed to the bar without forcing an immediate redraw.

    Args:
        pbar: the progress bar to advance
        model: the model whose ``pbar_suffix`` dict feeds the postfix
    """
    suffix = model.pbar_suffix
    if suffix:
        pbar.set_postfix(suffix, refresh=False)
    pbar.update()
def train_epoch(model: 'ContinualModel', train_loader: Iterable, args: Namespace, epoch: int, pbar: 'tqdm'):
    """
    Run one training pass over ``train_loader``, updating the model and the
    progress bar after every batch.

    Args:
        model: the model to be trained
        train_loader: the data loader for the training set
        args: the arguments from the command line
        epoch: the current epoch
        pbar: the progress bar to update
    """
    for step, batch in enumerate(train_loader):
        # Debug mode runs only a handful of iterations per epoch.
        if args.debug_mode and step > 5:
            break
        # Batches are (augmented inputs, labels, non-augmented inputs, ...).
        inputs, labels, not_aug_inputs = batch[0], batch[1], batch[2]
        inputs = inputs.to(model.device)
        labels = labels.to(model.device, dtype=torch.long)
        not_aug_inputs = not_aug_inputs.to(model.device)

        loss = model.observe(inputs, labels, not_aug_inputs, epoch=epoch)
        # A NaN loss means training has diverged; fail fast.
        # NOTE(review): `assert` is stripped under `python -O`; an explicit
        # raise would be more robust, but is kept as-is to preserve behavior.
        assert not math.isnan(loss)

        model.pbar_suffix.update({'loss': loss, 'lr': model.opt.param_groups[0]['lr']})
        update_progress_bar(pbar, model)
def train(model: 'ContinualModel', dataset: 'ContinualDataset', args: Optional[Namespace] = None) -> None:
    """
    The training process, including evaluations and loggers.

    Args:
        model: the module to be trained
        dataset: the continual dataset at hand
        args: the arguments of the current execution. If None, it will be
            loaded from the environment variable 'MAMMOTH_ARGS'.
    """
    if args is None:
        serialized = os.environ.get('MAMMOTH_ARGS')
        if serialized is None:
            raise ValueError("No arguments provided. Did you run the `load_runner` function?")
        args = Namespace(**json.loads(serialized))
    # Persist the (possibly reconstructed) arguments for later reloads.
    os.environ['MAMMOTH_ARGS'] = json.dumps(vars(args))

    dataset.reset()  # reset the dataset to the initial state
    model.net.to(model.device)
    torch.cuda.empty_cache()

    # Per-task accuracies, with and without masking to the current task's classes.
    results: list[float] = []
    results_mask_classes: list[float] = []

    if args.loadcheck is not None:
        model, past_res = mammoth_load_checkpoint(args, model)
        # Resume past results only when logging is enabled and they exist.
        if not args.disable_log and past_res is not None:
            results, results_mask_classes = past_res
        print('Checkpoint Loaded!')

    torch.cuda.empty_cache()
    try:
        for task_idx in range(dataset.N_TASKS):
            model.net.train()
            train_loader, _ = dataset.meta_get_data_loaders()
            model.begin_task(dataset)

            # One bar spans all epochs of the current task.
            with tqdm(train_loader, total=len(train_loader) * args.n_epochs, mininterval=0.1) as train_pbar:
                for epoch in range(args.n_epochs):
                    train_pbar.set_description(f"Task {task_idx + 1} - Epoch {epoch + 1}")
                    model.begin_epoch(epoch, dataset)
                    train_epoch(model, train_loader, args, pbar=train_pbar, epoch=epoch)
                    model.end_epoch(epoch, dataset)
                    update_progress_bar(train_pbar, model)

            model.end_task(dataset)

            accs, avg_loss = evaluate(model, dataset)
            mean_acc = np.mean(accs, axis=1)
            print(f'Accuracy for task {task_idx + 1}\t[Class-IL]: {mean_acc[0]:.2f}% \t[Task-IL]: {mean_acc[1]:.2f}%')
            print(f'Average loss: {avg_loss:.6f}')
            results.append(accs[0])
            results_mask_classes.append(accs[1])

            if args.savecheck:
                save_mammoth_checkpoint(task_idx, dataset.N_TASKS, args, model,
                                        results=(results, results_mask_classes),  # type: ignore
                                        optimizer_st=model.opt.state_dict())
    except KeyboardInterrupt:
        print("Training interrupted!")