Source code for models.er_ace_cscct

import torch

from utils.args import add_rehearsal_args, ArgumentParser
from models.cscct_utils.cscct_model import CscCtModel


class ErACECscCt(CscCtModel):
    """ER-ACE with future not fixed (as made by authors). Treated with CSCCT!

    The stream loss is computed on logits where the columns of classes that
    are neither in the current batch nor at/after the largest class id seen
    so far are filled with the most negative finite value (only once
    ``current_task > 0``). A replay loss on buffer samples and, when both of
    its weights are positive, the loss returned by ``get_cscct_loss`` are
    added on top.
    """

    NAME = 'er_ace_cscct'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        """Register the rehearsal and CSCCT arguments on ``parser``.

        Args:
            parser: the argument parser to extend (modified in place).

        Returns:
            The same ``parser`` object that was passed in.
        """
        add_rehearsal_args(parser)
        CscCtModel.add_cscct_args(parser)
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        """Forward all arguments to the parent constructor and start with no seen classes."""
        super().__init__(backbone, loss, args, transform, dataset)
        # Class ids met so far in the stream; grows (deduplicated) at every `observe` call.
        self.seen_so_far = torch.tensor([], dtype=torch.long, device=self.device)

    def observe(self, inputs: torch.Tensor, labels: torch.Tensor, not_aug_inputs: torch.Tensor, epoch=None):
        """Run one optimisation step on a stream batch and update the buffer.

        Args:
            inputs: batch fed to ``self.net`` for the stream loss.
            labels: targets of ``inputs``; their unique values are merged into
                ``self.seen_so_far``.
            not_aug_inputs: examples stored in the buffer together with
                ``labels`` (presumably the non-augmented version of
                ``inputs`` -- confirm against the caller).
            epoch: not used by this method.

        Returns:
            The total loss of the step as a Python number (``loss.item()``).
        """
        self.opt.zero_grad()

        # Track which classes appear in this batch and in the stream overall.
        batch_classes = labels.unique()
        self.seen_so_far = torch.cat([self.seen_so_far, batch_classes]).unique()

        logits = self.net(inputs)
        highest_seen = self.seen_so_far.max()

        # Non-zero entries mark the logit columns that stay untouched.
        keep = torch.zeros_like(logits)
        keep[:, batch_classes] = 1
        if highest_seen < (self.N_CLASSES - 1):
            # The slice starts at the largest seen id itself, so that column
            # and every higher ("future") one are left unmasked.
            keep[:, highest_seen:] = 1

        if self.current_task > 0:
            # Push every other column to the lowest finite value of the dtype.
            lowest = torch.finfo(logits.dtype).min
            logits = logits.masked_fill(keep == 0, lowest)

        loss = self.loss(logits, labels)

        uses_buffer = self.args.buffer_size > 0
        if self.current_task > 0 and uses_buffer:
            # Replay term: unmasked loss on a minibatch drawn from the buffer.
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            loss = loss + self.loss(self.net(buf_inputs), buf_labels)

            if self.args.csc_weight > 0 and self.args.ct_weight > 0:
                # CSCCT term on the stream batch concatenated with the buffer batch.
                joint_inputs = torch.cat([inputs, buf_inputs], dim=0)
                joint_targets = torch.cat([labels, buf_labels], dim=0)
                loss = loss + self.get_cscct_loss(joint_inputs, joint_targets)

        if uses_buffer:
            self.buffer.add_data(examples=not_aug_inputs, labels=labels)

        loss.backward()
        self.opt.step()

        return loss.item()