import torch
from copy import deepcopy
from torch.utils.data import TensorDataset
from tqdm import tqdm
from argparse import ArgumentParser
try:
import wandb
except ImportError:
wandb = None
from utils import binary_to_boolean_type
from utils.augmentations import RepeatedTransform
from utils.conf import create_seeded_dataloader
from utils.schedulers import CosineSchedule
from models.utils.continual_model import ContinualModel
from models.star_prompt_utils.second_stage_model import Model
from models.star_prompt_utils.generative_replay import Gaussian, MixtureOfGaussiansModel
[docs]
class SecondStageStarprompt(ContinualModel):
"""Second-stage of StarPrompt. Requires the keys saved from the first stage."""
NAME = 'second_stage_starprompt'
COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
[docs]
@staticmethod
def get_parser(parser) -> ArgumentParser:
parser.set_defaults(batch_size=128, optimizer='adam', lr=0.001)
frozen_group = parser.add_argument_group('Frozen hyperparameters')
frozen_group.add_argument("--virtual_bs_n", type=int, default=1,
help="virtual batch size iterations")
frozen_group.add_argument("--enable_data_aug_query", type=binary_to_boolean_type, default=1,
help="Use default transform with data aug to generate the CLIP's response?")
frozen_group.add_argument("--use_clip_preprocess_eval", type=binary_to_boolean_type, default=0,
help="Use CLIP's transform during eval instead of the default test transform?")
frozen_group.add_argument("--ortho_split_val", type=int, default=0)
frozen_group.add_argument('--gr_mog_n_iters', '--gr_mog_n_iters_second_stage', dest='gr_mog_n_iters_second_stage',
type=int, default=500, help="Number of EM iterations during fit for GR with MOG.")
frozen_group.add_argument('--gr_mog_n_components', type=int, default=5,
help="Number of components for GR with MOG.")
frozen_group.add_argument('--batch_size_gr', type=int, default=128,
help="Batch size for Generative Replay.")
frozen_group.add_argument('--num_samples_gr', type=int, default=256,
help="Number of samples for Generative Replay.")
frozen_group.add_argument('--prefix_tuning_prompt_len', type=int, default=5,
help="Prompt length for prefix tuning. Used only if `--prompt_mode==concat`.")
ablation_group = parser.add_argument_group('Ablations hyperparameters')
ablation_group.add_argument('--gr_model', type=str, default='mog', choices=['mog', 'gaussian'],
help="Type of distribution model for Generative Replay. "
"- `mog`: Mixture of Gaussian. "
"- `gaussian`: Single Gaussian distribution.")
ablation_group.add_argument("--enable_gr", type=binary_to_boolean_type, default=1,
help="Enable Generative Replay.")
ablation_group.add_argument('--statc_keys_use_templates', type=binary_to_boolean_type, default=1,
help="Use templates for the second stage if no keys are loaded.")
ablation_group.add_argument('--prompt_mode', type=str, default='residual', choices=['residual', 'concat'],
help="Prompt type for the second stage. "
"- `residual`: STAR-Prompt style prompting. "
"- `concat`: Prefix-Tuning style prompting.")
ablation_group.add_argument("--enable_confidence_modulation", type=binary_to_boolean_type, default=1,
help="Enable confidence modulation with CLIP similarities (Eq. 5 of the main paper)?")
# Tunable hyperparameters
tunable_group = parser.add_argument_group('Tunable hyperparameters')
tunable_group.add_argument("--lambda_ortho_second_stage", type=float, default=10,
help="orthogonality loss coefficient")
tunable_group.add_argument("--num_monte_carlo_gr", "--num_monte_carlo_gr_second_stage", dest="num_monte_carlo_gr_second_stage",
type=int, default=1, help="how many times to sample from the dataset for alignment")
tunable_group.add_argument("--num_epochs_gr", "--num_epochs_gr_second_stage", dest="num_epochs_gr_second_stage",
type=int, default=10, help="Num. of epochs for GR.")
tunable_group.add_argument("--learning_rate_gr", "--learning_rate_gr_second_stage", dest="learning_rate_gr_second_stage",
type=float, default=0.001, help="Learning rate for GR.")
# Very important parameter
parser.add_argument('--keys_ckpt_path', type=str,
help="Path for first-stage keys. "
"The keys can be saved by runninng `first_stage_starprompt` with `--save_first_stage_keys=1`."
"This can be:"
"- A path to a checkpoint file (.pt) containing ONLY THE FIRST STAGE KEYS."
"- A path to the checkpoint made by `first_stage_starprompt`"
"- The job-id (`conf_jobnum`) of the `first_stage_starprompt` run that made the keys."
"- A JSON file containing the job-id (`conf_jobnum`) of the `first_stage_starprompt` run that made the keys."
"The JSON is expected to contain an entry for each dataset and seed: `{dataset: {seed: job-id}}`.")
return parser
net: Model
def __init__(self, backbone, loss, args, transform, dataset=None):
super().__init__(backbone, loss, args, transform, dataset=dataset)
self.net = Model(args,
backbone=self.net,
dataset=self.dataset,
num_classes=self.num_classes,
device=self.device)
# REMOVE ALL TRACK RUNNING STATS FROM CLIP
for m in self.net.modules():
if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
m.track_running_stats = False
embed_dim = self.net.vit.embed_dim
self.distributions = torch.nn.ModuleList([self._get_dist(embed_dim)
for _ in range(self.num_classes)]).to(self.device)
self.classifier_state_dict = None
def _get_dist(self, embed_dim):
assert self.args.gr_model in ['mog', 'gaussian'], f"Invalid GR model: {self.args.gr_model}"
if self.args.gr_model == 'mog':
return MixtureOfGaussiansModel(embed_dim, n_components=self.args.gr_mog_n_components,
n_iters=self.args.gr_mog_n_iters_second_stage)
else:
return Gaussian(embed_dim)
[docs]
def norm(self, t):
return torch.norm(t, p=2, dim=-1, keepdim=True) + 1e-7
[docs]
@torch.no_grad()
def create_features_dataset(self):
labels, features = [], []
for _ti in range(self.current_task + 1):
prev_t_size, cur_t_size = self.compute_offsets(_ti)
for class_idx in range(prev_t_size, cur_t_size):
current_samples = self.distributions[class_idx](self.args.num_samples_gr)
features.append(current_samples)
labels.append(torch.ones(self.args.num_samples_gr) * class_idx)
features = torch.cat(features, dim=0)
labels = torch.cat(labels, dim=0).long()
return create_seeded_dataloader(self.args, TensorDataset(features, labels),
batch_size=self.args.batch_size_gr,
shuffle=True, num_workers=0)
[docs]
def train_alignment_epoch(self, classifier: torch.nn.Module, optim: torch.optim.Optimizer, epoch: int):
dl = self.create_features_dataset()
with tqdm(enumerate(dl), total=len(dl), desc=f'GR second stage epoch {epoch + 1}/{self.args.num_epochs_gr_second_stage}', leave=False) as pbar:
for i, (x, labels) in pbar:
optim.zero_grad()
x, labels = x.to(self.device, dtype=torch.float32), labels.to(self.device)
logits = classifier(x)
logits = logits[:, :self.n_seen_classes]
norm = self.norm(logits)
logits = logits / (0.1 * norm)
loss = self.loss(logits, labels)
loss.backward()
optim.step()
if not self.args.nowand:
assert wandb is not None, "wandb is not installed."
wandb.log({'ca_loss_second_stage': loss.item(), 'ca_lr_second_stage': optim.param_groups[0]['lr']})
pbar.set_postfix({'loss': loss.item()}, refresh=False)
[docs]
def align(self):
classifier = deepcopy(self.net.vit.head)
optim = torch.optim.SGD(lr=self.args.learning_rate_gr_second_stage,
params=classifier.parameters(),
momentum=0.0,
weight_decay=0.0)
num_epochs = self.args.num_epochs_gr_second_stage + (5 * self.current_task)
for e in range(num_epochs):
self.train_alignment_epoch(classifier, optim, e)
self.net.vit.head.weight.data.copy_(classifier.weight.data)
self.net.vit.head.bias.data.copy_(classifier.bias.data)
[docs]
@torch.no_grad()
def update_statistics(self, dataset):
features_dict = {i: [] for i in range(self.n_past_classes, self.n_seen_classes)}
self.net.eval()
with tqdm(total=self.args.num_monte_carlo_gr_second_stage * len(dataset.train_loader), desc='GR update statistics') as pbar:
for _ in range(self.args.num_monte_carlo_gr_second_stage):
for i, data in enumerate(dataset.train_loader):
if self.args.debug_mode and i > 3 and min([len(v) for k, v in features_dict.items()]) > self.args.gr_mog_n_components:
break
x, labels = data[0], data[1]
x, labels = x.to(self.device), labels.to(self.device, dtype=torch.long)
x, query_x = x[:, 0], x[:, 1]
if self.args.enable_data_aug_query:
query_x = None
features = self.net(x, query_x=query_x, return_features=True, cur_classes=self.n_seen_classes, frozen_past_classes=self.n_past_classes)
features = features[:, 0]
for class_idx in labels.unique():
features_dict[int(class_idx)].append(features[labels == class_idx])
pbar.update(1)
for class_idx in range(self.n_past_classes, self.n_seen_classes):
features_class_idx = torch.cat(features_dict[class_idx], dim=0)
self.distributions[class_idx].fit(features_class_idx.to(self.device))
[docs]
def backup(self):
print(f"BACKUP: Task - {self.current_task} - classes from "
f"{self.n_past_classes} - to {self.n_seen_classes}")
self.classifier_state_dict = deepcopy(self.net.vit.head.state_dict())
[docs]
def recall(self):
print(f"RECALL: Task - {self.current_task} - classes from "
f"{self.n_past_classes} - to {self.n_seen_classes}")
if self.current_task == 0 or not self.args.enable_gr:
return
assert self.classifier_state_dict
self.net.vit.head.weight.data.copy_(self.classifier_state_dict['weight'].data)
self.net.vit.head.bias.data.copy_(self.classifier_state_dict['bias'].data)
[docs]
def end_task(self, dataset):
if hasattr(self, 'opt'):
del self.opt # free up some vram
if self.args.enable_gr:
self.update_statistics(dataset)
self.backup()
if self.current_task > 0:
self.align()
[docs]
def get_parameters(self):
return [p for p in self.net.parameters() if p.requires_grad]
[docs]
def get_scheduler(self):
return CosineSchedule(self.opt, K=self.args.n_epochs)
[docs]
def begin_task(self, dataset):
if self.args.permute_classes:
if hasattr(self.net.prompter, 'old_args') and self.net.prompter.old_args is not None:
assert self.args.seed == self.net.prompter.old_args.seed
assert (self.args.class_order == self.net.prompter.old_args.class_order).all()
dataset.train_loader.dataset.transform = RepeatedTransform([dataset.train_loader.dataset.transform, self.net.prompter.clip_preprocess])
dataset.test_loaders[-1].dataset.transform = RepeatedTransform([dataset.test_loaders[-1].dataset.transform, self.net.prompter.clip_preprocess])
# NOTE: Remove these comments if you want to check if the keys are loaded correctly and results are the same as the first stage
# tot_data, tot_corr = 0, 0
# for i, ts in enumerate(dataset.test_loaders):
# task_tot, task_corr = 0, 0
# for data in ts:
# inputs, labels = data[0], data[1]
# inputs, labels = inputs[:, 1].to(self.device), labels.to(self.device) # only clip-preprocessed input
# queries = self.net.prompter.get_query(inputs)
# queries = torch.nn.functional.normalize(queries, dim=-1)
# logits = torch.einsum('bd,cd->bc', queries, self.net.prompter.keys.type(self.net.prompter.clip_model.dtype))
# task_corr += (logits.argmax(dim=-1) == labels).sum().item()
# task_tot += labels.shape[0]
# print(f"CLIP on TASK {i+1}: {task_corr / task_tot}")
# tot_corr += task_corr
# tot_data += task_tot
# print(f"AVG CLIP ON TASKS: {tot_corr / tot_data}") # the avg of the avg != the avg of the total
# For later GR
self.recall()
if hasattr(self, 'opt'):
del self.opt
self.opt = self.get_optimizer()
self.scheduler = self.get_scheduler()
[docs]
def forward(self, x):
x, query_x = x[:, 0], x[:, 1] # from repeated transform
if not self.args.use_clip_preprocess_eval:
query_x = None
logits = self.net(x, query_x=query_x, cur_classes=self.n_seen_classes)
logits = logits[:, :self.n_seen_classes]
return logits
[docs]
def observe(self, inputs, labels, not_aug_inputs, epoch=None):
stream_inputs, stream_labels = inputs, labels
stream_inputs, query_stream_inputs = stream_inputs[:, 0], stream_inputs[:, 1]
if self.args.enable_data_aug_query:
query_stream_inputs = None
stream_logits = self.net(stream_inputs, query_x=query_stream_inputs, cur_classes=self.n_seen_classes, frozen_past_classes=self.n_past_classes)
# Compute accuracy on current training batch for logging
with torch.no_grad():
stream_preds = stream_logits[:, :self.n_seen_classes].argmax(dim=1)
stream_acc = (stream_preds == stream_labels).sum().item() / stream_labels.shape[0]
# mask old classes
stream_logits[:, :self.n_past_classes] = -float('inf')
loss = self.loss(stream_logits[:, :self.n_seen_classes], stream_labels)
loss_ortho = self.net.prompter.compute_ortho_loss(frozen_past_classes=self.n_past_classes, cur_classes=self.n_seen_classes)
loss += self.args.lambda_ortho_second_stage * loss_ortho
if self.epoch_iteration == 0:
self.opt.zero_grad()
(loss / self.args.virtual_bs_n).backward()
# loss.backward()
if (self.epoch_iteration > 0 or self.args.virtual_bs_n == 1) and \
self.epoch_iteration % self.args.virtual_bs_n == 0:
self.opt.step()
self.opt.zero_grad()
return {'loss': loss.item(),
'stream_accuracy': stream_acc}