# Copyright 2021-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from utils.ring_buffer import RingBuffer
from datasets import get_dataset
from utils.args import *
from models.utils.continual_model import ContinualModel
from utils.buffer import Buffer
from utils.mixup import mixup
from utils.triplet import batch_hard_triplet_loss, negative_only_triplet_loss
import torch
import torch.nn.functional as F


class Ccic(ContinualModel):
    """Continual Semi-Supervised Learning via Continual Contrastive Interpolation Consistency."""
    NAME = 'ccic'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'cssl']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.set_defaults(optimizer='adam')
        add_rehearsal_args(parser)
        parser.add_argument('--alpha', type=float, default=0.5,
                            help='Weight of the unsupervised (negative-only triplet) loss.')
        parser.add_argument('--knn_k', '--k', type=int, default=2, dest='knn_k',
                            help='Number of neighbors (k) for the kNN classifier.')
        parser.add_argument('--memory_penalty', type=float, default=1.0,
                            help='Weight of the supervised loss on memory-buffer samples.')
        parser.add_argument('--k_aug', type=int, default=3,
                            help='Number of augmentations used to compute label predictions.')
        parser.add_argument('--mixmatch_alpha', '--lamda', type=float, default=0.5, dest='mixmatch_alpha',
                            help='Weight of the unsupervised (MixMatch consistency) loss.')
        parser.add_argument('--sharp_temp', type=float, default=0.5,
                            help='Temperature for sharpening pseudo-labels.')
        parser.add_argument('--mixup_alpha', type=float, default=0.75,
                            help='Alpha of the mixup interpolation.')
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        super().__init__(backbone, loss, args, transform, dataset=dataset)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.embeddings = None
        self.eye = torch.eye(self.num_classes).to(self.device)
        self.sup_virtual_batch = RingBuffer(self.args.batch_size)
        self.unsup_virtual_batch = RingBuffer(self.args.batch_size)

    def get_debug_iters(self):
        """
        Returns the number of iterations to wait before logging.
        - CCIC needs a couple more iterations to initialize the KNN.
        """
        return 1000 if len(self.buffer) < self.args.buffer_size else 5 

    def forward(self, x):
        if self.embeddings is None:
            with torch.no_grad():
                self.compute_embeddings()
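        # Predict via kNN in feature space: compare normalized features of x against
        # the buffered exemplar embeddings computed above.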
        n_seen_classes = self.cpt * self.current_task if isinstance(self.cpt, int) else sum(self.cpt[:self.current_task])
        n_remaining_classes = self.N_CLASSES - n_seen_classes
        buf_labels = self.buffer.labels[:self.buffer.num_seen_examples]
        feats = self.net(x, returnt='features')
        feats = F.normalize(feats, p=2, dim=1)
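        # Pairwise squared Euclidean distances between test features and buffer embeddings.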
        distances = (self.embeddings.unsqueeze(0) - feats.unsqueeze(1)).pow(2).sum(2)
        dist = torch.stack([distances[:, buf_labels == c].topk(1, largest=False)[0].mean(dim=1)
                            if (buf_labels == c).sum() > 0 else torch.zeros(x.shape[0]).to(self.device)
                            for c in range(n_seen_classes)] +
                           [torch.zeros(x.shape[0]).to(self.device)] * n_remaining_classes).T
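        # Vote among the k nearest buffer exemplars; the tiny scaled distance term only
        # breaks ties between classes receiving the same number of votes.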
        topkappas = self.eye[buf_labels[distances.topk(self.args.knn_k, largest=False)[1]]].sum(1)
        return topkappas - dist * 10e-6 

    def end_task(self, dataset):
        self.embeddings = None 

    def discard_unsupervised_labels(self, inputs, labels, not_aug_inputs):
        mask = labels != -1
        return inputs[mask], labels[mask], not_aug_inputs[mask] 

    def discard_supervised_labels(self, inputs, labels, not_aug_inputs):
        mask = labels == -1
        return inputs[mask], labels[mask], not_aug_inputs[mask] 

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        self.opt.zero_grad()
        real_batch_size = inputs.shape[0]
        sup_inputs, sup_labels, sup_not_aug_inputs = self.discard_unsupervised_labels(inputs, labels, not_aug_inputs)
        sup_inputs_for_buffer, sup_labels_for_buffer = sup_not_aug_inputs.clone(), sup_labels.clone()
        unsup_inputs, unsup_labels, unsup_not_aug_inputs = self.discard_supervised_labels(inputs, labels, not_aug_inputs)
        if len(sup_inputs) == 0 and self.buffer.is_empty():  # if there is no data to train on, just return 1.
            return 1.
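        # Virtual batches: ring buffers of recent raw stream examples keep the
        # effective batch size constant even when few labeled examples arrive per step.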
        self.sup_virtual_batch.add_data(sup_not_aug_inputs, sup_labels)
        sup_inputs, sup_labels = self.sup_virtual_batch.get_data(self.args.batch_size, transform=self.transform, device=self.device)
        if self.current_task > 0 and unsup_not_aug_inputs.shape[0] > 0:
            self.unsup_virtual_batch.add_data(unsup_not_aug_inputs)
            unsup_inputs = self.unsup_virtual_batch.get_data(self.args.batch_size, transform=self.transform, device=self.device)[0]
        # BUFFER RETRIEVAL
        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size, transform=self.transform, device=self.device)
            sup_inputs = torch.cat((sup_inputs, buf_inputs))
            sup_labels = torch.cat((sup_labels, buf_labels))
            if self.current_task > 0:
                masked_buf_inputs = self.buffer.get_data(self.args.minibatch_size,
                                                         mask_task_out=self.current_task,
                                                         transform=self.transform,
                                                         cpt=self.n_classes_current_task,
                                                         device=self.device)[0]
                unsup_labels = torch.cat((torch.zeros(unsup_inputs.shape[0]).to(self.device),
                                          torch.ones(masked_buf_inputs.shape[0]).to(self.device))).long()
                unsup_inputs = torch.cat((unsup_inputs, masked_buf_inputs))
        # ------------------ K AUG ---------------------
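        # Build k_aug augmented views of each unlabeled stream example; their averaged
        # predictions are sharpened below into pseudo-labels (as in MixMatch).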
        mask = labels != -1
        real_mask = mask[:real_batch_size]
        if (~real_mask).sum() > 0:
            unsup_aug_inputs = self.transform(not_aug_inputs[~real_mask].repeat_interleave(self.args.k_aug, 0))
        else:
            unsup_aug_inputs = torch.zeros((0,)).to(self.device)
        # ------------------ PSEUDO LABEL ---------------------
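        # Sharpening: raise the averaged predictions to the power 1/T (T = sharp_temp)
        # and renormalize, pushing the pseudo-label distribution toward one-hot.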
        self.net.eval()
        if len(unsup_aug_inputs):
            with torch.no_grad():
                unsup_aug_outputs = self.net(unsup_aug_inputs).reshape(self.args.k_aug, -1, self.eye.shape[0]).mean(0)
                unsup_sharp_outputs = unsup_aug_outputs ** (1 / self.args.sharp_temp)
                unsup_norm_outputs = unsup_sharp_outputs / unsup_sharp_outputs.sum(1).unsqueeze(1)
                unsup_norm_outputs = unsup_norm_outputs.repeat(self.args.k_aug, 1)
        else:
            unsup_norm_outputs = torch.zeros((0, len(self.eye))).to(self.device)
        self.net.train()
        # ------------------ MIXUP ---------------------
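        # mixup: interpolate each (input, target) pair with a random partner from the
        # shuffled pool W (utils.mixup presumably samples lambda ~ Beta(alpha, alpha),
        # as in standard mixup).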
        W_inputs = torch.cat((sup_inputs, unsup_aug_inputs))
        W_probs = torch.cat((self.eye[sup_labels], unsup_norm_outputs))
        perm = torch.randperm(W_inputs.shape[0])
        W_inputs, W_probs = W_inputs[perm], W_probs[perm]
        sup_shape = sup_inputs.shape[0]
        sup_mix_inputs, _ = mixup([(sup_inputs, W_inputs[:sup_shape]), (self.eye[sup_labels], W_probs[:sup_shape])], self.args.mixup_alpha)
        sup_mix_outputs = self.net(sup_mix_inputs)
        if len(unsup_aug_inputs):
            unsup_mix_inputs, _ = mixup(
                [(unsup_aug_inputs, W_inputs[sup_shape:]),
                 (unsup_norm_outputs, W_probs[sup_shape:])],
                self.args.mixup_alpha)
            unsup_mix_outputs = self.net(unsup_mix_inputs)
        effective_mbs = min(self.args.minibatch_size,
                            self.buffer.num_seen_examples)
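        # effective_mbs counts the trailing buffer samples in the concatenated batch;
        # the negative sentinel below (presumably) keeps the `[:-effective_mbs]` slice
        # from degenerating to the empty `[:-0]` while the buffer has no samples yet.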
        if effective_mbs == 0:
            effective_mbs = -self.N_CLASSES
        # ------------------ MIXMATCH LOSS ---------------------
        loss_X = 0
        if real_mask.sum() > 0:
            loss_X += self.loss(sup_mix_outputs[:-effective_mbs],
                                sup_labels[:-effective_mbs])
        if not self.buffer.is_empty():
            assert effective_mbs > 0
            loss_X += self.args.memory_penalty * self.loss(sup_mix_outputs[-effective_mbs:],
                                                           sup_labels[-effective_mbs:])
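        # Unsupervised consistency loss: MSE between the sharpened pseudo-labels and
        # the predictions on the mixed unlabeled views (MixMatch's L_U term).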
        if len(unsup_aug_inputs):
            loss_U = F.mse_loss(unsup_norm_outputs, unsup_mix_outputs) / self.eye.shape[0]
        else:
            loss_U = 0
        # CIC LOSS
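        # For roughly the first 90% of each task's epochs, interpolate supervised pairs
        # before embedding them, enforcing consistency of the representation under
        # interpolation (the CIC objective).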
        if self.current_task > 0 and epoch < self.args.n_epochs / 10 * 9:
            W_inputs = sup_inputs
            W_probs = self.eye[sup_labels]
            perm = torch.randperm(W_inputs.shape[0])
            W_inputs, W_probs = W_inputs[perm], W_probs[perm]
            sup_mix_inputs, _ = mixup([(sup_inputs, W_inputs), (self.eye[sup_labels], W_probs)], 1)
        else:
            sup_mix_inputs = sup_inputs
        # STANDARD TRIPLET
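        # Batch-hard triplet: for each anchor take the hardest positive and hardest
        # negative in the batch (margin 1); returns None when no valid triplet exists.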
        sup_mix_embeddings = self.net.features(sup_mix_inputs)
        loss = batch_hard_triplet_loss(sup_labels, sup_mix_embeddings, self.args.batch_size // 10,
                                       margin=1, margin_type='hard')
        if loss is None:
            loss = loss_X + self.args.mixmatch_alpha * loss_U
        else:
            loss += loss_X + self.args.mixmatch_alpha * loss_U
        self.buffer.add_data(examples=sup_inputs_for_buffer,
                             labels=sup_labels_for_buffer)
        # SELF-SUPERVISED PAST TASKS NEGATIVE ONLY
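        # Treat unlabeled stream examples (label 0) and past-task buffer samples
        # (label 1) as negatives only, since the stream's true classes are unknown.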
        if self.current_task > 0 and epoch < self.args.n_epochs / 10 * 9:
            unsup_embeddings = self.net.features(unsup_inputs)
            loss_unsup = negative_only_triplet_loss(unsup_labels, unsup_embeddings, self.args.batch_size // 10,
                                                    margin=1, margin_type='hard')
            if loss_unsup is not None:
                loss += self.args.alpha * loss_unsup
        loss.backward()
        self.opt.step()
        return loss.item() 

    @torch.no_grad()
    def compute_embeddings(self):
        """
        Computes a vector representing mean features for each class.
        """
        was_training = self.net.training
        self.net.eval()
        data = self.buffer.get_all_data(transform=self.normalization_transform)[0]
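        # Embed the whole buffer in batch-sized chunks to bound memory usage.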
        outputs = []
        while data.shape[0] > 0:
            inputs = data[:self.args.batch_size]
            data = data[self.args.batch_size:]
            out = self.net(inputs, returnt='features')
            out = F.normalize(out, p=2, dim=1)
            outputs.append(out)
        self.embeddings = torch.cat(outputs)
        self.net.train(was_training)
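

if __name__ == '__main__':
    # Minimal standalone sketch (not part of the original model) of the sharpening
    # step used in observe() above; it runs without the Mammoth framework.
    p = torch.tensor([[0.40, 0.35, 0.25]])     # predictions averaged over k_aug views
    T = 0.5                                    # args.sharp_temp
    sharp = p ** (1 / T)
    print(sharp / sharp.sum(1, keepdim=True))  # ~ tensor([[0.4638, 0.3551, 0.1812]])

    # Wiring the model into the framework (hedged sketch: entry points such as
    # `dataset.get_backbone()` may differ across Mammoth versions):
    #   args = Ccic.get_parser(ArgumentParser()).parse_args()
    #   dataset = get_dataset(args)
    #   model = Ccic(dataset.get_backbone(), dataset.get_loss(), args, dataset.get_transform())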