import os
import numpy as np
import torch
import torch.nn as nn
from datasets import get_dataset
from models.l2p_utils.vit_prompt import vit_base_patch16_224_l2p
class L2PModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        dataset = get_dataset(args)
        n_classes = dataset.N_CLASSES
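        # frozen reference ViT: its pre-logits [CLS] features act as the query
        # for prompt selection in forward(); it is kept in eval mode and never trained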
        self.original_model = vit_base_patch16_224_l2p(
            pretrained=True,
            num_classes=n_classes,
            drop_rate=0.0,
            drop_path_rate=0.0,
        )
        self.original_model.eval()
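        # prompted ViT that is actually trained: same backbone, plus the
        # learnable prompt pool and prompt keys configured below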
        self.model = vit_base_patch16_224_l2p(
            pretrained=True,
            num_classes=n_classes,
            prompt_length=args.length,
            embedding_key=args.embedding_key,
            prompt_init=args.prompt_key_init,
            prompt_pool=args.prompt_pool,
            prompt_key=args.prompt_key,
            pool_size=args.pool_size_l2p,
            top_k=args.top_k,
            batchwise_prompt=args.batchwise_prompt,
            prompt_key_init=args.prompt_key_init,
            head_type=args.head_type,
            use_prompt_mask=args.use_prompt_mask,
        )
        if args.use_original_ckpt:
            # download ckpt from https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
            if not os.path.exists('./data/imagenet21k_ViT-B_16.npz'):
                os.system('wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz -O ./data/imagenet21k_ViT-B_16.npz')
            lf = np.load('./data/imagenet21k_ViT-B_16.npz')
            ckpt = {k: lf[k] for k in lf.files}
            def translate_name(name):
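                # map Flax/JAX npz parameter names to timm-style PyTorch names, e.g.
                # 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/out/kernel'
                # -> 'blocks.0.attn.proj.weight'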
                name = name.replace('Transformer/', '')
                name = name.replace('encoderblock_', 'blocks.')
                name = name.replace('/', '.')
                name = name.replace('LayerNorm_', 'norm')
                name = name.replace('norm0', 'norm1')
                name = name.replace('MlpBlock_3', 'mlp')
                name = name.replace('Dense_0', 'fc1')
                name = name.replace('Dense_1', 'fc2')
                name = name.replace('MultiHeadDotProductAttention_1', 'attn')
                name = name.replace('kernel', 'weight')
                name = name.replace('out', 'proj')
                name = name.replace('scale', 'weight')
                name = name.replace('cls', 'cls_token')
                name = name.replace('embedding', 'patch_embed.proj')
                name = name.replace('posembed_input.pos_patch_embed.proj', 'pos_embed')
                name = name.replace('encoder_norm', 'norm')
                return name
            ckpt = {translate_name(k): v for k, v in ckpt.items()}
            for block_id in range(12):
                # convert qkv
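                # JAX stores query/key/value as separate (768, n_heads, head_dim)
                # kernels; timm's ViT uses one fused qkv Linear, so flatten each
                # to (768, 768) and concatenate along the output dimension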
                q = ckpt[f'blocks.{block_id}.attn.query.weight'].reshape(768, -1)
                k = ckpt[f'blocks.{block_id}.attn.key.weight'].reshape(768, -1)
                v = ckpt[f'blocks.{block_id}.attn.value.weight'].reshape(768, -1)
                qkv = np.concatenate([q, k, v], axis=1)
                ckpt[f'blocks.{block_id}.attn.qkv.weight'] = qkv
                ckpt.pop(f'blocks.{block_id}.attn.query.weight')
                ckpt.pop(f'blocks.{block_id}.attn.key.weight')
                ckpt.pop(f'blocks.{block_id}.attn.value.weight')
                q = ckpt[f'blocks.{block_id}.attn.query.bias'].reshape(-1)
                k = ckpt[f'blocks.{block_id}.attn.key.bias'].reshape(-1)
                v = ckpt[f'blocks.{block_id}.attn.value.bias'].reshape(-1)
                qkv = np.concatenate([q, k, v], axis=0)
                ckpt[f'blocks.{block_id}.attn.qkv.bias'] = qkv
                ckpt.pop(f'blocks.{block_id}.attn.query.bias')
                ckpt.pop(f'blocks.{block_id}.attn.key.bias')
                ckpt.pop(f'blocks.{block_id}.attn.value.bias')
                # transpose Linear kernels from JAX (in_features, out_features)
                # to PyTorch's (out_features, in_features) layout
                ckpt[f'blocks.{block_id}.mlp.fc1.weight'] = ckpt[f'blocks.{block_id}.mlp.fc1.weight'].T
                ckpt[f'blocks.{block_id}.mlp.fc2.weight'] = ckpt[f'blocks.{block_id}.mlp.fc2.weight'].T
                ckpt[f'blocks.{block_id}.attn.qkv.weight'] = ckpt[f'blocks.{block_id}.attn.qkv.weight'].T
                ckpt[f'blocks.{block_id}.attn.proj.weight'] = ckpt[f'blocks.{block_id}.attn.proj.weight'].reshape(-1, 768).T
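            # convert the patch-embedding kernel from JAX (H, W, C_in, C_out)
            # to PyTorch conv layout (C_out, C_in, H, W)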
            ckpt['patch_embed.proj.weight'] = ckpt['patch_embed.proj.weight'].transpose(-1, -2, -4, -3)
            # remove head
            del ckpt['head.weight']
            del ckpt['head.bias']
            del ckpt['pre_logits.weight']
            del ckpt['pre_logits.bias']
            # convert to torch
            ckpt = {k: torch.from_numpy(v) for k, v in ckpt.items()}
            # load_state_dict returns (missing_keys, unexpected_keys), in that order
            missing, unexpected = self.original_model.load_state_dict(ckpt, strict=False)
            assert len([x for x in missing if 'head' not in x]) == 0, f"Missing keys: {missing}"
            assert len([x for x in unexpected if 'head' not in x]) == 0, f"Unexpected keys: {unexpected}"
            # extend pos_embed with the prompted model's extra positions for the prompt tokens
            ckpt['pos_embed'] = torch.cat((ckpt['pos_embed'], self.model.pos_embed[:, ckpt['pos_embed'].shape[1]:]), dim=1)
            missing, unexpected = self.model.load_state_dict(ckpt, strict=False)
            assert len([x for x in missing if 'prompt' not in x and 'head' not in x]) == 0, f"Missing keys: {missing}"
            assert len([x for x in unexpected if 'prompt' not in x and 'head' not in x]) == 0, f"Unexpected keys: {unexpected}"
        if args.freeze:
            # freeze all parameters of the original (reference) ViT model
            for p in self.original_model.parameters():
                p.requires_grad = False
            # freeze the parameter groups whose names start with an entry of
            # args.freeze (e.g. blocks, patch_embed, cls_token)
            for n, p in self.model.named_parameters():
                if n.startswith(tuple(args.freeze)):
                    p.requires_grad = False
    def forward(self, x, return_reduce_sim_loss=False):
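        # the frozen reference model extracts [CLS] (pre-logits) features that
        # serve as the query against the prompt keys; no gradients flow here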
        with torch.no_grad():
            if self.original_model is not None:
                original_model_output = self.original_model(x)
                cls_features = original_model_output['pre_logits']
            else:
                cls_features = None
        outputs = self.model(x, task_id=-1, cls_features=cls_features, train=self.training)
        logits = outputs['logits']
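        # reduce_sim scores how well the query matches the selected prompt keys;
        # L2P subtracts a scaled version of it from the loss (the pull constraint)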
        reduce_sim = outputs.get('reduce_sim')
        if return_reduce_sim_loss:
            return {'logits': logits, 'reduce_sim': reduce_sim}
        else:
            return logits
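
# Usage sketch (hypothetical values, not shipped defaults): the Namespace fields
# mirror the attributes read in __init__ above; args must also carry whatever
# get_dataset(args) needs to resolve the dataset.
#
#   from argparse import Namespace
#   args = Namespace(
#       length=5, embedding_key='cls', prompt_key_init='uniform',
#       prompt_pool=True, prompt_key=True, pool_size_l2p=10, top_k=5,
#       batchwise_prompt=True, head_type='token', use_prompt_mask=True,
#       use_original_ckpt=False, freeze=['blocks', 'patch_embed', 'cls_token'],
#   )
#   model = L2PModel(args)
#   out = model(torch.randn(2, 3, 224, 224), return_reduce_sim_loss=True)
#   # L2P's pull constraint: loss = CE(out['logits'], y) - coeff * out['reduce_sim']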