Source code for utils.ring_buffer

"""
This module contains a variant of the reservoir buffer that fills its slots with a ring (circular) strategy instead of reservoir sampling.
"""

# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import numpy as np
import torch
from torchvision import transforms

from utils.augmentations import apply_transform


def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int:
    """
    Compute the buffer slot for the next example under the ring strategy.

    The running example counter is wrapped inside a window of
    ``buffer_portion_size`` slots, and that window is shifted by ``task``
    whole portions.

    Args:
        num_seen_examples: the number of examples seen so far
        buffer_portion_size: the number of slots in one portion of the buffer
        task: the index of the portion to write into

    Returns:
        the index of the slot to write
    """
    # Position inside the portion: wraps back to 0 every `buffer_portion_size` examples.
    offset_in_portion = num_seen_examples % buffer_portion_size
    # First slot of the portion selected by `task`.
    portion_start = task * buffer_portion_size
    return offset_in_portion + portion_start
class RingBuffer:
    """
    The memory buffer of rehearsal method.

    The ``buffer_size`` slots are split into ``n_tasks`` portions of
    ``buffer_size // n_tasks`` slots each. New samples are written into the
    portion selected by ``task_number``, overwriting the oldest slot of that
    portion once it is full (ring strategy).

    NOTE(review): ``task_number`` is never changed inside this class;
    presumably the caller advances it between tasks - confirm against usage.
    """

    def __init__(self, buffer_size, n_tasks=1, device="cpu"):
        """
        Creates an empty buffer; the storage tensors are allocated lazily.

        Args:
            buffer_size: the total number of slots in the buffer
            n_tasks: the number of portions the buffer is split into
            device: the device on which the storage tensors are kept
        """
        self.buffer_size = buffer_size
        self.buffer_portion_size = buffer_size // n_tasks
        self.device = device
        self.task_number = 0
        self.num_seen_examples = 0
        # Names of the per-sample tensors; 'examples' must stay first
        # (get_data / get_all_data handle it separately from the rest).
        self.attributes = ['examples', 'labels', 'logits', 'task_labels']

    def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor,
                     logits: torch.Tensor, task_labels: torch.Tensor) -> None:
        """
        Initializes just the required tensors.

        A zero-filled storage tensor is created for every argument that is
        not None and not already allocated, and the ``filled_space`` mask is
        (re)set to all False.

        Args:
            examples: tensor containing the images
            labels: tensor containing the labels
            logits: tensor containing the outputs of the network
            task_labels: tensor containing the task labels
        """
        # Explicit name -> argument mapping (avoids eval() on attribute names).
        provided = {
            'examples': examples,
            'labels': labels,
            'logits': logits,
            'task_labels': task_labels,
        }
        for attr_str in self.attributes:
            attr = provided.get(attr_str)
            if attr is not None and not hasattr(self, attr_str):
                # 'labels' and 'task_labels' are stored as integers, the rest as floats.
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size, *attr.shape[1:]),
                                                    dtype=typ, device=self.device))
        # Boolean mask of the slots that have been written at least once.
        self.filled_space = torch.zeros(self.buffer_size, dtype=torch.bool, device=self.device)

    def add_data(self, examples, labels=None, logits=None, task_labels=None):
        """
        Adds the data to the memory buffer according to the ring strategy.

        Storage is allocated on the first call, only for the tensors passed
        in that call.

        Args:
            examples: tensor containing the images
            labels: tensor containing the labels
            logits: tensor containing the outputs of the network
            task_labels: tensor containing the task labels
        """
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels, logits, task_labels)

        for i in range(examples.shape[0]):
            index = ring(self.num_seen_examples, self.buffer_portion_size, self.task_number)
            self.num_seen_examples += 1
            # ring() is non-negative for non-negative inputs; the guard is kept so
            # that a negative index is skipped instead of wrapping around.
            if index >= 0:
                self.examples[index] = examples[i].to(self.device)
                self.filled_space[index] = True
                if labels is not None:
                    self.labels[index] = labels[i].to(self.device)
                if logits is not None:
                    self.logits[index] = logits[i].to(self.device)
                if task_labels is not None:
                    self.task_labels[index] = task_labels[i].to(self.device)

    def get_data(self, size: int, transform: transforms = None, device=None) -> Tuple:
        """
        Random samples a batch of size items.

        Items are drawn without replacement among the slots that have
        actually been written; if fewer than ``size`` are available, all of
        them are returned.

        Args:
            size: the number of requested items
            transform: the transformation to be applied (data augmentation)
            device: the device of the returned tensors (defaults to the buffer device)

        Returns:
            a tuple with the requested items
        """
        target_device = self.device if device is None else device

        # Filled slots are NOT necessarily the first N slots: each task writes into
        # its own portion, so earlier portions may be only partially filled.
        filled_indexes = torch.nonzero(self.filled_space).flatten()
        populated_portion_length = filled_indexes.shape[0]
        if size > populated_portion_length:
            size = populated_portion_length

        picks = torch.from_numpy(np.random.choice(populated_portion_length, size=size, replace=False)).long()
        choice = filled_indexes[picks.to(filled_indexes.device)]

        if transform is None:
            def transform(x): return x
        ret_tuple = (apply_transform(self.examples[choice], transform=transform).to(target_device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr[choice].to(target_device),)

        return ret_tuple

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.

        The buffer counts as empty only while both ``num_seen_examples`` and
        ``task_number`` are zero.
        """
        return self.num_seen_examples == 0 and self.task_number == 0

    def get_all_data(self, transform: transforms = None, device=None) -> Tuple:
        """
        Return all the items in the memory buffer.

        Every slot is returned, including the ones never written (which still
        hold their initial zeros).

        Args:
            transform: the transformation to be applied (data augmentation)
            device: the device to be used

        Returns:
            a tuple with all the items in the memory buffer
        """
        target_device = self.device if device is None else device
        if transform is None:
            def transform(x): return x

        # The whole examples tensor is used, matching the other attributes below.
        ret_tuple = (apply_transform(self.examples, transform=transform).to(target_device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr.to(target_device),)
        return ret_tuple

    def empty(self) -> None:
        """
        Remove all the stored tensors and reset the example counter.

        The storage attributes are deleted, so the next ``add_data`` call
        allocates them again. ``task_number`` is left unchanged.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0
        # The mask only exists once init_tensors has run; a never-used buffer has none.
        if hasattr(self, 'filled_space'):
            self.filled_space[:] = False