CONTINUAL MODEL#
This is the base class for all models. It defines the interface that every model must implement and provides utility methods shared by all of them.
The observe method is the most important one: it is called at each training iteration and it is responsible for computing the loss and updating the model’s parameters.
The begin_task and end_task methods are called before and after each task, respectively.
The get_parser method returns the parser of the model. Additional model-specific hyper-parameters can be added by overriding this method.
The get_debug_iters method returns the number of iterations to be used for debugging. Default: 3.
The get_optimizer method returns the optimizer to be used for training. Default: SGD.
The load_buffer method is called when a buffer is loaded. Default: do nothing.
The meta_observe, meta_begin_task and meta_end_task methods are wrappers for observe, begin_task and end_task methods, respectively. They take care of updating the internal counters and of logging to wandb if installed.
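The wrapper pattern described above can be illustrated with a self-contained sketch. It does not import Mammoth: `SketchModel` and `MeanAbsModel` are hypothetical stand-ins, and the exact points at which the counters are reset or incremented are assumptions made for illustration, not the library's actual behavior. Logging to wandb is left out.

```python
from abc import ABC, abstractmethod


class SketchModel(ABC):
    """Stand-in for the base class; NOT Mammoth's ContinualModel."""

    def __init__(self):
        self.current_task = 0      # index of the task being trained
        self.task_iteration = 0    # iterations done in the current task

    @abstractmethod
    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Compute one training step and return the loss value."""

    def begin_task(self, dataset):
        """Hook run before each task; does nothing by default."""

    def end_task(self, dataset):
        """Hook run after each task; does nothing by default."""

    def meta_begin_task(self, dataset):
        self.task_iteration = 0    # assumed: counter restarts with each task
        self.begin_task(dataset)

    def meta_observe(self, *args, **kwargs):
        loss = self.observe(*args, **kwargs)
        self.task_iteration += 1   # assumed: one increment per training step
        return loss

    def meta_end_task(self, dataset):
        self.end_task(dataset)
        self.current_task += 1     # assumed: advance to the next task


class MeanAbsModel(SketchModel):
    """Toy subclass: the 'loss' is the mean absolute error of the batch."""

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        return sum(abs(x - y) for x, y in zip(inputs, labels)) / len(inputs)


model = MeanAbsModel()
model.meta_begin_task(dataset=None)
losses = [model.meta_observe([1.0, 2.0], [1.0, 4.0], [1.0, 2.0]) for _ in range(3)]
model.meta_end_task(dataset=None)
```

The training loop only ever calls the `meta_*` methods, so a subclass can override `observe`, `begin_task` and `end_task` without having to maintain the counters itself.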
The autolog_wandb method is used to automatically log to wandb all variables starting with “_wandb_” or “loss” in the observe function. It is called by meta_observe if wandb is installed. It can be overridden to add custom logging.
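The naming rule used by autolog_wandb can be illustrated without wandb installed. The sketch below only shows which local variables the rule would select; `collect_loggable` and `observe_step` are hypothetical names, the handling of `extra` is an assumption, and the real method may rename or post-process the values before logging them.

```python
def collect_loggable(local_vars, extra=None):
    """Pick the entries that the naming rule would select for logging."""
    selected = {
        name: value
        for name, value in local_vars.items()
        if name.startswith("_wandb_") or name.startswith("loss")
    }
    if extra:
        selected.update(extra)   # assumed: `extra` holds additional entries to log
    return selected


def observe_step():
    loss = 0.42
    loss_reg = 0.1             # selected: the name starts with "loss"
    _wandb_lr = 0.03           # selected: the name starts with "_wandb_"
    batch_size = 32            # not selected: the name matches neither prefix
    return collect_loggable(locals())


logged = observe_step()
```

Because the rule is based on prefixes, any local variable whose name begins with `loss` is picked up, not only the one called `loss`.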
Classes#
- class models.utils.continual_model.ContinualModel(backbone, loss, args, transform, dataset=None)[source]#
Bases: Module
Continual learning model.
- AVAIL_OPTIMS = ['sgd', 'adam', 'adamw']#
- autolog_wandb(locals, extra=None)[source]#
All variables starting with “_wandb_” or “loss” in the observe function are automatically logged to wandb upon return if wandb is installed.
- property classes_per_task#
Returns the raw number of classes per task. Warning: return value might be either an integer or a list of integers depending on the dataset.
- property cpt#
Alias of classes_per_task: returns the raw number of classes per task. Warning: return value might be either an integer or a list of integers depending on the dataset.
- property current_task#
Returns the index of the current task.
- dataset: ContinualDataset#
- property epoch_iteration#
Returns the number of iterations in the current epoch.
- get_optimizer(params=None, lr=None)[source]#
Returns the optimizer to be used for training.
Default: SGD.
- static get_parser(parser)[source]#
Defines model-specific hyper-parameters, which will be added to the command line arguments. Additional model-specific hyper-parameters can be added by overriding this method.
For backward compatibility, the parser object may be omitted (although this should be avoided). In this case, the method should create and return a new parser.
This method may also be used to set default values for all other hyper-parameters of the framework (e.g., lr, buffer_size, etc.) with the set_defaults method of the parser. In this case, this method MUST update the original parser object and not create a new one.
- Parameters:
parser (ArgumentParser) – the main parser, to which the model-specific arguments will be added
- Returns:
the parser of the model
- Return type:
ArgumentParser
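A minimal sketch of overriding get_parser, using only the standard argparse module. `MyModel`, the `--alpha` argument and the default values are hypothetical; the framework's main parser is simulated here by a plain ArgumentParser with a single `--lr` argument.

```python
from argparse import ArgumentParser


class MyModel:
    """Hypothetical model; only the parser hook is shown."""

    @staticmethod
    def get_parser(parser: ArgumentParser) -> ArgumentParser:
        # Model-specific hyper-parameter (hypothetical name).
        parser.add_argument("--alpha", type=float, default=0.5,
                            help="Weight of the regularization term.")
        # Override a framework-wide default on the SAME parser object.
        parser.set_defaults(lr=0.03)
        return parser


# Stand-in for the framework's main parser.
main_parser = ArgumentParser()
main_parser.add_argument("--lr", type=float, default=0.1)

returned = MyModel.get_parser(main_parser)
args = main_parser.parse_args(["--alpha", "0.25"])
```

Returning the parser that was passed in, rather than a new one, is what lets the `set_defaults` call take effect on the framework's own arguments.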
- meta_begin_task(dataset)[source]#
Wrapper for begin_task method.
Takes care of updating the internal counters.
- Parameters:
dataset – the current task’s dataset
- meta_end_task(dataset)[source]#
Wrapper for end_task method.
Takes care of updating the internal counters.
- Parameters:
dataset – the current task’s dataset
- meta_observe(*args, **kwargs)[source]#
Wrapper for observe method.
Takes care of dropping unlabeled data if not supported by the model and of logging to wandb if installed.
- Parameters:
inputs – batch of inputs
labels – batch of labels
not_aug_inputs – batch of inputs without augmentation
kwargs – some methods could require additional parameters
- Returns:
the value of the loss function
- property n_classes_current_task#
Returns the number of classes in the current task. Returns -1 if task has not been initialized yet.
- property n_past_classes#
Returns the number of classes seen in all tasks before the current one. Returns -1 if task has not been initialized yet.
- property n_remaining_classes#
Returns the number of classes remaining to be seen. Returns -1 if task has not been initialized yet.
- property n_seen_classes#
Returns the number of classes seen so far. Returns -1 if task has not been initialized yet.
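The relationship between the four class-count properties can be sketched as follows. `class_counts` is a hypothetical helper, not part of the library. It assumes a 0-based task index and that "seen so far" includes the current task, and it handles both forms that classes_per_task may take (an integer or a list of integers). The -1 case for an uninitialized task is not modeled.

```python
def class_counts(cpt, n_tasks, current_task):
    """Return (past, current, seen, remaining) class counts for a task index."""
    # Normalize: an int means every task has the same number of classes.
    per_task = [cpt] * n_tasks if isinstance(cpt, int) else list(cpt)
    past = sum(per_task[:current_task])   # classes in tasks before this one
    current = per_task[current_task]      # classes in this task
    seen = past + current                 # assumed: includes the current task
    remaining = sum(per_task) - seen      # classes in tasks still to come
    return past, current, seen, remaining


uniform = class_counts(2, n_tasks=5, current_task=1)
uneven = class_counts([5, 3, 2], n_tasks=3, current_task=2)
```

Normalizing the integer form to a list first keeps the arithmetic identical for both kinds of dataset.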
- net: MammothBackbone#
- abstract observe(inputs, labels, not_aug_inputs, epoch=None)[source]#
Compute a training step over a given batch of examples.
- original_transform: Compose#
- scheduler: _LRScheduler#
- property task_iteration#
Returns the number of iterations in the current task.
- transform: Compose | AugmentationSequential#