Source code for models.hal

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
import numpy as np
import torch
from torch.optim import SGD


from backbone import get_backbone
from datasets import get_dataset
from models.utils.continual_model import ContinualModel
from utils.args import add_rehearsal_args, ArgumentParser
from utils.ring_buffer import RingBuffer as Buffer


class HAL(ContinualModel):
    """Hindsight Anchor Learning."""
    NAME = 'hal'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        add_rehearsal_args(parser)
        parser.add_argument('--hal_lambda', type=float, default=0.1,
                            help='Weight of the anchor regularization term.')
        parser.add_argument('--beta', type=float, default=0.5,
                            help='Momentum of the exponential moving average of the feature mean phi.')
        parser.add_argument('--gamma', type=float, default=0.1,
                            help='Weight of the feature-matching term in the anchor optimization.')
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        super().__init__(backbone, loss, args, transform, dataset=dataset)
        self.task_number = 0
        self.buffer = Buffer(self.args.buffer_size, n_tasks=get_dataset(args).N_TASKS)
        self.hal_lambda = args.hal_lambda
        self.beta = args.beta
        self.gamma = args.gamma
        self.anchor_optimization_steps = 100
        self.finetuning_epochs = 1
        self.dataset = get_dataset(args)
        self.spare_model = get_backbone(self.args)
        self.spare_model.to(self.device)
        self.spare_opt = SGD(self.spare_model.parameters(), lr=self.args.lr)

    def end_task(self, dataset):
        self.task_number += 1
        # ring buffer management (only if we are not loading a checkpoint)
        if self.task_number > self.buffer.task_number:
            self.buffer.num_seen_examples = 0
            self.buffer.task_number = self.task_number
        # compute the anchors (only if we are not loading the model)
        if len(self.anchors) < self.task_number * dataset.N_CLASSES_PER_TASK:
            self.get_anchors(dataset)
            del self.phi

    def get_anchors(self, dataset):
        theta_t = self.net.get_params().detach().clone()
        self.spare_model.set_params(theta_t)

        # fine-tune the spare model on the memory buffer
        for _ in range(self.finetuning_epochs):
            inputs, labels = self.buffer.get_data(self.args.batch_size,
                                                  transform=self.transform, device=self.device)
            self.spare_opt.zero_grad()
            out = self.spare_model(inputs)
            loss = self.loss(out, labels)
            loss.backward()
            self.spare_opt.step()

        theta_m = self.spare_model.get_params().detach().clone()

        classes_for_this_task = np.unique(dataset.train_loader.dataset.targets)

        for a_class in classes_for_this_task:
            e_t = torch.rand(self.input_shape, requires_grad=True, device=self.device)
            e_t_opt = SGD([e_t], lr=self.args.lr)
            print(file=sys.stderr)
            for i in range(self.anchor_optimization_steps):
                e_t_opt.zero_grad()
                cum_loss = 0

                # maximize the loss under the fine-tuned weights theta_m ...
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_m.detach().clone())
                loss = -torch.sum(self.loss(self.spare_model(e_t.unsqueeze(0)),
                                            torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                # ... while keeping the anchor correctly classified under the current weights theta_t ...
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_t.detach().clone())
                loss = torch.sum(self.loss(self.spare_model(e_t.unsqueeze(0)),
                                           torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                # ... and its features close to the running mean phi
                self.spare_opt.zero_grad()
                loss = torch.sum(self.gamma * (self.spare_model(e_t.unsqueeze(0), returnt='features') - self.phi) ** 2)
                assert not self.phi.requires_grad
                loss.backward()
                cum_loss += loss.item()

                e_t_opt.step()

            e_t = e_t.detach()
            e_t.requires_grad = False
            self.anchors = torch.cat((self.anchors, e_t.unsqueeze(0)))
            del e_t
            print('Total anchors:', len(self.anchors), file=sys.stderr)

        self.spare_model.zero_grad()

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        real_batch_size = inputs.shape[0]
        if not hasattr(self, 'input_shape'):
            self.input_shape = inputs.shape[1:]
        if not hasattr(self, 'anchors'):
            self.anchors = torch.zeros(tuple([0] + list(self.input_shape))).to(self.device)
        if not hasattr(self, 'phi'):
            print('Building phi', file=sys.stderr)
            with torch.no_grad():
                self.phi = torch.zeros_like(self.net(inputs[0].unsqueeze(0), returnt='features'),
                                            requires_grad=False)
            assert not self.phi.requires_grad

        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform, device=self.device)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        old_weights = self.net.get_params().detach().clone()

        # first step: plain SGD on the current batch plus the replayed examples
        self.opt.zero_grad()
        outputs = self.net(inputs)

        k = self.task_number

        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
        first_loss = 0

        assert len(self.anchors) == self.dataset.N_CLASSES_PER_TASK * k
        if len(self.anchors) > 0:
            # second step: penalize the change of the anchor predictions
            # with respect to the weights before the update
            first_loss = loss.item()
            with torch.no_grad():
                pred_anchors = self.net(self.anchors)

            self.net.set_params(old_weights)
            pred_anchors -= self.net(self.anchors)
            loss = self.hal_lambda * (pred_anchors ** 2).mean()
            loss.backward()
            self.opt.step()

        # update the exponential moving average of the feature mean phi
        with torch.no_grad():
            self.phi = self.beta * self.phi + (1 - self.beta) * \
                self.net(inputs[:real_batch_size], returnt='features').mean(0)

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return first_loss + loss.item()
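

Usage note (not part of models/hal.py): the sketch below shows, under the assumption of an already constructed HAL instance and a mammoth-style dataset whose train_loader yields (augmented inputs, labels, non-augmented inputs) batches, how the observe()/end_task() contract above is meant to be driven. The real training loop lives in the repository's main entry point and adds further bookkeeping, so treat this only as an illustration.

def train_one_task(model, dataset):
    """Hypothetical driver for one task: per-batch HAL steps, then anchor learning."""
    model.net.train()
    for epoch in range(1):  # a single epoch for brevity
        for inputs, labels, not_aug_inputs in dataset.train_loader:
            inputs = inputs.to(model.device)
            labels = labels.to(model.device)
            not_aug_inputs = not_aug_inputs.to(model.device)
            # one HAL step: SGD on batch + replay, followed by the anchor penalty
            loss = model.observe(inputs, labels, not_aug_inputs, epoch=epoch)
    # after the task: reset the ring buffer pointer and learn the new anchors
    model.end_task(dataset)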