LWF MC

Arguments

Options

--wd_reg : float

Help: L2 regularization applied to the parameters.

  • Default: 0.0
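A minimal, framework-free sketch of what an L2 penalty of this kind usually looks like. It assumes the penalty is the sum of squared parameter values scaled by --wd_reg and added to the task loss; the helper names l2_penalty and regularized_loss are hypothetical and are not part of this module.

```python
def l2_penalty(params, wd_reg):
    """Return wd_reg times the sum of squared parameter values (assumed form)."""
    return wd_reg * sum(p * p for p in params)


def regularized_loss(task_loss, params, wd_reg=0.0):
    """Add the (assumed) L2 penalty to a task loss.

    With the documented default wd_reg = 0.0 the penalty vanishes and the
    task loss is returned unchanged.
    """
    return task_loss + l2_penalty(params, wd_reg)


# Usage with a hypothetical flattened parameter vector.
params = [0.5, -1.0, 2.0]  # squared sum = 5.25
print(regularized_loss(1.0, params, wd_reg=0.01))
```

In a real network the parameters would be tensors rather than a Python list, but the arithmetic is the same: a larger wd_reg pulls the weights more strongly toward zero.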

Classes

class models.lwf_mc.LwFMC(backbone, loss, args, transform, dataset=None)

Bases: ContinualModel

Learning without Forgetting - Multi-Class.

COMPATIBILITY: List[str] = ['class-il', 'task-il']
NAME: str = 'lwf_mc'
end_task(dataset)
get_loss(inputs, labels, task_idx, logits)

Computes the loss tensor.

Parameters:
  • inputs (Tensor) – the images to be fed to the network

  • labels (Tensor) – the ground-truth labels

  • task_idx (int) – the task index

Returns:

the differentiable loss value

Return type:

Tensor
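The multi-class variant of Learning without Forgetting is commonly formulated (as in the iCaRL paper's LwF.MC baseline) as a per-class binary cross-entropy. For classes seen in earlier tasks, the target is the sigmoid of the frozen previous model's output (a distillation target); for classes of the current task, the target is the one-hot ground truth. The sketch below shows that formulation for a single sample in plain Python. The helper lwf_mc_loss and its arguments are hypothetical, and treating old_logits as the counterpart of get_loss's logits argument is an assumption, not something this page confirms.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bce(p, t, eps=1e-12):
    """Binary cross-entropy between a probability p and a (soft) target t."""
    p = min(max(p, eps), 1.0 - eps)
    return -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))


def lwf_mc_loss(new_logits, label, old_logits, n_past, n_seen):
    """Assumed LwF-MC loss for one sample (hypothetical helper).

    new_logits: current model outputs, at least n_seen entries
    label: ground-truth class index (a class of the current task)
    old_logits: frozen previous-model outputs, or None on the first task
    n_past: number of classes seen before the current task
    n_seen: number of classes seen up to and including the current task
    """
    targets = []
    for c in range(n_seen):
        if c < n_past and old_logits is not None:
            # Old class: distill the previous model's sigmoid output.
            targets.append(sigmoid(old_logits[c]))
        else:
            # Current-task class: one-hot ground truth.
            targets.append(1.0 if c == label else 0.0)
    losses = [bce(sigmoid(z), t) for z, t in zip(new_logits[:n_seen], targets)]
    return sum(losses) / len(losses)


# First task: no previous model, so only one-hot targets are used.
print(lwf_mc_loss([2.0, -2.0], label=0, old_logits=None, n_past=0, n_seen=2))

# Second task: classes 0-1 are old, classes 2-3 are new.
old = [3.0, -1.0]
print(lwf_mc_loss([3.0, -1.0, 5.0, -5.0], label=2, old_logits=old, n_past=2, n_seen=4))
```

Because cross-entropy against a fixed soft target is minimized when the prediction equals that target, the distillation terms penalize the current model for drifting away from the previous model's outputs on old classes, which is how forgetting is discouraged without storing exemplars.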

static get_parser(parser)
Return type:

ArgumentParser
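A sketch of how a parser hook like this can register the --wd_reg option listed under Options, using only the standard argparse module. The option name, float type, 0.0 default, and help text come from this page; the function body itself is an illustration, and the actual implementation may register additional arguments.

```python
from argparse import ArgumentParser


def get_parser(parser: ArgumentParser) -> ArgumentParser:
    # Register the model-specific option documented under "Options".
    parser.add_argument('--wd_reg', type=float, default=0.0,
                        help='L2 regularization applied to the parameters.')
    return parser


parser = get_parser(ArgumentParser())
args = parser.parse_args(['--wd_reg', '0.0001'])
print(args.wd_reg)  # 0.0001
```

Returning the same parser object lets the caller chain several such hooks onto one ArgumentParser before parsing.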

observe(inputs, labels, not_aug_inputs, logits=None, epoch=None)