"""
DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning
Note:
    WARNING: DualPrompt USES A CUSTOM BACKBONE: `vit_base_patch16_224`.
    The backbone is a ViT-B/16 pretrained on ImageNet-21k and fine-tuned on ImageNet-1k.
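
Example:
    Minimal launch sketch (assumes Mammoth's standard `utils/main.py` entry point
    and that the `seq-cifar100-224` dataset is available; adjust to your setup):

        python utils/main.py --model dualprompt --dataset seq-cifar100-224 --lr 0.005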
"""
import logging
import torch
from models.dualprompt_utils.model import Model
from models.utils.continual_model import ContinualModel
from utils import binary_to_boolean_type
from utils.args import ArgumentParser
from datasets import get_dataset
class DualPrompt(ContinualModel):
    """DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning."""
    NAME = 'dualprompt'
    COMPATIBILITY = ['class-il', 'task-il']
    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.set_defaults(optimizer='adam', batch_size=128)
        parser.add_argument('--pretrained', default=1, type=binary_to_boolean_type, help='Load pretrained model or not')
        parser.add_argument('--use_permute_fix', type=binary_to_boolean_type, default=0, help='Apply fix to reshape issue from original implementation (ref: issue #56)')
        # Optimizer parameters
        parser.add_argument('--clip_grad', type=float, default=1.0, metavar='NORM', help='Clip gradient norm (default: 1.0)')
        # G-Prompt parameters
        parser.add_argument('--g_prompt_length', default=5, type=int, help='length of the G-Prompt')
        parser.add_argument('--g_prompt_layer_idx', default=[0, 1], type=int, nargs="+", help='indices of the transformer layers to which the G-Prompt is attached')
        parser.add_argument('--use_prefix_tune_for_g_prompt', default=1, type=binary_to_boolean_type, help='use prefix-tuning (instead of prompt-tuning) for the G-Prompt')
        # E-Prompt parameters
        parser.add_argument('--e_prompt_layer_idx', default=[2, 3, 4], type=int, nargs="+", help='indices of the transformer layers to which the E-Prompt is attached')
        parser.add_argument('--use_prefix_tune_for_e_prompt', type=binary_to_boolean_type, default=1, help='use prefix-tuning (instead of prompt-tuning) for the E-Prompt')
        # Use prompt pool in L2P to implement E-Prompt
        parser.add_argument('--size', default=10, type=int, help='size of the E-Prompt pool')
        parser.add_argument('--length', default=5, type=int, help='length of the E-Prompt')
        parser.add_argument('--top_k', default=1, type=int, help='number of prompts selected from the pool for each input')
        parser.add_argument('--initializer', default='uniform', type=str, help='initializer for the prompt parameters')
        parser.add_argument('--prompt_key_init', default='uniform', type=str, help='initializer for the prompt keys')
        parser.add_argument('--batchwise_prompt', default=1, type=binary_to_boolean_type, help='Use batch-wise prompting? (NOTE: '
                            'this should be avoided as it is not a fair comparison with other methods.)')
        parser.add_argument('--embedding_key', default='cls', type=str, help='how the query for prompt selection is computed from the frozen backbone (e.g. the CLS token)')
        parser.add_argument('--predefined_key', default='', type=str, help='use a predefined prompt key (leave empty to learn the keys)')
        parser.add_argument('--pull_constraint', default=1, type=binary_to_boolean_type, help='add the key-query similarity (pull constraint) term to the loss')
        parser.add_argument('--pull_constraint_coeff', default=1.0, type=float, help='coefficient of the pull constraint term')
        parser.add_argument('--same_key_value', default=0, type=binary_to_boolean_type, help='share the same prompt parameters for the attention key and value prefixes')
        # ViT parameters
        parser.add_argument('--global_pool', default='token', choices=['token', 'avg'], type=str, help='type of global pooling for final sequence')
        parser.add_argument('--head_type', default='token', choices=['token', 'gap', 'prompt', 'token+prompt'], type=str, help='input type of classification head')
        parser.add_argument('--freeze', default=['blocks', 'patch_embed', 'cls_token', 'norm', 'pos_embed'], nargs='*', type=str, help='parts of the backbone model to freeze')
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        del backbone
        logging.info("-" * 20)
        logging.info(f"DualPrompt USES A CUSTOM BACKBONE: `https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz` (vit_base_patch16_224_in21k_fn_in1k_old).")
        logging.info("Pretrained on Imagenet 21k and finetuned on ImageNet 1k.")
        logging.info("-" * 20)
        args.lr = args.lr * args.batch_size / 256.0
        tmp_dataset = get_dataset(args) if dataset is None else dataset
        backbone = Model(args, tmp_dataset.N_CLASSES)
        super().__init__(backbone, loss, args, transform, dataset=dataset)
    def begin_task(self, dataset):
        self.offset_1, self.offset_2 = self.dataset.get_offsets(self.current_task)
        if self.current_task > 0:
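            # Warm-start the new task's E-Prompt: copy the prompt parameters used for
            # the previous task into the slots reserved for the current task.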
            prev_start = (self.current_task - 1) * self.args.top_k
            prev_end = self.current_task * self.args.top_k
            cur_start = prev_end
            cur_end = (self.current_task + 1) * self.args.top_k
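            # Only transfer if both slices fit inside the prompt pool.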
            if (prev_end <= self.args.size) and (cur_end <= self.args.size):
                cur_idx = (slice(None), slice(None), slice(cur_start, cur_end)) if self.args.use_prefix_tune_for_e_prompt else (slice(None), slice(cur_start, cur_end))
                prev_idx = (slice(None), slice(None), slice(prev_start, prev_end)) if self.args.use_prefix_tune_for_e_prompt else (slice(None), slice(prev_start, prev_end))
                with torch.no_grad():
                    if self.net.model.e_prompt.prompt.grad is not None:
                        self.net.model.e_prompt.prompt.grad.zero_()
                    self.net.model.e_prompt.prompt[cur_idx] = self.net.model.e_prompt.prompt[prev_idx]
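                    # refresh the optimizer's parameter reference; the optimizer itself
                    # is rebuilt right below via get_optimizer()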
                    self.opt.param_groups[0]['params'] = self.net.model.parameters()
        self.opt = self.get_optimizer()
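        # the frozen, prompt-free copy of the backbone (used to extract the query
        # features for prompt selection) is always kept in eval mode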
        self.net.original_model.eval() 
    def get_parameters(self):
        return [p for p in self.net.model.parameters() if p.requires_grad] 
    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        log_dict = {}
        cur_lr = self.opt.param_groups[0]['lr']
        log_dict['lr'] = cur_lr
        outputs = self.net(inputs, task_id=self.current_task, train=True, return_outputs=True)
        logits = outputs['logits']
        # mask out the logits of past-task classes; the slicing up to offset_2 below also excludes future-task classes
        logits[:, :self.offset_1] = -float('inf')
        loss_clf = self.loss(logits[:, :self.offset_2], labels)
        loss = loss_clf
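        # L2P-style pull constraint: 'reduce_sim' is the similarity between the query
        # features and the selected prompt keys; subtracting it from the loss pulls
        # the selected keys towards the queries.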
        if self.args.pull_constraint and 'reduce_sim' in outputs:
            loss_pull_constraint = outputs['reduce_sim']
            loss = loss - self.args.pull_constraint_coeff * loss_pull_constraint
        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.model.parameters(), self.args.clip_grad)
        self.opt.step()
        return loss.item() 
    def forward(self, x):
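        # at inference the task identity is unknown (task_id=-1), so the E-Prompt is
        # selected by matching the query features against the learned prompt keys;
        # only the logits of classes seen so far are returned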
        res = self.net(x, task_id=-1, train=False, return_outputs=True)
        logits = res['logits']
        return logits[:, :self.offset_2]