Source code for models.coda_prompt

"""
CODA-Prompt: COntinual Decomposed Attention-based Prompting

Note:
    CODA-Prompt USES A CUSTOM BACKBONE: `vit_base_patch16_224`.
    The backbone is a ViT-B/16 pretrained on ImageNet-21k and fine-tuned on ImageNet-1k.
"""

import logging
from utils.args import *
from models.utils.continual_model import ContinualModel
import torch
from datasets import get_dataset
from models.coda_prompt_utils.model import Model
from utils.schedulers import CosineSchedule


class CodaPrompt(ContinualModel):
    """Continual Learning via CODA-Prompt: COntinual Decomposed Attention-based Prompting."""

    NAME = 'coda_prompt'
    COMPATIBILITY = ['class-il', 'task-il']
    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.set_defaults(lr=0.001, optimizer='adam', optim_mom=0.9)
        parser.add_argument('--mu', type=float, default=0.0, help='weight of prompt loss')
        parser.add_argument('--pool_size', type=int, default=100, help='pool size')
        parser.add_argument('--prompt_len', type=int, default=8, help='prompt length')
        parser.add_argument('--virtual_bs_iterations', '--virtual_bs_n', dest='virtual_bs_iterations',
                            type=int, default=1, help="virtual batch size iterations")
        return parser
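As a quick sanity check of the argument wiring (standalone illustration, not Mammoth's actual entry point, and it assumes the Mammoth package is importable), note that the `--virtual_bs_n` alias is routed to `args.virtual_bs_iterations` through `dest`:

from argparse import ArgumentParser
from models.coda_prompt import CodaPrompt  # requires the Mammoth environment

parser = CodaPrompt.get_parser(ArgumentParser())
args = parser.parse_args(['--mu', '0.1', '--virtual_bs_n', '4'])
print(args.mu, args.virtual_bs_iterations)  # 0.1 4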
    def __init__(self, backbone, loss, args, transform, dataset=None):
        del backbone
        print("-" * 20)
        logging.info("CODA-Prompt USES A CUSTOM BACKBONE: `vit_base_patch16_224`.")
        print("Pretrained on Imagenet 21k and finetuned on ImageNet 1k.")
        print("-" * 20)

        if args.lr_scheduler is not None:
            logging.info("CODA-Prompt uses a custom scheduler: cosine. Ignoring --lr_scheduler.")

        self.dataset = get_dataset(args)
        self.n_classes = self.dataset.N_CLASSES
        self.n_tasks = self.dataset.N_TASKS
        backbone = Model(num_classes=self.n_classes, pt=True,
                         prompt_param=[self.n_tasks, [args.pool_size, args.prompt_len, 0]])

        super().__init__(backbone, loss, args, transform, dataset=dataset)
        self.net.task_id = 0
        self.opt = self.get_optimizer()
    def get_optimizer(self):
        # Only the prompt pool and the classification head (`last`) are passed to the
        # optimizer, so the pretrained ViT backbone is never updated.
        params_to_opt = list(self.net.prompt.parameters()) + list(self.net.last.parameters())
        optimizer_arg = {'params': params_to_opt,
                         'lr': self.args.lr,
                         'weight_decay': self.args.optim_wd}
        if self.args.optimizer == 'sgd':
            opt = torch.optim.SGD(**optimizer_arg)
        elif self.args.optimizer == 'adam':
            opt = torch.optim.Adam(**optimizer_arg)
        else:
            raise ValueError('Optimizer not supported for this method')
        return opt
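Because only `self.net.prompt` and `self.net.last` are handed to the optimizer, the ViT weights stay fixed even if their `requires_grad` flags are set. A generic, standalone illustration of the same selective-optimization pattern (all names below are placeholders, not Mammoth objects):

import torch

backbone = torch.nn.Linear(8, 8)                    # stands in for the frozen ViT
head = torch.nn.Linear(8, 2)                        # stands in for `last`
opt = torch.optim.Adam(head.parameters(), lr=1e-3)  # backbone params are simply not registered

out = head(backbone(torch.randn(4, 8)))
out.sum().backward()
opt.step()  # updates only the head; backbone.weight is left unchanged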
    def begin_task(self, dataset):
        self.offset_1, self.offset_2 = self.dataset.get_offsets(self.current_task)

        if self.current_task != 0:
            # Advance the prompt pool to the new task and rebuild the optimizer.
            self.net.task_id = self.current_task
            self.net.prompt.process_task_count()
            self.opt = self.get_optimizer()

        self.scheduler = CosineSchedule(self.opt, K=self.args.n_epochs)
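`CosineSchedule` is Mammoth's own scheduler. Assuming it performs a plain cosine decay of the learning rate over the task's epochs, its effect is roughly analogous to the standard `CosineAnnealingLR` in this standalone sketch (exact warm-up and decay details may differ):

import torch

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=5)  # T_max plays the role of K=n_epochs
for epoch in range(5):
    # ... one epoch of training ...
    sched.step()
    print(epoch, opt.param_groups[0]['lr'])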
    def observe(self, inputs, labels, not_aug_inputs, epoch=0):
        labels = labels.long()
        self.opt.zero_grad()
        logits, loss_prompt = self.net(inputs, train=True)
        loss_prompt = loss_prompt.sum()

        # Keep only logits of classes seen so far and suppress those of previous tasks.
        logits = logits[:, :self.offset_2]
        logits[:, :self.offset_1] = -float('inf')

        loss_ce = self.loss(logits, labels)
        loss = loss_ce + self.args.mu * loss_prompt

        if self.task_iteration == 0:
            self.opt.zero_grad()
            torch.cuda.empty_cache()

        # Gradient accumulation over `virtual_bs_iterations` mini-batches.
        (loss / float(self.args.virtual_bs_iterations)).backward()
        if self.task_iteration > 0 and self.task_iteration % self.args.virtual_bs_iterations == 0:
            self.opt.step()
            self.opt.zero_grad()

        return loss.item()
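The `virtual_bs_iterations` logic above implements gradient accumulation (a "virtual batch size"): gradients from several mini-batches, each scaled by 1/N, are summed before a single optimizer step. A standalone sketch of the same pattern, with a placeholder model and random data:

import torch

model = torch.nn.Linear(10, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
N = 4  # plays the role of --virtual_bs_iterations

opt.zero_grad()
for it in range(16):
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    (loss / N).backward()        # scale so the accumulated gradient matches one large batch
    if (it + 1) % N == 0:
        opt.step()               # one update every N mini-batches
        opt.zero_grad()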
    def forward(self, x):
        return self.net(x)[:, :self.offset_2]
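At inference, the slice `[:, :self.offset_2]` keeps only the logits of classes seen up to the current task, so classes from future tasks can never be predicted. A toy, standalone illustration with made-up numbers:

import torch

logits = torch.randn(2, 10)  # 10 classes in total
offset_2 = 6                 # classes 0..5 seen up to the current task
preds = logits[:, :offset_2].argmax(dim=1)
print(preds)                 # predictions restricted to seen classes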