Source code for models.l2p

"""
L2P: Learning to Prompt for Continual Learning

Note:
    L2P USES A CUSTOM BACKBONE: `vit_base_patch16_224`.
    The backbone is a ViT-B/16 pretrained on ImageNet-21k and fine-tuned on ImageNet-1k.
"""

import torch

from models.utils.continual_model import ContinualModel
from utils.args import ArgumentParser
from timm import create_model  # noqa
from models.l2p_utils.l2p_model import L2PModel
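
# Illustration (not part of the original module): the `timm` import above is marked
# `# noqa` because it is not used directly in this file; the backbone is built inside
# L2PModel. As a hedged sketch, the underlying ViT is roughly equivalent to:
#
#     backbone = create_model('vit_base_patch16_224', pretrained=True)
#
# i.e. a ViT-B/16 pretrained on ImageNet-21k and fine-tuned on ImageNet-1k. The exact
# arguments used by L2PModel (prompt hooks, number of classes, etc.) may differ.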


class L2P(ContinualModel):
    """Learning to Prompt (L2P)."""

    NAME = 'l2p'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.set_defaults(optimizer='adam')

        # Prompt parameters
        parser.add_argument('--prompt_pool', default=True, type=bool)
        parser.add_argument('--pool_size_l2p', default=10, type=int, help='number of prompts (M in paper)')
        parser.add_argument('--length', default=5, type=int, help='length of prompt (L_p in paper)')
        parser.add_argument('--top_k', default=5, type=int, help='top k prompts to use (N in paper)')
        parser.add_argument('--prompt_key', default=True, type=bool, help='use learnable prompt keys')
        parser.add_argument('--prompt_key_init', default='uniform', type=str, help='initialization type for the prompt keys')
        parser.add_argument('--use_prompt_mask', default=False, type=bool)
        parser.add_argument('--batchwise_prompt', default=True, type=bool)
        parser.add_argument('--embedding_key', default='cls', type=str)
        parser.add_argument('--predefined_key', default='', type=str)
        parser.add_argument('--pull_constraint', default=True)
        parser.add_argument('--pull_constraint_coeff', default=0.1, type=float)
        parser.add_argument('--global_pool', default='token', choices=['token', 'avg'], type=str, help='type of global pooling for the final sequence')
        parser.add_argument('--head_type', default='prompt', choices=['token', 'gap', 'prompt', 'token+prompt'], type=str, help='input type of the classification head')
        parser.add_argument('--freeze', default=['blocks', 'patch_embed', 'cls_token', 'norm', 'pos_embed'], nargs='*', type=str, help='parts of the backbone model to freeze')

        # Learning rate schedule parameters
        parser.add_argument('--sched', default='constant', type=str, metavar='SCHEDULER', help='LR scheduler (default: "constant")')
        parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages')
        parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)')
        parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)')
        parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)')
        parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')
        parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR')
        parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports')
        parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
        parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10)')
        parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)')
        parser.add_argument('--unscale_lr', type=bool, default=True, help='scale lr by batch size (default: True)')
        parser.add_argument('--clip_grad', type=float, default=1, help='clip gradient norm')

        return parser
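    # Worked example (illustration, not part of the original source): with the defaults
    # above, the prompt pool holds M = pool_size_l2p = 10 prompts, each of length
    # L_p = length = 5 tokens. At each forward pass, the N = top_k = 5 prompts whose keys
    # best match the query are selected, so N * L_p = 25 prompt tokens are prepended to
    # the ViT-B/16 input sequence of 196 patch tokens plus the [CLS] token.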
    def __init__(self, backbone, loss, args, transform, dataset=None):
        """
        L2P re-defines the backbone model to include the prompt parameters.
        This is done *before* calling the super constructor, so that the
        backbone is already initialized when the super constructor is called.
        """
        del backbone
        print("-" * 20)
        print("WARNING: L2P USES A CUSTOM BACKBONE: `vit_base_patch16_224`.")
        print("Pretrained on ImageNet-21k and fine-tuned on ImageNet-1k.")
        print("-" * 20)

        # Linear scaling rule: scale the base learning rate with the batch size.
        args.lr = args.lr * args.batch_size / 256.0

        backbone = L2PModel(args)
        super().__init__(backbone, loss, args, transform, dataset=dataset)
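    # Worked example (illustration): with the linear scaling rule applied in __init__,
    # a base lr of 0.03 and a batch size of 128 give an effective lr of
    # 0.03 * 128 / 256 = 0.015. Note that the scaling is applied unconditionally here;
    # the `--unscale_lr` flag declared in get_parser is not consulted in this file.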
    def begin_task(self, dataset):
        # The reference model (used to compute the query features for prompt selection)
        # stays frozen, so it is kept in eval mode.
        self.net.original_model.eval()

        # Re-create the optimizer at the start of each task.
        if hasattr(self, 'opt'):
            self.opt.zero_grad(set_to_none=True)
            del self.opt
        self.opt = self.get_optimizer()
    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        outputs = self.net(inputs, return_outputs=True)
        logits = outputs['logits']

        # Here is the trick to mask out the classes of non-current tasks.
        logits[:, :self.n_past_classes] = -float('inf')

        loss = self.loss(logits[:, :self.n_seen_classes], labels)
        if self.args.pull_constraint and 'reduce_sim' in outputs:
            # Pull constraint: encourage the selected prompt keys to stay close to the query.
            loss = loss - self.args.pull_constraint_coeff * outputs['reduce_sim']

        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.args.clip_grad)
        self.opt.step()

        return loss.item()
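    # Illustration (not part of the original source) of the masking trick above:
    # with 10 classes per task, on the third task n_past_classes = 20 and
    # n_seen_classes = 30. Setting logits[:, :20] = -inf makes the softmax assign
    # (near-)zero probability to past classes, so the cross-entropy on
    # logits[:, :30] only discriminates among the 10 classes of the current task.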
    def get_parameters(self):
        # Only the prompt pool (and its keys) and the classification head are trained.
        return [p for n, p in self.net.model.named_parameters() if 'prompt' in n or 'head' in n]
    def forward(self, x):
        return self.net(x)[:, :self.n_seen_classes]
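
# Usage sketch (assumption: the standard mammoth entry point and flag names, which are
# not defined in this file and may differ in your version):
#
#     python utils/main.py --model l2p --dataset seq-cifar100 \
#         --lr 0.03 --batch_size 128 --n_epochs 5
#
# The L2P-specific hyperparameters (e.g. --pool_size_l2p, --top_k, --pull_constraint_coeff)
# take the defaults registered in get_parser above unless overridden.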