Source code for models.lucir

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
from datasets import get_dataset
from torch import nn

from models.utils.continual_model import ContinualModel
from utils import binary_to_boolean_type
from utils.args import add_rehearsal_args, ArgumentParser
from utils.batch_norm import bn_track_stats
from utils.buffer import Buffer, fill_buffer, icarl_replay
from utils.conf import create_seeded_dataloader


def lucir_batch_hard_triplet_loss(labels, embeddings, k, margin, num_old_classes):
    """
    LUCIR triplet loss.
    """
    gt_index = torch.zeros(embeddings.size()).to(embeddings.device)
    gt_index = gt_index.scatter(1, labels.reshape(-1, 1).long(), 1).ge(0.5)
    gt_scores = embeddings.masked_select(gt_index)

    # get top-K scores on novel classes
    max_novel_scores = embeddings[:, num_old_classes:].topk(k, dim=1)[0]

    # the index of hard samples, i.e., samples of old classes
    hard_index = labels.lt(num_old_classes)
    hard_num = torch.nonzero(hard_index).size(0)

    if hard_num > 0:
        gt_scores = gt_scores[hard_index].view(-1, 1).repeat(1, k)
        max_novel_scores = max_novel_scores[hard_index]
        assert (gt_scores.size() == max_novel_scores.size())
        assert (gt_scores.size(0) == hard_num)

        target = torch.ones(hard_num * k, 1).to(embeddings.device)
        loss = nn.MarginRankingLoss(margin=margin)(gt_scores.view(-1, 1),
                                                   max_novel_scores.view(-1, 1), target)
    else:
        loss = torch.zeros(1).to(embeddings.device)

    return loss
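# A minimal usage sketch of the loss above (illustrative only; the tensors and sizes
# below are made-up assumptions, not part of the module):
#     scores = torch.randn(8, 10)              # cosine scores over 10 classes
#     labels = torch.randint(0, 10, (8,))      # ground-truth labels
#     loss = lucir_batch_hard_triplet_loss(labels, scores, k=2, margin=0.5, num_old_classes=5)
# Only samples whose label falls among the old classes (< num_old_classes) contribute:
# their ground-truth score is pushed above the top-k novel-class scores by `margin`.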
class CustomClassifier(nn.Module):
    def __init__(self, in_features, cpt, n_tasks):
        super().__init__()

        self.weights = nn.ParameterList(
            [nn.parameter.Parameter(torch.Tensor(cpt, in_features)) for _ in range(n_tasks)]
        )
        self.sigma = nn.parameter.Parameter(torch.Tensor(1))

        self.in_features = in_features
        self.cpt = cpt
        self.n_tasks = n_tasks

        self.reset_parameters()

        self.weights[0].requires_grad = True
    def reset_parameters(self):
        for i in range(self.n_tasks):
            stdv = 1. / math.sqrt(self.weights[i].size(1))
            self.weights[i].data.uniform_(-stdv, stdv)
            self.weights[i].requires_grad = False

        self.sigma.data.fill_(1)
    def forward(self, x):
        return self.noscale_forward(x) * self.sigma
    def reset_weight(self, i):
        stdv = 1. / math.sqrt(self.weights[i].size(1))
        self.weights[i].data.uniform_(-stdv, stdv)
        self.weights[i].requires_grad = True
        self.weights[i - 1].requires_grad = False
    def noscale_forward(self, x):
        out = None

        x = F.normalize(x, p=2, dim=1).reshape(len(x), -1)

        for t in range(self.n_tasks):
            o = F.linear(x, F.normalize(self.weights[t], p=2, dim=1))
            if out is None:
                out = o
            else:
                out = torch.cat((out, o), dim=1)

        return out
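# CustomClassifier is a cosine-similarity classification head: both the features and the
# per-class weight vectors are L2-normalized, so every logit is a cosine similarity
# rescaled by the single learnable parameter `sigma`, with one weight block per task and
# only one block trainable at a time. A minimal sketch of what it computes (the shapes
# below are illustrative assumptions):
#     clf = CustomClassifier(in_features=512, cpt=10, n_tasks=5)
#     x = torch.randn(4, 512)
#     logits = clf(x)      # shape (4, 50); equal to clf.noscale_forward(x) * clf.sigma
# reset_weight(i) re-initializes block i, makes it trainable and freezes block i - 1.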
class Lucir(ContinualModel):
    """Continual Learning via LUCIR: Learning a Unified Classifier Incrementally via Rebalancing."""
    NAME = 'lucir'
    COMPATIBILITY = ['class-il', 'task-il']
    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        add_rehearsal_args(parser)
        parser.add_argument('--lamda_base', type=float, required=False, default=5.,
                            help='Regularization weight for embedding cosine similarity.')
        parser.add_argument('--lamda_mr', type=float, required=False, default=1.,
                            help='Regularization weight for margin-ranking loss.')
        parser.add_argument('--k_mr', type=int, required=False, default=2,
                            help='K for margin-ranking loss.')
        parser.add_argument('--mr_margin', type=float, default=0.5, required=False,
                            help='Margin for margin-ranking loss.')
        parser.add_argument('--fitting_epochs', type=int, required=False, default=20,
                            help='Number of epochs to finetune on coreset after each task.')
        parser.add_argument('--lr_finetune', type=float, required=False, default=0.01,
                            help='Learning Rate for finetuning.')
        parser.add_argument('--imprint_weights', type=binary_to_boolean_type, required=False, default=1,
                            help='Apply weight imprinting?')
        return parser
    def __init__(self, backbone, loss, args, transform, dataset=None):
        super(Lucir, self).__init__(backbone, loss, args, transform, dataset=dataset)
        self.dataset = get_dataset(args)

        # Instantiate buffers
        self.buffer = Buffer(self.args.buffer_size)
        self.eye = torch.eye(self.dataset.N_CLASSES_PER_TASK *
                             self.dataset.N_TASKS).to(self.device)

        self.old_net = None
        self.epochs = int(args.n_epochs)
        self.lamda_cos_sim = args.lamda_base

        self.net.classifier = CustomClassifier(
            self.net.classifier.in_features, self.dataset.N_CLASSES_PER_TASK, self.dataset.N_TASKS)

        upd_weights = [p for n, p in self.net.named_parameters()
                       if 'classifier' not in n and '_fc' not in n] + [self.net.classifier.weights[0], self.net.classifier.sigma]
        fix_weights = list(self.net.classifier.weights[1:])

        self.opt = torch.optim.SGD([{'params': upd_weights, 'lr': self.args.lr, 'momentum': self.args.optim_mom, 'weight_decay': self.args.optim_wd},
                                    {'params': fix_weights, 'lr': 0, 'momentum': self.args.optim_mom, 'weight_decay': 0}])

        self.ft_lr_strat = [10]

        self.c_epoch = -1
    def update_classifier(self):
        self.net.classifier.reset_weight(self.current_task)
    def forward(self, x):
        with torch.no_grad():
            outputs = self.net(x)
        return outputs
    def observe(self, inputs, labels, not_aug_inputs, logits=None, epoch=None, fitting=False):
        if not hasattr(self, 'classes_so_far'):
            self.register_buffer('classes_so_far', labels.unique().to('cpu'))
        else:
            self.register_buffer('classes_so_far', torch.cat((
                self.classes_so_far, labels.to('cpu'))).unique())

        self.opt.zero_grad()

        loss = self.get_loss(inputs, labels.long(), self.current_task)

        loss.backward()

        self.opt.step()

        return loss.item()
    def get_loss(self, inputs: torch.Tensor, labels: torch.Tensor, task_idx: int) -> torch.Tensor:
        """
        Computes the loss tensor.

        Args:
            inputs: the images to be fed to the network
            labels: the ground-truth labels
            task_idx: the task index

        Returns:
            the differentiable loss value
        """
        pc = task_idx * self.dataset.N_CLASSES_PER_TASK
        ac = (task_idx + 1) * self.dataset.N_CLASSES_PER_TASK

        outputs = self.net(inputs, returnt='features').float()
        cos_output = self.net.classifier.noscale_forward(outputs)
        outputs = outputs.reshape(outputs.size(0), -1)

        loss = F.cross_entropy(cos_output * self.net.classifier.sigma, labels)

        if task_idx > 0:
            with torch.no_grad():
                logits = self.old_net(inputs, returnt='features')
                logits = logits.reshape(logits.size(0), -1)

            loss2 = F.cosine_embedding_loss(
                outputs, logits.detach(), torch.ones(outputs.shape[0]).to(outputs.device)) * self.lamda_cos_sim

            # Remove rescale by sigma before this loss
            loss3 = lucir_batch_hard_triplet_loss(
                labels, cos_output, self.args.k_mr, self.args.mr_margin, pc) * self.args.lamda_mr

            loss = loss + loss2 + loss3

        return loss
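    # For reference, the objective assembled above is, for task_idx > 0,
    #     L = L_CE + lamda_cos_sim * L_dist + lamda_mr * L_MR
    # where L_CE is the cross-entropy on the sigma-rescaled cosine logits, L_dist is the
    # cosine-embedding distillation term between current and frozen old features, and
    # L_MR is the margin-ranking (triplet) loss on the un-rescaled cosine scores,
    # restricted to samples belonging to old classes.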
    def begin_task(self, dataset):
        if self.current_task > 0:
            icarl_replay(self, dataset)
            with torch.no_grad():
                # Update model classifier
                self.update_classifier()

                if self.args.imprint_weights:
                    self.imprint_weights(dataset)

                # Restore optimizer LR
                upd_weights = [p for n, p in self.net.named_parameters()
                               if 'classifier' not in n] + [self.net.classifier.weights[self.current_task], self.net.classifier.sigma]
                fix_weights = list(self.net.classifier.weights[:self.current_task])

                if self.current_task < self.dataset.N_TASKS - 1:
                    fix_weights += list(self.net.classifier.weights[self.current_task + 1:])

                self.opt = torch.optim.SGD([{'params': upd_weights, 'lr': self.args.lr, 'weight_decay': self.args.optim_wd},
                                            {'params': fix_weights, 'lr': 0, 'weight_decay': 0}],
                                           lr=self.args.lr, momentum=self.args.optim_mom, weight_decay=self.args.optim_wd)
    def end_task(self, dataset) -> None:
        self.old_net = deepcopy(self.net.eval())
        self.net.train()

        with torch.no_grad():
            fill_buffer(self.buffer, dataset, self.current_task, net=self.net, use_herding=True)

        if self.args.fitting_epochs is not None and self.args.fitting_epochs > 0:
            self.fit_buffer(self.args.fitting_epochs)

        # Adapt lambda
        self.lamda_cos_sim = math.sqrt(self.current_task) * float(self.args.lamda_base)
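    # A hedged note on the lambda adaptation above: the LUCIR paper uses an adaptive
    # distillation weight lambda = lambda_base * sqrt(|C_old| / |C_new|). Since every
    # task here contributes the same number of classes, that ratio grows linearly with
    # the number of past tasks, which the sqrt(current_task) factor approximates.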
    def imprint_weights(self, dataset):
        self.net.eval()
        old_embedding_norm = torch.cat([self.net.classifier.weights[i]
                                        for i in range(self.current_task)]).norm(dim=1, keepdim=True)
        average_old_embedding_norm = torch.mean(
            old_embedding_norm, dim=0).cpu().type(torch.DoubleTensor)
        num_features = self.net.classifier.in_features
        novel_embedding = torch.zeros((self.dataset.N_CLASSES_PER_TASK, num_features))
        loader = dataset.train_loader

        cur_dataset = deepcopy(loader.dataset)

        for cls_idx in range(self.current_task * self.dataset.N_CLASSES_PER_TASK,
                             (self.current_task + 1) * self.dataset.N_CLASSES_PER_TASK):

            cls_indices = np.asarray(loader.dataset.targets) == cls_idx
            cur_dataset.data = loader.dataset.data[cls_indices]
            cur_dataset.targets = np.zeros((cur_dataset.data.shape[0]))

            dt = create_seeded_dataloader(self.args, cur_dataset,
                                          batch_size=self.args.batch_size, num_workers=0)

            num_samples = cur_dataset.data.shape[0]
            cls_features = torch.empty((num_samples, num_features))
            for j, d in enumerate(dt):
                tt = self.net(d[0].to(self.device), returnt='features').cpu()
                if 'ntu' in self.args.dataset:
                    tt = F.adaptive_avg_pool3d(tt, 1)
                cls_features[j * self.args.batch_size:(j + 1) * self.args.batch_size] = tt.reshape(len(tt), -1)

            norm_features = F.normalize(cls_features, p=2, dim=1)
            cls_embedding = torch.mean(norm_features, dim=0)

            novel_embedding[cls_idx - self.current_task * self.dataset.N_CLASSES_PER_TASK] = F.normalize(
                cls_embedding, p=2, dim=0) * average_old_embedding_norm

        self.net.classifier.weights[self.current_task].data = novel_embedding.to(self.device)

        self.net.train()
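    # Weight imprinting (above), in short: each new-class vector is set to the mean of the
    # L2-normalized features of that class's training samples, then renormalized and
    # rescaled to the average norm of the old-class vectors, so that new and old cosine
    # logits start on a comparable scale before any gradient step on the new task.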
    def fit_buffer(self, opt_steps):
        old_opt = self.opt

        # Optimize only final embeddings
        self.opt = torch.optim.SGD(self.net.classifier.parameters(), self.args.lr_finetune,
                                   momentum=self.args.optim_mom, weight_decay=self.args.optim_wd)

        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.opt, milestones=self.ft_lr_strat, gamma=0.1)

        with bn_track_stats(self, False):
            for _ in range(opt_steps):
                examples, labels = self.buffer.get_all_data(self.transform, device=self.device)
                dset = torch.utils.data.TensorDataset(examples, labels)
                torch.cuda.synchronize()
                dt = create_seeded_dataloader(self.args, dset, shuffle=True,
                                              batch_size=self.args.batch_size, num_workers=0)

                for inputs, labels in dt:
                    self.observe(inputs, labels, None, fitting=True)

                lr_scheduler.step()

        self.opt = old_opt
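    # fit_buffer (above) runs the class-balanced finetuning stage of LUCIR on the rehearsal
    # buffer: only the cosine classifier is optimized, with a separate finetuning learning
    # rate and a step schedule, while batch-norm statistics are frozen via bn_track_stats.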