Source code for models.starprompt

import logging
from argparse import ArgumentParser

import torch

from models.star_prompt_utils.end_to_end_model import STARPromptModel
from models.utils.continual_model import ContinualModel
from utils import binary_to_boolean_type
from utils.schedulers import CosineSchedule

try:
    import clip
except ImportError:
    raise ImportError("Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git (also requires `huggingface-hub`)")


class STARPrompt(ContinualModel):
    """Second stage of StarPrompt. Requires the keys saved from the first stage."""

    NAME = 'starprompt'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    net: STARPromptModel

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.set_defaults(batch_size=128, optimizer='adam', lr=0.001)

        frozen_group = parser.add_argument_group('Frozen hyperparameters')
        frozen_group.add_argument("--virtual_bs_n", type=int, default=1,
                                  help="virtual batch size iterations")
        frozen_group.add_argument("--ortho_split_val", type=int, default=0)
        frozen_group.add_argument('--gr_mog_n_iters_second_stage', type=int, default=500,
                                  help="Number of EM iterations during fit for GR with MOG on the second stage.")
        frozen_group.add_argument('--gr_mog_n_iters_first_stage', type=int, default=200,
                                  help="Number of EM iterations during fit for GR with MOG on the first stage.")
        frozen_group.add_argument('--gr_mog_n_components', type=int, default=5,
                                  help="Number of components for GR with MOG (both first and second stage).")
        frozen_group.add_argument('--batch_size_gr', type=int, default=128,
                                  help="Batch size for Generative Replay (both first and second stage).")
        frozen_group.add_argument('--num_samples_gr', type=int, default=256,
                                  help="Number of samples for Generative Replay (both first and second stage).")
        frozen_group.add_argument('--prefix_tuning_prompt_len', type=int, default=5,
                                  help="Prompt length for prefix tuning. Used only if `--prompt_mode==concat`.")

        ablation_group = parser.add_argument_group('Ablations hyperparameters')
        ablation_group.add_argument('--gr_model', type=str, default='mog', choices=['mog', 'gaussian'],
                                    help="Type of distribution model for Generative Replay (both first and second stage). "
                                         "- `mog`: Mixture of Gaussian. "
                                         "- `gaussian`: Single Gaussian distribution.")
        ablation_group.add_argument("--enable_gr", type=binary_to_boolean_type, default=1,
                                    help="Enable Generative Replay (both first and second stage).")
        ablation_group.add_argument('--prompt_mode', type=str, default='residual', choices=['residual', 'concat'],
                                    help="Prompt type for the second stage. "
                                         "- `residual`: STAR-Prompt style prompting. "
                                         "- `concat`: Prefix-Tuning style prompting.")
        ablation_group.add_argument("--enable_confidence_modulation", type=binary_to_boolean_type, default=1,
                                    help="Enable confidence modulation with CLIP similarities (Eq. 5 of the main paper)?")

        tunable_group = parser.add_argument_group('Tunable hyperparameters')
        # second stage
        tunable_group.add_argument("--lambda_ortho_second_stage", type=float, default=10,
                                   help="orthogonality loss coefficient")
        tunable_group.add_argument("--num_monte_carlo_gr_second_stage", type=int, default=1,
                                   help="how many times to sample from the dataset for alignment")
        tunable_group.add_argument("--num_epochs_gr_second_stage", type=int, default=10,
                                   help="Num. of epochs for GR.")
        tunable_group.add_argument("--learning_rate_gr_second_stage", type=float, default=0.001,
                                   help="Learning rate for GR.")
        # first stage
        tunable_group.add_argument("--num_monte_carlo_gr_first_stage", type=int, default=1,
                                   help="how many times to sample from the dataset for alignment")
        tunable_group.add_argument("--learning_rate_gr_first_stage", type=float, default=0.05,
                                   help="Learning rate for Generative Replay.")
        tunable_group.add_argument("--lambda_ortho_first_stage", type=float, default=30,
                                   help="Orthogonality loss coefficient for coop")
        tunable_group.add_argument("--num_epochs_gr_first_stage", type=int, default=10,
                                   help="Num. of epochs for Generative Replay.")

        parser.add_argument("--clip_backbone", type=str, default='ViT-L/14',
                            help="CLIP backbone architecture", choices=clip.available_models())

        first_stage_optim_group = parser.add_argument_group('First stage optimization hyperparameters')
        first_stage_optim_group.add_argument("--first_stage_optim", type=str, default='sgd', choices=['sgd', 'adam'],
                                             help="First stage optimizer")
        first_stage_optim_group.add_argument("--first_stage_lr", type=float, default=0.002,
                                             help="First stage learning rate")
        first_stage_optim_group.add_argument("--first_stage_momentum", type=float, default=0,
                                             help="First stage momentum")
        first_stage_optim_group.add_argument("--first_stage_weight_decay", type=float, default=0,
                                             help="First stage weight decay")
        first_stage_optim_group.add_argument("--first_stage_epochs", type=int,
                                             help="First stage epochs. If not set, it will be the same as `n_epochs`.")

        return parser
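
The defaults registered above can be inspected outside the full training framework. A minimal sketch, assuming a plain `argparse.ArgumentParser` may be passed in place of the framework's pre-populated parser:

from argparse import ArgumentParser

from models.starprompt import STARPrompt

# Build the argument parser and inspect a few defaults.
parser = STARPrompt.get_parser(ArgumentParser())
args = parser.parse_args(['--lambda_ortho_second_stage', '15', '--first_stage_epochs', '5'])
print(args.gr_model, args.prompt_mode, args.first_stage_epochs)  # -> mog residual 5
print(args.lr, args.batch_size)  # parser-level defaults set via set_defaults: 0.001 128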

    def __init__(self, backbone, loss, args, transform, dataset=None):
        if not hasattr(args, 'first_stage_epochs') or args.first_stage_epochs is None:
            logging.info("`first_stage_epochs` not set. Setting it to `n_epochs`.")
            args.first_stage_epochs = args.n_epochs

        super().__init__(backbone, loss, args, transform, dataset=dataset)

        self.net = STARPromptModel(args,
                                   backbone=self.net,
                                   dataset=self.dataset,
                                   num_classes=self.num_classes,
                                   device=self.device)

    def end_task(self, dataset):
        if hasattr(self, 'opt'):
            del self.opt  # free up some vram

        if self.args.enable_gr:
            self.net.update_statistics(dataset, self.n_past_classes, self.n_seen_classes)
            self.net.backup(self.current_task, self.n_past_classes, self.n_seen_classes)

            if self.current_task > 0:
                if self.args.seed is not None:
                    torch.manual_seed(self.args.seed)
                self.net.align(self.current_task, self.n_seen_classes, self.loss)
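
When `--enable_gr` is on, `update_statistics` fits a class-conditional feature distribution at the end of each task (a mixture of Gaussians for `--gr_model=mog`, a single Gaussian otherwise), and `align` later samples from it to rehearse the classifier on past classes. A rough, illustrative sketch of that idea using scikit-learn's `GaussianMixture` in place of the model's own implementation; function names and shapes here are assumptions, not the actual API:

import torch
from sklearn.mixture import GaussianMixture

def fit_class_mog(features: torch.Tensor, n_components: int = 5, n_iters: int = 500):
    # Fit a per-class mixture of Gaussians on (N, D) pre-logit features
    # (cf. --gr_mog_n_components and --gr_mog_n_iters_second_stage).
    gm = GaussianMixture(n_components=n_components, max_iter=n_iters, covariance_type='diag')
    gm.fit(features.cpu().numpy())
    return gm

def sample_replay(gm, n_samples: int = 256) -> torch.Tensor:
    # Draw synthetic features to rehearse the classifier on past classes (cf. --num_samples_gr).
    synth, _ = gm.sample(n_samples)
    return torch.from_numpy(synth).float()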

    def get_parameters(self):
        if not isinstance(self.net, STARPromptModel):  # during initialization
            return super().get_parameters()
        return [p for p in self.net.second_stage.parameters() if p.requires_grad]

    def get_scheduler(self):
        return CosineSchedule(self.opt, K=self.args.n_epochs)

    def begin_task(self, dataset):
        # clean junk on GPU
        if hasattr(self, 'opt'):
            del self.opt
        torch.cuda.empty_cache()

        # adapt CLIP on current task
        self.net.train_first_stage_on_task(dataset, self.current_task, self.n_past_classes, self.n_seen_classes, self.loss)
        self.net.update_keys(self.n_past_classes, self.n_seen_classes)

        self.net.second_stage.train()  # initialize second stage

        # For later GR
        self.net.recall_classifier_second_stage(self.current_task, self.n_past_classes, self.n_seen_classes)

        self.opt = self.get_optimizer()
        self.scheduler = self.get_scheduler()

    def forward(self, x):
        logits = self.net(x, cur_classes=self.n_seen_classes)
        logits = logits[:, :self.n_seen_classes]
        return logits

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        # second stage only
        stream_inputs, stream_labels = inputs, labels
        stream_logits = self.net(stream_inputs, cur_classes=self.n_seen_classes, frozen_past_classes=self.n_past_classes)

        # Compute accuracy on current training batch for logging
        with torch.no_grad():
            stream_preds = stream_logits[:, :self.n_seen_classes].argmax(dim=1)
            stream_acc = (stream_preds == stream_labels).sum().item() / stream_labels.shape[0]

        # mask old classes
        stream_logits[:, :self.n_past_classes] = -float('inf')
        loss = self.loss(stream_logits[:, :self.n_seen_classes], stream_labels)

        loss_ortho = self.net.second_stage.prompter.compute_ortho_loss(frozen_past_classes=self.n_past_classes, cur_classes=self.n_seen_classes)
        loss += self.args.lambda_ortho_second_stage * loss_ortho

        if self.epoch_iteration == 0:
            self.opt.zero_grad()

        (loss / self.args.virtual_bs_n).backward()
        if (self.epoch_iteration > 0 or self.args.virtual_bs_n == 1) and \
                self.epoch_iteration % self.args.virtual_bs_n == 0:
            self.opt.step()
            self.opt.zero_grad()

        return {'loss': loss.item(), 'stream_accuracy': stream_acc}
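
`observe` combines cross-entropy on the current task (after masking past-class logits to `-inf`) with the prompt orthogonality penalty, and accumulates gradients over `--virtual_bs_n` iterations before stepping. A self-contained sketch of that masking-plus-accumulation pattern; `model.ortho_loss()` is a hypothetical stand-in for the prompter's orthogonality term, and the accumulation schedule is simplified with respect to the method above:

import torch
import torch.nn.functional as F

def masked_ce(logits: torch.Tensor, labels: torch.Tensor, n_past: int, n_seen: int) -> torch.Tensor:
    # Past-class logits are masked so the loss only competes among current-task classes.
    logits = logits.clone()
    logits[:, :n_past] = -float('inf')
    return F.cross_entropy(logits[:, :n_seen], labels)

def train_epoch(model, loader, opt, n_past, n_seen, lambda_ortho=10.0, virtual_bs_n=1):
    opt.zero_grad()
    for it, (x, y) in enumerate(loader):
        logits = model(x)
        loss = masked_ce(logits, y, n_past, n_seen)
        loss = loss + lambda_ortho * model.ortho_loss()  # hypothetical hook for the orthogonality term
        (loss / virtual_bs_n).backward()                 # scale so accumulated gradients match one large batch
        if (it + 1) % virtual_bs_n == 0:
            opt.step()
            opt.zero_grad()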