LWS

Arguments

Options

--buf_lambda_logits float

Help: Penalty weight for the BCE loss on past logits.

  • Default: 1

--kd_lambda float

Help: Penalty weight for the MSE loss on cluster logits (fixed to 1 and not included in the hyperparameter search).

  • Default: 1

--buf_lambda_clusters float

Help: Penalty weight for the BCE loss on past clusters.

  • Default: 1

--gamma float

Help: Weight of the cluster contribution (Eqs. 3 and 4).

  • Default: 1

--k int

Help: Number of clusters

  • Default: 8

--n_bin int

Help: Number of bins

  • Default: 4

--momentum float

Help: Momentum for the weight update.

  • Default: 0.3
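The class reference below documents only that the static get_parser(parser) returns an ArgumentParser. As a minimal sketch, assuming these options are registered with the standard argparse module, the block below adds the seven LwS options with the types and defaults listed above. add_lws_args is a hypothetical standalone function, not the library's source, and its help strings paraphrase the descriptions on this page.

```python
from argparse import ArgumentParser


def add_lws_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the LwS-specific options on an existing parser (illustrative sketch)."""
    # Loss weights; help strings paraphrase the option descriptions above.
    parser.add_argument('--buf_lambda_logits', type=float, default=1.0,
                        help='Penalty weight for the BCE loss on past logits.')
    parser.add_argument('--kd_lambda', type=float, default=1.0,
                        help='Penalty weight for the MSE loss on cluster logits.')
    parser.add_argument('--buf_lambda_clusters', type=float, default=1.0,
                        help='Penalty weight for the BCE loss on past clusters.')
    parser.add_argument('--gamma', type=float, default=1.0,
                        help='Weight of the cluster contribution.')
    # Clustering and weighting hyperparameters.
    parser.add_argument('--k', type=int, default=8, help='Number of clusters.')
    parser.add_argument('--n_bin', type=int, default=4, help='Number of bins.')
    parser.add_argument('--momentum', type=float, default=0.3,
                        help='Momentum for the weight update.')
    return parser


args = add_lws_args(ArgumentParser()).parse_args(['--k', '4', '--gamma', '0.5'])
print(args.k, args.gamma, args.n_bin, args.momentum)  # 4 0.5 4 0.3
```

Options not given on the command line keep their defaults, so n_bin and momentum above come out as 4 and 0.3.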

Rehearsal arguments

Arguments shared by all rehearsal-based methods.

--buffer_size int

Help: The size of the memory buffer.

  • Default: None

--minibatch_size int

Help: The size of the mini-batches sampled from the memory buffer.

  • Default: None
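The two rehearsal options control how much past data is stored and how much of it is replayed per step. As a minimal sketch, assuming a reservoir-sampling memory (a common policy for rehearsal buffers; this page does not say which policy the model uses), buffer_size caps the number of stored items and minibatch_size sets the size of the replay batch drawn at each step. ReservoirBuffer is a hypothetical class. Both options default to None on this page, so the sketch takes them as explicit integers.

```python
import random
from typing import Any, List, Optional


class ReservoirBuffer:
    """Fixed-size rehearsal memory filled with reservoir sampling (illustrative)."""

    def __init__(self, buffer_size: int, seed: Optional[int] = None) -> None:
        self.buffer_size = buffer_size
        self.items: List[Any] = []
        self.num_seen = 0
        self.rng = random.Random(seed)

    def add(self, item: Any) -> None:
        # After n insertions, each item seen so far is stored with probability buffer_size / n.
        self.num_seen += 1
        if len(self.items) < self.buffer_size:
            self.items.append(item)
        else:
            slot = self.rng.randrange(self.num_seen)
            if slot < self.buffer_size:
                self.items[slot] = item

    def sample(self, minibatch_size: int) -> List[Any]:
        # Draw a replay mini-batch without replacement, capped by the current fill level.
        k = min(minibatch_size, len(self.items))
        return self.rng.sample(self.items, k)


buf = ReservoirBuffer(buffer_size=5, seed=0)
for x in range(100):
    buf.add(x)
replay = buf.sample(minibatch_size=3)
print(len(buf.items), len(replay))  # 5 3
```

Reservoir sampling keeps the memory an unbiased sample of the whole stream without knowing its length in advance, which is why it suits task sequences of unknown size.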

Classes

class models.lws.LwS(backbone, loss, args, transform, dataset=None)

Bases: ContinualModel

Implementation of “Towards Unbiased Continual Learning: Avoiding Forgetting in the Presence of Spurious Correlations”

COMPATIBILITY: List[str] = ['biased-class-il']
NAME: str = 'lws'
begin_epoch(epoch, dataset)
begin_task(dataset)
cluster_counts()
clustering(train_loader)
compute_stats()
end_task(dataset)
extract_features(train_loader)
forward(inputs)
freeze_classifiers(task_id, freeze_cluster=False)
get_classes_and_clusters(inputs)
get_initial_losses(dataset)
static get_parser(parser)
Return type:

ArgumentParser

get_task_weights()
get_weights(indexes)
init_classifiers()
observe(inputs, labels, not_aug_inputs, epoch, indexes)
update(target_losses, clusters_losses, indexes, epoch)
update_cluster_weights(indexes, epoch)
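The method list gives hook names and signatures but not the order in which a training loop invokes them. The sketch below assumes a conventional continual-learning loop: begin_task once per task, begin_epoch once per epoch, observe once per mini-batch, and end_task when the task finishes. HookLogger and train are hypothetical stand-ins that only record the calls; they do not reproduce what LwS does inside each hook.

```python
from typing import Any, List, Sequence, Tuple


class HookLogger:
    """Stand-in for a continual model that records which hook runs when (hypothetical)."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def begin_task(self, dataset: Any) -> None:
        self.calls.append('begin_task')

    def begin_epoch(self, epoch: int, dataset: Any) -> None:
        self.calls.append(f'begin_epoch({epoch})')

    def observe(self, inputs, labels, not_aug_inputs, epoch, indexes) -> float:
        self.calls.append('observe')
        return 0.0  # a real model would return the training loss here

    def end_task(self, dataset: Any) -> None:
        self.calls.append('end_task')


def train(model: HookLogger, tasks: Sequence[Tuple[Any, list]], n_epochs: int) -> None:
    # Assumed order: begin_task/end_task around each task, begin_epoch per epoch,
    # observe per mini-batch of (inputs, labels, not_aug_inputs, indexes).
    for dataset, batches in tasks:
        model.begin_task(dataset)
        for epoch in range(n_epochs):
            model.begin_epoch(epoch, dataset)
            for inputs, labels, not_aug_inputs, indexes in batches:
                model.observe(inputs, labels, not_aug_inputs, epoch, indexes)
        model.end_task(dataset)


model = HookLogger()
toy_batches = [([0.1], [0], [0.1], [0]), ([0.2], [1], [0.2], [1])]
train(model, tasks=[('task-0', toy_batches)], n_epochs=1)
print(model.calls)
# ['begin_task', 'begin_epoch(0)', 'observe', 'observe', 'end_task']
```

Passing indexes to observe lets a model keep per-sample state across epochs, which is consistent with methods such as get_weights(indexes) and update_cluster_weights(indexes, epoch) taking sample indexes.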