Source code for models.ranpac

"""
Slow Learner with Classifier Alignment.

Note:
    SLCA USES A CUSTOM BACKBONE (see `feature_extractor_type` argument)

Arguments:
    --feature_extractor_type: the type of convnet to use. `vit-b-p16` is the default: ViT-B/16 pretrained on Imagenet 21k (**NO** finetuning on ImageNet 1k)
"""

import copy
import logging

import numpy as np
from models.ranpac_utils.toolkit import target2onehot
from utils import binary_to_boolean_type
from utils.args import *
from models.utils.continual_model import ContinualModel

import torch
import torch.nn.functional as F
from utils.conf import get_device
from models.ranpac_utils.ranpac import RanPAC_Model


class RanPAC(ContinualModel):
    """RanPAC: Random Projections and Pre-trained Models for Continual Learning."""

    NAME = 'ranpac'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    net: RanPAC_Model

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.set_defaults(pretrain_type='in21k')
        parser.set_defaults(optim_mom=0.9, optim_wd=0.0005, batch_size=48)
        parser.add_argument('--rp_size', type=int, default=10000,
                            help='size of the random projection layer (L in the paper)')
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        self.device = get_device()
        print("-" * 20)
        print("WARNING: RanPAC USES `in21k` AS DEFAULT PRETRAIN. CHANGE IT WITH `--pretrain_type` IF NEEDED.")
        backbone = RanPAC_Model(backbone, args)
        print("-" * 20)
        super().__init__(backbone, loss, args, transform, dataset=dataset)

    def get_parameters(self):
        return self.net._network.parameters()

    def get_scheduler(self):
        return torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=self.args.n_epochs, eta_min=0)

    def end_task(self, dataset):
        if self.current_task == 0:
            self.freeze_backbone()
            self.setup_RP()
        dataset.train_loader.dataset.transform = self.dataset.TEST_TRANSFORM
        self.replace_fc(dataset.train_loader)

    def setup_RP(self):
        """Enable the random-projection head: allocate the projection W_rand and the accumulators Q and G."""
        self.net._network.fc.use_RP = True

        # RP with M > 0
        M = self.args.rp_size
        self.net._network.fc.weight = torch.nn.Parameter(torch.Tensor(self.net._network.fc.out_features, M).to(self.net._network.device))  # num classes in task x M
        self.net._network.fc.reset_parameters()
        self.net._network.fc.W_rand = torch.randn(self.net._network.fc.in_features, M).to(self.net._network.device)
        self.W_rand = copy.deepcopy(self.net._network.fc.W_rand)  # make a copy that gets passed each time the head is replaced

        self.Q = torch.zeros(M, self.dataset.N_CLASSES)
        self.G = torch.zeros(M, M)

    def replace_fc(self, trainloader):
        """Update the classifier head in closed form from the current task's training features."""
        self.net._network.eval()

        # these lines are needed because the CosineLinear head gets deleted between streams
        # and replaced by one with more classes (for CIL)
        self.net._network.fc.use_RP = True
        self.net._network.fc.W_rand = self.W_rand

        Features_f = []
        label_list = []
        with torch.no_grad():
            for i, data in enumerate(trainloader):
                data, label = data[0].to(self.device), data[1].to(self.device)
                embedding = self.net._network.convnet(data)
                Features_f.append(embedding.cpu())
                label_list.append(label.cpu())
        Features_f = torch.cat(Features_f, dim=0)
        label_list = torch.cat(label_list, dim=0)

        Y = target2onehot(label_list, self.dataset.N_CLASSES)
        # print('Number of pre-trained feature dimensions = ', Features_f.shape[-1])
        Features_h = torch.nn.functional.relu(Features_f @ self.net._network.fc.W_rand.cpu())
        self.Q = self.Q + Features_h.T @ Y
        self.G = self.G + Features_h.T @ Features_h
        ridge = self.optimise_ridge_parameter(Features_h, Y)
        Wo = torch.linalg.solve(self.G + ridge * torch.eye(self.G.size(dim=0)), self.Q).T  # better numerical stability than .inv
        self.net._network.fc.weight.data = Wo[0:self.net._network.fc.weight.shape[0], :].to(self.net._network.device)

    def optimise_ridge_parameter(self, Features, Y):
        """Pick the ridge parameter by fitting on the first 80% of samples and minimising the MSE on the held-out 20%."""
        ridges = 10.0 ** np.arange(-8, 9)
        num_val_samples = int(Features.shape[0] * 0.8)
        losses = []
        Q_val = Features[0:num_val_samples, :].T @ Y[0:num_val_samples, :]
        G_val = Features[0:num_val_samples, :].T @ Features[0:num_val_samples, :]
        for ridge in ridges:
            Wo = torch.linalg.solve(G_val + ridge * torch.eye(G_val.size(dim=0)), Q_val).T  # better numerical stability than .inv
            Y_train_pred = Features[num_val_samples::, :] @ Wo.T
            losses.append(F.mse_loss(Y_train_pred, Y[num_val_samples::, :]))
        ridge = ridges[np.argmin(np.array(losses))]
        logging.info("Optimal lambda: " + str(ridge))
        return ridge

    def begin_task(self, dataset):
        # temporarily remove RP weights
        del self.net._network.fc
        self.net._network.fc = None
        self.net._network.update_fc(self.n_seen_classes)  # creates a new head with a new number of classes (if CIL)

        if self.current_task == 0:
            self.opt = self.get_optimizer()
            self.scheduler = self.get_scheduler()

        self.opt.zero_grad()

    def freeze_backbone(self, is_first_session=False):
        # Freeze the parameters for ViT.
        if isinstance(self.net._network.convnet, torch.nn.Module):
            for name, param in self.net._network.convnet.named_parameters():
                if is_first_session:
                    if "head." not in name and "ssf_scale" not in name and "ssf_shift_" not in name:
                        param.requires_grad = False
                else:
                    param.requires_grad = False

    def observe(self, inputs, labels, not_aug_inputs, epoch=0):
        if self.current_task == 0:
            # simple train on first task
            logits = self.net._network(inputs)["logits"]
            loss = self.loss(logits, labels)

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            return loss.item()
        return 0

    def forward(self, x):
        return self.net._network(x)['logits']
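
# -----------------------------------------------------------------------------
# Minimal standalone sketch (not part of the model above): it illustrates, on
# synthetic data, the closed-form ridge-regression head update that
# `setup_RP` / `replace_fc` implement. The shapes and the fixed ridge value are
# illustrative assumptions; the real code accumulates Q and G across tasks and
# tunes the ridge with `optimise_ridge_parameter`.
if __name__ == "__main__":
    torch.manual_seed(0)
    N, D, M, C = 256, 768, 1024, 10           # samples, feature dim, RP dim, classes
    F_feat = torch.randn(N, D)                 # features from a frozen backbone
    y = torch.randint(0, C, (N,))
    Y = torch.nn.functional.one_hot(y, C).float()

    W_rand = torch.randn(D, M)                 # fixed random projection
    H = torch.relu(F_feat @ W_rand)            # projected, non-linear features

    Q = H.T @ Y                                # (M x C) class-correlation statistics
    G = H.T @ H                                # (M x M) Gram matrix
    ridge = 1e2                                # assumed ridge; the model optimises it
    W_head = torch.linalg.solve(G + ridge * torch.eye(M), Q).T  # (C x M) head weights

    acc = ((H @ W_head.T).argmax(dim=1) == y).float().mean()
    print(f"train accuracy of the closed-form head: {acc:.3f}")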