# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from argparse import ArgumentParser
from copy import deepcopy
from tqdm import tqdm
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from utils import binary_to_boolean_type, none_or_float
from utils.schedulers import CosineSchedule
from datasets import get_dataset_class
from models.utils.continual_model import ContinualModel
from models.lora_prototype_utils.lora_prompt import Model
from models.lora_prototype_utils.generative_replay import FeaturesDataset
from models.lora_prototype_utils.utils import create_optimizer
from models.lora_prototype_utils.utils import get_dist
from models.lora_prototype_utils.utils import AlignmentLoss
from models.lora_prototype_utils.utils import linear_probing_epoch
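# Example invocation (hypothetical values), assuming the model is launched through
# mammoth's usual entry point and standard flags (--model, --dataset, --lr, --n_epochs):
#
#   python utils/main.py --model second_order --dataset seq-cifar100 \
#       --lr 0.03 --n_epochs 5 --tuning_style lora --lora_r 16 \
#       --alpha_ita 0.1 --use_iel 0
#
# All values above are illustrative only; see SecondOrder.get_parser() for the defaults.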
def int_or_all(x):
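    """
    Argparse type: accept either the literal string 'all' or an integer.
    """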
if x == 'all':
return x
    return int(x)
class SecondOrder(ContinualModel):
NAME = 'second_order'
COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
net: Model
@staticmethod
def get_parser(parser: ArgumentParser) -> ArgumentParser:
parser.set_defaults(pretrain_type='in21k', optimizer='adamw')
# OPTIM PARAMS
        parser.add_argument("--virtual_bs_n", type=int, default=1,
                            help="Number of gradient-accumulation iterations per optimizer step (virtual batch size).")
parser.add_argument('--clip_grad', type=none_or_float, default=100,
help='Clip gradient norm (None means no clipping)')
# FINE-TUNING PARAMS
parser.add_argument('--tuning_style', type=str, default='lora',
choices=['lora', 'full', 'ia3'],
help='Strategy to use for tuning the model.\n'
'- "lora": LoRA\n'
'- "full": full fine-tuning\n'
'- "ia3": IA3')
parser.add_argument('--lora_r', type=int, default=16,
help='LoRA rank. Used if `tuning_style` is "lora".')
# PRE-TUNING
parser.add_argument('--num_epochs_pretuning', type=int, default=3,
help='Number of epochs for pre-tuning')
parser.add_argument("--learning_rate_pretuning", type=float, default=0.01,
help="Learning rate for pre-tuning.")
parser.add_argument('--fisher_mc_classes', type=int_or_all, default='all',
help='Number of classes to use for EWC Fisher computation.\n'
'- "all": slow but accurate, uses all classes\n'
'- <int>: use subset of <int> classes, faster but less accurate')
        parser.add_argument("--num_samples_align_pretuning", type=int, default=256,
                            help="Number of samples drawn from each per-class Gaussian.")
        parser.add_argument("--batch_size_align_pretuning", type=int, default=128,
                            help="Batch size for classifier alignment (CA).")
        parser.add_argument("--num_epochs_align_pretuning", type=int, default=10,
                            help="Number of epochs for classifier alignment (CA).")
        parser.add_argument("--lr_align_pretuning", type=float,
                            default=0.01, help="Learning rate for classifier alignment (CA).")
# REGULARIZATION PARAMS
        parser.add_argument('--use_iel', type=binary_to_boolean_type, choices=[0, 1], default=0,
                            help="If 1, tune with IEL; if 0 (default), tune with ITA.")
# IEL
parser.add_argument('--beta_iel', type=float, default=0.0, help="Beta parameter of IEL (Eq. 18/19)")
# ITA
parser.add_argument('--alpha_ita', type=float, default=0.0, help="Alpha parameter of ITA (Eq. 11)")
        parser.add_argument('--req_weight_cls', type=float,
                            help="Regularization weight for the classifier (alpha for ITA, beta for IEL). "
                                 "If None, falls back to the alpha/beta of ITA/IEL.")
parser.add_argument('--simple_reg_weight_cls', type=float, default=0.0,
help="Regularization weight for simple MSE-based loss for the classifier.")
return parser
def __init__(self, backbone, loss, args, transform, dataset=None):
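        # Resolve '--fisher_mc_classes all' into the dataset's total number of
        # classes before building the backbone; otherwise cast the value to int.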
if args.fisher_mc_classes == 'all':
dset_cls = get_dataset_class(args)
args.fisher_mc_classes = dset_cls.N_CLASSES
else:
args.fisher_mc_classes = int(args.fisher_mc_classes)
assert args.beta_iel >= 0., "Beta parameter of IEL must be >= 0"
assert args.alpha_ita >= 0., "Alpha parameter of ITA must be >= 0"
args.req_weight_cls = args.req_weight_cls if args.req_weight_cls is not None else \
(args.beta_iel if args.use_iel else args.alpha_ita)
backbone = Model(args, dataset, backbone)
super().__init__(backbone, loss, args, transform, dataset=dataset)
self.output_dim = backbone.output_dim
distributions = [get_dist(self.output_dim) for _ in range(self.num_classes)]
self.distributions = torch.nn.ModuleList(distributions).to(self.device)
pretrain_distributions = [get_dist(self.output_dim) for _ in range(self.num_classes)]
self.pretrain_distributions = torch.nn.ModuleList(pretrain_distributions).to(self.device)
self.old_epoch, self.iteration = 0, 0
self.custom_scheduler = self.get_scheduler()
self.alignment_loss = AlignmentLoss(self.dataset, self.device)
self.pretraining_classifier = deepcopy(self.net.vit.head)
self.buffergrad = None
self.buffergrad_cls = None
self.beta_iel = self.args.beta_iel
self.alpha_ita = self.args.alpha_ita
self.req_weight_cls = self.args.req_weight_cls
self.reg_loss_is_active = self.beta_iel > 0. or self.alpha_ita > 0.
self.reg_loss_cls_is_active = self.req_weight_cls > 0.
@torch.no_grad()
def create_synthetic_features_dataset(self, distributions_to_sample_from=None, upto: int = None):
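        """
        Build a DataLoader of synthetic features sampled from the stored per-class
        distributions (fitted in `compute_statistics`), covering every task up to
        `upto`; by default, all tasks seen so far including the current one.
        """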
labels, features = [], []
if upto is None:
upto = self.current_task + 1
else:
assert isinstance(upto, int)
num_samples_per_class = self.args.num_samples_align_pretuning
if distributions_to_sample_from is None:
distributions_to_sample_from = self.distributions
for _ti in range(upto):
prev_t_size, cur_t_size = self.dataset.get_offsets(_ti)
for class_idx in range(prev_t_size, cur_t_size):
current_samples = distributions_to_sample_from[class_idx](num_samples_per_class, 1.0)
features.append(current_samples)
labels.append(torch.ones((num_samples_per_class,)) * class_idx)
features = torch.cat(features, dim=0)
labels = torch.cat(labels, dim=0).long()
return DataLoader(FeaturesDataset(features, labels),
batch_size=self.args.batch_size_align_pretuning, shuffle=True,
num_workers=0, drop_last=True)
@torch.no_grad()
def create_features_dataset(self, data_loader, use_lora: bool):
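        """
        Extract backbone features (the first token, i.e. the [CLS] embedding for a
        ViT backbone) for every sample of `data_loader`, with or without the LoRA
        modules active depending on `use_lora`, and wrap them in a DataLoader.
        """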
labels, features = [], []
orig_mode = self.net.training
self.net.eval()
for i, data in enumerate(data_loader):
if self.args.debug_mode and i > 101:
break
x, y, _ = data
x, y = x.to(self.device), y.to(self.device).long()
z = self.net(x, train=False, return_features=True,
use_lora=use_lora)
z = z[:, 0]
features.append(z.detach().cpu())
labels.append(y.detach().cpu())
features = torch.cat(features, dim=0)
labels = torch.cat(labels, dim=0).long()
self.net.train(orig_mode)
return DataLoader(FeaturesDataset(features, labels),
batch_size=self.args.batch_size_align_pretuning, shuffle=True,
num_workers=0, drop_last=True)
@torch.no_grad()
def compute_statistics(self, dataset, distributions, use_lora):
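        """
        Fit one feature distribution per class on the [CLS] embeddings of the
        training set, extracted with or without LoRA depending on `use_lora`.
        """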
features_dict = defaultdict(list)
orig_mode = self.net.training
self.net.eval()
for i, data in enumerate(dataset.train_loader):
if self.args.debug_mode and i > 101:
break
x, labels, _ = data
x, labels = x.to(self.device), labels.to(self.device).long()
features = self.net(x, train=False, return_features=True, use_lora=use_lora)
features = features[:, 0]
for class_idx in labels.unique():
features_dict[int(class_idx)].append(features[labels == class_idx])
self.net.train(orig_mode)
for class_idx in features_dict.keys():
features_class_idx = torch.cat(features_dict[class_idx], dim=0).to(self.device)
distributions[class_idx].fit(features_class_idx)
def get_sgd_optim(self, cls, lr):
params = cls.build_optimizer_args(lr)
return torch.optim.SGD(lr=lr, params=params)
def sched(self, optim, num_epochs: int):
return CosineAnnealingLR(optimizer=optim, T_max=num_epochs)
def align_pretuning(self, cls, distributions_to_sample_from, desc=''):
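        """
        Classifier alignment: train `cls` on synthetic features sampled from
        `distributions_to_sample_from`, for a number of epochs that grows with the
        task index.
        """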
optim = self.get_sgd_optim(cls, lr=self.args.lr_align_pretuning)
num_epochs = self.args.num_epochs_align_pretuning + 5 * self.current_task
lr_scheduler = self.sched(optim, num_epochs)
for _ in tqdm(range(num_epochs), total=num_epochs, desc=desc):
data_loader = self.create_synthetic_features_dataset(
distributions_to_sample_from=distributions_to_sample_from)
linear_probing_epoch(data_loader, self.alignment_loss, cls,
optim, lr_scheduler, self.device)
return cls
def masked_loss(self, cls, x, labels):
"""
        Cross-entropy on the current task only: logits of the past classes are
        masked to -inf before computing the loss.
"""
logits = cls(x)
logits[:, :self.n_past_classes] = -float('inf')
loss = self.loss(logits, labels)
loss_val = loss.detach().item()
return loss, {'ce_pretuning': loss_val}
def linear_probing(self, dataset, classifier, lr, num_epochs,
desc='', use_lora: bool = False):
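        """
        Train `classifier` with the masked cross-entropy loss on features extracted
        from the real training stream (features are re-extracted at every epoch).
        """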
optim = self.get_sgd_optim(classifier, lr=lr)
for _ in tqdm(range(num_epochs), total=num_epochs, desc=desc):
data_loader = self.create_features_dataset(dataset.train_loader,
use_lora=use_lora)
linear_probing_epoch(data_loader, self.masked_loss, classifier,
optim, None, self.device, debug_mode=self.args.debug_mode == 1)
return classifier
def pretuning(self, dataset):
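        """
        Pre-tuning phase run at the beginning of each task: fit the pre-training
        feature distributions (LoRA disabled), linearly probe a copy of the
        pre-training classifier on the new task, align it on synthetic features
        from all classes seen so far, and store the result back into
        `self.pretraining_classifier`.
        """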
self.compute_statistics(dataset, self.pretrain_distributions,
use_lora=False)
lr = self.args.learning_rate_pretuning
num_epochs = self.args.num_epochs_pretuning
classifier = deepcopy(self.pretraining_classifier)
classifier.enable_training()
classifier = self.linear_probing(dataset, classifier, lr, num_epochs,
desc='Pre-Tuning - Task-IL (begin)',
use_lora=False)
self.align_pretuning(classifier, self.pretrain_distributions,
desc='Pre-Tuning - Class-IL (begin)')
self.pretraining_classifier.assign(classifier)
def get_optimizer(self):
optimizer_arg = self.net.build_optimizer_args(self.args.lr)
return create_optimizer(self.args.optimizer, optimizer_arg, momentum=0.9)
def get_scheduler(self):
return CosineSchedule(self.opt, K=self.args.n_epochs)
def update_statistics(self, dataset):
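        """
        Refresh the Fisher information used by the regularizer: the pre-tuning
        classifier is temporarily swapped in as head and, when IEL is enabled, a
        generative-replay loader of synthetic features is passed to `update_fisher`.
        """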
self.net.vit.head.backup()
self.net.vit.head.assign(self.pretraining_classifier)
generative_dataloader = None
if self.args.use_iel:
if self.current_task > 0:
generative_dataloader = self.create_synthetic_features_dataset(self.pretrain_distributions, self.current_task)
else:
generative_dataloader = self.create_synthetic_features_dataset(self.pretrain_distributions, self.current_task + 1)
self.net.update_fisher(dataset, generative_dataloader, self.args.debug_mode == 1)
self.net.vit.head.recall()
def begin_task(self, dataset):
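        """
        Task setup: grow the classification heads, run pre-tuning, refresh the
        Fisher statistics, rebuild the optimizer, scheduler and gradient buffers,
        and select ITA or IEL for the upcoming task.
        """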
num_classes = self.n_classes_current_task
if self.current_task > 0:
self.pretraining_classifier.update(nb_classes=num_classes)
self.net.vit.head.update(nb_classes=num_classes)
self.alignment_loss.set_current_task(self.current_task)
self.net.set_current_task(self.current_task)
self.pretuning(dataset)
self.update_statistics(dataset)
if hasattr(self, 'opt'):
self.opt.zero_grad()
del self.opt
self.opt = self.get_optimizer()
self.custom_scheduler = self.get_scheduler()
self.old_epoch, self.iteration = 0, 0
if self.buffergrad is not None:
del self.buffergrad
if self.buffergrad_cls is not None:
del self.buffergrad_cls
self.buffergrad = [torch.zeros_like(p)
for p in self.opt.param_groups[0]['params']]
self.buffergrad_cls = [torch.zeros_like(p)
for p in self.opt.param_groups[1]['params']]
# Train either with ITA or IEL
self.net.ensemble(self.args.use_iel)
torch.cuda.empty_cache()
def end_task(self, dataset):
# Evaluate the merged model
self.net.ensemble(True)
def forward(self, x, task_weights=None, returnt='out'):
assert returnt in ['out', 'features']
logits = self.net(x, train=False, task_weights=task_weights, return_features=returnt == 'features')
if returnt == 'features':
return logits
return logits[:, :self.n_seen_classes]
def accuracy(self, pred, labels):
stream_preds = pred[:, :self.n_seen_classes].argmax(dim=1)
acc = (stream_preds == labels).sum().item() / len(labels)
return acc
def compute_loss(self, stream_logits, stream_labels):
"""
        Compute the masked cross-entropy loss for the current task: past-class
        logits are set to -inf and only the classes seen so far are scored.
"""
stream_logits[:, :self.n_past_classes] = -float('inf')
loss = self.loss(stream_logits[:, :self.n_seen_classes], stream_labels)
return loss
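    # Low-level gradient bookkeeping used to combine the task loss with the
    # second-order regularizers: gradients are backed up, applied manually with
    # the scheduled learning rate, and restored around the regularization step.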
def _grad_backup(self, param_group, buffer, set_to_zero: bool):
for idx_p, p in enumerate(param_group):
buffer[idx_p].copy_(p.grad)
if set_to_zero:
torch.nn.init.zeros_(p.grad)
def _grad_recall(self, param_group, buffer, op):
for idx_p in range(len(param_group)):
if op == 'add':
param_group[idx_p].grad.add_(buffer[idx_p])
else:
param_group[idx_p].grad.copy_(buffer[idx_p])
def _apply_grads(self, lr, param_group, clip_grad):
for myparam in param_group:
if clip_grad is not None:
torch.nn.utils.clip_grad_norm_(myparam, clip_grad)
myparam.add_(myparam.grad, alpha=-lr)
myparam.grad.zero_()
@torch.no_grad()
def grad_backup(self, param_group: str):
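        """
        Copy the current gradients of the selected parameter group ('layers' or
        'cls') into its buffer; classifier gradients are additionally zeroed so
        that the regularizer's backward pass can overwrite them.
        """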
assert param_group in ['layers', 'cls']
id_group = {'layers': 0, 'cls': 1}[param_group]
set_to_zero = {'layers': False, 'cls': True}[param_group]
id_buffer = {'layers': self.buffergrad,
'cls': self.buffergrad_cls}[param_group]
self._grad_backup(self.opt.param_groups[id_group]['params'],
id_buffer, set_to_zero=set_to_zero)
@torch.no_grad()
def grad_recall(self, param_group: str, op='set'):
assert op in ['add', 'set']
id_group = {'layers': 0, 'cls': 1}[param_group]
id_buffer = {'layers': self.buffergrad, 'cls': self.buffergrad_cls}[param_group]
self._grad_recall(self.opt.param_groups[id_group]['params'], id_buffer, op)
@torch.no_grad()
def apply_grads(self, param_group: str):
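        """
        Manually apply the gradients currently stored in the selected parameter
        group: the update uses the ratio between the scheduled and the base
        learning rate, with optional clipping, and the gradients are then zeroed.
        """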
assert param_group in ['layers', 'cls']
id_group = {'layers': 0, 'cls': 1}[param_group]
lr = 1.0
if self.custom_scheduler is not None:
base_lr = self.custom_scheduler.base_lrs[id_group]
lr = self.custom_scheduler.get_lr()[id_group] / base_lr
self._apply_grads(lr, self.opt.param_groups[id_group]['params'],
clip_grad=self.args.clip_grad)
def observe(self, inputs, labels, not_aug_inputs, epoch=None):
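        """
        Single training step: masked cross-entropy on the incoming batch, with
        optional gradient accumulation (`--virtual_bs_n`). When a regularizer is
        active, its gradients are applied through the backup/apply/recall
        machinery before the optimizer step on the task loss.
        """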
labels = labels.long()
if self.custom_scheduler and self.old_epoch != epoch:
if epoch > 0:
self.custom_scheduler.step()
self.old_epoch = epoch
self.iteration = 0
self.net.iteration = self.iteration
log_dict = {}
stream_inputs, stream_labels = inputs, labels
stream_logits = self.net(stream_inputs, train=True)
with torch.no_grad():
log_dict['stream_class_il'] = self.accuracy(stream_logits, stream_labels)
stream_logits[:, :self.n_past_classes] = -float('inf')
with torch.no_grad():
log_dict['stream_task_il'] = self.accuracy(stream_logits, stream_labels)
loss = self.compute_loss(stream_logits, stream_labels)
if self.iteration == 0:
self.opt.zero_grad()
if self.args.virtual_bs_n > 1:
loss = loss / self.args.virtual_bs_n
loss.backward()
if (self.iteration > 0 or self.args.virtual_bs_n == 1) and \
self.iteration % self.args.virtual_bs_n == 0:
if self.reg_loss_is_active:
self.grad_backup('layers')
if self.reg_loss_cls_is_active:
self.grad_backup('cls')
with torch.set_grad_enabled(self.reg_loss_is_active):
reg_loss, dotprod_loss = self.net.compute_reg_loss(do_backward=self.reg_loss_is_active,
do_loss_computation=True)
with torch.set_grad_enabled(self.reg_loss_cls_is_active):
reg_cls = self.net.compute_classifier_reg_loss(
cls_ref=self.pretraining_classifier,
do_backward=self.reg_loss_cls_is_active)
with torch.no_grad():
log_dict['reg_loss'] = reg_loss.detach()
log_dict['dotprod_loss'] = dotprod_loss.detach()
log_dict['reg_cls'] = reg_cls.detach()
if self.reg_loss_is_active:
self.apply_grads('layers')
self.grad_recall('layers')
if self.reg_loss_cls_is_active:
self.apply_grads('cls')
self.grad_recall('cls')
self.opt.step()
self.opt.zero_grad()
self.iteration += 1
log_dict['loss'] = loss.item()
return log_dict