# Copyright 2021-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from utils.ring_buffer import RingBuffer
from datasets import get_dataset
from utils.args import *
from models.utils.continual_model import ContinualModel
from utils.buffer import Buffer
from utils.mixup import mixup
from utils.triplet import batch_hard_triplet_loss, negative_only_triplet_loss
import torch
import torch.nn.functional as F
class Ccic(ContinualModel):
"""Continual Semi-Supervised Learning via Continual Contrastive Interpolation Consistency."""
NAME = 'ccic'
COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'cssl']
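# High-level sketch (derived from the code below, not an authoritative summary):
# labeled data is replayed from a rehearsal buffer and combined via mixup; unlabeled
# data receives sharpened pseudo-labels obtained by averaging predictions over several
# augmentations (MixMatch-style), plus contrastive triplet terms. At inference time,
# `forward` performs a kNN lookup over the L2-normalized features of the buffered exemplars.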
@staticmethod
def get_parser(parser) -> ArgumentParser:
parser.set_defaults(optimizer='adam')
add_rehearsal_args(parser)
parser.add_argument('--alpha', type=float, default=0.5,
help='Weight of the unsupervised (negative-only triplet) loss.')
parser.add_argument('--knn_k', '--k', type=int, default=2, dest='knn_k',
help='Number of neighbours (k) used by the kNN classifier at inference.')
parser.add_argument('--memory_penalty', type=float,
default=1.0, help='Weight of the cross-entropy loss on replayed buffer samples.')
parser.add_argument('--k_aug', type=int, default=3,
help='Number of augmentations used to compute the pseudo-label predictions.')
parser.add_argument('--mixmatch_alpha', '--lamda', type=float, default=0.5, dest='mixmatch_alpha',
help='Weight of the MixMatch consistency (unsupervised) loss.')
parser.add_argument('--sharp_temp', default=0.5,
type=float, help='Temperature for sharpening.')
parser.add_argument('--mixup_alpha', default=0.75, type=float,
help='Interpolation strength for mixup.')
return parser
def __init__(self, backbone, loss, args, transform, dataset=None):
super().__init__(backbone, loss, args, transform, dataset=dataset)
self.buffer = Buffer(self.args.buffer_size, self.device)
self.embeddings = None
self.eye = torch.eye(self.num_classes).to(self.device)
self.sup_virtual_batch = RingBuffer(self.args.batch_size)
self.unsup_virtual_batch = RingBuffer(self.args.batch_size)
def get_debug_iters(self):
"""
Returns the number of iterations to wait before logging.
- CCIC needs a couple more iterations to initialize the KNN.
"""
return 1000 if len(self.buffer) < self.args.buffer_size else 5
def forward(self, x):
if self.embeddings is None:
with torch.no_grad():
self.compute_embeddings()
n_seen_classes = self.cpt * self.current_task if isinstance(self.cpt, int) else sum(self.cpt[:self.current_task])
n_remaining_classes = self.N_CLASSES - n_seen_classes
buf_labels = self.buffer.labels[:self.buffer.num_seen_examples]
feats = self.net(x, returnt='features')
feats = F.normalize(feats, p=2, dim=1)
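# As implemented below: `distances` holds squared Euclidean distances between the
# normalized query features and every buffered embedding; `dist` stores, per seen
# class, the distance to that class's nearest exemplar (zero for unseen or empty
# classes), while `topkappas` counts how many of the `knn_k` nearest buffer samples
# belong to each class. The kNN vote dominates the returned score; the small
# distance term (scaled by 10e-6) only breaks ties.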
distances = (self.embeddings.unsqueeze(0) - feats.unsqueeze(1)).pow(2).sum(2)
dist = torch.stack([distances[:, buf_labels == c].topk(1, largest=False)[0].mean(dim=1)
if (buf_labels == c).sum() > 0 else torch.zeros(x.shape[0]).to(self.device)
for c in range(n_seen_classes)] +
[torch.zeros(x.shape[0]).to(self.device)] * n_remaining_classes).T
topkappas = self.eye[buf_labels[distances.topk(self.args.knn_k, largest=False)[1]]].sum(1)
return topkappas - dist * 10e-6
def end_task(self, dataset):
self.embeddings = None
def discard_unsupervised_labels(self, inputs, labels, not_aug_inputs):
mask = labels != -1
return inputs[mask], labels[mask], not_aug_inputs[mask]
def discard_supervised_labels(self, inputs, labels, not_aug_inputs):
mask = labels == -1
return inputs[mask], labels[mask], not_aug_inputs[mask]
def observe(self, inputs, labels, not_aug_inputs, epoch=None):
self.opt.zero_grad()
real_batch_size = inputs.shape[0]
sup_inputs, sup_labels, sup_not_aug_inputs = self.discard_unsupervised_labels(inputs, labels, not_aug_inputs)
sup_inputs_for_buffer, sup_labels_for_buffer = sup_not_aug_inputs.clone(), sup_labels.clone()
unsup_inputs, unsup_labels, unsup_not_aug_inputs = self.discard_supervised_labels(inputs, labels, not_aug_inputs)
if len(sup_inputs) == 0 and self.buffer.is_empty(): # if there is no data to train on, just return 1.
return 1.
self.sup_virtual_batch.add_data(sup_not_aug_inputs, sup_labels)
sup_inputs, sup_labels = self.sup_virtual_batch.get_data(self.args.batch_size, transform=self.transform, device=self.device)
if self.current_task > 0 and unsup_not_aug_inputs.shape[0] > 0:
self.unsup_virtual_batch.add_data(unsup_not_aug_inputs)
unsup_inputs = self.unsup_virtual_batch.get_data(self.args.batch_size, transform=self.transform, device=self.device)[0]
# BUFFER RETRIEVAL
if not self.buffer.is_empty():
buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size, transform=self.transform, device=self.device)
sup_inputs = torch.cat((sup_inputs, buf_inputs))
sup_labels = torch.cat((sup_labels, buf_labels))
if self.current_task > 0:
masked_buf_inputs = self.buffer.get_data(self.args.minibatch_size,
mask_task_out=self.current_task,
transform=self.transform,
cpt=self.n_classes_current_task,
device=self.device)[0]
unsup_labels = torch.cat((torch.zeros(unsup_inputs.shape[0]).to(self.device),
torch.ones(masked_buf_inputs.shape[0]).to(self.device))).long()
unsup_inputs = torch.cat((unsup_inputs, masked_buf_inputs))
# ------------------ K AUG ---------------------
mask = labels != -1
real_mask = mask[:real_batch_size]
if (~real_mask).sum() > 0:
unsup_aug_inputs = self.transform(not_aug_inputs[~real_mask].repeat_interleave(self.args.k_aug, 0))
else:
unsup_aug_inputs = torch.zeros((0,)).to(self.device)
# ------------------ PSEUDO LABEL ---------------------
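# Pseudo-labeling (MixMatch-style sharpening, as implemented below): predictions on
# the augmented unlabeled views are averaged over the first reshaped dimension,
# raised to the power 1/`sharp_temp` and re-normalized so that the soft targets
# become more peaked, then tiled with `repeat` to match the augmented batch.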
self.net.eval()
if len(unsup_aug_inputs):
with torch.no_grad():
unsup_aug_outputs = self.net(unsup_aug_inputs).reshape(self.args.k_aug, -1, self.eye.shape[0]).mean(0)
unsup_sharp_outputs = unsup_aug_outputs ** (1 / self.args.sharp_temp)
unsup_norm_outputs = unsup_sharp_outputs / unsup_sharp_outputs.sum(1).unsqueeze(1)
unsup_norm_outputs = unsup_norm_outputs.repeat(self.args.k_aug, 1)
else:
unsup_norm_outputs = torch.zeros((0, len(self.eye))).to(self.device)
self.net.train()
# ------------------ MIXUP ---------------------
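# Mixup step (see utils.mixup for the exact interpolation convention): labeled and
# pseudo-labeled samples are concatenated, randomly permuted, and each example is
# linearly interpolated with its shuffled counterpart. Note that the mixed targets
# returned by `mixup` are discarded here; the losses below use the original labels
# and the sharpened pseudo-labels.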
self.opt.zero_grad()
W_inputs = torch.cat((sup_inputs, unsup_aug_inputs))
W_probs = torch.cat((self.eye[sup_labels], unsup_norm_outputs))
perm = torch.randperm(W_inputs.shape[0])
W_inputs, W_probs = W_inputs[perm], W_probs[perm]
sup_shape = sup_inputs.shape[0]
sup_mix_inputs, _ = mixup([(sup_inputs, W_inputs[:sup_shape]), (self.eye[sup_labels], W_probs[:sup_shape])], self.args.mixup_alpha)
sup_mix_outputs = self.net(sup_mix_inputs)
if len(unsup_aug_inputs):
unsup_mix_inputs, _ = mixup(
[(unsup_aug_inputs, W_inputs[sup_shape:]),
(unsup_norm_outputs, W_probs[sup_shape:])],
self.args.mixup_alpha)
unsup_mix_outputs = self.net(unsup_mix_inputs)
effective_mbs = min(self.args.minibatch_size,
self.buffer.num_seen_examples)
if effective_mbs == 0:
effective_mbs = -self.N_CLASSES
# ------------------ CIC LOSS ---------------------
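# Loss composition (as implemented below): `loss_X` is the cross-entropy on the mixed
# labeled samples, with the replayed buffer portion weighted by `memory_penalty`;
# `loss_U` is a MixMatch-style MSE consistency term between the sharpened
# pseudo-labels and the predictions on the mixed unlabeled samples, weighted later
# by `mixmatch_alpha`.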
loss_X = 0
if real_mask.sum() > 0:
loss_X += self.loss(sup_mix_outputs[:-effective_mbs],
sup_labels[:-effective_mbs])
if not self.buffer.is_empty():
assert effective_mbs > 0
loss_X += self.args.memory_penalty * self.loss(sup_mix_outputs[-effective_mbs:],
sup_labels[-effective_mbs:])
if len(unsup_aug_inputs):
loss_U = F.mse_loss(unsup_norm_outputs, unsup_mix_outputs) / self.eye.shape[0]
else:
loss_U = 0
# CIC LOSS
if self.current_task > 0 and epoch < self.args.n_epochs / 10 * 9:
W_inputs = sup_inputs
W_probs = self.eye[sup_labels]
perm = torch.randperm(W_inputs.shape[0])
W_inputs, W_probs = W_inputs[perm], W_probs[perm]
sup_mix_inputs, _ = mixup([(sup_inputs, W_inputs), (self.eye[sup_labels], W_probs)], 1)
else:
sup_mix_inputs = sup_inputs
# STANDARD TRIPLET
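# Batch-hard triplet loss on the (possibly mixed) labeled embeddings (see
# utils.triplet.batch_hard_triplet_loss); it may return None when no valid triplet
# can be formed, in which case only loss_X and loss_U contribute below.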
sup_mix_embeddings = self.net.features(sup_mix_inputs)
loss = batch_hard_triplet_loss(sup_labels, sup_mix_embeddings, self.args.batch_size // 10,
margin=1, margin_type='hard')
if loss is None:
loss = loss_X + self.args.mixmatch_alpha * loss_U
else:
loss += loss_X + self.args.mixmatch_alpha * loss_U
self.buffer.add_data(examples=sup_inputs_for_buffer,
labels=sup_labels_for_buffer)
# SELF-SUPERVISED PAST TASKS NEGATIVE ONLY
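# Unlabeled contrastive term: `unsup_labels`, set above, marks current-stream
# unlabeled samples with 0 and replayed past-task samples with 1, so the
# negative-only triplet loss only pushes the two groups apart; it is weighted by
# `alpha` whenever it is defined.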
if self.current_task > 0 and epoch < self.args.n_epochs / 10 * 9:
unsup_embeddings = self.net.features(unsup_inputs)
loss_unsup = negative_only_triplet_loss(unsup_labels, unsup_embeddings, self.args.batch_size // 10,
margin=1, margin_type='hard')
if loss_unsup is not None:
loss += self.args.alpha * loss_unsup
loss.backward()
self.opt.step()
return loss.item()
@torch.no_grad()
def compute_embeddings(self):
"""
Computes a vector representing mean features for each class.
"""
was_training = self.net.training
self.net.eval()
data = self.buffer.get_all_data(transform=self.normalization_transform)[0]
outputs = []
while data.shape[0] > 0:
inputs = data[:self.args.batch_size]
data = data[self.args.batch_size:]
out = self.net(inputs, returnt='features')
out = F.normalize(out, p=2, dim=1)
outputs.append(out)
self.embeddings = torch.cat(outputs)
self.net.train(was_training)