TRAINING
Functions
- utils.training.evaluate(model, dataset, last=False, return_loss=False)
Evaluates the accuracy of the model on each task seen so far.
The accuracy is computed for every task up to and including the current one, considering only the classes seen so far.
- Parameters:
model (ContinualModel) – the model to be evaluated
dataset (ContinualDataset) – the continual dataset at hand
last – a boolean indicating whether to evaluate only the last task
return_loss – a boolean indicating whether to return the loss in addition to the accuracy
- Returns:
a tuple of lists containing the Class-IL and Task-IL accuracies for each task. If return_loss is True, the loss is also returned as a third element.
- Return type:
tuple
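The difference between the two accuracies returned by evaluate can be illustrated with a small pure-Python sketch. It assumes that every task has the same number of classes (`classes_per_task`, a hypothetical parameter) and that the labels of task k occupy the contiguous range [k * classes_per_task, (k + 1) * classes_per_task). Logits are plain lists rather than tensors, accuracies are expressed as percentages, and the `last` and `return_loss` options are left out; this is an illustration, not the library's implementation.

```python
from typing import List, Sequence, Tuple

# Hypothetical input format: for each task seen so far, a list of
# (logits, label) pairs, where the logits cover every class seen so far.
TaskData = Sequence[Tuple[Sequence[float], int]]


def argmax(values: Sequence[float], start: int = 0) -> int:
    """Index of the largest value, shifted by `start`."""
    best = max(range(len(values)), key=lambda i: values[i])
    return best + start


def evaluate_sketch(tasks: Sequence[TaskData],
                    classes_per_task: int) -> Tuple[List[float], List[float]]:
    """Return (class_il, task_il) accuracy lists (in %), one entry per task."""
    class_il, task_il = [], []
    for k, samples in enumerate(tasks):
        # Assumption: task k owns the contiguous label range [lo, hi).
        lo, hi = k * classes_per_task, (k + 1) * classes_per_task
        correct_cil = correct_til = 0
        for logits, label in samples:
            # Class-IL: predict among all classes seen so far.
            if argmax(logits) == label:
                correct_cil += 1
            # Task-IL: the task identity k is known, so predict only among its classes.
            if argmax(logits[lo:hi], start=lo) == label:
                correct_til += 1
        class_il.append(100.0 * correct_cil / len(samples))
        task_il.append(100.0 * correct_til / len(samples))
    return class_il, task_il
```

Restricting the argmax to the task's own slice is equivalent to setting every other class to -inf before taking the argmax over all classes, which is why Task-IL accuracy is never lower than Class-IL accuracy on the same samples.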
- utils.training.initialize_wandb(args)
Initializes wandb, if installed.
- Parameters:
args (Namespace) – the arguments of the current execution
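A minimal sketch of the "if installed" pattern, assuming availability is checked with `importlib.util.find_spec` before importing. `initialize_wandb_sketch` and the `wandb_project` attribute are hypothetical names, not this library's actual function body or argument names; `wandb.init` itself does accept `project` and `config` keyword arguments.

```python
import argparse
import importlib.util
from typing import Optional


def module_available(name: str) -> bool:
    """True if top-level module `name` can be imported, without importing it."""
    return importlib.util.find_spec(name) is not None


def initialize_wandb_sketch(args: argparse.Namespace) -> Optional[object]:
    """Start a wandb run if wandb is installed; otherwise do nothing."""
    if not module_available("wandb"):
        return None  # wandb missing: training proceeds without remote logging
    import wandb
    # Hypothetical attribute name: the real argument names may differ.
    project = getattr(args, "wandb_project", None)
    # Record every command-line argument as the run configuration.
    return wandb.init(project=project, config=vars(args))
```

Checking with `find_spec` keeps wandb an optional dependency: the import only happens on the branch where the package is known to exist.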
- utils.training.mask_classes(outputs, dataset, k)
Given the output tensor, the dataset at hand and the current task index, masks the output tensor by setting the responses for the classes of all other tasks to -inf. It is used to obtain the results for the Task-IL setting.
- Parameters:
outputs (Tensor) – the output tensor
dataset (ContinualDataset) – the continual dataset
k (int) – the task index
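The masking can be sketched in pure Python, with nested lists standing in for the output tensor. The sketch assumes equally sized tasks whose classes are contiguous, replaces the `dataset` argument with a hypothetical `classes_per_task` integer, and returns a new list rather than operating on a tensor.

```python
import math
from typing import List


def mask_classes_sketch(outputs: List[List[float]],
                        classes_per_task: int,
                        k: int) -> List[List[float]]:
    """Return a copy of `outputs` where every class outside task k is -inf."""
    # Assumption: task k owns the contiguous class range [lo, hi).
    lo, hi = k * classes_per_task, (k + 1) * classes_per_task
    return [
        [value if lo <= j < hi else -math.inf for j, value in enumerate(row)]
        for row in outputs
    ]
```

For example, with two classes per task and k = 1, the row [0.2, 0.9, 0.4, 0.1] becomes [-inf, -inf, 0.4, 0.1], so the predicted class moves from class 1 (a task-0 class) to class 2.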
- utils.training.train(model, dataset, args)
Runs the full training process, including evaluation and logging.
- Parameters:
model (ContinualModel) – the model to be trained
dataset (ContinualDataset) – the continual dataset at hand
args (Namespace) – the arguments of the current execution
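At a high level, a continual training loop of this kind alternates between training on the current task and evaluating on every task seen so far. The skeleton below is a hedged sketch of that structure only: `train_epoch`, `evaluate` and `log` are hypothetical callbacks standing in for the model, the data loaders and the loggers, and it is not a description of what utils.training.train does internally.

```python
from typing import Callable, Dict, List


def train_sketch(n_tasks: int,
                 n_epochs: int,
                 train_epoch: Callable[[int, int], int],
                 evaluate: Callable[[int], List[float]],
                 log: Callable[[Dict[str, object]], None]) -> List[List[float]]:
    """Task-by-task loop: train on task t, then evaluate on tasks 0..t."""
    history: List[List[float]] = []
    for task in range(n_tasks):
        for epoch in range(n_epochs):
            # Hypothetical callback: trains one epoch, returns iterations done.
            n_iters = train_epoch(task, epoch)
            log({"task": task, "epoch": epoch, "iterations": n_iters})
        # Hypothetical callback: one accuracy per task seen so far.
        accs = evaluate(task)
        history.append(accs)
        log({"task": task, "accuracies": accs})
    return history
```

Keeping the per-task accuracy lists in `history` makes it possible to track how accuracy on early tasks changes as later tasks are learned.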
- utils.training.train_single_epoch(model, train_loader, args, epoch, current_task, system_tracker=None, data_len=None, scheduler=None)
Trains the model for a single epoch.
- Parameters:
model (ContinualModel) – the model to be trained
train_loader (Iterable) – the data loader for the training set
args (Namespace) – the arguments from the command line
epoch (int) – the current epoch
current_task (int) – the current task index
system_tracker – the system tracker used to monitor system statistics
data_len – the length of the training data loader. If None, the progress bar does not show the training percentage.
scheduler – the scheduler for the current epoch
- Returns:
the number of iterations performed in the current epoch
- Return type:
int
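A rough sketch of a single training epoch, showing how the return value and the `data_len` and `scheduler` parameters might interact. `step_fn` is a hypothetical callable that performs one optimization step on a batch and returns its loss; stepping the scheduler once at the end of the epoch is an assumption, not documented behavior.

```python
from typing import Callable, Iterable, Optional, Protocol


class SupportsStep(Protocol):
    def step(self) -> None: ...


def train_single_epoch_sketch(step_fn: Callable[[object], float],
                              train_loader: Iterable,
                              data_len: Optional[int] = None,
                              scheduler: Optional[SupportsStep] = None) -> int:
    """Run one pass over `train_loader` and return the number of iterations."""
    n_iters = 0
    for batch in train_loader:
        loss = step_fn(batch)  # hypothetical: one optimization step on `batch`
        n_iters += 1
        if data_len is not None:
            # A percentage is only meaningful when the loader length is known.
            pct = 100.0 * n_iters / data_len
            print(f"\r{pct:5.1f}% | loss {loss:.4f}", end="")
        else:
            print(f"\r{n_iters} it | loss {loss:.4f}", end="")
    if n_iters:
        print()  # finish the progress line
    if scheduler is not None:
        scheduler.step()  # assumption: scheduler advanced once per epoch
    return n_iters
```

Counting iterations explicitly, instead of relying on `len(train_loader)`, keeps the return value correct for iterables that have no length, which is also why `data_len` is optional.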