PURIDIVER

Arguments

Options

--use_bn_classifier int

Help: Use batch normalization in the classifier?

  • Default: 1

  • Choices: 0, 1

--freeze_buffer_after_first int

Help: Freeze the buffer after the first task (i.e., simulate an online update of the buffer; useful for multi-epoch training)?

  • Default: 0

  • Choices: 0, 1

--initial_alpha float

Help: None

  • Default: 0.5

--disable_train_aug int

Help: Disable training augmentation?

  • Default: 1

  • Choices: 0, 1

--buffer_fitting_epochs int

Help: Number of epochs for which the model is fit on the buffer.

  • Default: 255

--warmup_buffer_fitting_epochs int

Help: Number of warmup epochs during which the model is fit on the buffer with plain cross-entropy (CE).

  • Default: 10
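The two buffer-fitting options above split the buffer-fitting run into a warmup phase and a later phase. The sketch below only illustrates that split; the function name and the phase labels are hypothetical. The post-warmup objective is not documented here (the class exposes `train_with_mixmatch`, which hints at a MixMatch-style stage, but that is an assumption).

```python
def fit_buffer_schedule(buffer_fitting_epochs=255, warmup_buffer_fitting_epochs=10):
    """Hypothetical helper: label each buffer-fitting epoch with its training phase."""
    phases = []
    for epoch in range(buffer_fitting_epochs):
        if epoch < warmup_buffer_fitting_epochs:
            # Warmup: fit the buffer with plain cross-entropy.
            phases.append("warmup_ce")
        else:
            # After warmup: whatever objective the method uses (assumed, not documented).
            phases.append("post_warmup")
    return phases


phases = fit_buffer_schedule()
```

With the defaults, the first 10 of the 255 epochs are warmup epochs and the remaining 245 use the post-warmup objective.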

--enable_cutmix int

Help: Enable cutmix augmentation?

  • Default: 1

  • Choices: 0, 1

--cutmix_prob float

Help: Cutmix probability

  • Default: 0.5
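For reference, a generic CutMix step gated by a probability like `--cutmix_prob` can be sketched as follows. This is the standard CutMix recipe, not necessarily how PuriDivER applies it: the function name, the `alpha` parameter of the Beta distribution, and the per-batch (rather than per-sample) gating are assumptions.

```python
import random

import torch


def cutmix(images, labels, cutmix_prob=0.5, alpha=1.0):
    """Hypothetical sketch: return (mixed_images, labels_a, labels_b, lam).

    The loss is then lam * CE(out, labels_a) + (1 - lam) * CE(out, labels_b).
    """
    # Apply CutMix to the whole batch with probability `cutmix_prob` (assumed gating).
    if random.random() >= cutmix_prob:
        return images, labels, labels, 1.0
    lam = random.betavariate(alpha, alpha)
    perm = torch.randperm(images.size(0))
    _, _, h, w = images.shape
    # Box side lengths chosen so that the box covers roughly (1 - lam) of the image.
    cut_ratio = (1.0 - lam) ** 0.5
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    cy, cx = random.randrange(h), random.randrange(w)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    # Paste the box from a shuffled copy of the batch.
    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Recompute lam from the actual (possibly clipped) box area.
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)
    return mixed, labels, labels[perm], lam
```

Recomputing `lam` after clipping keeps the label weights consistent with the pixel area that was actually replaced.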

Rehearsal arguments

Arguments shared by all rehearsal-based methods.

--buffer_size int

Help: The size of the memory buffer.

  • Default: None

--minibatch_size int

Help: The batch size of the memory buffer.

  • Default: None
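The options listed above map directly onto `argparse` declarations. The stand-alone parser below mirrors the documented names, types, defaults and choices. It is a hypothetical illustration, not the parser that `PuriDivER.get_parser` actually builds, and `build_puridiver_parser` is an invented name.

```python
import argparse


def build_puridiver_parser():
    """Hypothetical stand-alone parser mirroring the documented PuriDivER options."""
    parser = argparse.ArgumentParser(description="PuriDivER options (sketch)")
    # PuriDivER-specific options, with the documented types, defaults and choices.
    parser.add_argument("--use_bn_classifier", type=int, default=1, choices=[0, 1])
    parser.add_argument("--freeze_buffer_after_first", type=int, default=0, choices=[0, 1])
    parser.add_argument("--initial_alpha", type=float, default=0.5)
    parser.add_argument("--disable_train_aug", type=int, default=1, choices=[0, 1])
    parser.add_argument("--buffer_fitting_epochs", type=int, default=255)
    parser.add_argument("--warmup_buffer_fitting_epochs", type=int, default=10)
    parser.add_argument("--enable_cutmix", type=int, default=1, choices=[0, 1])
    parser.add_argument("--cutmix_prob", type=float, default=0.5)
    # Rehearsal arguments shared by all rehearsal-based methods.
    parser.add_argument("--buffer_size", type=int, default=None)
    parser.add_argument("--minibatch_size", type=int, default=None)
    return parser


# Parse an explicit argument list instead of sys.argv.
args = build_puridiver_parser().parse_args(["--buffer_size", "500", "--minibatch_size", "32"])
```

Because `type=int` runs before the `choices` check, passing `--enable_cutmix 0` is accepted, while `--enable_cutmix 2` is rejected with a usage error.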

Classes

class models.puridiver.CustomDataset(data, targets, transform=None, probs=None, extra=None, device='cpu')

Bases: Dataset

set_probs(probs)

Set the probability that each data point is correctly labeled (i.e., that it belongs to the Gaussian with the lowest mean).
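The "Gaussian with the lowest mean" wording suggests a two-component Gaussian mixture fit to some per-sample score, where the low-mean component is treated as clean. The sketch below assumes that the score is the per-sample training loss, as in DivideMix-style noisy-label methods. It fits a 1-D two-component GMM with plain EM and returns each sample's posterior for the low-mean component. The helper name is hypothetical, and the actual implementation may rely on a library GMM instead.

```python
import math


def _normal_pdf(x, mean, var):
    return math.exp(-((x - mean) ** 2) / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)


def clean_probs_from_losses(losses, n_iter=50, eps=1e-6):
    """Hypothetical helper: P(low-mean Gaussian | loss) for each sample."""
    n = len(losses)
    avg = sum(losses) / n
    init_var = sum((x - avg) ** 2 for x in losses) / n + eps
    # Start one component at the smallest loss and one at the largest.
    mu = [min(losses), max(losses)]
    var = [init_var, init_var]
    weight = [0.5, 0.5]

    def responsibilities():
        out = []
        for x in losses:
            w = [weight[k] * _normal_pdf(x, mu[k], var[k]) for k in range(2)]
            total = sum(w)
            out.append([wk / total for wk in w] if total > 0 else [0.5, 0.5])
        return out

    for _ in range(n_iter):
        resp = responsibilities()  # E-step
        for k in range(2):  # M-step: mixture weight, mean, variance
            nk = sum(r[k] for r in resp) + eps
            weight[k] = nk / n
            mu[k] = sum(r[k] * x for r, x in zip(resp, losses)) / nk
            var[k] = sum(r[k] * (x - mu[k]) ** 2 for r, x in zip(resp, losses)) / nk + eps

    clean = 0 if mu[0] <= mu[1] else 1
    return [r[clean] for r in responsibilities()]
```

The resulting list is the kind of per-sample probability vector that `set_probs` would receive. Low-loss samples end up close to 1 and high-loss (likely mislabeled) samples close to 0.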

class models.puridiver.PuriDivER(backbone, loss, args, transform, dataset=None)

Bases: ContinualModel

PuriDivER: Online Continual Learning on a Contaminated Data Stream with Blurry Task Boundaries.

COMPATIBILITY: List[str] = ['class-il', 'task-il']
NAME: str = 'puridiver'
base_fit_buffer(loader=None)
begin_task(dataset)
end_task(dataset)
fit_buffer()
get_classifier_weights()
get_current_alpha_sim_score(loss)
static get_parser(parser)
Return type:

ArgumentParser

get_scheduler()
get_sim_score(feats, targets)
get_subset_dl_from_idxs(idxs, batch_size, probs=None, transform=None)
observe(inputs, labels, not_aug_inputs, true_labels, epoch)
puridiver_update_buffer(stream_not_aug_inputs, stream_labels, stream_true_labels)
split_data_puridiver(n=2)
train_with_mixmatch(loader_L, loader_U, loader_R)

Functions

models.puridiver.get_dataloader_from_buffer(args, buffer, batch_size, shuffle=False, transform=None)
models.puridiver.get_hard_transform(dataset)
models.puridiver.soft_cross_entropy_loss(input, target, reduction='mean')

https://github.com/pytorch/pytorch/issues/11959

Parameters:
  • input – (batch, *)

  • target – (batch, *) same shape as input, each item must be a valid distribution: target[i, :].sum() == 1.
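Cross-entropy against a soft target distribution is `-sum_c target[c] * log_softmax(input)[c]` per sample, followed by the chosen reduction. The sketch below implements that formula with PyTorch, assuming the class dimension is dim 1. It is an illustration of the math, not the library's own code, and `soft_cross_entropy` is a hypothetical name. Since PyTorch 1.10, `torch.nn.functional.cross_entropy` also accepts class-probability targets directly.

```python
import torch
import torch.nn.functional as F


def soft_cross_entropy(input, target, reduction="mean"):
    """Hypothetical sketch: cross-entropy with probability-distribution targets."""
    # Per-sample loss: -sum over classes of target * log_softmax(input).
    per_sample = -(target * F.log_softmax(input, dim=1)).sum(dim=1)
    if reduction == "mean":
        return per_sample.mean()
    if reduction == "sum":
        return per_sample.sum()
    if reduction == "none":
        return per_sample
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With one-hot targets this reduces to ordinary cross-entropy on the corresponding hard labels, which makes a convenient sanity check.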