GSS BUFFER
This module contains a version of the reservoir buffer specifically designed for the GSS (Gradient-based Sample Selection) model.
Classes
- class utils.gss_buffer.Buffer(buffer_size, device, minibatch_size, model=None)
Bases: object
Memory buffer for the GSS model. The buffer supports only examples and labels tensors.
- add_data(examples, labels=None)
Adds the data to the memory buffer according to the reservoir strategy.
- Parameters:
examples – tensor containing the images
labels – tensor containing the labels (optional)
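The reservoir strategy named above can be illustrated with a minimal, stdlib-only sketch. It assumes the classic reservoir rule (fill the buffer first, then keep the n-th incoming example with probability buffer_size / n by overwriting a random slot). The ReservoirBuffer class, its seed argument and the _reservoir_index helper are hypothetical; plain Python lists stand in for the examples and labels tensors, and nothing GSS-specific (such as the model argument) is modelled.

```python
import random
from typing import Any, List, Optional


class ReservoirBuffer:
    """Hypothetical, simplified buffer illustrating reservoir insertion.

    Lists replace tensors; the GSS-specific parts of the real Buffer are omitted.
    """

    def __init__(self, buffer_size: int, seed: Optional[int] = None):
        self.buffer_size = buffer_size
        self.num_seen_examples = 0
        self.examples: List[Any] = []
        self.labels: List[Any] = []
        self.rng = random.Random(seed)

    def _reservoir_index(self) -> int:
        """Slot for the next incoming example, or -1 if it is discarded."""
        if self.num_seen_examples < self.buffer_size:
            return self.num_seen_examples  # buffer not full yet: next free slot
        # Keep the new example with probability buffer_size / (num_seen + 1).
        r = self.rng.randint(0, self.num_seen_examples)  # both bounds inclusive
        return r if r < self.buffer_size else -1

    def add_data(self, examples: List[Any], labels: Optional[List[Any]] = None) -> None:
        for i, x in enumerate(examples):
            idx = self._reservoir_index()
            self.num_seen_examples += 1
            if idx == -1:
                continue  # example not kept
            y = labels[i] if labels is not None else None
            if idx == len(self.examples):
                # Filling phase: append to the end of the buffer.
                self.examples.append(x)
                self.labels.append(y)
            else:
                # Buffer full: overwrite the randomly chosen slot.
                self.examples[idx] = x
                self.labels[idx] = y


buf = ReservoirBuffer(buffer_size=5, seed=0)
buf.add_data(list(range(100)), labels=[n % 10 for n in range(100)])
print(buf.num_seen_examples, len(buf.examples))  # prints: 100 5
```

Because every example, seen or not, increments num_seen_examples, each of the 100 inputs ends up in the buffer with the same probability (5/100), regardless of when it arrived.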
- get_all_data(transform=None)
Returns all the items in the memory buffer.
- Parameters:
transform (torchvision.transforms, optional) – the transformation to be applied (data augmentation)
- Returns:
a tuple with all the items in the memory buffer
- Return type:
Tuple
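As a rough illustration of how the optional transform interacts with get_all_data, the sketch below assumes that the transformation is applied to each stored example individually and that the stored data itself is left unchanged. SimpleBuffer is a hypothetical stand-in that uses lists and a plain Python callable instead of tensors and torchvision transforms.

```python
from typing import Any, Callable, List, Optional, Tuple


class SimpleBuffer:
    """Hypothetical stand-in for the memory buffer: lists instead of tensors."""

    def __init__(self, examples: List[Any], labels: List[Any]):
        self.examples = examples
        self.labels = labels

    def get_all_data(
        self, transform: Optional[Callable[[Any], Any]] = None
    ) -> Tuple[List[Any], List[Any]]:
        # Apply the (assumed per-example) augmentation to copies of the data;
        # without a transform, return shallow copies of the stored examples.
        if transform is not None:
            examples = [transform(x) for x in self.examples]
        else:
            examples = list(self.examples)
        return examples, list(self.labels)


buf = SimpleBuffer(examples=[1.0, 2.0, 3.0], labels=[0, 1, 0])
xs, ys = buf.get_all_data(transform=lambda x: x * 10)
print(xs, ys)  # prints: [10.0, 20.0, 30.0] [0, 1, 0]
```

Returning new lists rather than the internal ones keeps callers from accidentally mutating the buffer contents through the returned tuple.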
- get_data(size, transform=None, give_index=False, random=False, device=None)
Randomly samples a batch of items from the memory buffer; the number of items is given by size.
- Parameters:
size (int) – the number of requested items
transform (torchvision.transforms, optional) – the transformation to be applied (data augmentation)
- Returns:
a tuple with the requested items
- Return type:
Tuple
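A minimal sketch of random batch sampling is given below, written as a standalone function rather than a method. Several points are assumptions not stated in the documentation above: items are drawn uniformly without replacement, a size larger than the number of stored items is capped, the transform is applied per sampled example, and give_index=True appends the sampled indices to the returned tuple. The random and device parameters are not modelled, and the rng argument is hypothetical (added only to make runs reproducible).

```python
import random
from typing import Any, Callable, List, Optional, Tuple


def get_data(
    examples: List[Any],
    labels: List[Any],
    size: int,
    transform: Optional[Callable[[Any], Any]] = None,
    give_index: bool = False,
    rng: Optional[random.Random] = None,
) -> Tuple:
    """Hypothetical sketch: sample up to `size` distinct items uniformly at random."""
    if rng is None:
        rng = random.Random()
    # Assumption: never request more items than the buffer holds.
    size = min(size, len(examples))
    choice = rng.sample(range(len(examples)), size)  # without replacement
    batch_x = [examples[i] for i in choice]
    if transform is not None:
        batch_x = [transform(x) for x in batch_x]  # per-example augmentation
    batch_y = [labels[i] for i in choice]
    ret: Tuple = (batch_x, batch_y)
    if give_index:
        ret += (choice,)  # assumption: indices are appended last
    return ret


batch_x, batch_y, idx = get_data(
    list("abcdef"), [0, 1, 2, 3, 4, 5], size=3, give_index=True, rng=random.Random(1)
)
print(batch_x, batch_y, idx)
```

Sampling indices once and using them for both examples and labels keeps each sampled image aligned with its label.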