import logging
import sys
import time
import numpy as np
import torch
from torch import nn
import tqdm
from typing import TYPE_CHECKING
from copy import deepcopy
import torch.nn.functional as F
from models.utils.continual_model import ContinualModel
from utils import binary_to_boolean_type
from utils.args import add_rehearsal_args
from utils.autoaugment import CIFAR10Policy
from utils.buffer import Buffer
from torchvision import transforms
from utils.kornia_utils import to_kornia_transform
if TYPE_CHECKING:
    from datasets.utils.continual_dataset import ContinualDataset

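# `get_hard_transform` is referenced in Cnll.__init__ but its definition is missing from
# this excerpt. The sketch below is a plausible reconstruction, inferred from the
# CIFAR10Policy and torchvision imports above (and the strong augmentation used in the
# original CNLL code); it is not necessarily the original implementation.
def get_hard_transform(args, dataset: 'ContinualDataset'):
    """Strong augmentation pipeline applied to the noisy (unlabeled) samples."""
    return transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # assumes 32x32 inputs (e.g. CIFAR)
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),  # AutoAugment policy for CIFAR-10
        transforms.ToTensor(),
        dataset.get_normalization_transform(),
    ])
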
def kl_divergence(p, q):
    """Per-sample KL divergence KL(p || q) over the class dimension, with eps-smoothing for numerical stability."""
    return (p * ((p + 1e-10) / (q + 1e-10)).log()).sum(dim=1)

class Jensen_Shannon(nn.Module):
    """Jensen-Shannon divergence: JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = (p + q) / 2."""

    def forward(self, p, q):
        m = (p + q) / 2
        return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

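# Sanity check (illustrative): the JSD of a distribution with itself is ~0 (up to the
# 1e-10 smoothing), e.g. Jensen_Shannon()(p, p) for p = torch.tensor([[0.5, 0.5]]).
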
class NegEntropy(object):
    """Negative-entropy confidence penalty: returns the mean of sum_c p_c * log(p_c), i.e. -H(p)."""

    def __call__(self, outputs):
        probs = torch.softmax(outputs, dim=1)
        return torch.mean(torch.sum(probs.log() * probs, dim=1))

def linear_rampup(current, warm_up, rampup_length=16):
    """Linearly ramp a weight from 0 to 1 over `rampup_length` epochs, starting after `warm_up` epochs."""
    current = np.clip((current - warm_up) / rampup_length, 0.0, 1.0)
    return float(current)

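# Worked example: with warm_up=5 and the default rampup_length=16,
# linear_rampup(5, 5) == 0.0, linear_rampup(13, 5) == 0.5, and linear_rampup(21, 5) == 1.0.
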
class SemiLoss(object):
    """DivideMix-style semi-supervised loss: soft cross-entropy on labeled samples plus an MSE term on unlabeled samples, with a linearly ramped weight."""

    def __init__(self, args):
        self.args = args

    def __call__(self, outputs_x, outputs_x2, targets_x, outputs_u, targets_u, epoch, warm_up):
        # outputs_x2 is accepted for API symmetry but is unused in the computation
        probs_u = torch.softmax(outputs_u, dim=1)
        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        Lu = torch.mean((probs_u - targets_u) ** 2)
        return Lx, Lu, linear_rampup(epoch, warm_up)

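# In equation form, with t_x and t_u the soft targets built during finetuning:
#   L_x = -E_x[ sum_c t_{x,c} * log softmax(o_x)_c ]   (cross-entropy with soft labels)
#   L_u =  E_u[ || softmax(o_u) - t_u ||^2 ]           (MSE/Brier consistency on unlabeled data)
# The caller scales L_u by lambda_u times the returned ramp factor.
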
class Dataset(torch.utils.data.Dataset):
    """Minimal in-memory dataset wrapping pre-loaded tensors, with an optional per-item transform."""

    def __init__(self, data, targets=None, transform=None, device="cpu"):
        self.data = data.to(device)
        self.targets = targets.to(device) if targets is not None else None
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.transform:
            data = self.transform(self.data[idx])
        else:
            data = self.data[idx]
        if self.targets is not None:
            return data, self.targets[idx]
        return data

class Cnll(ContinualModel):
    """
    Implementation of `CNLL: A Semi-supervised Approach For Continual Noisy Label Learning <https://github.com/nazmul-karim170/CNLL>`_ from CVPRW 2022.
    """
    NAME = 'cnll'
    COMPATIBILITY = ['class-il', 'task-il']

    @staticmethod
    def get_parser(parser):
        parser.set_defaults(optim_mom=0.9, optim_wd=5e-4, lr=0)  # lr and optimizer are ignored
        add_rehearsal_args(parser)
        parser.add_argument('--cnll_debug_mode', type=binary_to_boolean_type, default=False,
                            help='Run CNLL with just a few iterations?')
        parser.add_argument('--unlimited_buffer', type=binary_to_boolean_type, default=False,
                            help='Use unlimited buffers?')
        parser.add_argument('--delayed_buffer_size', type=int, default=500,
                            help='Size of the delayed buffer.')
        parser.add_argument('--noisy_buffer_size', type=int, default=1000,
                            help='Size of the noisy buffer.')
        parser.add_argument('--warmup_epochs', type=int, default=5, help='Warmup epochs')
        parser.add_argument('--finetune_epochs', type=int, default=10, help='Finetuning epochs')
        parser.add_argument('--warmup_lr', type=float, default=0.0005, help='Warmup learning rate')
        parser.add_argument('--subsample_clean', type=int, default=25,
                            help='Number of high-confidence samples to subsample from the clean buffer (N_1 in the paper)')
        parser.add_argument('--subsample_noisy', type=int, default=50,
                            help='Number of high-confidence samples to subsample from the noisy buffer (N_2 in the paper)')
        parser.add_argument('--sharp_temp', type=float, default=0.5, help='Temperature for label co-guessing')
        parser.add_argument('--mixup_alpha', type=float, default=4, help='Alpha parameter of the Beta distribution for mixup')
        parser.add_argument('--lambda_u', type=float, default=30, help='Weight for the unsupervised loss')
        parser.add_argument('--lambda_c', type=float, default=0.025, help='Weight for the contrastive loss')
        parser.add_argument('--finetune_lr', type=float, help='Learning rate used during finetuning. Default is `warmup_lr` * 50')
        return parser

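    # Illustrative invocation, assuming Mammoth's usual entry point (exact dataset and
    # noise flags depend on your setup):
    #   python utils/main.py --model cnll --dataset seq-cifar10 --buffer_size 500
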
    def __init__(self, backbone, loss, args, transform, dataset=None):
        logging.info("LR and optimizer are ignored in CNLL (will use warmup_lr, finetune_lr and SGD)")
        if args.finetune_lr is None:
            args.finetune_lr = args.warmup_lr * 50
        backbone.classifier_re = deepcopy(backbone.classifier)
        super().__init__(backbone, loss, args, transform, dataset=dataset)

        # initializing buffers
        self.buffer = Buffer(self.args.buffer_size, "cpu")
        self.delayed_buffer = Buffer(self.args.delayed_buffer_size, "cpu")
        self.noisy_buffer = Buffer(self.args.noisy_buffer_size, "cpu")
        if self.args.unlimited_buffer:
            logging.warning("Using unlimited buffer!")
            self.high_fidelity_buffer = Buffer(-1, device='cpu', sample_selection_strategy='unlimited')
            self.high_fidelity_noisy_buffer = Buffer(-1, device='cpu', sample_selection_strategy='unlimited')
        else:
            # bounded high-fidelity buffers
            self.high_fidelity_buffer = Buffer(self.args.buffer_size, "cpu")
            self.high_fidelity_noisy_buffer = Buffer(self.args.noisy_buffer_size, "cpu")

        self.has_fitted_once = False  # set to True after the first finetuning pass
        self._past_it_t = time.time()
        self._task_t = time.time()
        self._avg_it_t = 0
        self.past_loss = 0
        self.conf_penalty = NegEntropy()
        self.JS_dist = Jensen_Shannon()
        self.eye = torch.eye(self.num_classes).to(self.device)
        self.hard_transform = to_kornia_transform(get_hard_transform(self.args, dataset))
        self.semi_sul_loss = SemiLoss(args)

    def warm_up_on_buffer(self, buffer: Buffer, has_fitted_once: bool):
        opt = torch.optim.SGD(self.net.parameters(), lr=self.args.warmup_lr, momentum=self.args.optim_mom, weight_decay=self.args.optim_wd)
        if has_fitted_once:
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, self.args.finetune_epochs, 2e-2)
        else:
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, 240, 2e-4)

        avg_loss, n_batches = 0.0, 0
        for _ in range(self.args.warmup_epochs):
            self.net.train()
            for batch in buffer.get_dataloader(self.args, batch_size=self.args.batch_size, shuffle=True):
                opt.zero_grad()
                x, y = batch[0].to(self.device), batch[1].to(self.device)
                x = self.transform(x)
                pred = self.net(x)
                loss = self.loss(pred, y)
                if 'asym' in self.args.noise_type:
                    # penalize over-confident predictions under asymmetric noise
                    penalty = self.conf_penalty(pred)
                    if not torch.isnan(penalty):
                        loss += penalty
                loss.backward()
                opt.step()
                avg_loss += loss.item()
                n_batches += 1
            sched.step()
        return avg_loss / max(n_batches, 1)

    def begin_task(self, dataset):
        if self.current_task > 0:
            ct = time.time()
            remaining_time = ((ct - self._task_t) * (self.n_tasks - self.current_task)) * self.args.n_epochs
            logging.debug(f"Task {self.current_task-1} lasted {ct-self._task_t:.2f}s | remaining: {remaining_time:.2f}s")
        self.observe_it = 0
        self.tot_its = ((len(dataset.train_loader.dataset) // self.args.delayed_buffer_size) + 1) * self.args.n_epochs

        # FIRST LIE: NO TASK BOUNDARY
        self.current_classes = np.unique(dataset.train_loader.dataset.targets)
        weight = torch.zeros(self.num_classes)
        weight[self.current_classes] = 1
        weight = weight.to(self.device)
        self.loss = nn.CrossEntropyLoss(weight=weight)
        self._task_t = self._past_it_t = time.time()

    @torch.no_grad()
    def sample_selection_JSD(self, buffer: Buffer):
        """Compute the per-sample Jensen-Shannon distance between the (masked, argmax one-hot) prediction and the given label."""
        selected_indexes = torch.zeros(len(buffer))
        for batch_idx, batch in enumerate(buffer.get_dataloader(self.args, batch_size=self.args.batch_size * 2, drop_last=True, shuffle=False)):
            inputs, targets = batch[0].to(self.device), batch[1].to(self.device)
            batch_size = inputs.size()[0]

            # Get the network's predictions, masked to the classes of the current task
            preds = torch.softmax(self.net(inputs), dim=-1)
            out = torch.zeros(preds.size()).to(self.device)
            out[:, self.current_classes] = preds[:, self.current_classes]
            _, ind = torch.max(out, 1)
            out_final = torch.zeros(preds.size()).to(self.device)
            for kk in range(out.size()[0]):
                out_final[kk, ind[kk]] = 1

            dist = self.JS_dist(out_final, F.one_hot(targets, num_classes=self.num_classes))
            selected_indexes[int(batch_idx * batch_size):int((batch_idx + 1) * batch_size)] = dist
        return selected_indexes

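    # Note: the per-row loop in sample_selection_JSD builds a one-hot vector of the masked
    # argmax; an equivalent vectorized form (illustrative) would be
    #   out_final = F.one_hot(out.argmax(dim=1), num_classes=self.num_classes).float()
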
    @torch.no_grad()
    def get_partition_buffer_indexes(self, buffer: Buffer):
        """Partition the buffer into clean and noisy samples by thresholding the JSD scores at their mean; small subsets of each partition are additionally retained for the high-fidelity buffers."""
        buffer_size = len(buffer)
        selected_indexes = self.sample_selection_JSD(buffer)
        threshold = torch.mean(selected_indexes)
        SR = torch.sum(selected_indexes < threshold).item() / buffer_size  # estimated clean ratio
        selected_indexes = selected_indexes.cpu().numpy()
        pred_idx = np.argsort(selected_indexes)[0: int(SR * buffer_size)]
        idx = np.arange(buffer_size)
        pred_idx_noisy = np.setdiff1d(idx, pred_idx)
        repl_idx = np.array(pred_idx)[:self.args.subsample_clean]
        repl_idx_noisy = np.array(pred_idx_noisy)[:self.args.subsample_noisy]
        return pred_idx, pred_idx_noisy, repl_idx, repl_idx_noisy

    def _observe(self, not_aug_x, y, true_y):
        self.delayed_buffer.add_data(examples=not_aug_x.unsqueeze(0),
                                     labels=y.unsqueeze(0),
                                     true_labels=true_y.unsqueeze(0))
        avg_expert_loss, avg_self_loss = -1, -1
        if self.delayed_buffer.is_full():
            if self.args.cnll_debug_mode and self.observe_it > 2:
                return 0, 0
            self.observe_it += 1

            ctime = time.time()
            # running mean of the seconds spent per delayed-buffer iteration
            self._avg_it_t += ((ctime - self._past_it_t) - self._avg_it_t) / self.observe_it
            remaining_time = (self.tot_its - self.observe_it) * self._avg_it_t
            logging.debug(f"[Task {self.current_task}] Buffer iteration: {self.observe_it}/{self.tot_its} (s/it: {self._avg_it_t:.2f}s | rem: {remaining_time:.2f}s)")
            self._past_it_t = ctime

            # Warm up on the delayed buffer D
            pret = time.time()
            logging.debug(" - Warm up...")
            avg_expert_loss = self.warm_up_on_buffer(self.delayed_buffer, self.has_fitted_once)
            logging.debug(f" Done (s: {time.time()-pret:.2f}s)")

            pret = time.time()
            logging.debug(" - Purifying buffer...")
            # Partition D into clean and noisy samples
            clean_idxs, noisy_idxs, high_fidelity_clean_idxs, high_fidelity_noisy_idxs = self.get_partition_buffer_indexes(self.delayed_buffer)
            # Add the clean samples to the purified buffer P, the rest to the noisy buffer
            self.buffer.add_data(examples=self.delayed_buffer.examples[clean_idxs],
                                 labels=self.delayed_buffer.labels[clean_idxs],
                                 true_labels=self.delayed_buffer.true_labels[clean_idxs])
            self.noisy_buffer.add_data(examples=self.delayed_buffer.examples[noisy_idxs],
                                       labels=self.delayed_buffer.labels[noisy_idxs],
                                       true_labels=self.delayed_buffer.true_labels[noisy_idxs])
            self.high_fidelity_buffer.add_data(examples=self.delayed_buffer.examples[high_fidelity_clean_idxs],
                                               labels=self.delayed_buffer.labels[high_fidelity_clean_idxs],
                                               true_labels=self.delayed_buffer.true_labels[high_fidelity_clean_idxs])
            self.high_fidelity_noisy_buffer.add_data(examples=self.delayed_buffer.examples[high_fidelity_noisy_idxs],
                                                     labels=self.delayed_buffer.labels[high_fidelity_noisy_idxs],
                                                     true_labels=self.delayed_buffer.true_labels[high_fidelity_noisy_idxs])
            logging.debug(f" Done (s: {time.time()-pret:.2f}s)")
            self.delayed_buffer.empty()

            if self.buffer.is_full():
                self.has_fitted_once = True
                pret = time.time()
                logging.debug(" - Clean buffer is full, fine-tuning model on buffers...")
                self.past_loss = self.finetune_on_buffers()
                self.buffer.empty()
                self.noisy_buffer.empty()
                logging.debug(f" Done (s: {time.time()-pret:.2f}s)")
        return avg_expert_loss, avg_self_loss

    @torch.no_grad()
    def coguess_label(self, xa, xb, y):
        # Label co-guessing of unlabeled samples
        outputs_u11 = self.net(xa)
        outputs_u12 = self.net(xb)

        # Pseudo-label: average the predictions on the two augmented views
        pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1)) / 2
        ptu = pu ** (1 / self.args.sharp_temp)  # temperature sharpening
        targets_u = ptu / ptu.sum(dim=1, keepdim=True)
        targets_u = targets_u.detach()

        # Label refinement
        outputs_x = self.net(xa)
        outputs_x2 = self.net(xb)
        px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
        px = y  # refinement disabled: use the given labels directly (originally w_x*y + (1-w_x)*px)
        ptx = px ** (1 / self.args.sharp_temp)  # temperature sharpening
        targets_x = ptx / ptx.sum(dim=1, keepdim=True)
        return targets_x, targets_u

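    # Worked example of sharpening (illustrative): with sharp_temp=0.5, a prediction
    # [0.8, 0.2] is raised to the power 1/0.5 = 2, giving [0.64, 0.04], which renormalizes
    # to roughly [0.94, 0.06]: sharpening pushes soft pseudo-labels towards one-hot.
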
    def ssl_loss(self, all_inputs, all_targets, batch_size, c_iter):
        idx = torch.randperm(all_inputs.size(0))
        l = np.random.beta(self.args.mixup_alpha, self.args.mixup_alpha)
        l = max(l, 1 - l)

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        # Mixup
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        feats = self.net(mixed_input, 'features')
        logits1, logits = self.net.classifier(feats), self.net.classifier_re(feats)
        logits_x = logits1[:batch_size * 2]
        logits_x1 = logits[:batch_size * 2]
        logits_u = logits[batch_size * 2:]

        # Combined loss
        Lx, Lu, lamb = self.semi_sul_loss(logits_x, logits_x1,
                                          mixed_target[:batch_size * 2], logits_u,
                                          mixed_target[batch_size * 2:], c_iter, self.args.warmup_epochs)

        # Regularization: KL between a uniform prior and the mean prediction
        prior = torch.ones(self.num_classes, device=self.device) / self.num_classes
        pred_mean = torch.softmax(logits, dim=1).mean(0)
        penalty = torch.sum(prior * torch.log(prior / pred_mean))

        lamb *= self.args.lambda_u
        return Lx + 0.1 * lamb * Lu + penalty

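    # Note on mixup: l ~ Beta(mixup_alpha, mixup_alpha) is clamped to [0.5, 1] via
    # max(l, 1 - l), so every mixed sample stays dominated by its own input rather than
    # the shuffled partner.
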
    def finetune_on_buffers(self):
        """Fit the model on the purified (clean) and noisy buffers with the semi-supervised objective."""
        self.net.train()
        opt = torch.optim.SGD(self.net.parameters(), lr=self.args.finetune_lr, momentum=self.args.optim_mom, weight_decay=self.args.optim_wd)

        all_clean_data = torch.cat([self.buffer.examples[:len(self.buffer)],
                                    self.high_fidelity_buffer.examples[:len(self.high_fidelity_buffer)]], dim=0)
        all_clean_labels = torch.cat([self.buffer.labels[:len(self.buffer)],
                                      self.high_fidelity_buffer.labels[:len(self.high_fidelity_buffer)]], dim=0)
        clean_dset = Dataset(all_clean_data, all_clean_labels, device=self.device)
        clean_dl = torch.utils.data.DataLoader(clean_dset, batch_size=self.args.batch_size, shuffle=True, drop_last=True)

        all_noisy_data = torch.cat([self.noisy_buffer.examples[:len(self.noisy_buffer)],
                                    self.high_fidelity_noisy_buffer.examples[:len(self.high_fidelity_noisy_buffer)]], dim=0)
        noisy_dset = Dataset(all_noisy_data, device=self.device)
        noisy_dl = torch.utils.data.DataLoader(noisy_dset, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
        noisy_iter = iter(noisy_dl)

        for epoch in tqdm.trange(self.args.finetune_epochs, desc="Buffer fitting", leave=False, disable=True):
            avgloss = 0
            if self.args.cnll_debug_mode and epoch > 2:
                break
            for batch_idx, dat in enumerate(clean_dl):
                try:
                    noisy_dat = next(noisy_iter)
                except StopIteration:
                    # the noisy loader ran out: restart it
                    noisy_iter = iter(noisy_dl)
                    noisy_dat = next(noisy_iter)
                opt.zero_grad()

                x, y = dat[0], dat[1]
                x, y = x.to(self.device), y.to(self.device)
                onehot_y = self.eye[y]

                # two weak views of the clean samples, two strong views of the noisy ones
                clean_xa, clean_xb = self.transform(x), self.transform(x)
                unlabeled_x = noisy_dat.to(self.device)
                unlabeled_xa, unlabeled_xb = self.hard_transform(unlabeled_x), self.hard_transform(unlabeled_x)

                refined_y_clean, pseudo_y = self.coguess_label(unlabeled_xa, unlabeled_xb, onehot_y)

                all_inputs = torch.cat([clean_xa, clean_xb, unlabeled_xa, unlabeled_xb], dim=0)
                all_labels = torch.cat([refined_y_clean, refined_y_clean, pseudo_y, pseudo_y], dim=0)

                loss = self.ssl_loss(all_inputs, all_labels, len(clean_xa), epoch + batch_idx / len(clean_dl))
                loss.backward()
                opt.step()
                avgloss += loss.item()
            avgloss /= len(clean_dl)
        return avgloss  # return the average loss of the last epoch

    def observe(self, inputs, labels, not_aug_inputs, true_labels):
        for y, not_aug_x, true_y in zip(labels, not_aug_inputs, true_labels):
            avg_expert_loss, avg_self_loss = self._observe(not_aug_x, y, true_y)
        return self.past_loss if avg_self_loss < 0 else avg_self_loss