BUFFER
Classes
- class utils.buffer.ABSSampling(buffer_size, device, dataset)
Bases:
LARSSampling
- class utils.buffer.BaseSampleSelection(buffer_size, device)
Bases:
object
Base class for sample selection strategies.
- class utils.buffer.Buffer(buffer_size, device='cpu', sample_selection_strategy='reservoir', **kwargs)
Bases:
object
The memory buffer used by rehearsal methods.
- add_data(examples, labels=None, logits=None, task_labels=None, attention_maps=None, true_labels=None, sample_selection_scores=None)
Adds the data to the memory buffer according to the buffer's sample selection strategy (reservoir by default).
- Parameters:
examples – tensor containing the images
labels – tensor containing the labels
logits – tensor containing the outputs of the network
task_labels – tensor containing the task labels
attention_maps – list of tensors containing the attention maps
true_labels – if setting is noisy, the true labels associated with the examples. Used only for logging.
sample_selection_scores – tensor containing the scores used for the sample selection strategy. NOTE: this is only used if the sample selection strategy defines the update method.
Note
Only the examples are required. The other tensors are initialized only if they are provided.
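For the default reservoir strategy, the choice of where an incoming example lands can be sketched as follows. This is a minimal, standard-library illustration of classic reservoir sampling, not the library's own code: the names `reservoir_slot` and `run_stream` are hypothetical, and plain Python lists stand in for tensors.

```python
import random


def reservoir_slot(num_seen, buffer_size, rng):
    """Return the buffer slot for the next example, or -1 to discard it.

    num_seen is the number of examples observed before this one.
    (Hypothetical helper, for illustration only.)
    """
    if num_seen < buffer_size:
        # Buffer not yet full: fill slots in arrival order.
        return num_seen
    # Buffer full: keep the new example with probability
    # buffer_size / (num_seen + 1), overwriting a uniformly chosen slot.
    candidate = rng.randint(0, num_seen)  # inclusive on both ends
    return candidate if candidate < buffer_size else -1


def run_stream(stream, buffer_size, seed=0):
    """Feed a stream of examples through the reservoir rule."""
    rng = random.Random(seed)
    buffer = [None] * buffer_size
    for num_seen, example in enumerate(stream):
        slot = reservoir_slot(num_seen, buffer_size, rng)
        if slot >= 0:
            buffer[slot] = example
    return buffer
```

With this rule every example in the stream has the same chance of being in the buffer at any point, which is why the buffer never needs to know the stream length in advance.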
- get_data(size, transform=None, return_index=False, device=None, mask_task_out=None, cpt=None, return_not_aug=False, not_aug_transform=None)
Randomly samples a batch of size items.
- Parameters:
size (int) – the number of requested items
transform (Module | None) – the transformation to be applied (data augmentation)
return_index – if True, returns the indexes of the sampled items
mask_task_out – if not None, masks OUT the examples from the given task
cpt – the number of classes per task (required if mask_task_out is not None and task_labels are not present)
return_not_aug – if True, also returns the not augmented items
not_aug_transform – the transformation to be applied to the not augmented items (if return_not_aug is True)
- Returns:
a tuple containing the requested items. If return_index is True, the tuple contains the indexes as first element.
- Return type:
tuple
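The interaction between mask_task_out and cpt can be illustrated with a small sketch. It assumes that, when task labels are not stored, the task of an example is inferred as `label // cpt` (classes assigned to tasks in contiguous blocks), and that a request larger than the buffer is capped at the number of available items. Both are assumptions made for illustration, and `sample_indices` is a hypothetical helper, not part of `utils.buffer`.

```python
import random


def sample_indices(labels, size, mask_task_out=None, cpt=None, seed=0):
    """Pick up to `size` distinct buffer positions, skipping one task if asked."""
    candidates = list(range(len(labels)))
    if mask_task_out is not None:
        if cpt is None:
            raise ValueError("cpt is needed to infer the task from a label")
        # Assumption: task id = label // cpt (contiguous class blocks per task).
        candidates = [i for i in candidates if labels[i] // cpt != mask_task_out]
    rng = random.Random(seed)
    # Never request more items than are available after masking.
    return rng.sample(candidates, min(size, len(candidates)))
```

The returned positions would then be used to gather examples, labels, and any other stored attributes, with the transform applied to the examples only.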
- get_data_by_index(indexes, transform=None, device=None)
Returns the data stored at the given indexes.
- init_tensors(examples, labels, logits, task_labels, true_labels)
Initializes just the required tensors.
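- Parameters:
examples – tensor containing the images
labels – tensor containing the labels
logits – tensor containing the outputs of the network
task_labels – tensor containing the task labels
true_labels – if setting is noisy, the true labels associated with the examples. Used only for logging.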
- to(device)
Move the buffer and its attributes to the specified device.
- Parameters:
device – The device to move the buffer and its attributes to.
- Returns:
The buffer instance with the updated device and attributes.
- property used_attributes
Returns a list of attributes that are currently being used by the object.
- class utils.buffer.LARSSampling(buffer_size, device)
Bases:
BaseSampleSelection
- class utils.buffer.LossAwareBalancedSampling(buffer_size, device)
Bases:
BaseSampleSelection
Combination of Loss-Aware Sampling (LARS) and Balanced Reservoir Sampling (BRS) from "Rethinking Experience Replay: a Bag of Tricks for Continual Learning".
- class utils.buffer.ReservoirSampling(buffer_size, device)
Bases:
BaseSampleSelection
Functions
- utils.buffer.fill_buffer(buffer, dataset, t_idx, net=None, use_herding=False, required_attributes=None)
Adds examples from the current task to the memory buffer. Supports images, labels, task_labels, and logits.
- Parameters:
buffer (Buffer) – the memory buffer
dataset (ContinualDataset) – the dataset from which to take the examples
t_idx (int) – the task index
net (ContinualModel) – (optional) the model instance. Needed when the buffer stores logits; if provided, logits are added to the buffer.
use_herding – (optional) if True, uses herding strategy. Otherwise, random sampling.
required_attributes (List[str]) – (optional) the attributes to be added to the buffer. If None and buffer is empty, adds only examples and labels.
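To make the use_herding option more concrete, here is a minimal sketch of greedy herding in the iCaRL sense: items are picked one at a time so that the mean of the selected feature vectors stays as close as possible to the mean of all feature vectors. It is an assumption that fill_buffer follows this scheme; the function name `herding_selection` is hypothetical, and the sketch works on the features of a single class given as plain lists of floats.

```python
def herding_selection(features, num_samples):
    """Greedy herding: pick items whose running mean tracks the overall mean.

    features: list of equal-length lists of floats (one class only).
    Returns the selected positions, in selection order.
    """
    dim = len(features[0])
    n = len(features)
    mean = [sum(f[d] for f in features) / n for d in range(dim)]
    selected = []
    running_sum = [0.0] * dim
    for k in range(1, min(num_samples, n) + 1):
        best_idx, best_dist = None, float("inf")
        for i, f in enumerate(features):
            if i in selected:
                continue
            # Squared distance between the target mean and the mean obtained
            # by adding candidate i to the items selected so far.
            dist = sum(
                (mean[d] - (running_sum[d] + f[d]) / k) ** 2 for d in range(dim)
            )
            if dist < best_dist:
                best_idx, best_dist = i, dist
        selected.append(best_idx)
        running_sum = [running_sum[d] + features[best_idx][d] for d in range(dim)]
    return selected
```

With use_herding=False the documented alternative is random sampling, which skips the feature computation entirely and therefore does not need a model.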
- utils.buffer.icarl_replay(self, dataset, val_set_split=0)
Merge the replay buffer with the current task data. Optionally split the replay buffer into a validation set.
- Parameters:
self (ContinualModel) – the model instance
dataset – the dataset
val_set_split – the fraction of the replay buffer to be used as validation set
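The effect of val_set_split can be sketched with plain lists. This illustration assumes that the held-out fraction is taken at random from the replay buffer and that the remaining buffer items are appended to the current task data; how the real function selects and stores the validation items is not specified here. The name `merge_with_replay` is hypothetical.

```python
import random


def merge_with_replay(task_data, buffer_data, val_set_split=0.0, seed=0):
    """Return (train_data, val_data) after merging the buffer into the task data."""
    rng = random.Random(seed)
    shuffled = list(buffer_data)
    rng.shuffle(shuffled)
    # Assumption: the validation set is a random fraction of the buffer.
    num_val = int(len(shuffled) * val_set_split)
    val_data = shuffled[:num_val]
    # The rest of the buffer is replayed together with the current task.
    train_data = list(task_data) + shuffled[num_val:]
    return train_data, val_data
```

With the default val_set_split=0 the validation set is empty and the whole buffer is merged into the training data.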