"""
Custom buffer for Learning without Shortcuts (LwS).
"""
import torch
import numpy as np
from typing import Tuple
from torch.functional import Tensor
from torchvision import transforms
from torch.utils.data import Dataset
import math
from typing import Tuple
from utils.augmentations import apply_transform
[docs]
class Buffer(Dataset):
    def __init__(self, buffer_size, device, n_tasks, attributes=['examples', 'labels', 'logits', 'task_labels'], n_bin=8):
        """
        Initializes the memory buffer.
        Args:
            buffer_size: the maximum buffer size
            device: the device to store the data
            n_tasks: the total number of tasks
            attributes: the attributes to store in the memory buffer
            n_bin: the number of bins for the reservoir binning strategy
        """
        self.buffer_size = buffer_size
        self.device = device
        self.num_seen_examples = 0
        self.task = 1
        self.task_number = n_tasks
        self.attributes = attributes
        self.delta = torch.zeros(buffer_size, device=device)
        self.balanced_class_perm = None
        self.num_bins = n_bin
        self.bins = np.zeros(self.num_bins)  # Initialize bins with zero counts
        self.min_loss = float('inf')
        self.max_loss = float('-inf')
        self.budget = (self.buffer_size // self.num_bins) // self.task
        self.num_examples = 0
[docs]
    def reset_budget(self):
        self.task += 1
        self.budget = (self.buffer_size // self.num_bins) // self.task 
[docs]
    def reset_bins(self):
        self.bins = np.zeros(self.num_bins)
        self.min_loss = float('inf')
        self.max_loss = float('-inf')
        self.reset_budget() 
[docs]
    def update_loss_range(self, loss_value):
        """
        Updates the min and max loss values seen, for binning purposes.
        """
        self.min_loss = min(self.min_loss, loss_value)
        self.max_loss = max(self.max_loss, loss_value) 
[docs]
    def get_bin_index(self, loss_value):
        """
        Determines the bin index for a given loss value.
        """
        bin_range = self.max_loss - self.min_loss
        if bin_range == 0:
            return 0  # All losses are the same, only one bin needed
        bin_width = bin_range / self.num_bins
        bin_index = int((loss_value - self.min_loss) / bin_width)
        return min(bin_index, self.num_bins - 1)  # To handle the max loss 
[docs]
    def reservoir_bin_loss(self, loss_value: float) -> int:
        """
        Modified reservoir sampling algorithm considering loss values and binning.
        """
        self.update_loss_range(loss_value)
        bin_index = self.get_bin_index(loss_value)
        if self.bins[bin_index] < self.budget:
            if self.num_examples < self.buffer_size:
                self.bins[bin_index] += 1
                return self.num_examples
            else:
                rand = np.random.randint(0, self.buffer_size)
                self.bins[bin_index] += 1
                return rand
        else:
            return -1 
[docs]
    def reservoir_loss(self, num_seen_examples: int, buffer_size: int, loss_value: float) -> int:
        """
        Modified reservoir sampling algorithm considering loss values
        """
        # Probability based on the loss value (higher loss, higher probability)
        loss_probability = math.exp(loss_value) / (1 + math.exp(loss_value))
        rand = np.random.random()
        if rand < loss_probability and self.budget > 0:
            self.budget -= 1
            return np.random.randint(buffer_size)
        else:
            return -1 
[docs]
    def to(self, device):
        self.device = device
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                setattr(self, attr_str, getattr(self, attr_str).to(device))
        return self 
[docs]
    def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor,
                     logits: torch.Tensor, task_labels: torch.Tensor,
                     clusters_labels=None, clusters_logits=None,
                     loss_values=None) -> None:
        for attr_str in self.attributes:
            attr = eval(attr_str)
            if attr is not None and not hasattr(self, attr_str):
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                if attr_str.startswith('loss_val'):
                    setattr(self, attr_str, torch.zeros((self.buffer_size,
                                                         *attr.shape[1:]), dtype=typ, device=self.device) - 1)
                else:
                    setattr(self, attr_str, torch.zeros((self.buffer_size,
                                                         *attr.shape[1:]), dtype=typ, device=self.device)) 
[docs]
    def add_data(self, examples, labels=None, clusters_labels=None, logits=None, clusters_logits=None, task_labels=None, loss_values=None):
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels, logits, task_labels, clusters_labels=clusters_labels, clusters_logits=clusters_logits, loss_values=loss_values)
        rix = []
        for i in range(examples.shape[0]):
            index = self.reservoir_bin_loss(loss_values[i])
            self.num_seen_examples += 1
            if index >= 0:
                self.num_examples += 1
                if self.examples.device != self.device:
                    self.examples.to(self.device)
                self.examples[index] = examples[i].to(self.device)
                if labels is not None:
                    if self.labels.device != self.device:
                        self.labels.to(self.device)
                    self.labels[index] = labels[i].to(self.device)
                if clusters_labels is not None:
                    if self.clusters_labels.device != self.device:
                        self.clusters_labels.to(self.device)
                    self.clusters_labels[index] = clusters_labels[i].to(self.device)
                if logits is not None:
                    if self.logits.device != self.device:
                        self.logits.to(self.device)
                    self.logits[index] = logits[i].to(self.device)
                if clusters_logits is not None:
                    if self.clusters_logits.device != self.device:
                        self.clusters_logits.to(self.device)
                    self.clusters_logits[index] = clusters_logits[i].to(self.device)
                if task_labels is not None:
                    if self.task_labels.device != self.device:
                        self.task_labels.to(self.device)
                    self.task_labels[index] = task_labels[i].to(self.device)
                if loss_values is not None:
                    if self.loss_values.device != self.device:
                        self.loss_values.to(self.device)
                    self.loss_values[index] = loss_values[i].to(self.device)
            rix.append(index)
        return torch.tensor(rix).to(self.device) 
[docs]
    def update_losses(self, loss_values, indexes):
        self.loss_values[indexes] = loss_values 
[docs]
    def get_losses(self):
        return self.loss_values.cpu().numpy() 
[docs]
    def get_task_labels(self):
        return self.task_labels.cpu().numpy() 
[docs]
    def get_data(self, size: int, transform: transforms = None, return_index=False, to_device=None) -> Tuple:
        m_t = min(self.num_examples, self.examples.shape[0])
        if size > m_t:
            size = m_t
        target_device = self.device if to_device is None else to_device
        choice = np.random.choice(m_t, size=size, replace=False)
        if transform is None:
            def transform(x): return x
        ret_tuple = (apply_transform(self.examples[choice], transform=transform).to(target_device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str).to(target_device)
                ret_tuple += (attr[choice],)
        if not return_index:
            return ret_tuple
        else:
            return (torch.tensor(choice).to(target_device), ) + ret_tuple 
[docs]
    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        if self.num_seen_examples == 0:
            return True
        else:
            return False 
[docs]
    def empty(self) -> None:
        """
        Set all the tensors to None.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0