import copy
import logging
import time
import numpy as np
import torch
from torch import nn
import tqdm
from backbone import get_backbone
from models.utils.continual_model import ContinualModel
from utils import binary_to_boolean_type
from utils.args import add_rehearsal_args
from utils.bmm import BetaMixture1D
from utils.buffer import Buffer
import torch.nn.functional as F
import networkx as nx
from copy import deepcopy
def _get_projector_prenet(net, device=None, bn=True):
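"""Build a small MLP projection head (Linear -> [BatchNorm] -> ReLU -> Linear) sized from the ResNet feature width."""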
device = net.device if hasattr(net, 'device') else (device if device is not None else "cpu")
assert "resnet" in type(net).__name__.lower(), "Only resnet is supported for now"
sizes = [net.nf * 8, net.nf * 8, 256]
layers = []
for i in range(len(sizes) - 2):
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=True).to(device))
if bn:
layers.append(nn.BatchNorm1d(sizes[i + 1]).to(device))
layers.append(nn.ReLU(inplace=True).to(device))
layers.append(nn.Linear(sizes[-2], sizes[-1], bias=True).to(device))
return nn.Sequential(*layers).to(device)
def init_simclr_net(model, device=None):
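"""Attach a SimCLR projector (and a predictor initialized as a copy of the projector) to the given backbone."""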
model.projector = _get_projector_prenet(model, device=device, bn=False)
model.predictor = deepcopy(model.projector)
return model
class SimCLR:
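"""Callable implementing the SimCLR (NT-Xent) contrastive loss over two augmented views of a batch."""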
def __init__(self, transform, temp=0.5, eps=1e-6, filter_bs_len=None, correlation_mask=None):
self.temp = temp
self.eps = eps
self.filter_bs_len = filter_bs_len
self.correlation_mask = correlation_mask
self.transform = transform
def __call__(self, model, x):
xa = self.transform(x)
xb = self.transform(x)
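# two independent augmentations of the same batch form the positive pairs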
outa = model.projector(model(xa))
outb = model.projector(model(xb))
outa = F.normalize(outa, dim=1)
outb = F.normalize(outb, dim=1)
out = torch.cat([outb, outa], dim=0)
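# pairwise cosine similarities between all 2N augmented views (2N x 2N matrix)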
cov = F.cosine_similarity(out.unsqueeze(1), out.unsqueeze(0), dim=-1)
# extract the similarity scores of the positive pairs (diagonals offset by +/- batch size)
l_pos = torch.diag(cov, self.filter_bs_len)
r_pos = torch.diag(cov, -self.filter_bs_len)
positives = torch.cat([l_pos, r_pos]).view(2 * self.filter_bs_len, 1)
negatives = cov[self.correlation_mask].view(2 * self.filter_bs_len, -1)
logits = torch.cat((positives, negatives), dim=1)
logits /= self.temp
labels = torch.zeros(2 * self.filter_bs_len).to(cov.device).long()
loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * self.filter_bs_len)
return loss
def disable_linear(backbone):
# replace the linear classifier with an identity, keeping its in/out feature metadata
in_features = backbone.classifier.in_features
out_features = backbone.classifier.out_features
backbone.classifier = nn.Identity()
backbone.classifier.in_features = in_features
backbone.classifier.out_features = out_features
return backbone
class Spr(ContinualModel):
"""
Implementation of `Continual Learning on Noisy Data Streams via Self-Purified Replay <https://github.com/ecrireme/SPR>`_ from ICCV 2021.
"""
NAME = 'spr'
COMPATIBILITY = ['class-il', 'task-il']
OVERRIDE_SUPPORT_DISTRIBUTED = True
@staticmethod
def get_parser(parser):
parser.set_defaults(optimizer='adam', lr=0.0002)
add_rehearsal_args(parser)
parser.add_argument('--spr_debug_mode', type=binary_to_boolean_type, default=False, help='Run SPR with just a few iterations?')
parser.add_argument('--delayed_buffer_size', type=int, default=500,
help='Size of the delayed buffer.')
parser.add_argument('--fitting_lr', type=float, default=0.002,
help='LR used during finetuning (classifier buffer fitting on P)')
parser.add_argument('--fitting_epochs', type=int, default=50,
help='Number of epochs used during finetuning (classifier buffer fitting on P)')
parser.add_argument('--inner_train_epochs', type=int, default=3000,
help='Inner train epochs for SSL (base net)')
parser.add_argument('--expert_train_epochs', type=int, default=4000,
help='Inner train epochs for SSL (expert)')
parser.add_argument('--simclr_temp', type=float, default=0.5,
help='Temperature for simclr SSL loss')
parser.add_argument('--fitting_sched_lr_stepsize', type=int, default=300,
help='Step size for the LR scheduler during finetuning (classifier buffer fitting on P)')
parser.add_argument('--fitting_sched_lr_gamma', type=float, default=0.1,
help='Gamma for the LR scheduler during finetuining (classifier buffer fitting on P)')
parser.add_argument('--fitting_batch_size', type=int, default=16,
help='Batch size for finetuning (classifier buffer fitting on P)')
parser.add_argument('--fitting_clip_value', type=float, default=0.5,
help='Gradient clipping for finetuning')
parser.add_argument('--E_max', type=int, default=5,
help='Number of stochastic ensemble for expert')
return parser
def __init__(self, backbone, loss, args, transform, dataset=None):
cl_in_features = backbone.classifier.in_features
# replace the linear classifier of the base net with an identity
backbone = disable_linear(backbone)
init_simclr_net(backbone)
super().__init__(backbone, loss, args, transform, dataset=dataset)
self.cl_in_features = cl_in_features
self.buffer = Buffer(self.args.buffer_size, "cpu")
self.delayed_buffer = Buffer(self.args.delayed_buffer_size, "cpu")
self.past_loss = 0
self.get_optimizer()
self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=self.args.inner_train_epochs, eta_min=0, last_epoch=-1)
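# finetuned_model: a copy of the base backbone with a linear head, fit on the purified buffer only at evaluation time (see fit_buffer)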
self.finetuned_model = get_backbone(args)
missing, ignored = self.finetuned_model.load_state_dict(copy.deepcopy(self.net.state_dict()), strict=False)
assert len([m for m in missing if 'classifier' not in m and 'fc' not in m]) == 0, missing
self.finetuned_model.classifier = nn.Linear(self.cl_in_features, self.num_classes).to(self.device)
self.finetuned_model.to(self.device)
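# expert_model: a separate SSL network, re-initialized and trained on the delayed buffer each time it fills (see train_self_expert)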
self.expert_model = get_backbone(args)
self.expert_model = disable_linear(self.expert_model)
init_simclr_net(self.expert_model)
missing, ignored = self.expert_model.load_state_dict(copy.deepcopy(self.net.state_dict()), strict=False)
assert len([m for m in missing if 'classifier' not in m and 'fc' not in m]) == 0, missing
self.expert_model.to(self.device)
self.expert_opt = self.get_optimizer(self.expert_model.parameters())
self.expert_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.expert_opt, T_max=self.args.expert_train_epochs, eta_min=0, last_epoch=-1)
self.expert_model.to("cpu")
self.net.to("cpu")
self.finetuned_model.to("cpu")
@torch.no_grad()
def cluster_and_sample(self):
"""filter samples in delay buffer"""
self.expert_model.eval()
self.expert_model.to(self.device)
self.delayed_buffer.to(self.device)
xs = self.delayed_buffer.examples
ys = self.delayed_buffer.labels
tls = self.delayed_buffer.true_labels
corrs = tls == ys
features = self.expert_model.projector(self.expert_model(xs))
features = F.normalize(features, dim=1)
clean_idx, clean_p = [], []
noisy_samples_selected, corr_samples_selected = 0, 0
for u_y in torch.unique(ys):
y_mask = ys == u_y
corr = corrs[y_mask]
feature = features[y_mask]
# ignore negative similarities
_similarity_matrix = torch.relu(F.cosine_similarity(feature.unsqueeze(1), feature.unsqueeze(0), dim=-1))
# stochastic ensemble
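# draw E_max random graphs from the similarity matrix and average the resulting per-sample clean posteriors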
_clean_ps = torch.zeros((self.args.E_max, len(feature)), dtype=torch.double)
for _i in range(self.args.E_max):
# sample binary adjacency matrix from bernoulli distribution (of similarity matrix)
similarity_matrix = (_similarity_matrix > torch.rand_like(_similarity_matrix)).type(torch.float32)
similarity_matrix[similarity_matrix == 0] = 1e-5 # add a small value so the adjacency matrix stays strictly positive
# get centrality
g = nx.from_numpy_array(similarity_matrix.cpu().numpy())
info = nx.eigenvector_centrality(g, max_iter=6000, weight='weight') # dict: node index -> centrality value
centrality = np.asarray(list(info.values()))
# fit a two-component Beta mixture (BMM) to the centrality scores
# to separate presumably-clean from presumably-noisy samples
bmm_model = BetaMixture1D(max_iters=10)
# remove outliers, normalize, and fit the mixture
c = np.copy(centrality)
c, c_min, c_max = bmm_model.outlier_remove(c)
c = bmm_model.normalize(c, c_min, c_max)
bmm_model.fit(c)
bmm_model.create_lookup(1) # 0: noisy, 1: clean
# get posterior
c = np.copy(centrality)
c = bmm_model.normalize(c, c_min, c_max)
p = bmm_model.look_lookup(c)
_clean_ps[_i] = torch.from_numpy(p)
_clean_ps = torch.mean(_clean_ps, dim=0)
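# keep a sample with probability equal to its averaged clean posterior (Bernoulli draw)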
m = _clean_ps > torch.rand_like(_clean_ps)
clean_idx.extend(torch.nonzero(y_mask)[:, -1][m].tolist())
clean_p.extend(_clean_ps[m].tolist())
corr_samples_selected += corr[m].sum().item()
noisy_samples_selected += (~corr)[~m].sum().item()
return clean_idx, torch.Tensor(clean_p), corr_samples_selected, noisy_samples_selected
def train_self_expert(self):
"""Train expert model with samples from delay buffer only"""
self.finetuned_model.to("cpu")
self.net.to("cpu")
# reset expert model
nt = get_backbone(self.args)
nt = disable_linear(nt)
nt = init_simclr_net(nt)
missing, ignored = self.expert_model.load_state_dict(nt.state_dict())
assert len([m for m in missing if 'classifier' not in m and 'fc' not in m]) == 0, missing
self.expert_model.to(self.device)
torch.cuda.empty_cache()
def _get_correlated_mask(bs):
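# boolean mask over the 2*bs x 2*bs similarity matrix that removes self-similarities
# (main diagonal) and positive pairs (diagonals at +/- bs), leaving only negatives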
diag = np.eye(2 * bs)
l1 = np.eye((2 * bs), 2 * bs, k=-bs)
l2 = np.eye((2 * bs), 2 * bs, k=bs)
mask = torch.from_numpy((diag + l1 + l2))
mask = (1 - mask).type(torch.bool)
return mask
# total batch size = delayed buffer size (delayed samples only)
bs = min(self.args.delayed_buffer_size, len(self.delayed_buffer))
self.expert_model.train()
correlation_mask = _get_correlated_mask(bs).to(self.device)
loss_fn = SimCLR(self.transform, temp=self.args.simclr_temp, filter_bs_len=bs, correlation_mask=correlation_mask)
sampler = torch.utils.data.RandomSampler(self.delayed_buffer, replacement=True)
self.delayed_buffer.to(self.device)
delayer_dl = self.delayed_buffer.get_dataloader(self.args, batch_size=bs, drop_last=True, sampler=sampler)
totloss, cit = 0, 0
for epoch_i in tqdm.trange(self.args.expert_train_epochs, desc="Expert network training", leave=False):
if self.args.spr_debug_mode == 1 and epoch_i > 10:
break
for data in delayer_dl:
inputs = data[0].to(self.device)
self.expert_opt.zero_grad()
loss = loss_fn(self.expert_model, inputs)
loss.backward()
self.expert_opt.step()
totloss += loss.item()
cit += 1
# hold the LR fixed for the first 10 epochs (warm-up), then step the cosine schedule
if epoch_i >= 10:
self.expert_lr_scheduler.step()
return totloss / cit
def train_self_base(self):
"""Self Replay. train base model with samples from delay and purified buffer"""
self.expert_model.to("cpu")
self.finetuned_model.to("cpu")
self.net.to(self.device)
def _get_correlated_mask(bs):
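# same mask as in train_self_expert: keep only negative-pair entries of the similarity matrix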
diag = np.eye(2 * bs)
l1 = np.eye((2 * bs), 2 * bs, k=-bs)
l2 = np.eye((2 * bs), 2 * bs, k=bs)
mask = torch.from_numpy((diag + l1 + l2))
mask = (1 - mask).type(torch.bool)
return mask
# total batch size = buffer size (split between delayed and purified buffers)
bs = self.args.buffer_size
# If purified buffer is full, train using it also
db_bs = (bs // 2) if self.buffer.is_full() else bs
db_bs = min(db_bs, len(self.delayed_buffer))
pb_bs = min(bs - db_bs, len(self.buffer))
self.net.train()
correlation_mask = _get_correlated_mask(db_bs + pb_bs).to(self.device)
loss_fn = SimCLR(self.transform, temp=self.args.simclr_temp, filter_bs_len=db_bs + pb_bs, correlation_mask=correlation_mask)
totloss, cit = 0, 0
delayed_sampler = torch.utils.data.RandomSampler(self.delayed_buffer, replacement=True)
delayed_dl = self.delayed_buffer.get_dataloader(self.args, batch_size=db_bs, drop_last=False, sampler=delayed_sampler)
for epoch_i in tqdm.trange(self.args.inner_train_epochs, desc="Base network training", leave=False):
if self.args.spr_debug_mode == 1 and epoch_i > 10:
break
for data in delayed_dl:
x = data[0].to(self.device)
if pb_bs > 0:
xp = self.buffer.get_data(pb_bs)[0].to(self.device)
x = torch.cat([x, xp], dim=0)
self.opt.zero_grad()
loss = loss_fn(self.net, x)
loss.backward()
self.opt.step()
totloss += loss.item()
cit += 1
# hold the LR fixed for the first 10 epochs (warm-up), then step the cosine schedule
if epoch_i >= 10:
self.lr_scheduler.step()
return totloss / cit
def begin_task(self, dataset):
self.observe_it = 0
def _observe(self, not_aug_x, y, true_y):
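"""Store one stream sample in the delayed buffer; once it is full, run SSL training, purify it, and move the clean samples to the purified buffer."""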
self.delayed_buffer.add_data(examples=not_aug_x.unsqueeze(0),
labels=y.unsqueeze(0),
true_labels=true_y.unsqueeze(0))
avg_expert_loss, avg_self_loss = -1, -1
if self.delayed_buffer.is_full():
self.observe_it += 1
# Train expert net with SSL on D only
avg_expert_loss = self.train_self_expert()
# Train base net with SSL on D and P
avg_self_loss = self.train_self_base()
pret = time.time()
# Get clean samples from D
clean_idx, _, corr_sel, noisy_sel = self.cluster_and_sample()
# Add clean samples to P
self.buffer.add_data(examples=self.delayed_buffer.examples[clean_idx],
labels=self.delayed_buffer.labels[clean_idx],
true_labels=self.delayed_buffer.true_labels[clean_idx])
logging.debug("Purifying buffer took %.2fs", time.time() - pret)
self.delayed_buffer.empty()
return avg_expert_loss, avg_self_loss
def fit_buffer(self):
"""Fit finetuned model on purified buffer, before eval"""
logging.debug("Fitting finetuned model on purified buffer")
self.expert_model.to("cpu")
self.net.to("cpu")
self.finetuned_model = get_backbone(self.args)
missing, ignored = self.finetuned_model.load_state_dict(copy.deepcopy(self.net.state_dict()), strict=False)
assert len([m for m in missing if 'classifier' not in m and 'fc' not in m]) == 0
assert len([m for m in ignored if 'projector' not in m and 'predictor' not in m]) == 0
self.finetuned_model.classifier = nn.Linear(self.cl_in_features, self.n_seen_classes).to(self.device)
self.finetuned_model.to(self.device)
opt = self.get_optimizer(self.finetuned_model.parameters(), lr=self.args.fitting_lr)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=self.args.fitting_sched_lr_stepsize, gamma=self.args.fitting_sched_lr_gamma)
sampler = torch.utils.data.RandomSampler(self.buffer)
buffer_dl = self.buffer.get_dataloader(self.args, batch_size=self.args.fitting_batch_size, drop_last=True, sampler=sampler)  # NOTE: no augmentation transform is applied when fitting on the purified buffer
self.finetuned_model.train()
ce_loss = nn.NLLLoss()
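# NLLLoss over log-softmax outputs is equivalent to standard cross-entropy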
for epoch in tqdm.trange(self.args.fitting_epochs, desc="Buffer fitting", leave=False, disable=True):
if self.args.spr_debug_mode == 1 and epoch > 10:
break
for dat in buffer_dl:
x, y = dat[0], dat[1]
x, y = x.to(self.device), y.to(self.device)
opt.zero_grad()
out = self.finetuned_model(x)
loss = ce_loss(F.log_softmax(out, dim=1), y)
loss.backward()
if self.args.fitting_clip_value is not None:
nn.utils.clip_grad_value_(self.finetuned_model.parameters(), self.args.fitting_clip_value)
opt.step()
sched.step()
self.finetuned_model.eval()
def forward(self, inputs):
if self.finetuned_model.device != inputs.device:
self.finetuned_model.to(inputs.device)
return self.finetuned_model(inputs)
def end_task(self, dataset):
# fit classifier on P
self.fit_buffer()
self.buffer.to(self.device)
def observe(self, inputs, labels, not_aug_inputs, true_labels):
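"""Process the incoming batch one sample at a time, as SPR operates on a per-sample stream."""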
for y, not_aug_x, true_y in zip(labels, not_aug_inputs, true_labels): # un-batch the data
avg_expert_loss, avg_self_loss = self._observe(not_aug_x, y, true_y)
return self.past_loss if avg_self_loss < 0 else avg_self_loss