# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from argparse import ArgumentParser
import torch
from torch.nn import functional as F
from utils import binary_to_boolean_type
from utils.args import add_rehearsal_args
from models.utils.continual_model import ContinualModel
from utils.batch_norm import bn_track_stats
class XDerCe(ContinualModel):
"""Continual learning via eXtended Dark Experience Replay with cross-entropy on future heads."""
NAME = 'xder_ce'
COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
@staticmethod
def get_parser(parser) -> ArgumentParser:
add_rehearsal_args(parser)
        parser.add_argument('--alpha', type=float, required=True, help='Weight of the logit-distillation (DER) replay penalty.')
        parser.add_argument('--beta', type=float, required=True, help='Weight of the label-replay (DER++) penalty.')
parser.add_argument('--gamma', type=float, default=0.85, help='Weight for logit update') # log_update_weight
parser.add_argument('--constr_eta', type=float, default=0.01, help='Regularization weight for past/future constraints') # constr_weight
parser.add_argument('--constr_margin', type=float, default=0.3, help='Margin for past/future constraints')
parser.add_argument('--align_bn', type=binary_to_boolean_type, default=1, help='Use BatchNorm alignment')
return parser
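    # A minimal example invocation (illustrative only: the mammoth entry point, dataset
    # name and hyper-parameter values below are assumptions, not recommended settings):
    #   python utils/main.py --model xder_ce --dataset seq-cifar100 \
    #       --buffer_size 500 --lr 0.03 --alpha 0.3 --beta 0.8 --minibatch_size 32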
def __init__(self, backbone, loss, args, transform, dataset=None):
super().__init__(backbone, loss, args, transform, dataset=dataset)
from utils.buffer import Buffer
self.buffer = Buffer(self.args.buffer_size)
self.update_counter = torch.zeros(self.args.buffer_size)
def end_task(self, dataset):
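        """At the end of each task: shrink the buffer to an equal per-class quota,
        fill the freed slots with samples from the task that just ended, and refresh
        the logits stored for past-task samples on the newly learned heads."""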
tng = self.training
self.train()
        # Reduce the coreset: keep an equal per-class quota of the buffered samples
if self.current_task > 0:
examples_per_class = self.args.buffer_size // self.n_seen_classes
buf_x, buf_lab, buf_log, buf_tl = self.buffer.get_all_data()
self.buffer.empty()
for tl in buf_lab.unique():
idx = tl == buf_lab
ex, lab, log, tasklab = buf_x[idx], buf_lab[idx], buf_log[idx], buf_tl[idx]
first = min(ex.shape[0], examples_per_class)
self.buffer.add_data(
examples=ex[:first],
labels=lab[:first],
logits=log[:first],
task_labels=tasklab[:first]
)
        # Fill the freed slots with samples (and logits) from the task that just ended
examples_last_task = self.buffer.buffer_size - self.buffer.num_seen_examples
examples_per_class = examples_last_task // self.cpt
ce = torch.tensor([examples_per_class] * self.cpt).int()
ce[torch.randperm(self.cpt)[:examples_last_task - (examples_per_class * self.cpt)]] += 1
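        # Illustrative example: with 23 free slots and cpt = 5, every class reserves
        # 4 slots and the remaining 3 slots go to 3 randomly chosen classes.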
with torch.no_grad():
with bn_track_stats(self, False):
for data in dataset.train_loader:
inputs, labels, not_aug_inputs = data[0], data[1], data[2]
inputs = inputs.to(self.device)
not_aug_inputs = not_aug_inputs.to(self.device)
outputs = self.net(inputs)
if all(ce == 0):
break
                    # Rescale the already-learned (past) heads of the fresh logits before storing them
if self.current_task > 0:
outputs = self.update_logits(outputs, outputs, labels, 0, self.current_task)
flags = torch.zeros(len(inputs)).bool()
for j in range(len(flags)):
if ce[labels[j] % self.cpt] > 0:
flags[j] = True
ce[labels[j] % self.cpt] -= 1
self.buffer.add_data(examples=not_aug_inputs[flags],
labels=labels[flags],
logits=outputs.data[flags],
task_labels=(torch.ones(len(not_aug_inputs)) * self.current_task)[flags])
                # Update the "future past" logits: refresh heads that have been learned
                # since each past-task sample was stored
buf_idx, buf_inputs, buf_labels, buf_logits, _ = self.buffer.get_data(self.buffer.buffer_size,
transform=self.transform, return_index=True, device=self.device)
buf_outputs = []
while len(buf_inputs):
buf_outputs.append(self.net(buf_inputs[:self.args.batch_size]))
buf_inputs = buf_inputs[self.args.batch_size:]
buf_outputs = torch.cat(buf_outputs)
chosen = ((buf_labels // self.cpt) < self.current_task).to(self.buffer.device)
if chosen.any():
to_transplant = self.update_logits(buf_logits[chosen], buf_outputs[chosen], buf_labels[chosen], self.current_task, self.n_tasks - self.current_task)
self.buffer.logits[buf_idx[chosen], :] = to_transplant.to(self.buffer.device)
self.buffer.task_labels[buf_idx[chosen]] = self.current_task
self.update_counter = torch.zeros(self.args.buffer_size)
self.train(tng)
def update_logits(self, old, new, gt, task_start, n_tasks=1):
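        """Copy the logits of heads [task_start, task_start + n_tasks) from ``new``
        into ``old``, rescaling them by ``gamma`` wherever their maximum would
        otherwise exceed the ground-truth logit stored in ``old``."""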
offset_1, _ = self.dataset.get_offsets(task_start)
offset_2, _ = self.dataset.get_offsets(task_start + n_tasks)
transplant = new[:, offset_1:offset_2]
gt_values = old[torch.arange(len(gt)), gt]
max_values = transplant.max(1).values
coeff = self.args.gamma * gt_values / max_values
coeff = coeff.unsqueeze(1).repeat(1, offset_2 - offset_1)
mask = (max_values > gt_values).unsqueeze(1).repeat(1, offset_2 - offset_1)
transplant[mask] *= coeff[mask]
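        # Illustrative numbers: with gamma = 0.85, a ground-truth logit of 2.0 and a
        # transplanted-head maximum of 4.0, coeff = 0.85 * 2.0 / 4.0 = 0.425, so the
        # block's maximum becomes 4.0 * 0.425 = 1.7, i.e. just below the ground truth.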
old[:, offset_1:offset_2] = transplant
return old
def observe(self, inputs, labels, not_aug_inputs, epoch=None):
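        """Single training step: cross-entropy on the current head for the stream batch,
        DER logit distillation and DER++ label replay on buffered samples, and a margin
        constraint keeping future-head logits below those of the sample's own task head."""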
self.opt.zero_grad()
with bn_track_stats(self, not self.args.align_bn or self.current_task == 0):
outputs = self.net(inputs)
# Present head
loss_stream = self.loss(outputs[:, self.n_past_classes:self.n_seen_classes], labels - self.n_past_classes)
loss_der, loss_derpp = torch.tensor(0.), torch.tensor(0.)
if not self.buffer.is_empty():
# Distillation Replay Loss (all heads)
buf_idx1, buf_inputs1, buf_labels1, buf_logits1, buf_tl1 = self.buffer.get_data(
self.args.minibatch_size, transform=self.transform, return_index=True, device=self.device)
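            # When align_bn is active, a slice of current-task inputs is appended so the
            # BatchNorm running statistics also see the current distribution; the extra
            # rows are discarded again right after the forward pass.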
if self.args.align_bn:
buf_inputs1 = torch.cat([buf_inputs1, inputs[:self.args.minibatch_size // self.current_task]])
buf_outputs1 = self.net(buf_inputs1)
if self.args.align_bn:
buf_inputs1 = buf_inputs1[:self.args.minibatch_size]
buf_outputs1 = buf_outputs1[:self.args.minibatch_size]
mse = F.mse_loss(buf_outputs1, buf_logits1, reduction='none')
loss_der = self.args.alpha * mse.mean()
# Label Replay Loss (past heads)
buf_idx2, buf_inputs2, buf_labels2, buf_logits2, buf_tl2 = self.buffer.get_data(
self.args.minibatch_size, transform=self.transform, return_index=True, device=self.device)
with bn_track_stats(self, not self.args.align_bn):
buf_outputs2 = self.net(buf_inputs2)
buf_ce = self.loss(buf_outputs2[:, :self.n_past_classes], buf_labels2)
loss_derpp = self.args.beta * buf_ce
# Merge Batches & Remove Duplicates
buf_idx = torch.cat([buf_idx1, buf_idx2])
buf_inputs = torch.cat([buf_inputs1, buf_inputs2])
buf_labels = torch.cat([buf_labels1, buf_labels2])
buf_logits = torch.cat([buf_logits1, buf_logits2])
buf_outputs = torch.cat([buf_outputs1, buf_outputs2])
buf_tl = torch.cat([buf_tl1, buf_tl2])
            # Remove duplicates (samples drawn in both replay batches)
eyey = torch.eye(self.buffer.buffer_size).to(buf_idx.device)[buf_idx]
umask = (eyey * eyey.cumsum(0)).sum(1) < 2
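            # Illustrative example: for buf_idx = [3, 5, 3] the one-hot rows accumulate in
            # the cumulative sum, the second occurrence of index 3 scores >= 2, and
            # umask = [True, True, False] keeps only the first copy of each buffer index.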
buf_idx = buf_idx[umask].to(self.buffer.device)
buf_inputs = buf_inputs[umask]
buf_labels = buf_labels[umask]
buf_logits = buf_logits[umask]
buf_outputs = buf_outputs[umask]
buf_tl = buf_tl[umask]
# Update Future Past Logits
with torch.no_grad():
chosen = ((buf_labels // self.cpt) < self.current_task).to(self.buffer.device)
c = chosen.clone()
self.update_counter[buf_idx[chosen]] += 1
chosen[c] = torch.rand_like(chosen[c].float()) * self.update_counter[buf_idx[c]] < 1
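                # A drawn past-task sample is refreshed with probability ~1/k, where k is
                # the number of times it has been drawn so far, so stored logits are not
                # overwritten too aggressively within a single task.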
if chosen.any():
assert self.current_task > 0
to_transplant = self.update_logits(buf_logits[chosen], buf_outputs[chosen], buf_labels[chosen], self.current_task, self.n_tasks - self.current_task)
self.buffer.logits[buf_idx[chosen], :] = to_transplant.to(self.buffer.device)
self.buffer.task_labels[buf_idx[chosen]] = self.current_task
# Consistency Loss (future heads)
loss_constr_futu = torch.tensor(0.)
if self.current_task < self.n_tasks - 1:
# Future Logits Constraint
bad_head = outputs[:, self.n_seen_classes:]
good_head = outputs[:, self.n_past_classes:self.n_seen_classes]
if not self.buffer.is_empty():
buf_tlgt = buf_labels // self.cpt
bad_head = torch.cat([bad_head, buf_outputs[:, self.n_seen_classes:]])
good_head = torch.cat([good_head, torch.stack(buf_outputs.split(self.cpt, 1), 1)[torch.arange(len(buf_tlgt)), buf_tlgt]])
loss_constr = bad_head.max(1)[0] + self.args.constr_margin - good_head.max(1)[0]
mask = loss_constr > 0
if (mask).any():
loss_constr_futu = self.args.constr_eta * loss_constr[mask].mean()
loss = loss_stream + loss_der + loss_derpp + loss_constr_futu
loss.backward()
self.opt.step()
return loss.item()