import logging
import time
from typing import Union
import numpy as np
from sklearn.mixture import GaussianMixture
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import trange
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from argparse import ArgumentParser, Namespace
from utils.args import add_rehearsal_args
from utils.augmentations import RepeatedTransform, cutmix_data
from utils.autoaugment import CIFAR10Policy
from utils.buffer import Buffer
import torch.nn.functional as F
from torchvision import transforms
from utils.conf import create_seeded_dataloader
from utils.kornia_utils import to_kornia_transform
class CustomDataset(torch.utils.data.Dataset):
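    """In-memory dataset wrapping buffer tensors. Each item is the augmented sample, its label,
    optional extra information, the non-augmented sample, and a per-sample probability."""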
def __init__(self, data: torch.Tensor, targets: torch.Tensor, transform=None, probs=None, extra=None, device="cpu"):
self.device = device
self.data = data.to(self.device)
self.targets = targets.to(device) if targets is not None else None
self.transform = transform
self.probs = (torch.ones(len(self.data)) / len(self.data)).to(device) if probs is None else probs.to(device)
self.extra = extra.to(device) if extra is not None else None
def set_probs(self, probs: Union[np.ndarray, torch.Tensor]):
"""
Set the probability of each data point being correct (i.e., belonging to the Gaussian with the lowest mean)
"""
if not isinstance(probs, torch.Tensor):
probs = torch.tensor(probs)
self.probs = probs.to(self.data.device)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
"""
Return the data, the target, the extra information (if any), the not augmented data, and the probability of the data point being correct
Returns:
- data: the augmented data
- target: the target
- extra: (optional) additional information
- not_aug_data: the data without augmentation
- prob: the probability of the data point being correct
"""
not_aug_data = self.data[idx]
data = not_aug_data.clone()
if self.transform:
data = self.transform(data)
if len(data.shape) > 3:
if data.shape[0] == 1:
data = data.squeeze(0)
elif data.shape[1] == 1:
data = data.squeeze(1)
ret = (data, self.targets[idx],)
if self.extra is not None:
ret += (self.extra[idx],)
ret += (not_aug_data,)
return ret + (self.probs[idx],)
def soft_cross_entropy_loss(input, target, reduction='mean'):
"""
https://github.com/pytorch/pytorch/issues/11959
Args:
input: (batch, *)
target: (batch, *) same shape as input, each item must be a valid distribution: target[i, :].sum() == 1.
"""
logprobs = torch.nn.functional.log_softmax(input.view(input.shape[0], -1), dim=1)
batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
if reduction == 'none':
return batchloss
elif reduction == 'mean':
return torch.mean(batchloss)
elif reduction == 'sum':
return torch.sum(batchloss)
else:
raise NotImplementedError('Unsupported reduction mode.')
def get_dataloader_from_buffer(args: Namespace, buffer: Buffer, batch_size: int, shuffle=False, transform=None):
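    """Build a DataLoader over the whole buffer content, or return None if the buffer is empty."""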
if len(buffer) == 0:
return None
buf_data = buffer.get_all_data(device="cpu")
inputs, labels = buf_data[0], buf_data[1]
# Building train dataset
train_dataset = CustomDataset(inputs, labels, transform=transform)
return create_seeded_dataloader(args, train_dataset, non_verbose=True, batch_size=batch_size, shuffle=shuffle, num_workers=0)
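# NOTE: `get_hard_transform` is referenced below but is not defined in this listing. The body
# given here is a plausible reconstruction (an assumption, not necessarily the original):
# a strong CIFAR-style augmentation pipeline using the dataset's normalization transform.
def get_hard_transform(dataset: ContinualDataset):
    return transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        dataset.get_normalization_transform(),
    ])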
class PuriDivER(ContinualModel):
"""PuriDivER: Online Continual Learning on a Contaminated Data Stream with Blurry Task Boundaries."""
NAME = 'puridiver'
COMPATIBILITY = ['class-il', 'task-il']
@staticmethod
def get_parser(parser) -> ArgumentParser:
parser.set_defaults(n_epochs=1, optim_mom=0.9, optim_wd=1e-4, optim_nesterov=1, batch_size=16)
add_rehearsal_args(parser)
parser.add_argument('--use_bn_classifier', type=int, default=1, choices=[0, 1],
help='Use batch normalization in the classifier?')
parser.add_argument('--freeze_buffer_after_first', type=int, default=0, choices=[0, 1],
help='Freeze buffer after first task (i.e., simulate online update of the buffer, useful for multi-epoch)?')
parser.add_argument('--initial_alpha', type=float, default=0.5)
parser.add_argument('--disable_train_aug', type=int, default=1, choices=[0, 1], help='Disable training augmentation?')
parser.add_argument('--buffer_fitting_epochs', type=int, default=255, help='Number of epochs to fit on buffer')
        parser.add_argument('--warmup_buffer_fitting_epochs', type=int, default=10, help='Number of warmup epochs during which the buffer is fit with plain CE')
parser.add_argument('--enable_cutmix', type=int, default=1, choices=[0, 1], help='Enable cutmix augmentation?')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='Cutmix probability')
return parser
def __init__(self, backbone, loss, args, transform, dataset=None):
assert args.dataset in ['seq-cifar10', 'seq-cifar100'], 'PuriDivER is only compatible with CIFAR datasets (extend `get_hard_transform` for other datasets)'
super().__init__(backbone, loss, args, transform, dataset=dataset)
self.buffer = Buffer(self.args.buffer_size, "cpu")
self._past_it_t = time.time()
self._avg_it_t = 0
self.past_loss = 0
self.eye = torch.eye(self.num_classes).to(self.device)
hard_transform = get_hard_transform(self.dataset)
try:
self.hard_transform = to_kornia_transform(hard_transform)
except NotImplementedError as e:
logging.error('Kornia not available, raising error instead of using PIL transforms (would be waaay too slow).')
# NOTE: uncomment the following line if you want to use PIL transforms
# self.hard_transform = hard_transform
raise e
def get_subset_dl_from_idxs(self, idxs, batch_size, probs=None, transform=None):
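        """Build a DataLoader over the buffer samples selected by `idxs`, keeping the (possibly noisy)
        labels, attaching the true labels as extra information, and the optional per-sample probabilities."""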
if idxs is None:
return None
assert batch_size is not None
examples, labels, true_labels = self.buffer.get_all_data()
examples, labels, true_labels = examples[idxs], labels[idxs], true_labels[idxs]
if probs is not None:
probs = torch.from_numpy(probs)
dataset = CustomDataset(examples, labels, extra=true_labels, probs=probs, transform=transform)
return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
@torch.no_grad()
def split_data_puridiver(self, n=2):
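        """
        Split the buffer into clean, ambiguous, and noisy subsets (PuriDivER).

        A first GMM is fit on the per-sample cross-entropy loss: the component with the lower mean
        is treated as the clean ("correct") set. A second GMM is fit on the predictive uncertainty
        of the high-loss samples to separate ambiguous from noisy ("incorrect") samples.

        Returns a DataLoader for each subset (or None if a subset is empty or too small).
        """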
self.net.eval()
losses = []
uncertainties = []
for batch_idx, batch in enumerate(get_dataloader_from_buffer(self.args, self.buffer, batch_size=64, shuffle=False)):
            x, y = batch[0], batch[1]
            x, y = x.to(self.device), y.to(self.device)
x = self.normalization_transform(x)
out = self.net(x)
probs = F.softmax(out, dim=1)
uncerts = 1 - torch.max(probs, 1)[0]
losses.append(F.cross_entropy(out, y, reduction='none'))
uncertainties.append(uncerts)
losses = torch.cat(losses, dim=0).cpu()
uncertainties = torch.cat(uncertainties, dim=0).cpu().reshape(-1, 1)
losses = (losses - losses.min()) / (losses.max() - losses.min())
losses = losses.unsqueeze(1)
        # GMM to separate correct (low-loss) samples from the rest
gmm_loss = GaussianMixture(n_components=n, max_iter=10, tol=1e-2, reg_covar=5e-4)
gmm_loss.fit(losses)
gmm_loss_means = gmm_loss.means_
if gmm_loss_means[0] <= gmm_loss_means[1]:
small_loss_idx = 0
large_loss_idx = 1
else:
small_loss_idx = 1
large_loss_idx = 0
loss_prob = gmm_loss.predict_proba(losses)
pred = loss_prob.argmax(axis=1)
corr_idxs = np.where(pred == small_loss_idx)[0]
if len(corr_idxs) == 0:
return None, None, None
        # 2nd GMM on the large-loss samples
high_loss_idxs = np.where(pred == large_loss_idx)[0]
ambiguous_idxs, incorrect_idxs = None, None
if len(high_loss_idxs) > 2:
# GMM for uncertain vs incorrect samples
gmm_uncert = GaussianMixture(n_components=n, max_iter=10, tol=1e-2, reg_covar=5e-4)
gmm_uncert.fit(uncertainties[high_loss_idxs])
prob_uncert = gmm_uncert.predict_proba(uncertainties[high_loss_idxs])
pred_uncert = prob_uncert.argmax(axis=1)
if gmm_uncert.means_[0] <= gmm_uncert.means_[1]:
small_loss_idx = 0
large_loss_idx = 1
else:
small_loss_idx = 1
large_loss_idx = 0
idx_uncert = np.where(pred_uncert == small_loss_idx)[0]
amb_size = len(idx_uncert)
ambiguous_batch_size = max(2, int(amb_size / len(corr_idxs) * self.args.batch_size))
if amb_size <= 2:
ambiguous_idxs = None
else:
ambiguous_idxs = high_loss_idxs[idx_uncert]
idx_uncert = np.where(pred_uncert == large_loss_idx)[0]
incorrect_size = len(idx_uncert)
incorrect_batch_size = max(2, int(incorrect_size / len(corr_idxs) * self.args.batch_size))
if incorrect_size <= 2:
incorrect_idxs = None
else:
incorrect_idxs = high_loss_idxs[idx_uncert]
correct_dl = self.get_subset_dl_from_idxs(corr_idxs, self.args.batch_size, transform=self.hard_transform)
if ambiguous_idxs is not None:
ambiguous_dl = self.get_subset_dl_from_idxs(ambiguous_idxs, ambiguous_batch_size, transform=RepeatedTransform([self.transform, self.hard_transform], autosqueeze=True))
else:
ambiguous_dl = None
if incorrect_idxs is not None:
incorrect_dl = self.get_subset_dl_from_idxs(incorrect_idxs, incorrect_batch_size, probs=loss_prob[incorrect_idxs], transform=RepeatedTransform([
self.transform, self.hard_transform], autosqueeze=True))
else:
incorrect_dl = None
return correct_dl, ambiguous_dl, incorrect_dl
def train_with_mixmatch(self, loader_L, loader_U, loader_R):
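        """
        Semi-supervised training on the three buffer splits:
        - L (clean): standard cross-entropy, optionally with cutmix;
        - U (ambiguous): consistency (MSE) between the weakly and strongly augmented views;
        - R (noisy): soft cross-entropy on the strong view against pseudo-labels that interpolate
          the stored label and the prediction on the weak view, weighted by the probability of being clean.
        The three losses are combined with coefficients proportional to the subset batch sizes.
        """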
criterion_U = nn.MSELoss()
criterion_L = nn.CrossEntropyLoss()
iter_U = iter(loader_U)
iter_R = iter(loader_R)
avg_loss = 0
# R: weak, hard
# L: hard
# U: weak, hard
self.net.train()
for i, batch in enumerate(loader_L):
if self.args.debug_mode and i > 10:
break
self.opt.zero_grad()
inputs_L, labels_L = batch[0], batch[1]
if len(inputs_L) == 1:
continue
try:
inputs_U = next(iter_U)[0]
            except StopIteration:
iter_U = iter(loader_U)
inputs_U = next(iter_U)[0]
try:
batch_R = next(iter_R)
inputs_R, labels_R, probs_R = batch_R[0], batch_R[1], batch_R[-1]
            except StopIteration:
iter_R = iter(loader_R)
batch_R = next(iter_R)
inputs_R, labels_R, probs_R = batch_R[0], batch_R[1], batch_R[-1]
inputs_L, labels_L = inputs_L.to(self.device), labels_L.to(self.device)
inputs_U, inputs_R = inputs_U.to(self.device), inputs_R.to(self.device)
labels_R, probs_R = labels_R.to(self.device), probs_R.to(self.device)
labels_R = F.one_hot(labels_R, self.num_classes)
corr_prob = probs_R[:, 0].unsqueeze(1).expand(-1, self.num_classes)
inputs_U = torch.cat([inputs_U[:, 0], inputs_U[:, 1]], dim=0)
inputs_R = torch.cat([inputs_R[:, 0], inputs_R[:, 1]], dim=0)
do_cutmix = self.args.enable_cutmix and np.random.random(1) < self.args.cutmix_prob
if do_cutmix:
inputs_L, labels_L_a, labels_L_b, lam = cutmix_data(inputs_L, labels_L, force=True)
all_inputs = torch.cat([inputs_R, inputs_U, inputs_L], dim=0)
all_outputs = self.net(all_inputs)
outputs_R, outputs_U, outputs_L = torch.split(all_outputs, [inputs_R.size(0), inputs_U.size(0), inputs_L.size(0)])
loss_L = lam * self.loss(outputs_L, labels_L_a) + (1 - lam) * criterion_L(outputs_L, labels_L_b)
else:
all_inputs = torch.cat([inputs_R, inputs_U, inputs_L], dim=0)
all_outputs = self.net(all_inputs)
outputs_R, outputs_U, outputs_L = torch.split(all_outputs, [inputs_R.size(0), inputs_U.size(0), inputs_L.size(0)])
loss_L = self.loss(outputs_L, labels_L)
outputs_U_weak, outputs_U_strong = torch.split(outputs_U, outputs_U.size(0) // 2)
outputs_R_pseudo, outputs_R = torch.split(outputs_R, outputs_R.size(0) // 2) # weak, strong
probs_R_pseudo = torch.softmax(outputs_R_pseudo, dim=1)
soft_pseudo_labels = corr_prob * labels_R + (1 - corr_prob) * probs_R_pseudo.detach()
loss_R = soft_cross_entropy_loss(outputs_R, soft_pseudo_labels)
loss_U = criterion_U(outputs_U_weak, outputs_U_strong)
coeff_L = (len(labels_L) / (len(labels_L) + len(labels_R) + len(outputs_U_weak)))
coeff_R = (len(labels_R) / (len(labels_R) + len(labels_L) + len(outputs_U_weak)))
coeff_U = (len(outputs_U_weak) / (len(labels_R) + len(labels_L) + len(outputs_U_weak)))
loss = coeff_L * loss_L + coeff_U * loss_U + coeff_R * loss_R
assert not torch.isnan(loss).any()
# backward
loss.backward()
self.opt.step()
avg_loss += loss.item()
return avg_loss / len(loader_L)
def base_fit_buffer(self, loader=None):
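        """Fit the network on the buffer with plain cross-entropy (optionally with cutmix).
        If no loader is given, one is built from the whole buffer with the hard augmentation."""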
self.net.train()
avg_loss = 0
if loader is None:
loader = get_dataloader_from_buffer(self.args, self.buffer, batch_size=self.args.batch_size, shuffle=True, transform=self.hard_transform)
for i, batch in enumerate(loader):
x, y = batch[0].to(self.device), batch[1].to(self.device)
if len(x) == 1:
continue
if self.args.debug_mode and i > 10:
break
self.opt.zero_grad()
do_cutmix = self.args.enable_cutmix and np.random.rand(1) < self.args.cutmix_prob
if do_cutmix:
x, y_a, y_b, lam = cutmix_data(x, y, force=True)
out = self.net(x)
loss = lam * self.loss(out, y_a) + (1 - lam) * self.loss(out, y_b)
else:
out = self.net(x)
loss = self.loss(out, y)
assert not torch.isnan(loss).any()
loss.backward()
self.opt.step()
avg_loss += loss.item()
return avg_loss / len(loader)
def fit_buffer(self):
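        """
        Train on the buffer at the end of the task: plain cross-entropy during the warmup epochs,
        then the PuriDivER split (clean/ambiguous/noisy) with MixMatch-style training whenever both
        the ambiguous and noisy splits are available, falling back to plain CE otherwise.
        The learning rate follows the cosine warm-restart schedule, stepped once per epoch.
        """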
for param_group in self.opt.param_groups:
param_group["lr"] = self.args.lr
with trange(self.args.buffer_fitting_epochs) as pbar:
for epoch in pbar:
if self.args.debug_mode and epoch > self.args.warmup_buffer_fitting_epochs + 50:
break
if epoch < self.args.warmup_buffer_fitting_epochs:
tp = 'warmup'
loss = self.base_fit_buffer()
else:
correct_dl, ambiguous_dl, incorrect_dl = self.split_data_puridiver()
if ambiguous_dl is not None and incorrect_dl is not None:
tp = 'puridiver'
loss = self.train_with_mixmatch(correct_dl, ambiguous_dl, incorrect_dl)
else:
tp = 'base'
loss = self.base_fit_buffer()
buf_not_aug_inputs, buf_labels, buf_true_labels = self.buffer.get_all_data()
_, _, buf_acc, true_buf_acc = self._non_observe_data(self.normalization_transform(buf_not_aug_inputs), buf_labels, buf_true_labels)
perc_clean = (self.buffer.labels == self.buffer.true_labels).float().mean().item()
pbar.set_postfix(loss=loss, buf_acc=buf_acc, true_buf_acc=true_buf_acc, perc_clean=perc_clean, lr=self.opt.param_groups[0]["lr"], refresh=False)
pbar.set_description(f'Epoch {epoch + 1}/{self.args.buffer_fitting_epochs} [{tp}]', refresh=False)
self.custom_scheduler.step()
def end_task(self, dataset):
        # fit on the (purified) buffer at the end of the task
if self.args.buffer_fitting_epochs > 0:
self.fit_buffer()
def get_classifier_weights(self):
if isinstance(self.net.classifier, nn.Sequential):
return self.net.classifier[0].weight.detach()
return self.net.classifier.weight.detach()
def get_sim_score(self, feats, targets):
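        """
        Compute a standardized per-sample similarity score within a class, used as the diversity term
        when deciding which sample to drop from the buffer. Only the feature dimensions most relevant
        to the class (according to the classifier weights) are considered.
        """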
# relevant representation
cl_weights = self.get_classifier_weights()
relevant_idx = cl_weights[targets[0], :] > cl_weights.mean(dim=0)
cls_features = feats[:, relevant_idx]
        # mean pairwise cosine similarity of each sample to the others in the class
        sim_score = torch.cosine_similarity(cls_features.unsqueeze(0), cls_features.unsqueeze(1), dim=-1).mean(dim=1)
return (sim_score - sim_score.mean()) / sim_score.std()
def get_current_alpha_sim_score(self, loss):
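        """Adaptive trade-off between diversity (similarity) and purity (loss): the weight of the
        similarity term shrinks as the mean buffer loss grows above 1."""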
return self.args.initial_alpha * min(1, 1 / loss)
def get_scheduler(self):
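        # Cosine annealing with warm restarts; with T_0=1 and T_mult=2 the restart periods double
        # (1, 2, 4, ...), which lines up with the default of 255 buffer-fitting epochs (1+2+...+128).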
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
self.opt, T_0=1, T_mult=2, eta_min=self.args.lr * 0.01
)
def begin_task(self, dataset):
self.total_its = len(dataset.train_loader) * self.args.n_epochs
if self.current_task == 0 and self.args.use_bn_classifier:
self.net.classifier = nn.Sequential(nn.Linear(self.net.classifier.in_features, self.net.classifier.out_features, bias=False),
nn.BatchNorm1d(self.net.classifier.out_features, affine=True, eps=1e-6).to(self.device)).to(self.device)
for m in self.net.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
m.eps = 1e-6
self.opt = self.get_optimizer()
self.custom_scheduler = self.get_scheduler()
for param_group in self.opt.param_groups:
param_group["lr"] = self.args.lr
if self.args.disable_train_aug:
dataset.train_loader.dataset.transform = self.dataset.TEST_TRANSFORM
@torch.no_grad()
def _non_observe_data(self, inputs: torch.Tensor, labels: torch.Tensor, true_labels: torch.Tensor = None):
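        """Evaluation pass (no gradient): compute features, per-sample losses, and accuracy with
        respect to both the (possibly noisy) stored labels and the true labels."""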
was_training = self.net.training
self.net.eval()
dset = CustomDataset(inputs, labels, extra=true_labels, device=self.device)
dl = DataLoader(dset, batch_size=min(len(dset), 256), shuffle=False, num_workers=0)
feats = []
losses = []
true_accs, accs = [], []
for batch in dl:
inputs, labels, true_labels = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device)
out, feat = self.net(inputs, returnt='both')
acc = (out.argmax(dim=1) == labels).float().mean().item()
tacc = (out.argmax(dim=1) == true_labels).float().mean().item()
feats.append(feat)
losses.append(F.cross_entropy(out, labels, reduction='none'))
accs.append(acc)
true_accs.append(tacc)
feats = torch.cat(feats, dim=0)
losses = torch.cat(losses, dim=0)
acc = np.mean(accs)
true_acc = np.mean(true_accs)
self.net.train(was_training)
return feats, losses, acc, true_acc
def puridiver_update_buffer(self, stream_not_aug_inputs: torch.Tensor, stream_labels: torch.Tensor, stream_true_labels: torch.Tensor):
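        """
        Balanced (purity + diversity) buffer update. While the buffer is not full, stream samples are
        simply added (returning -1, -1). Otherwise, samples are dropped one at a time from the most
        represented class, removing the one with the highest combined score
        (1 - alpha) * loss + alpha * similarity, until the buffer fits again.
        """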
if len(self.buffer) < self.args.buffer_size:
self.buffer.add_data(examples=stream_not_aug_inputs, labels=stream_labels, true_labels=stream_true_labels)
return -1, -1
buf_not_aug_inputs, buf_labels, buf_true_labels = self.buffer.get_all_data()
buf_not_aug_inputs, buf_labels, buf_true_labels = buf_not_aug_inputs.to(self.device), buf_labels.to(self.device), buf_true_labels.to(self.device)
not_aug_inputs = torch.cat([buf_not_aug_inputs, stream_not_aug_inputs], dim=0)
labels = torch.cat([buf_labels, stream_labels], dim=0)
true_labels = torch.cat([buf_true_labels, stream_true_labels], dim=0)
cur_idxs = torch.arange(len(not_aug_inputs)).to(self.device)
feats, losses, buf_acc, true_buf_acc = self._non_observe_data(self.normalization_transform(not_aug_inputs), labels, true_labels=true_labels)
alpha_sim_score = self.get_current_alpha_sim_score(losses.mean())
lbs = labels[cur_idxs]
while len(lbs) > self.args.buffer_size:
fts = feats[cur_idxs]
lss = losses[cur_idxs]
clss, cls_cnt = lbs.unique(return_counts=True)
# argmax w/ random tie-breaking
cls_to_drop = clss[cls_cnt == cls_cnt.max()]
cls_to_drop = cls_to_drop[torch.randperm(len(cls_to_drop))][0]
mask = lbs == cls_to_drop
sim_score = self.get_sim_score(fts[mask], lbs[mask])
div_score = (1 - alpha_sim_score) * lss[mask] + alpha_sim_score * sim_score
drop_cls_idx = div_score.argmax()
drop_idx = cur_idxs[mask][drop_cls_idx]
cur_idxs = cur_idxs[cur_idxs != drop_idx]
lbs = labels[cur_idxs]
self.buffer.empty()
self.buffer.add_data(examples=not_aug_inputs[cur_idxs], labels=labels[cur_idxs], true_labels=true_labels[cur_idxs])
return buf_acc, true_buf_acc
def observe(self, inputs, labels, not_aug_inputs, true_labels, epoch):
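        """Online step: train on the current batch (plus a buffer batch from the second task onwards,
        optionally with cutmix), then update the buffer with the stream samples."""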
self.net.train()
B = len(inputs)
self.opt.zero_grad()
if self.current_task > 0: # starting from second task
buf_inputs, buf_labels, _ = self.buffer.get_data(
self.args.minibatch_size, transform=self.hard_transform, device=self.device)
inputs = torch.cat((inputs, buf_inputs))
labels = torch.cat((labels, buf_labels))
do_cutmix = self.args.enable_cutmix and np.random.rand(1) < self.args.cutmix_prob
if do_cutmix:
inputs, labels_a, labels_b, lam = cutmix_data(inputs, labels, force=True)
outputs = self.net(inputs)
loss = lam * self.loss(outputs, labels_a) + (1 - lam) * self.loss(outputs, labels_b)
else:
outputs = self.net(inputs)
loss = self.loss(outputs, labels)
assert not torch.isnan(loss).any()
loss.backward()
self.opt.step()
if self.args.freeze_buffer_after_first == 0 or epoch == 0:
self.puridiver_update_buffer(not_aug_inputs[:B], labels[:B], true_labels[:B])
return loss.item()