LWF MC
Arguments
Options
- --wd_reg (float)
Help: L2 regularization applied to the parameters.
Default: 0.0
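As a rough illustration of what an L2 regularization weight like this typically controls, the sketch below adds a scaled sum of squared parameter values to a base loss. It is a minimal pure-Python sketch, not the library's code: the helper names `l2_penalty` and `regularized_loss` are hypothetical, and the assumption that the penalty is added directly to the loss (rather than applied through the optimizer) is not confirmed by this page.

```python
def l2_penalty(params, wd_reg):
    """Return wd_reg times the sum of squared parameter values.

    `params` is a flat iterable of floats standing in for the model
    parameters (hypothetical stand-in for a flattened tensor).
    """
    return wd_reg * sum(p * p for p in params)


def regularized_loss(base_loss, params, wd_reg=0.0):
    """Add the L2 penalty to an already computed loss value.

    With the default wd_reg=0.0 the base loss is returned unchanged,
    which matches the documented default of the option.
    """
    return base_loss + l2_penalty(params, wd_reg)


# Example values chosen only for illustration.
no_reg = regularized_loss(1.0, [1.0, 2.0])          # penalty disabled
with_reg = regularized_loss(1.0, [1.0, 2.0], 0.5)   # 1.0 + 0.5 * (1 + 4)
```

With the default value of 0.0 the penalty term vanishes, so the option only has an effect when set explicitly.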
Classes
- class models.lwf_mc.LwFMC(backbone, loss, args, transform, dataset=None)
Bases: ContinualModel
Learning without Forgetting - Multi-Class.
- get_loss(inputs, labels, task_idx, logits)
Computes the loss tensor.
- Parameters:
inputs (Tensor) – the images to be fed to the network
labels (Tensor) – the ground-truth labels
task_idx (int) – the task index
- Returns:
the differentiable loss value
- Return type:
Tensor
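The page does not describe how the loss is computed or what the `logits` argument holds. The sketch below shows one plausible reading, based on the usual description of LwF-MC: a binary cross-entropy over all classes seen so far, where targets for the current task's classes are one-hot labels and targets for earlier classes come from a previous model's outputs. Everything here is an assumption for illustration: the function names, the `classes_per_task` parameter, the single-sample pure-Python form, and the treatment of `old_targets` as probabilities in [0, 1].

```python
import math


def bce_with_logits(outputs, targets):
    """Mean binary cross-entropy between raw logits and targets in [0, 1]."""
    total = 0.0
    for z, t in zip(outputs, targets):
        # Numerically stable form: max(z, 0) - z * t + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))
    return total / len(outputs)


def lwf_mc_style_loss(outputs, label, task_idx, old_targets, classes_per_task):
    """Hypothetical single-sample loss in the spirit of LwF-MC.

    outputs      -- raw network logits, one per class
    label        -- ground-truth class index of the sample
    task_idx     -- index of the current task (0-based)
    old_targets  -- assumed: previous model's per-class probabilities
    """
    pc = task_idx * classes_per_task        # classes seen before this task
    ac = (task_idx + 1) * classes_per_task  # classes seen including this task
    one_hot = [1.0 if c == label else 0.0 for c in range(ac)]
    if task_idx == 0:
        # First task: nothing to distill from, use plain one-hot targets.
        targets = one_hot
    else:
        # Later tasks: old-class targets come from the previous model,
        # new-class targets from the ground-truth label.
        targets = list(old_targets[:pc]) + one_hot[pc:]
    return bce_with_logits(outputs[:ac], targets)


first_task = lwf_mc_style_loss([0.0, 0.0], 0, 0, None, 2)
second_task = lwf_mc_style_loss([0.0, 0.0, 0.0, 0.0], 2, 1, [0.5, 0.5], 2)
```

In this reading, the distillation targets for old classes are what keeps the network's earlier predictions from drifting while it learns the new classes.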