# Source code for models.gem

# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the gem_license file in the root of this source tree.

import numpy as np
import torch
import os
from utils.conf import warn_once

try:
    import quadprog as solver
except ImportError:
    warn_once("`quadprog` not found, trying with `qpsolvers`. Note that the code is only tested with `quadprog`.")
    try:
        import qpsolvers as solver
    except ImportError:
        solver = None
        warn_once('Warning: qpsolvers not found (GEM and A-GEM will not work)')
        raise ImportError

from models.utils.continual_model import ContinualModel
from utils.args import add_rehearsal_args, ArgumentParser
from utils.buffer import Buffer, fill_buffer


def store_grad(params, grads, grad_dims):
    """Flatten the current parameter gradients into ``grads``.

    Args:
        params: zero-argument callable returning an iterable of parameters
            (e.g. a module's bound ``parameters`` method).
        grads: tensor receiving the flattened gradients, indexed along its
            first dimension by flat parameter offset. It is zeroed first, so
            a parameter without a gradient leaves zeros in its slot.
        grad_dims: number of elements of each parameter, in the same order
            as ``params()`` yields them.
    """
    # Start from a clean slate so stale values never leak between calls.
    grads.fill_(0.0)

    # Keep a running offset instead of re-summing ``grad_dims`` every step.
    offset = 0
    for param, numel in zip(params(), grad_dims):
        if param.grad is not None:
            # ``reshape`` also handles non-contiguous gradients, where
            # ``view(-1)`` would raise.
            grads[offset:offset + numel].copy_(param.grad.data.reshape(-1))
        # The offset advances for every parameter, with or without a
        # gradient, because ``grad_dims`` describes all of them.
        offset += numel
def overwrite_grad(params, newgrad, grad_dims):
    """Write a flat gradient vector back into the parameters' ``.grad``.

    Used to replace the gradients with a corrected vector whenever a
    constraint violation occurs.

    Args:
        params: zero-argument callable returning an iterable of parameters
            (e.g. a module's bound ``parameters`` method).
        newgrad: corrected gradient, indexed along its first dimension by
            flat parameter offset.
        grad_dims: number of elements of each parameter, in the same order
            as ``params()`` yields them.
    """
    # Keep a running offset instead of re-summing ``grad_dims`` every step.
    offset = 0
    for param, numel in zip(params(), grad_dims):
        if param.grad is not None:
            chunk = newgrad[offset:offset + numel]
            # Restore the parameter's own shape before copying in place.
            param.grad.data.copy_(chunk.reshape(param.grad.shape))
        # Parameters without a gradient are skipped but still occupy their
        # slot in the flat vector.
        offset += numel
def _solve_gem_dual(gram, linear, lower_bound):
    """Return argmin_v 1/2 v^T gram v - linear^T v subject to v >= lower_bound.

    The problem is stated in ``quadprog`` convention. When the module-level
    ``solver`` is ``qpsolvers`` it is translated to that package's convention
    (minimize 1/2 v^T P v + q^T v subject to G v <= h, backend chosen through
    the ``solver=`` keyword, solution returned directly rather than in a
    tuple).
    """
    n_rows = lower_bound.shape[0]
    if getattr(solver, '__name__', '') == 'qpsolvers':
        backends = getattr(solver, 'available_solvers', None)
        if not backends:
            raise RuntimeError('qpsolvers is installed but no QP backend is available')
        # Flip the sign of the linear term and of the inequality to match
        # the qpsolvers problem statement.
        v = solver.solve_qp(gram, -linear, -np.eye(n_rows), -lower_bound,
                            solver=backends[0])
        if v is None:
            raise ValueError('QP solver did not return a solution')
        return v
    # quadprog returns a tuple whose first element is the solution.
    return solver.solve_qp(gram, linear, np.eye(n_rows), lower_bound)[0]


def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
    """Project ``gradient`` so that it does not conflict with ``memories``.

    Solves the GEM dual QP described in the paper, given a proposed gradient
    and the gradients computed on the memory of past tasks, and overwrites
    ``gradient`` in place with the projected update.

    Args:
        gradient: proposed gradient with ``p`` elements; the result is
            copied back as a ``(p, 1)`` column, so it must accept that shape.
        memories: ``(p, t)`` tensor with one column per past task.
        margin: lower bound imposed on every dual variable.
        eps: value added to the diagonal of the Gram matrix to keep it
            positive definite.
    """
    # Work in float64 on the CPU: the QP solvers operate on numpy arrays.
    memories_np = memories.cpu().t().double().numpy()
    gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
    n_rows = memories_np.shape[0]

    # Symmetrize explicitly to remove floating-point asymmetry.
    gram = np.dot(memories_np, memories_np.transpose())
    gram = 0.5 * (gram + gram.transpose()) + np.eye(n_rows) * eps
    linear = np.dot(memories_np, gradient_np) * -1
    lower_bound = np.zeros(n_rows) + margin

    v = _solve_gem_dual(gram, linear, lower_bound)

    # Recover the primal solution and write it back in place.
    x = np.dot(v, memories_np) + gradient_np
    gradient.copy_(torch.from_numpy(x).view(-1, 1))
class Gem(ContinualModel):
    """Continual learning via Gradient Episodic Memory."""
    NAME = 'gem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        """Register the rehearsal arguments and the GEM margin ``--gamma``."""
        add_rehearsal_args(parser)
        parser.add_argument('--gamma', type=float, default=0.5,
                            help='Margin parameter for GEM.')
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        """Build the model, its rehearsal buffer and the gradient storage."""
        assert solver is not None, 'GEM requires quadprog (linux only, python <= 3.10) or qpsolvers (cross-platform)'
        super().__init__(backbone, loss, args, transform, dataset=dataset)
        self.buffer = Buffer(self.args.buffer_size)

        # Allocate temporary synaptic memory.
        # Number of elements of every parameter, in ``self.parameters()`` order.
        self.grad_dims = [pp.data.numel() for pp in self.parameters()]
        # One flat gradient vector per finished task (appended in ``end_task``).
        self.grads_cs = []
        # Flat gradient of the current batch.
        self.grads_da = torch.zeros(sum(self.grad_dims), device=self.device)

    def end_task(self, dataset):
        """Reserve a gradient slot for the finished task and fill the buffer."""
        self.grads_cs.append(
            torch.zeros(sum(self.grad_dims), device=self.device))
        fill_buffer(self.buffer, dataset, self.current_task,
                    required_attributes=['examples', 'labels', 'task_labels'])

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Run one optimization step on the batch and return the loss value.

        When the buffer is not empty, the gradient of each buffered task is
        computed first. The gradient of the current batch is then projected
        with ``project2cone2`` if its dot product with any stored task
        gradient is negative.
        """
        if not self.buffer.is_empty():
            buf_inputs, buf_labels, buf_task_labels = self.buffer.get_data(
                self.args.buffer_size, transform=self.transform, device=self.device)

            for tt in buf_task_labels.unique():
                # compute gradient on the memory buffer
                self.opt.zero_grad()
                task_mask = buf_task_labels == tt
                cur_task_outputs = self.forward(buf_inputs[task_mask])
                penalty = self.loss(cur_task_outputs, buf_labels[task_mask])
                penalty.backward()
                store_grad(self.parameters, self.grads_cs[tt], self.grad_dims)

        # now compute the grad on the current data
        self.opt.zero_grad()
        outputs = self.forward(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()

        # check if gradient violates buffer constraints
        if not self.buffer.is_empty():
            # copy gradient
            store_grad(self.parameters, self.grads_da, self.grad_dims)

            # Build the (p, t) matrix of task gradients once and reuse it.
            memories = torch.stack(self.grads_cs).T
            dot_prod = torch.mm(self.grads_da.unsqueeze(0), memories)
            if (dot_prod < 0).any():
                # ``unsqueeze`` returns a view, so the projection is written
                # straight into ``self.grads_da``.
                project2cone2(self.grads_da.unsqueeze(1), memories,
                              margin=self.args.gamma)
                # copy gradients back
                overwrite_grad(self.parameters, self.grads_da, self.grad_dims)

        self.opt.step()

        return loss.item()