Source code for models.xder_ce

# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from argparse import ArgumentParser
import torch
from torch.nn import functional as F

from utils import binary_to_boolean_type
from utils.args import add_rehearsal_args
from models.utils.continual_model import ContinualModel
from utils.batch_norm import bn_track_stats


class XDerCe(ContinualModel):
    """Continual learning via eXtended Dark Experience Replay with cross-entropy on future heads."""
    NAME = 'xder_ce'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        add_rehearsal_args(parser)
        parser.add_argument('--alpha', type=float, required=True,
                            help='Penalty weight for the logit-distillation replay loss.')
        parser.add_argument('--beta', type=float, required=True,
                            help='Penalty weight for the label replay loss.')
        parser.add_argument('--gamma', type=float, default=0.85,
                            help='Weight for logit update')  # log_update_weight
        parser.add_argument('--constr_eta', type=float, default=0.01,
                            help='Regularization weight for past/future constraints')  # constr_weight
        parser.add_argument('--constr_margin', type=float, default=0.3,
                            help='Margin for past/future constraints')
        parser.add_argument('--align_bn', type=binary_to_boolean_type, default=1,
                            help='Use BatchNorm alignment')
        return parser
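    # A typical invocation (hypothetical: the entry point, dataset name, and
    # hyperparameter values below are illustrative, not prescribed by this file):
    #
    #   python utils/main.py --model xder_ce --dataset seq-cifar10 \
    #       --buffer_size 500 --lr 0.03 --alpha 0.3 --beta 0.9
    #
    # --alpha and --beta are required; the remaining flags fall back to the
    # defaults declared above.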
    def __init__(self, backbone, loss, args, transform, dataset=None):
        super().__init__(backbone, loss, args, transform, dataset=dataset)
        from utils.buffer import Buffer
        self.buffer = Buffer(self.args.buffer_size)
        self.update_counter = torch.zeros(self.args.buffer_size)
    def end_task(self, dataset):
        tng = self.training
        self.train()

        # FDR-style coreset reduction: rebalance the buffer to an equal
        # per-class quota over all classes seen so far
        if self.current_task > 0:
            examples_per_class = self.args.buffer_size // self.n_seen_classes
            buf_x, buf_lab, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()
            for tl in buf_lab.unique():
                idx = tl == buf_lab
                ex, lab, log, tasklab = buf_x[idx], buf_lab[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_class)
                self.buffer.add_data(
                    examples=ex[:first],
                    labels=lab[:first],
                    logits=log[:first],
                    task_labels=tasklab[:first]
                )

        # FDR-style insertion: fill the freed slots with examples of the task
        # that just ended, scattering the remainder at random over its classes
        examples_last_task = self.buffer.buffer_size - self.buffer.num_seen_examples
        examples_per_class = examples_last_task // self.cpt
        ce = torch.tensor([examples_per_class] * self.cpt).int()
        ce[torch.randperm(self.cpt)[:examples_last_task - (examples_per_class * self.cpt)]] += 1

        with torch.no_grad():
            with bn_track_stats(self, False):
                for data in dataset.train_loader:
                    inputs, labels, not_aug_inputs = data[0], data[1], data[2]
                    inputs = inputs.to(self.device)
                    not_aug_inputs = not_aug_inputs.to(self.device)
                    outputs = self.net(inputs)
                    if all(ce == 0):
                        break

                    # update past logits before storing them
                    if self.current_task > 0:
                        outputs = self.update_logits(outputs, outputs, labels, 0, self.current_task)

                    flags = torch.zeros(len(inputs)).bool()
                    for j in range(len(flags)):
                        if ce[labels[j] % self.cpt] > 0:
                            flags[j] = True
                            ce[labels[j] % self.cpt] -= 1

                    self.buffer.add_data(examples=not_aug_inputs[flags],
                                         labels=labels[flags],
                                         logits=outputs.data[flags],
                                         task_labels=(torch.ones(len(not_aug_inputs)) * self.current_task)[flags])

                # update future past: refresh the stored logits of past-task
                # examples with the heads learned during the current task
                buf_idx, buf_inputs, buf_labels, buf_logits, _ = self.buffer.get_data(
                    self.buffer.buffer_size, transform=self.transform,
                    return_index=True, device=self.device)

                buf_outputs = []
                while len(buf_inputs):
                    buf_outputs.append(self.net(buf_inputs[:self.args.batch_size]))
                    buf_inputs = buf_inputs[self.args.batch_size:]
                buf_outputs = torch.cat(buf_outputs)

                chosen = ((buf_labels // self.cpt) < self.current_task).to(self.buffer.device)

                if chosen.any():
                    to_transplant = self.update_logits(buf_logits[chosen], buf_outputs[chosen],
                                                       buf_labels[chosen], self.current_task,
                                                       self.n_tasks - self.current_task)
                    self.buffer.logits[buf_idx[chosen], :] = to_transplant.to(self.buffer.device)
                    self.buffer.task_labels[buf_idx[chosen]] = self.current_task

        self.update_counter = torch.zeros(self.args.buffer_size)

        self.train(tng)
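    # The quota tensor `ce` above gives each of the `cpt` new classes an equal
    # share of the free slots and assigns the remainder to random classes. A
    # minimal standalone sketch of that arithmetic (illustrative values):
    #
    #   >>> import torch
    #   >>> cpt, free_slots = 5, 23
    #   >>> quota = torch.tensor([free_slots // cpt] * cpt).int()   # tensor([4, 4, 4, 4, 4])
    #   >>> quota[torch.randperm(cpt)[:free_slots - int(quota.sum())]] += 1
    #   >>> int(quota.sum())
    #   23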
    def update_logits(self, old, new, gt, task_start, n_tasks=1):
        # Transplant the logits of heads [task_start, task_start + n_tasks)
        # from `new` into `old`, rescaled so that no transplanted logit
        # exceeds the ground-truth logit stored in `old`
        offset_1, _ = self.dataset.get_offsets(task_start)
        offset_2, _ = self.dataset.get_offsets(task_start + n_tasks)

        transplant = new[:, offset_1:offset_2]

        gt_values = old[torch.arange(len(gt)), gt]
        max_values = transplant.max(1).values
        coeff = self.args.gamma * gt_values / max_values
        coeff = coeff.unsqueeze(1).repeat(1, offset_2 - offset_1)
        mask = (max_values > gt_values).unsqueeze(1).repeat(1, offset_2 - offset_1)
        transplant[mask] *= coeff[mask]
        old[:, offset_1:offset_2] = transplant

        return old
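    # Worked example of the rescaling (illustrative numbers): with gamma = 0.85,
    # a ground-truth logit gt_values = 2.0 and a transplanted head whose maximum
    # is max_values = 4.0 > 2.0, the row is scaled by
    #
    #   coeff = 0.85 * 2.0 / 4.0 = 0.425
    #
    # so the transplanted maximum becomes 4.0 * 0.425 = 1.7 = 0.85 * gt_values,
    # i.e. strictly below the ground-truth logit. Rows whose maximum already
    # sits below gt_values are left untouched by the mask.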
    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        self.opt.zero_grad()

        with bn_track_stats(self, not self.args.align_bn or self.current_task == 0):
            outputs = self.net(inputs)

        # Present head
        loss_stream = self.loss(outputs[:, self.n_past_classes:self.n_seen_classes],
                                labels - self.n_past_classes)

        loss_der, loss_derpp = torch.tensor(0.), torch.tensor(0.)
        if not self.buffer.is_empty():
            # Distillation Replay Loss (all heads)
            buf_idx1, buf_inputs1, buf_labels1, buf_logits1, buf_tl1 = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform,
                return_index=True, device=self.device)

            if self.args.align_bn:
                buf_inputs1 = torch.cat([buf_inputs1, inputs[:self.args.minibatch_size // self.current_task]])

            buf_outputs1 = self.net(buf_inputs1)

            if self.args.align_bn:
                buf_inputs1 = buf_inputs1[:self.args.minibatch_size]
                buf_outputs1 = buf_outputs1[:self.args.minibatch_size]

            mse = F.mse_loss(buf_outputs1, buf_logits1, reduction='none')
            loss_der = self.args.alpha * mse.mean()

            # Label Replay Loss (past heads)
            buf_idx2, buf_inputs2, buf_labels2, buf_logits2, buf_tl2 = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform,
                return_index=True, device=self.device)

            with bn_track_stats(self, not self.args.align_bn):
                buf_outputs2 = self.net(buf_inputs2)

            buf_ce = self.loss(buf_outputs2[:, :self.n_past_classes], buf_labels2)
            loss_derpp = self.args.beta * buf_ce

            # Merge Batches & Remove Duplicates
            buf_idx = torch.cat([buf_idx1, buf_idx2])
            buf_inputs = torch.cat([buf_inputs1, buf_inputs2])
            buf_labels = torch.cat([buf_labels1, buf_labels2])
            buf_logits = torch.cat([buf_logits1, buf_logits2])
            buf_outputs = torch.cat([buf_outputs1, buf_outputs2])
            buf_tl = torch.cat([buf_tl1, buf_tl2])
            # remove duplicates: keep only the first occurrence of each buffer index
            eyey = torch.eye(self.buffer.buffer_size).to(buf_idx.device)[buf_idx]
            umask = (eyey * eyey.cumsum(0)).sum(1) < 2

            buf_idx = buf_idx[umask].to(self.buffer.device)
            buf_inputs = buf_inputs[umask]
            buf_labels = buf_labels[umask]
            buf_logits = buf_logits[umask]
            buf_outputs = buf_outputs[umask]
            buf_tl = buf_tl[umask]

            # Update Future Past Logits: each past-task example is refreshed
            # with probability ~ 1 / (number of updates it received so far)
            with torch.no_grad():
                chosen = ((buf_labels // self.cpt) < self.current_task).to(self.buffer.device)
                c = chosen.clone()
                self.update_counter[buf_idx[chosen]] += 1
                chosen[c] = torch.rand_like(chosen[c].float()) * self.update_counter[buf_idx[c]] < 1

                if chosen.any():
                    assert self.current_task > 0
                    to_transplant = self.update_logits(buf_logits[chosen], buf_outputs[chosen],
                                                       buf_labels[chosen], self.current_task,
                                                       self.n_tasks - self.current_task)
                    self.buffer.logits[buf_idx[chosen], :] = to_transplant.to(self.buffer.device)
                    self.buffer.task_labels[buf_idx[chosen]] = self.current_task

        # Consistency Loss (future heads)
        loss_constr_futu = torch.tensor(0.)
        if self.current_task < self.n_tasks - 1:
            # Future Logits Constraint: future heads must stay below the
            # ground-truth head by at least constr_margin
            bad_head = outputs[:, self.n_seen_classes:]
            good_head = outputs[:, self.n_past_classes:self.n_seen_classes]

            if not self.buffer.is_empty():
                buf_tlgt = buf_labels // self.cpt
                bad_head = torch.cat([bad_head, buf_outputs[:, self.n_seen_classes:]])
                good_head = torch.cat([good_head,
                                       torch.stack(buf_outputs.split(self.cpt, 1), 1)[torch.arange(len(buf_tlgt)), buf_tlgt]])

            loss_constr = bad_head.max(1)[0] + self.args.constr_margin - good_head.max(1)[0]

            mask = loss_constr > 0
            if mask.any():
                loss_constr_futu = self.args.constr_eta * loss_constr[mask].mean()

        loss = loss_stream + loss_der + loss_derpp + loss_constr_futu

        loss.backward()
        self.opt.step()

        return loss.item()
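# The eye/cumsum trick in `observe` keeps only the first occurrence of each
# buffer index when the two replay batches overlap. A standalone sketch with
# made-up indices:
#
#   >>> import torch
#   >>> idx = torch.tensor([3, 1, 3, 0])           # index 3 sampled twice
#   >>> eyey = torch.eye(5)[idx]                   # one one-hot row per sample
#   >>> keep = (eyey * eyey.cumsum(0)).sum(1) < 2  # occurrence rank < 2
#   >>> keep
#   tensor([ True,  True, False,  True])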