RING BUFFER#

This module contains a variant of the rehearsal memory buffer that replaces stored items with a ring buffer strategy instead of reservoir sampling.

Classes#

class utils.ring_buffer.RingBuffer(buffer_size, n_tasks=1, device='cpu')[source]#

Bases: object

The memory buffer used by rehearsal methods.

add_data(examples, labels=None, logits=None, task_labels=None)[source]#

Adds the data to the memory buffer according to the ring buffer strategy.

Parameters:
  • examples – tensor containing the images

  • labels – tensor containing the labels

  • logits – tensor containing the outputs of the network

  • task_labels – tensor containing the task labels
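
As a rough illustration of what a ring strategy means for add_data, the sketch below keeps a fixed number of slots and overwrites the oldest one once the buffer is full. The class name TinyRingBuffer and its attributes are hypothetical, it stores plain Python lists rather than tensors, and it omits logits and task labels; it is not the library's implementation.

```python
class TinyRingBuffer:
    """Hypothetical, list-based stand-in for a fixed-size ring buffer."""

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.num_seen_examples = 0
        self.examples = [None] * buffer_size
        self.labels = [None] * buffer_size

    def add_data(self, examples, labels):
        # Each new item goes to slot (items seen so far) mod (capacity),
        # so the oldest stored item is overwritten once the buffer is full.
        for example, label in zip(examples, labels):
            index = self.num_seen_examples % self.buffer_size
            self.examples[index] = example
            self.labels[index] = label
            self.num_seen_examples += 1


buffer = TinyRingBuffer(buffer_size=3)
buffer.add_data(["a", "b", "c", "d"], [0, 0, 1, 1])
print(buffer.examples)  # ['d', 'b', 'c']
```

Unlike reservoir sampling, which replaces a random slot with decreasing probability, this scheme is deterministic: the buffer always holds the most recently added items.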

empty()[source]#

Set all the tensors to None.

get_all_data(transform=None, device=None)[source]#

Return all the items in the memory buffer.

Parameters:
  • transform (torchvision.transforms) – the transformation to be applied (data augmentation)

  • device – the device to be used

Returns:

a tuple with all the items in the memory buffer

Return type:

Tuple

get_data(size, transform=None, device=None)[source]#

Randomly samples a batch of size items from the buffer.

Parameters:
  • size (int) – the number of requested items

  • transform (torchvision.transforms) – the transformation to be applied (data augmentation)

  • device – the device to be used

Returns:

a tuple with the requested items

Return type:

Tuple
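
The sampling step described for get_data can be sketched as follows. This is a minimal sketch under stated assumptions: the helper name sample_batch is hypothetical, it works on plain lists instead of tensors, and both the clamping of size to the number of stored items and the sampling without replacement are assumptions rather than documented behaviour.

```python
import random


def sample_batch(examples, labels, size, transform=None, rng=random):
    """Hypothetical helper: draw `size` stored items without replacement."""
    # Assumption: never request more items than are stored.
    size = min(size, len(examples))
    indices = rng.sample(range(len(examples)), size)
    # Apply the optional transform (data augmentation) to the examples only.
    if transform is None:
        transform = lambda x: x
    batch_examples = [transform(examples[i]) for i in indices]
    batch_labels = [labels[i] for i in indices]
    return batch_examples, batch_labels


stored_examples = [1.0, 2.0, 3.0, 4.0]
stored_labels = [0, 0, 1, 1]
xs, ys = sample_batch(stored_examples, stored_labels, size=2,
                      transform=lambda x: x * 10)
print(len(xs), len(ys))  # 2 2
```

The same indices are used for every returned list, so each example stays paired with its label.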

init_tensors(examples, labels, logits, task_labels)[source]#

Initializes just the required tensors.

Parameters:
  • examples (Tensor) – tensor containing the images

  • labels (Tensor) – tensor containing the labels

  • logits (Tensor) – tensor containing the outputs of the network

  • task_labels (Tensor) – tensor containing the task labels

is_empty()[source]#

Returns True if the buffer is empty, False otherwise.

Return type:

bool

Functions#

utils.ring_buffer.ring(num_seen_examples, buffer_portion_size, task)[source]#

Return type:

int
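
The page gives only the signature of ring. One plausible reading, stated here as an assumption rather than documented behaviour, is that the buffer is split into equal per-task portions and the function maps an example counter to a slot inside the given task's portion. The function name ring_index below is hypothetical.

```python
def ring_index(num_seen_examples, buffer_portion_size, task):
    """Assumed behaviour: wrap around within one task's portion of the buffer."""
    # The position inside the portion wraps around once the portion is full.
    offset = num_seen_examples % buffer_portion_size
    # Assumption: each task owns a contiguous block of buffer_portion_size slots.
    return task * buffer_portion_size + offset


# With portions of 4 slots, task 1 would own slots 4..7.
print([ring_index(n, 4, task=1) for n in range(6)])  # [4, 5, 6, 7, 4, 5]
```

Under this reading, items from one task never overwrite items stored for another task.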