"""
Slow Learner with Classifier Alignment.
Note:
    SLCA USES A CUSTOM BACKBONE (see `feature_extractor_type` argument)
Arguments:
    --feature_extractor_type: the type of convnet to use. `vit-b-p16` is the default: ViT-B/16 pretrained on Imagenet 21k (**NO** finetuning on ImageNet 1k)
"""
import copy
import logging
import numpy as np
from models.ranpac_utils.toolkit import target2onehot
from utils import binary_to_boolean_type
from utils.args import *
from models.utils.continual_model import ContinualModel
import torch
import torch.nn.functional as F
from utils.conf import get_device
from models.ranpac_utils.ranpac import RanPAC_Model
class RanPAC(ContinualModel):
    """RanPAC: Random Projections and Pre-trained Models for Continual Learning."""
    NAME = 'ranpac'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']
    net: RanPAC_Model
    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.set_defaults(pretrain_type='in21k')
        parser.set_defaults(optim_mom=0.9, optim_wd=0.0005, batch_size=48)
        parser.add_argument('--rp_size', type=int, default=10000, help='size of the random projection layer (M in the paper)')
        return parser 
    def __init__(self, backbone, loss, args, transform, dataset=None):
        self.device = get_device()
        logging.warning("-" * 20)
        logging.warning(f"RanPAC USES `in21k` AS DEFAULT PRETRAIN. CHANGE IT WITH `--pretrain_type` IF NEEDED.")
        logging.warning("-" * 20)
        backbone = RanPAC_Model(backbone, args)
        super().__init__(backbone, loss, args, transform, dataset=dataset)
    def get_parameters(self):
        return self.net._network.parameters() 
    def get_scheduler(self):
        return torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=self.args.n_epochs, eta_min=0) 
    def end_task(self, dataset):
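        # At the end of the first task: freeze the pretrained backbone and switch the classifier
        # to the random-projection (RP) head. At the end of every task: re-fit the head in closed
        # form on features extracted with the *test* transform (no augmentation).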
        if self.current_task == 0:
            self.freeze_backbone()
            self.setup_RP()
        dataset.train_loader.dataset.transform = self.dataset.TEST_TRANSFORM
        self.replace_fc(dataset.train_loader) 
    def setup_RP(self):
        self.net._network.fc.use_RP = True
        # RP with M > 0
        M = self.args.rp_size
        self.net._network.fc.weight = torch.nn.Parameter(torch.Tensor(self.net._network.fc.out_features, M).to(self.net._network.device))  # num classes in task x M
        self.net._network.fc.reset_parameters()
        self.net._network.fc.W_rand = torch.randn(self.net._network.fc.in_features, M).to(self.net._network.device)
        self.W_rand = copy.deepcopy(self.net._network.fc.W_rand)  # make a copy that gets passed each time the head is replaced
        self.Q = torch.zeros(M, self.dataset.N_CLASSES)
        self.G = torch.zeros(M, M) 
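    # Q and G accumulate the sufficient statistics of the ridge-regression problem over the whole
    # stream: Q = H^T Y (projected features times one-hot targets) and G = H^T H (Gram matrix of
    # the projected features). The head is then W = ((G + lambda*I)^{-1} Q)^T, computed in
    # `replace_fc` below.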
    def replace_fc(self, trainloader):
        self.net._network.eval()
        # these lines are needed because the CosineLinear head gets deleted between streams and replaced by one with more classes (for CIL)
        self.net._network.fc.use_RP = True
        self.net._network.fc.W_rand = self.W_rand
        Features_f = []
        label_list = []
        with torch.no_grad():
            for i, data in enumerate(trainloader):
                data, label = data[0].to(self.device), data[1].to(self.device)
                embedding = self.net._network.convnet(data)
                Features_f.append(embedding.cpu())
                label_list.append(label.cpu())
        Features_f = torch.cat(Features_f, dim=0)
        label_list = torch.cat(label_list, dim=0)
        Y = target2onehot(label_list, self.dataset.N_CLASSES)
        Features_h = torch.nn.functional.relu(Features_f @ self.net._network.fc.W_rand.cpu())
        self.Q = self.Q + Features_h.T @ Y
        self.G = self.G + Features_h.T @ Features_h
        ridge = self.optimise_ridge_parameter(Features_h, Y)
        Wo = torch.linalg.solve(self.G + ridge * torch.eye(self.G.size(dim=0)), self.Q).T  # better numerical stability than .inv()
        self.net._network.fc.weight.data = Wo[0:self.net._network.fc.weight.shape[0], :].to(self.net._network.device) 
    def optimise_ridge_parameter(self, Features, Y):
        ridges = 10.0**np.arange(-8, 9)
        num_fit_samples = int(Features.shape[0] * 0.8)  # first 80% fit the head, last 20% validate the ridge value
        losses = []
        Q_val = Features[0:num_fit_samples, :].T @ Y[0:num_fit_samples, :]
        G_val = Features[0:num_fit_samples, :].T @ Features[0:num_fit_samples, :]
        for ridge in ridges:
            Wo = torch.linalg.solve(G_val + ridge * torch.eye(G_val.size(dim=0)), Q_val).T  # better numerical stability than .inv()
            Y_val_pred = Features[num_fit_samples:, :] @ Wo.T
            losses.append(F.mse_loss(Y_val_pred, Y[num_fit_samples:, :]))
        ridge = ridges[np.argmin(np.array(losses))]
        logging.info("Optimal lambda: " + str(ridge))
        return ridge 
    def begin_task(self, dataset):
        # temporarily remove RP weights
        del self.net._network.fc
        self.net._network.fc = None
        self.net._network.update_fc(self.n_seen_classes)  # creates a new head with a new number of classes (if CIL)
        if self.current_task == 0:
            self.opt = self.get_optimizer()
            self.custom_scheduler = self.get_scheduler()
            self.opt.zero_grad() 
    def freeze_backbone(self, is_first_session=False):
        # Freeze the parameters for ViT.
        if isinstance(self.net._network.convnet, torch.nn.Module):
            for name, param in self.net._network.convnet.named_parameters():
                if is_first_session:
                    if "head." not in name and "ssf_scale" not in name and "ssf_shift_" not in name:
                        param.requires_grad = False
                else:
                    param.requires_grad = False 
    def observe(self, inputs, labels, not_aug_inputs, epoch=0):
        if self.current_task == 0:  # simple train on first task
            logits = self.net._network(inputs)["logits"]
            loss = self.loss(logits, labels)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            return loss.item()
        # Later tasks: no gradient-based training; the classifier is re-fit in closed form in `end_task`.
        return 0
    def forward(self, x):
        return self.net._network(x)['logits']
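

# Minimal, self-contained sketch of the random-projection + ridge-regression update implemented in
# `setup_RP` / `replace_fc` above, run on synthetic data. The shapes and the lambda value are
# illustrative assumptions, not the values used by the model.
if __name__ == "__main__":
    torch.manual_seed(0)
    N, D, M, C = 512, 768, 2048, 10           # samples, feature dim, RP size, classes
    feats = torch.randn(N, D)                 # stand-in for frozen-backbone features
    labels = torch.randint(0, C, (N,))
    Y = F.one_hot(labels, C).float()          # one-hot targets
    W_rand = torch.randn(D, M)                # fixed random projection
    H = F.relu(feats @ W_rand)                # non-linearly projected features
    Q = H.T @ Y                               # accumulated H^T Y
    G = H.T @ H                               # accumulated Gram matrix H^T H
    lam = 1e2                                 # ridge strength (illustrative)
    Wo = torch.linalg.solve(G + lam * torch.eye(M), Q).T  # closed-form head, shape (C, M)
    acc = ((H @ Wo.T).argmax(dim=1) == labels).float().mean().item()
    print("head shape:", tuple(Wo.shape), "train acc:", acc)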