# Source code for models.joint_gcl

# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch

from backbone import get_backbone
from models.utils.continual_model import ContinualModel
from utils.args import ArgumentParser
from utils.status import progress_bar


class JointGCL(ContinualModel):
    """Joint training: a strong, simple baseline.

    General-continual variant: ``observe`` performs no optimisation and only
    buffers every incoming batch. ``end_task`` then builds a fresh backbone and
    trains it for a single shuffled pass over all buffered samples.
    """
    NAME = 'joint_gcl'
    COMPATIBILITY = ['general-continual']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        """Set the default number of epochs to 1; no extra arguments are added."""
        parser.set_defaults(n_epochs=1)
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        super(JointGCL, self).__init__(backbone, loss, args, transform, dataset=dataset)
        # Batches accumulated by `observe`, concatenated and trained on in `end_task`.
        self.old_data = []
        self.old_labels = []

    def end_task(self, dataset):
        """Re-initialise the backbone and fit it once on every buffered sample.

        Args:
            dataset: not used by this implementation.
        """
        if not self.old_data:
            # Nothing was observed: torch.cat would raise on an empty list, and
            # rebuilding the network would only throw the current one away.
            return

        # reinit network
        self.net = get_backbone(self.args)
        self.net.to(self.device)
        self.net.train()
        self.opt = self.get_optimizer()

        # gather data
        all_data = torch.cat(self.old_data)
        all_labels = torch.cat(self.old_labels)

        # Shuffle once up front. Indexing with the permutation inside the loop
        # would materialise a full permuted copy of the buffer for every batch.
        rp = torch.randperm(len(all_data))
        all_data, all_labels = all_data[rp], all_labels[rp]

        batch_size = self.args.batch_size
        n_batches = math.ceil(len(all_data) / batch_size)

        # train (single epoch because GCL)
        for i in range(n_batches):
            start, stop = i * batch_size, (i + 1) * batch_size
            inputs = all_data[start:stop].to(self.device)
            labels = all_labels[start:stop].to(self.device)

            self.opt.zero_grad()
            outputs = self.net(inputs)
            loss = self.loss(outputs, labels.long())
            loss.backward()
            self.opt.step()

            progress_bar(i, n_batches, 0, 'J', loss.item())

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Buffer the incoming batch without training on it.

        Returns:
            0, since no loss is computed here.
        """
        self.old_data.append(inputs.data)
        self.old_labels.append(labels.data)
        return 0