TRAINING

Functions

utils.training.evaluate(model, dataset, last=False, return_loss=False)

Evaluates the accuracy of the model for each past task.

Accuracy is computed for every task up to and including the current one, considering only the classes seen so far.

Parameters:
  • model (ContinualModel) – the model to be evaluated

  • dataset (ContinualDataset) – the continual dataset at hand

  • last (bool) – whether to evaluate only the last task

  • return_loss (bool) – whether to return the loss in addition to the accuracy

Returns:

a tuple of two lists, containing the class-il and task-il accuracy for each task. If return_loss is True, the loss is returned as a third element.

Return type:

Tuple[list, list]
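
As an illustration of how the returned tuple might be consumed, the sketch below averages the per-task accuracies and picks up the optional third element. The helper name summarize_accuracies and the sample numbers are hypothetical and not part of utils.training; the only assumption taken from the description above is the tuple layout (class-il list, task-il list, optional loss).

```python
from statistics import mean


def summarize_accuracies(result):
    """Average the per-task accuracies of an evaluate-style result tuple.

    Hypothetical helper: `result` is assumed to be
    (class_il_accs, task_il_accs) or (class_il_accs, task_il_accs, loss).
    """
    class_il, task_il = result[0], result[1]
    summary = {
        "class_il_mean": mean(class_il),
        "task_il_mean": mean(task_il),
    }
    # A third element is only present when the loss was requested.
    if len(result) > 2:
        summary["loss"] = result[2]
    return summary


# Made-up accuracies for two tasks, for illustration only.
example = ([80.0, 60.0], [90.0, 70.0])
print(summarize_accuracies(example))
```

Checking the tuple length rather than unpacking into a fixed number of names keeps the same helper usable whether or not return_loss was set.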

utils.training.initialize_wandb(args)

Initializes wandb, if installed.

Parameters:

args (Namespace) – the arguments of the current execution
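
The "if installed" behaviour can be approximated with an availability check before any import is attempted. The sketch below is one common way to do this and is not necessarily how initialize_wandb is implemented; the function name wandb_available is hypothetical.

```python
import importlib.util


def wandb_available():
    """Return True if the `wandb` package can be imported in this environment."""
    # find_spec returns None for a top-level package that is not installed.
    return importlib.util.find_spec("wandb") is not None


if wandb_available():
    print("wandb found: logging could be initialized")
else:
    print("wandb not installed: skipping initialization")
```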

utils.training.mask_classes(outputs, dataset, k)

Given the output tensor, the dataset at hand and the current task, masks the output tensor by setting the responses of all other tasks to -inf. It is used to obtain the results for the task-il setting.

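Parameters:
  • outputs – the output tensor

  • dataset (ContinualDataset) – the continual dataset at hand

  • k – the index of the current task
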
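
The masking idea can be sketched without tensors. The version below works on plain lists of floats and assumes, purely for illustration, that every task owns the same number of consecutive class indices (n_classes_per_task is a hypothetical parameter; the library function takes the dataset instead). Unlike a tensor-based implementation, it returns a new list rather than modifying its input.

```python
import math


def mask_classes_sketch(outputs, n_classes_per_task, k):
    """Keep the responses of task k and set every other response to -inf.

    Assumption: task k owns class indices
    [k * n_classes_per_task, (k + 1) * n_classes_per_task).
    """
    start = k * n_classes_per_task
    end = (k + 1) * n_classes_per_task
    return [
        [value if start <= j < end else -math.inf for j, value in enumerate(row)]
        for row in outputs
    ]


# One sample, four classes split into two tasks of two classes each.
logits = [[0.1, 2.0, 0.3, 5.0]]
masked = mask_classes_sketch(logits, n_classes_per_task=2, k=0)
# The prediction is now restricted to the classes of task 0.
prediction = max(range(len(masked[0])), key=lambda j: masked[0][j])
print(masked, prediction)
```

Because the masked entries are -inf, taking the argmax over the full row can only select a class that belongs to task k, which is what the task-il setting requires.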
utils.training.train(model, dataset, args)

The training process, including evaluations and loggers.

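Parameters:
  • model (ContinualModel) – the model to be trained

  • dataset (ContinualDataset) – the continual dataset at hand

  • args (Namespace) – the arguments of the current execution
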
utils.training.train_single_epoch(model, train_loader, args, epoch, current_task, system_tracker=None, data_len=None, scheduler=None)

Trains the model for a single epoch.

Parameters:
  • model (ContinualModel) – the model to be trained

  • train_loader (Iterable) – the data loader for the training set

  • args (Namespace) – the arguments from the command line

  • epoch (int) – the current epoch

  • current_task (int) – the current task index

  • system_tracker – the system tracker to monitor the system stats

  • data_len – the length of the training data loader. If None, the progress bar will not show the training percentage

  • scheduler – the scheduler for the current epoch

Returns:

the number of iterations performed in the current epoch

Return type:

int
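
The overall shape of a single-epoch loop that returns its iteration count can be sketched as follows. This is a simplified illustration, not the library's implementation: the observe callable stands in for whatever update step the model performs on a batch, and the scheduler, system tracker and progress bar are omitted. All names in the block are hypothetical.

```python
def train_single_epoch_sketch(observe, train_loader, data_len=None):
    """Run `observe` on every batch and return the number of iterations.

    `observe` is a hypothetical callable that performs one update step and
    returns the loss for the batch.
    """
    n_iterations = 0
    for batch in train_loader:
        loss = observe(batch)
        n_iterations += 1
        if data_len is not None:
            # A completion percentage is only available when the length is known.
            percent = 100.0 * n_iterations / data_len
            print(f"iteration {n_iterations}: loss={loss:.3f} ({percent:.0f}%)")
        else:
            print(f"iteration {n_iterations}: loss={loss:.3f}")
    return n_iterations


# Three dummy batches; the "loss" is just the batch mean, for illustration.
batches = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
count = train_single_epoch_sketch(lambda b: sum(b) / len(b), batches, data_len=len(batches))
print(count)
```

The branch on data_len mirrors the documented behaviour: without a known length, progress can still be reported per iteration but not as a percentage.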