"""
DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning
Note:
WARNING: DualPrompt USES A CUSTOM BACKBONE: `vit_base_patch16_224`.
The backbone is a ViT-B/16 pretrained on Imagenet 21k and finetuned on ImageNet 1k.
"""
import logging
import torch
from models.dualprompt_utils.model import Model
from models.utils.continual_model import ContinualModel
from utils.args import ArgumentParser
from datasets import get_dataset
class DualPrompt(ContinualModel):
"""DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning."""
NAME = 'dualprompt'
COMPATIBILITY = ['class-il', 'task-il']
@staticmethod
def get_parser(parser) -> ArgumentParser:
parser.add_argument('--train_mask', default=True, type=bool, help='if using the class mask at training')
parser.add_argument('--pretrained', default=True, help='Load pretrained model or not')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)')
parser.add_argument('--drop-path', type=float, default=0.0, metavar='PCT', help='Drop path rate (default: 0.)')
# Optimizer parameters
        parser.add_argument('--clip_grad', type=float, default=1.0, metavar='NORM', help='Clip gradient norm (default: 1.0)')
# G-Prompt parameters
parser.add_argument('--use_g_prompt', default=True, type=bool, help='if using G-Prompt')
parser.add_argument('--g_prompt_length', default=5, type=int, help='length of G-Prompt')
parser.add_argument('--g_prompt_layer_idx', default=[0, 1], type=int, nargs="+", help='the layer index of the G-Prompt')
parser.add_argument('--use_prefix_tune_for_g_prompt', default=True, type=bool, help='if using the prefix tune for G-Prompt')
# E-Prompt parameters
parser.add_argument('--use_e_prompt', default=True, type=bool, help='if using the E-Prompt')
parser.add_argument('--e_prompt_layer_idx', default=[2, 3, 4], type=int, nargs="+", help='the layer index of the E-Prompt')
parser.add_argument('--use_prefix_tune_for_e_prompt', default=True, type=bool, help='if using the prefix tune for E-Prompt')
        # Use the prompt pool from L2P to implement the E-Prompt
        parser.add_argument('--prompt_pool', default=True, type=bool, help='use a prompt pool (as in L2P) to implement the E-Prompt')
        parser.add_argument('--size', default=10, type=int, help='total number of prompts in the pool')
        parser.add_argument('--length', default=5, type=int, help='length of each prompt in the pool')
        parser.add_argument('--top_k', default=1, type=int, help='number of prompts selected from the pool per input')
        parser.add_argument('--initializer', default='uniform', type=str, help='initialization of the prompt parameters')
        parser.add_argument('--prompt_key', default=True, type=bool, help='use learnable keys for prompt selection')
        parser.add_argument('--prompt_key_init', default='uniform', type=str, help='initialization of the prompt keys')
        parser.add_argument('--use_prompt_mask', default=True, type=bool, help='restrict prompt selection to the slots of the current task during training')
        parser.add_argument('--mask_first_epoch', default=False, type=bool)
        parser.add_argument('--shared_prompt_pool', default=True, type=bool, help='share the prompt pool across tasks')
        parser.add_argument('--shared_prompt_key', default=False, type=bool, help='share the prompt keys across tasks')
        parser.add_argument('--batchwise_prompt', default=True, type=bool, help='select the same prompts for the whole batch')
        parser.add_argument('--embedding_key', default='cls', type=str, help='feature used as the query for prompt selection')
        parser.add_argument('--predefined_key', default='', type=str)
        parser.add_argument('--pull_constraint', default=True, help='add the query-key similarity (pull constraint) to the loss')
        parser.add_argument('--pull_constraint_coeff', default=1.0, type=float, help='coefficient of the pull constraint term')
        parser.add_argument('--same_key_value', default=False, type=bool, help='use the same parameters for prompt keys and values')
# ViT parameters
parser.add_argument('--global_pool', default='token', choices=['token', 'avg'], type=str, help='type of global pooling for final sequence')
parser.add_argument('--head_type', default='token', choices=['token', 'gap', 'prompt', 'token+prompt'], type=str, help='input type of classification head')
        parser.add_argument('--freeze', default=['blocks', 'patch_embed', 'cls_token', 'norm', 'pos_embed'], nargs='*', type=str, help='parts of the backbone model to freeze')
return parser
def __init__(self, backbone, loss, args, transform, dataset=None):
        # The provided backbone is discarded: DualPrompt builds its own prompted ViT below.
        del backbone
        logging.info("-" * 20)
        logging.info("DualPrompt USES A CUSTOM BACKBONE: `vit_base_patch16_224`.")
        logging.info("Pretrained on ImageNet-21k and finetuned on ImageNet-1k.")
        logging.info("-" * 20)
args.lr = args.lr * args.batch_size / 256.0
tmp_dataset = get_dataset(args) if dataset is None else dataset
backbone = Model(args, tmp_dataset.N_CLASSES)
super().__init__(backbone, loss, args, transform, dataset=dataset)
def begin_task(self, dataset):
self.offset_1, self.offset_2 = self.dataset.get_offsets(self.current_task)
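        # When the prompt pool is shared across tasks, warm-start the current task's E-Prompt
        # slots from the prompts learned on the previous task (prompt transfer).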
        if self.current_task > 0:
            # Each task owns `top_k` consecutive slots of the prompt pool.
            prev_start = (self.current_task - 1) * self.args.top_k
            prev_end = self.current_task * self.args.top_k
            cur_start = prev_end
            cur_end = (self.current_task + 1) * self.args.top_k
            # Copy only if both the previous and the current slots fit inside the pool.
            if prev_end <= self.args.size and cur_end <= self.args.size:
                cur_idx = (slice(None), slice(None), slice(cur_start, cur_end)) if self.args.use_prefix_tune_for_e_prompt else (slice(None), slice(cur_start, cur_end))
                prev_idx = (slice(None), slice(None), slice(prev_start, prev_end)) if self.args.use_prefix_tune_for_e_prompt else (slice(None), slice(prev_start, prev_end))
                with torch.no_grad():
                    if self.net.model.e_prompt.prompt.grad is not None:
                        self.net.model.e_prompt.prompt.grad.zero_()
                    self.net.model.e_prompt.prompt[cur_idx] = self.net.model.e_prompt.prompt[prev_idx]
                    self.opt.param_groups[0]['params'] = self.net.model.parameters()
        # Rebuild the optimizer so it tracks the (possibly re-initialized) prompt parameters.
        self.opt = self.get_optimizer()
        # The frozen, prompt-free ViT that computes the query features stays in eval mode.
        self.net.original_model.eval()
def get_parameters(self):
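        # Return only the trainable parameters (with the default `--freeze` list these are
        # the prompts, prompt keys and classification head).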
return [p for p in self.net.model.parameters() if p.requires_grad]
def observe(self, inputs, labels, not_aug_inputs, epoch=None):
log_dict = {}
cur_lr = self.opt.param_groups[0]['lr']
log_dict['lr'] = cur_lr
outputs = self.net(inputs, task_id=self.current_task, train=True, return_outputs=True)
logits = outputs['logits']
        # Mask out the logits of past-task classes (set to -inf); future-task classes are
        # excluded by slicing up to `offset_2` below.
if self.args.train_mask:
logits[:, :self.offset_1] = -float('inf')
loss_clf = self.loss(logits[:, :self.offset_2], labels)
loss = loss_clf
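        # Pull constraint (as in L2P): 'reduce_sim' measures the similarity between the query
        # features and the selected prompt keys; subtracting it pulls the matched keys
        # towards the queries.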
if self.args.pull_constraint and 'reduce_sim' in outputs:
loss_pull_constraint = outputs['reduce_sim']
loss = loss - self.args.pull_constraint_coeff * loss_pull_constraint
self.opt.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net.model.parameters(), self.args.clip_grad)
self.opt.step()
return loss.item()
def forward(self, x):
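        # At inference the task identity is unknown: task_id=-1 lets the E-Prompt be selected
        # by query-key matching, and logits of classes beyond the current offset are dropped.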
res = self.net(x, task_id=-1, train=False, return_outputs=True)
logits = res['logits']
return logits[:, :self.offset_2]