# Source code for utils.gss_buffer

"""
This module contains a version of the reservoir buffer that is specifically designed for the GSS model.
"""

# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms

from utils.augmentations import apply_transform


class Buffer:
    """
    Memory buffer for the GSS (Gradient-based Sample Selection) model.

    The buffer supports only `examples` and `labels` tensors. Each occupied
    slot carries one scalar in `self.scores`; candidates compete for slots
    with probability driven by gradient cosine-similarity against the
    stored data (see `functional_reservoir`).
    """

    def __init__(self, buffer_size, device, minibatch_size, model=None):
        """
        Args:
            buffer_size: maximum number of stored examples
            device: torch device on which the buffer tensors are allocated
            minibatch_size: number of stored samples drawn when scoring an
                incoming batch against the buffer
            model: object exposing `get_grads(x, y)`, used to compute
                per-sample gradients for the GSS score
        """
        self.buffer_size = buffer_size
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ['examples', 'labels']
        self.model = model
        self.minibatch_size = minibatch_size
        # Memoised per-sample gradients, keyed by buffer slot index.
        self.cache = {}
        self.fathom = 0
        self.fathom_mask = None
        self.reset_fathom()

    def reset_fathom(self):
        """Rewind the sequential-sampling cursor and reshuffle the visit order."""
        self.fathom = 0
        # Only permute indices that can actually hold data.
        if hasattr(self, 'examples'):
            n_valid = min(self.num_seen_examples, self.examples.shape[0])
        else:
            n_valid = self.num_seen_examples
        self.fathom_mask = torch.randperm(n_valid)

    def get_grad_score(self, batch_x, batch_y, X, Y, indices):
        """
        Score `(batch_x, batch_y)` against the buffer samples `(X, Y)`.

        Returns the maximum cosine similarity, shifted by +1 (hence in
        [0, 2]), between the gradient of the incoming batch and the
        per-sample gradients of the provided buffer samples.
        """
        g = self.model.get_grads(batch_x, batch_y)
        G = []
        for x, y, idx in zip(X, Y, indices):
            # Stored-sample gradients are cached by buffer index and
            # invalidated when the slot is overwritten (see add_data).
            if idx not in self.cache:
                self.cache[idx] = self.model.get_grads(x.unsqueeze(0), y.unsqueeze(0))
            G.append(self.cache[idx])
        G = torch.cat(G).to(g.device)

        c_score = 0
        grads_at_a_time = 5  # let's split this so your gpu does not melt. You're welcome.
        for it in range(int(np.ceil(G.shape[0] / grads_at_a_time))):
            chunk = G[it * grads_at_a_time: (it + 1) * grads_at_a_time]
            tmp = F.cosine_similarity(g, chunk, dim=1).max().item() + 1
            c_score = max(c_score, tmp)
        return c_score

    def functional_reservoir(self, x, y, batch_c, bigX=None, bigY=None, indices=None):
        """
        Decide where (if anywhere) the sample `(x, y)` should be stored.

        Args:
            x, y: the candidate example and label
            batch_c: the batch-level GSS score of the incoming data
            bigX, bigY, indices: buffer subsample (and its indices) to score
                the candidate against when the buffer is full

        Returns:
            (index, score): `index` is the slot to write (-1 means discard)
            and `score` the value to record in `self.scores` for that slot.
        """
        # Buffer not full yet: append at the next free slot.
        if self.num_seen_examples < self.buffer_size:
            return self.num_seen_examples, batch_c

        if batch_c < 1:
            single_c = self.get_grad_score(x.unsqueeze(0), y.unsqueeze(0), bigX, bigY, indices)
            s = self.scores.cpu().numpy()
            # Pick a victim slot with probability proportional to its score,
            # then replace it with probability s[i] / (s[i] + single_c).
            i = np.random.choice(np.arange(0, self.buffer_size), size=1, p=s / s.sum())[0]
            rand = np.random.rand(1)[0]
            if rand < s[i] / (s[i] + single_c):
                return i, single_c
        return -1, 0

    def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Lazily allocates the buffer tensors on the first `add_data` call.

        Args:
            examples: tensor containing the images
            labels: tensor containing the labels
        """
        # Explicit mapping instead of eval(): same lookup, no code execution.
        provided = {'examples': examples, 'labels': labels}
        for attr_str in self.attributes:
            attr = provided[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                # 'labels' (endswith 'els') stored as int64, data as float32.
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                setattr(self, attr_str,
                        torch.zeros((self.buffer_size, *attr.shape[1:]),
                                    dtype=typ, device=self.device))
        # One scalar score per slot. functional_reservoir uses this as a 1-D
        # probability vector, so the shape is (buffer_size,) explicitly
        # (previously it inherited the trailing dims of the last attribute).
        self.scores = torch.zeros(self.buffer_size, dtype=torch.float32, device=self.device)

    def add_data(self, examples, labels=None):
        """
        Adds the data to the memory buffer according to the GSS strategy.

        Args:
            examples: tensor containing the images
            labels: tensor containing the labels

        NOTE(review): despite the `labels=None` default, `labels[i]` is
        indexed unconditionally below, so callers must pass labels.
        """
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels)

        # Score the incoming batch against a random buffer subsample.
        if self.num_seen_examples > 0:
            bigX, bigY, indices = self.get_data(min(self.minibatch_size, self.num_seen_examples),
                                                give_index=True, random=True)
            c = self.get_grad_score(examples, labels, bigX, bigY, indices)
        else:
            bigX, bigY, indices = None, None, None
            c = 0.1

        for i in range(examples.shape[0]):
            index, score = self.functional_reservoir(examples[i], labels[i], c, bigX, bigY, indices)
            self.num_seen_examples += 1
            if index >= 0:
                self.examples[index] = examples[i].to(self.device)
                if labels is not None:
                    self.labels[index] = labels[i].to(self.device)
                self.scores[index] = score
                # The slot content changed: its cached gradient is stale.
                if index in self.cache:
                    del self.cache[index]

    def drop_cache(self):
        """Forget all memoised per-sample gradients."""
        self.cache = {}

    def get_data(self, size: int, transform: 'transforms' = None, give_index=False,
                 random=False, device=None) -> Tuple:
        """
        Samples a batch of `size` items from the buffer.

        Args:
            size: the number of requested items
            transform: the transformation to be applied (data augmentation)
            give_index: if True, also return the chosen buffer indices
            random: if True sample uniformly without replacement; otherwise
                walk the buffer sequentially in `fathom_mask` order
            device: target device of the returned tensors (defaults to the
                buffer's own device)

        Returns:
            a tuple with the requested items (+ their indices if requested)
        """
        target_device = self.device if device is None else device

        if size > self.examples.shape[0]:
            size = self.examples.shape[0]

        if random:
            choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]),
                                      size=min(size, self.num_seen_examples), replace=False)
        else:
            # Sequential pass: take the next `size` positions of the shuffled
            # permutation, rewinding the cursor once everything was visited.
            choice = np.arange(self.fathom,
                               min(self.fathom + size, self.examples.shape[0], self.num_seen_examples))
            choice = self.fathom_mask[choice]
            self.fathom += len(choice)
            if self.fathom >= self.examples.shape[0] or self.fathom >= self.num_seen_examples:
                self.fathom = 0

        if transform is None:
            def transform(x):
                return x

        selected_samples = self.examples[choice]
        ret_tuple = (apply_transform(selected_samples, transform=transform).to(target_device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr[choice].to(target_device),)
        if give_index:
            ret_tuple += (choice,)
        return ret_tuple

    def is_empty(self) -> bool:
        """Returns True if no example has been added to the buffer yet."""
        return self.num_seen_examples == 0

    def get_all_data(self, transform: 'transforms' = None) -> Tuple:
        """
        Returns all the items in the memory buffer.

        Args:
            transform: the transformation to be applied (data augmentation)

        Returns:
            a tuple with all the items in the memory buffer
        """
        if transform is None:
            def transform(x):
                return x
        ret_tuple = (torch.stack([transform(ee.cpu()) for ee in self.examples]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                ret_tuple += (getattr(self, attr_str),)
        return ret_tuple

    def empty(self) -> None:
        """
        Drops all stored tensors and resets the example counter.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0