import time
import numpy as np
from sklearn.mixture import GaussianMixture
import torch
from torch import nn
from torch.utils.data import DataLoader
import tqdm
import logging
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from argparse import ArgumentParser, Namespace
from utils.args import add_rehearsal_args
from utils.augmentations import RepeatedTransform, cutmix_data
from utils.autoaugment import CIFAR10Policy
from utils.buffer import Buffer
import torch.nn.functional as F
from torchvision import transforms
from utils.conf import create_seeded_dataloader
from utils.kornia_utils import to_kornia_transform
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, data: torch.Tensor, targets: torch.Tensor, transform=None, probs=None, extra=None, device="cpu"):
self.device = device
self.data = data.to(self.device)
self.targets = targets.to(device) if targets is not None else None
self.transform = transform
self.probs = (torch.ones(len(self.data)) / len(self.data)).to(device) if probs is None else probs.to(device)
self.extra = extra.to(device) if extra is not None else None
def set_probs(self, probs: np.ndarray | torch.Tensor):
"""
Set the probability of each data point being correct (i.e., belonging to the Gaussian with the lowest mean)
"""
if not isinstance(probs, torch.Tensor):
probs = torch.tensor(probs)
self.probs = probs.to(self.data.device)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
"""
Return the data, the target, the extra information (if any), the not augmented data, and the probability of the data point being correct
Returns:
- data: the augmented data
- target: the target
- extra: (optional) additional information
- not_aug_data: the data without augmentation
- prob: the probability of the data point being correct
"""
not_aug_data = self.data[idx]
data = not_aug_data.clone()
if self.transform:
data = self.transform(data)
if len(data.shape) > 3:
if data.shape[0] == 1:
data = data.squeeze(0)
elif data.shape[1] == 1:
data = data.squeeze(1)
ret = (data, self.targets[idx],)
if self.extra is not None:
ret += (self.extra[idx],)
ret += (not_aug_data,)
return ret + (self.probs[idx],)
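# Illustrative sketch (not part of the original module): building a CustomDataset and
# reading one item. Shapes are assumptions for the example.
#
#   >>> data = torch.rand(10, 3, 32, 32)            # 10 CHW images
#   >>> targets = torch.randint(0, 10, (10,))
#   >>> ds = CustomDataset(data, targets, transform=None)
#   >>> img, lbl, not_aug, prob = ds[0]              # no `extra`, so 4 values are returned
#   >>> len(ds), prob.item()                         # probs default to uniform 1 / len(ds)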
def soft_cross_entropy_loss(input, target, reduction='mean'):
"""
https://github.com/pytorch/pytorch/issues/11959
Args:
input: (batch, *)
target: (batch, *) same shape as input, each item must be a valid distribution: target[i, :].sum() == 1.
"""
logprobs = torch.nn.functional.log_softmax(input.view(input.shape[0], -1), dim=1)
batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
if reduction == 'none':
return batchloss
elif reduction == 'mean':
return torch.mean(batchloss)
elif reduction == 'sum':
return torch.sum(batchloss)
else:
raise NotImplementedError('Unsupported reduction mode.')
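# Illustrative sketch (not part of the original module): soft targets over 3 classes.
#
#   >>> logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
#   >>> soft_targets = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
#   >>> soft_cross_entropy_loss(logits, soft_targets)                     # scalar mean over the batch
#   >>> soft_cross_entropy_loss(logits, soft_targets, reduction='none')   # per-sample losses, shape (2,)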
def get_dataloader_from_buffer(args: Namespace, buffer: Buffer, batch_size: int, shuffle=False, transform=None):
if len(buffer) == 0:
return None
buf_data = buffer.get_all_data(device="cpu")
inputs, labels = buf_data[0], buf_data[1]
# Building train dataset
train_dataset = CustomDataset(inputs, labels, transform=transform)
return create_seeded_dataloader(args, train_dataset, non_verbose=True, batch_size=batch_size, shuffle=shuffle, num_workers=0)
class PuriDivER(ContinualModel):
"""PuriDivER: Online Continual Learning on a Contaminated Data Stream with Blurry Task Boundaries."""
NAME = 'puridiver'
COMPATIBILITY = ['class-il', 'task-il']
@staticmethod
def get_parser(parser) -> ArgumentParser:
parser.set_defaults(n_epochs=1, optim_mom=0.9, optim_wd=1e-4, optim_nesterov=1, batch_size=16)
add_rehearsal_args(parser)
parser.add_argument('--use_bn_classifier', type=int, default=1, choices=[0, 1],
help='Use batch normalization in the classifier?')
parser.add_argument('--freeze_buffer_after_first', type=int, default=0, choices=[0, 1],
help='Freeze buffer after first task (i.e., simulate online update of the buffer, useful for multi-epoch)?')
parser.add_argument('--initial_alpha', type=float, default=0.5, help='Initial weight of the similarity score (vs. the loss) when selecting buffer samples to drop')
parser.add_argument('--disable_train_aug', type=int, default=1, choices=[0, 1], help='Disable training augmentation?')
parser.add_argument('--buffer_fitting_epochs', type=int, default=255, help='Number of epochs to fit on buffer')
parser.add_argument('--warmup_buffer_fitting_epochs', type=int, default=10, help='Number of warmup epochs during which the buffer is fit with plain cross-entropy')
parser.add_argument('--enable_cutmix', type=int, default=1, choices=[0, 1], help='Enable cutmix augmentation?')
parser.add_argument('--cutmix_prob', type=float, default=0.5, help='Cutmix probability')
return parser
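# Example invocation (illustrative; assumes the usual Mammoth entry point `utils/main.py` and
# the `--buffer_size` flag added by `add_rehearsal_args` - adjust to your setup):
#
#   python utils/main.py --model puridiver --dataset seq-cifar10 \
#       --buffer_size 500 --lr 0.03 --buffer_fitting_epochs 255
#
# The PuriDivER-specific flags above keep their defaults unless overridden.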
def __init__(self, backbone, loss, args, transform, dataset=None):
assert args.dataset in ['seq-cifar10', 'seq-cifar100'], 'PuriDivER is only compatible with CIFAR datasets (extend `get_hard_transform` for other datasets)'
super().__init__(backbone, loss, args, transform, dataset=dataset)
self.buffer = Buffer(self.args.buffer_size, "cpu")
self._past_it_t = time.time()
self._avg_it_t = 0
self.past_loss = 0
self.eye = torch.eye(self.num_classes).to(self.device)
hard_transform = get_hard_transform(self.dataset)
try:
self.hard_transform = to_kornia_transform(hard_transform)
except NotImplementedError as e:
logging.error('Kornia not available, raising error instead of using PIL transforms (would be way too slow).')
# NOTE: uncomment the following line if you want to use PIL transforms
# self.hard_transform = hard_transform
raise e
def get_subset_dl_from_idxs(self, idxs, batch_size, probs=None, transform=None):
if idxs is None:
return None
assert batch_size is not None
examples, labels, true_labels = self.buffer.get_all_data()
examples, labels, true_labels = examples[idxs], labels[idxs], true_labels[idxs]
if probs is not None:
probs = torch.from_numpy(probs)
dataset = CustomDataset(examples, labels, extra=true_labels, probs=probs, transform=transform)
return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
@torch.no_grad()
def split_data_puridiver(self, n=2):
self.net.eval()
losses = []
uncertainties = []
for batch_idx, batch in enumerate(get_dataloader_from_buffer(self.args, self.buffer, batch_size=64, shuffle=False)):
x, y = batch[0], batch[1]
x, y = x.to(self.device), y.to(self.device)
x = self.normalization_transform(x)
out = self.net(x)
probs = F.softmax(out, dim=1)
uncerts = 1 - torch.max(probs, 1)[0]
losses.append(F.cross_entropy(out, y, reduction='none'))
uncertainties.append(uncerts)
losses = torch.cat(losses, dim=0).cpu()
uncertainties = torch.cat(uncertainties, dim=0).cpu().reshape(-1, 1)
losses = (losses - losses.min()) / (losses.max() - losses.min())
losses = losses.unsqueeze(1)
# GMM separating low-loss (clean) samples from the rest
gmm_loss = GaussianMixture(n_components=n, max_iter=10, tol=1e-2, reg_covar=5e-4)
gmm_loss.fit(losses)
gmm_loss_means = gmm_loss.means_
if gmm_loss_means[0] <= gmm_loss_means[1]:
small_loss_idx = 0
large_loss_idx = 1
else:
small_loss_idx = 1
large_loss_idx = 0
loss_prob = gmm_loss.predict_proba(losses)
pred = loss_prob.argmax(axis=1)
corr_idxs = np.where(pred == small_loss_idx)[0]
if len(corr_idxs) == 0:
return None, None, None
# 2nd GMM on the large-loss samples
high_loss_idxs = np.where(pred == large_loss_idx)[0]
ambiguous_idxs, incorrect_idxs = None, None
if len(high_loss_idxs) > 2:
# GMM for uncertain vs incorrect samples
gmm_uncert = GaussianMixture(n_components=n, max_iter=10, tol=1e-2, reg_covar=5e-4)
gmm_uncert.fit(uncertainties[high_loss_idxs])
prob_uncert = gmm_uncert.predict_proba(uncertainties[high_loss_idxs])
pred_uncert = prob_uncert.argmax(axis=1)
if gmm_uncert.means_[0] <= gmm_uncert.means_[1]:
small_loss_idx = 0
large_loss_idx = 1
else:
small_loss_idx = 1
large_loss_idx = 0
idx_uncert = np.where(pred_uncert == small_loss_idx)[0]
amb_size = len(idx_uncert)
ambiguous_batch_size = max(2, int(amb_size / len(corr_idxs) * self.args.batch_size))
if amb_size <= 2:
ambiguous_idxs = None
else:
ambiguous_idxs = high_loss_idxs[idx_uncert]
idx_uncert = np.where(pred_uncert == large_loss_idx)[0]
incorrect_size = len(idx_uncert)
incorrect_batch_size = max(2, int(incorrect_size / len(corr_idxs) * self.args.batch_size))
if incorrect_size <= 2:
incorrect_idxs = None
else:
incorrect_idxs = high_loss_idxs[idx_uncert]
correct_dl = self.get_subset_dl_from_idxs(corr_idxs, self.args.batch_size, transform=self.hard_transform)
if ambiguous_idxs is not None:
ambiguous_dl = self.get_subset_dl_from_idxs(ambiguous_idxs, ambiguous_batch_size, transform=RepeatedTransform([self.transform, self.hard_transform], autosqueeze=True))
else:
ambiguous_dl = None
if incorrect_idxs is not None:
incorrect_dl = self.get_subset_dl_from_idxs(incorrect_idxs, incorrect_batch_size, probs=loss_prob[incorrect_idxs], transform=RepeatedTransform([
self.transform, self.hard_transform], autosqueeze=True))
else:
incorrect_dl = None
return correct_dl, ambiguous_dl, incorrect_dl
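# Illustrative sketch (not part of the original module) of the two-component GMM used above:
# fit 1-D losses and identify the "clean" component as the one with the smaller mean.
#
#   >>> from sklearn.mixture import GaussianMixture
#   >>> toy_losses = np.concatenate([np.random.normal(0.1, 0.05, 80),
#   ...                              np.random.normal(0.9, 0.05, 20)]).reshape(-1, 1)
#   >>> gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4).fit(toy_losses)
#   >>> small = int(np.argmin(gmm.means_.ravel()))        # index of the low-loss ("clean") Gaussian
#   >>> clean_prob = gmm.predict_proba(toy_losses)[:, small]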
def train_with_mixmatch(self, loader_L, loader_U, loader_R):
criterion_U = nn.MSELoss()
criterion_L = nn.CrossEntropyLoss()
iter_U = iter(loader_U)
iter_R = iter(loader_R)
avg_loss = 0
# Augmented views produced by each loader:
# R (incorrect, re-labeled with soft pseudo-labels): weak + hard
# L (clean): hard only
# U (ambiguous, used for consistency regularization): weak + hard
self.net.train()
for i, batch in enumerate(loader_L):
if self.args.debug_mode and i > 10:
break
self.opt.zero_grad()
inputs_L, labels_L = batch[0], batch[1]
if len(inputs_L) == 1:
continue
try:
inputs_U = next(iter_U)[0]
except StopIteration:
iter_U = iter(loader_U)
inputs_U = next(iter_U)[0]
try:
batch_R = next(iter_R)
inputs_R, labels_R, probs_R = batch_R[0], batch_R[1], batch_R[-1]
except StopIteration:
iter_R = iter(loader_R)
batch_R = next(iter_R)
inputs_R, labels_R, probs_R = batch_R[0], batch_R[1], batch_R[-1]
inputs_L, labels_L = inputs_L.to(self.device), labels_L.to(self.device)
inputs_U, inputs_R = inputs_U.to(self.device), inputs_R.to(self.device)
labels_R, probs_R = labels_R.to(self.device), probs_R.to(self.device)
labels_R = F.one_hot(labels_R, self.num_classes)
corr_prob = probs_R[:, 0].unsqueeze(1).expand(-1, self.num_classes)
inputs_U = torch.cat([inputs_U[:, 0], inputs_U[:, 1]], dim=0)
inputs_R = torch.cat([inputs_R[:, 0], inputs_R[:, 1]], dim=0)
do_cutmix = self.args.enable_cutmix and np.random.random(1) < self.args.cutmix_prob
if do_cutmix:
inputs_L, labels_L_a, labels_L_b, lam = cutmix_data(inputs_L, labels_L, force=True)
all_inputs = torch.cat([inputs_R, inputs_U, inputs_L], dim=0)
all_outputs = self.net(all_inputs)
outputs_R, outputs_U, outputs_L = torch.split(all_outputs, [inputs_R.size(0), inputs_U.size(0), inputs_L.size(0)])
loss_L = lam * self.loss(outputs_L, labels_L_a) + (1 - lam) * criterion_L(outputs_L, labels_L_b)
else:
all_inputs = torch.cat([inputs_R, inputs_U, inputs_L], dim=0)
all_outputs = self.net(all_inputs)
outputs_R, outputs_U, outputs_L = torch.split(all_outputs, [inputs_R.size(0), inputs_U.size(0), inputs_L.size(0)])
loss_L = self.loss(outputs_L, labels_L)
outputs_U_weak, outputs_U_strong = torch.split(outputs_U, outputs_U.size(0) // 2)
outputs_R_pseudo, outputs_R = torch.split(outputs_R, outputs_R.size(0) // 2) # weak, strong
probs_R_pseudo = torch.softmax(outputs_R_pseudo, dim=1)
soft_pseudo_labels = corr_prob * labels_R + (1 - corr_prob) * probs_R_pseudo.detach()
loss_R = soft_cross_entropy_loss(outputs_R, soft_pseudo_labels)
loss_U = criterion_U(outputs_U_weak, outputs_U_strong)
coeff_L = (len(labels_L) / (len(labels_L) + len(labels_R) + len(outputs_U_weak)))
coeff_R = (len(labels_R) / (len(labels_R) + len(labels_L) + len(outputs_U_weak)))
coeff_U = (len(outputs_U_weak) / (len(labels_R) + len(labels_L) + len(outputs_U_weak)))
loss = coeff_L * loss_L + coeff_U * loss_U + coeff_R * loss_R
assert not torch.isnan(loss).any()
# backward
loss.backward()
self.opt.step()
avg_loss += loss.item()
return avg_loss / len(loader_L)
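# The combined objective above weights each term by its share of the joint batch:
#   loss = |L|/(|L|+|R|+|U|) * CE(L) + |U|/(|L|+|R|+|U|) * MSE(U_weak, U_strong) + |R|/(|L|+|R|+|U|) * softCE(R, pseudo)
# where |L|, |R| and |U| are the clean, incorrect (re-labeled) and ambiguous batch sizes.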
def base_fit_buffer(self, loader=None):
self.net.train()
avg_loss = 0
if loader is None:
loader = get_dataloader_from_buffer(self.args, self.buffer, batch_size=self.args.batch_size, shuffle=True, transform=self.hard_transform)
for i, batch in enumerate(loader):
x, y = batch[0].to(self.device), batch[1].to(self.device)
if len(x) == 1:
continue
if self.args.debug_mode and i > 10:
break
self.opt.zero_grad()
do_cutmix = self.args.enable_cutmix and np.random.rand(1) < self.args.cutmix_prob
if do_cutmix:
x, y_a, y_b, lam = cutmix_data(x, y, force=True)
out = self.net(x)
loss = lam * self.loss(out, y_a) + (1 - lam) * self.loss(out, y_b)
else:
out = self.net(x)
loss = self.loss(out, y)
assert not torch.isnan(loss).any()
loss.backward()
self.opt.step()
avg_loss += loss.item()
return avg_loss / len(loader)
def fit_buffer(self):
for param_group in self.opt.param_groups:
param_group["lr"] = self.args.lr
with tqdm.trange(self.args.buffer_fitting_epochs) as pbar:
for epoch in pbar:
if self.args.debug_mode and epoch > self.args.warmup_buffer_fitting_epochs + 50:
break
if epoch < self.args.warmup_buffer_fitting_epochs:
tp = 'warmup'
loss = self.base_fit_buffer()
else:
correct_dl, ambiguous_dl, incorrect_dl = self.split_data_puridiver()
if ambiguous_dl is not None and incorrect_dl is not None:
tp = 'puridiver'
loss = self.train_with_mixmatch(correct_dl, ambiguous_dl, incorrect_dl)
else:
tp = 'base'
loss = self.base_fit_buffer()
buf_not_aug_inputs, buf_labels, buf_true_labels = self.buffer.get_all_data()
_, _, buf_acc, true_buf_acc = self._non_observe_data(self.normalization_transform(buf_not_aug_inputs), buf_labels, buf_true_labels)
perc_clean = (self.buffer.labels == self.buffer.true_labels).float().mean().item()
pbar.set_postfix(loss=loss, buf_acc=buf_acc, true_buf_acc=true_buf_acc, perc_clean=perc_clean, lr=self.opt.param_groups[0]["lr"], refresh=False)
pbar.set_description(f'Epoch {epoch + 1}/{self.args.buffer_fitting_epochs} [{tp}]', refresh=False)
self.scheduler.step()
def end_task(self, dataset):
# fit on the (purified) buffer at the end of the task
if self.args.buffer_fitting_epochs > 0:
self.fit_buffer()
def get_classifier_weights(self):
if isinstance(self.net.classifier, nn.Sequential):
return self.net.classifier[0].weight.detach()
return self.net.classifier.weight.detach()
def get_sim_score(self, feats, targets):
# relevant representation
cl_weights = self.get_classifier_weights()
relevant_idx = cl_weights[targets[0], :] > cl_weights.mean(dim=0)
cls_features = feats[:, relevant_idx]
# mean pairwise cosine similarity of each sample to the other same-class samples
sim_score = torch.cosine_similarity(cls_features.unsqueeze(1), cls_features.unsqueeze(0), dim=-1).mean(dim=1)
return (sim_score - sim_score.mean()) / sim_score.std()
def get_current_alpha_sim_score(self, loss):
return self.args.initial_alpha * min(1, 1 / loss)
def get_scheduler(self):
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
self.opt, T_0=1, T_mult=2, eta_min=self.args.lr * 0.01
)
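# With T_0=1 and T_mult=2, CosineAnnealingWarmRestarts restarts after 1, 2, 4, 8, ... epochs
# (each cycle decaying from args.lr down to eta_min = 0.01 * args.lr), so later cycles anneal
# over progressively longer spans of the buffer-fitting phase.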
def begin_task(self, dataset):
self.total_its = len(dataset.train_loader) * self.args.n_epochs
if self.current_task == 0 and self.args.use_bn_classifier:
self.net.classifier = nn.Sequential(nn.Linear(self.net.classifier.in_features, self.net.classifier.out_features, bias=False),
nn.BatchNorm1d(self.net.classifier.out_features, affine=True, eps=1e-6).to(self.device)).to(self.device)
for m in self.net.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
m.eps = 1e-6
self.opt = self.get_optimizer()
self.scheduler = self.get_scheduler()
for param_group in self.opt.param_groups:
param_group["lr"] = self.args.lr
if self.args.disable_train_aug:
dataset.train_loader.dataset.transform = self.dataset.TEST_TRANSFORM
@torch.no_grad()
def _non_observe_data(self, inputs: torch.Tensor, labels: torch.Tensor, true_labels: torch.Tensor = None):
was_training = self.net.training
self.net.eval()
dset = CustomDataset(inputs, labels, extra=true_labels, device=self.device)
dl = DataLoader(dset, batch_size=min(len(dset), 256), shuffle=False, num_workers=0)
feats = []
losses = []
true_accs, accs = [], []
for batch in dl:
inputs, labels, true_labels = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device)
out, feat = self.net(inputs, returnt='both')
acc = (out.argmax(dim=1) == labels).float().mean().item()
tacc = (out.argmax(dim=1) == true_labels).float().mean().item()
feats.append(feat)
losses.append(F.cross_entropy(out, labels, reduction='none'))
accs.append(acc)
true_accs.append(tacc)
feats = torch.cat(feats, dim=0)
losses = torch.cat(losses, dim=0)
acc = np.mean(accs)
true_acc = np.mean(true_accs)
self.net.train(was_training)
return feats, losses, acc, true_acc
def puridiver_update_buffer(self, stream_not_aug_inputs: torch.Tensor, stream_labels: torch.Tensor, stream_true_labels: torch.Tensor):
if len(self.buffer) < self.args.buffer_size:
self.buffer.add_data(examples=stream_not_aug_inputs, labels=stream_labels, true_labels=stream_true_labels)
return -1, -1
buf_not_aug_inputs, buf_labels, buf_true_labels = self.buffer.get_all_data()
buf_not_aug_inputs, buf_labels, buf_true_labels = buf_not_aug_inputs.to(self.device), buf_labels.to(self.device), buf_true_labels.to(self.device)
not_aug_inputs = torch.cat([buf_not_aug_inputs, stream_not_aug_inputs], dim=0)
labels = torch.cat([buf_labels, stream_labels], dim=0)
true_labels = torch.cat([buf_true_labels, stream_true_labels], dim=0)
cur_idxs = torch.arange(len(not_aug_inputs)).to(self.device)
feats, losses, buf_acc, true_buf_acc = self._non_observe_data(self.normalization_transform(not_aug_inputs), labels, true_labels=true_labels)
alpha_sim_score = self.get_current_alpha_sim_score(losses.mean())
lbs = labels[cur_idxs]
while len(lbs) > self.args.buffer_size:
fts = feats[cur_idxs]
lss = losses[cur_idxs]
clss, cls_cnt = lbs.unique(return_counts=True)
# argmax w/ random tie-breaking
cls_to_drop = clss[cls_cnt == cls_cnt.max()]
cls_to_drop = cls_to_drop[torch.randperm(len(cls_to_drop))][0]
mask = lbs == cls_to_drop
sim_score = self.get_sim_score(fts[mask], lbs[mask])
div_score = (1 - alpha_sim_score) * lss[mask] + alpha_sim_score * sim_score
drop_cls_idx = div_score.argmax()
drop_idx = cur_idxs[mask][drop_cls_idx]
cur_idxs = cur_idxs[cur_idxs != drop_idx]
lbs = labels[cur_idxs]
self.buffer.empty()
self.buffer.add_data(examples=not_aug_inputs[cur_idxs], labels=labels[cur_idxs], true_labels=true_labels[cur_idxs])
return buf_acc, true_buf_acc
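# Diversity-aware eviction: within the most-populated class, each candidate gets
#   div_score = (1 - alpha) * loss + alpha * standardized_similarity
# with alpha = initial_alpha * min(1, 1 / mean_loss), and the sample with the highest
# score (high loss and/or highly redundant) is dropped until the buffer fits `args.buffer_size`.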
def observe(self, inputs, labels, not_aug_inputs, true_labels, epoch):
self.net.train()
B = len(inputs)
self.opt.zero_grad()
if self.current_task > 0: # starting from second task
buf_inputs, buf_labels, _ = self.buffer.get_data(
self.args.minibatch_size, transform=self.hard_transform, device=self.device)
inputs = torch.cat((inputs, buf_inputs))
labels = torch.cat((labels, buf_labels))
do_cutmix = self.args.enable_cutmix and np.random.rand(1) < self.args.cutmix_prob
if do_cutmix:
inputs, labels_a, labels_b, lam = cutmix_data(inputs, labels, force=True)
outputs = self.net(inputs)
loss = lam * self.loss(outputs, labels_a) + (1 - lam) * self.loss(outputs, labels_b)
else:
outputs = self.net(inputs)
loss = self.loss(outputs, labels)
assert not torch.isnan(loss).any()
loss.backward()
self.opt.step()
if self.args.freeze_buffer_after_first == 0 or epoch == 0:
self.puridiver_update_buffer(not_aug_inputs[:B], labels[:B], true_labels[:B])
return loss.item()