Source code for models.l2p_utils.l2p_model

import os
import numpy as np
import torch
import torch.nn as nn
from datasets import get_dataset
from models.l2p_utils.vit_prompt import vit_base_patch16_224_l2p


class L2PModel(nn.Module):
    def __init__(self, args):
        super().__init__()

        dataset = get_dataset(args)
        n_classes = dataset.N_CLASSES

        # Frozen reference ViT, used only to extract the [CLS] features
        # that serve as the query for prompt selection.
        self.original_model = vit_base_patch16_224_l2p(
            pretrained=True,
            num_classes=n_classes,
            drop_rate=0.0,
            drop_path_rate=0.0,
        )
        self.original_model.eval()

        # Prompted ViT that is actually trained.
        self.model = vit_base_patch16_224_l2p(
            pretrained=True,
            num_classes=n_classes,
            prompt_length=args.length,
            embedding_key=args.embedding_key,
            prompt_init=args.prompt_key_init,
            prompt_pool=args.prompt_pool,
            prompt_key=args.prompt_key,
            pool_size=args.pool_size_l2p,
            top_k=args.top_k,
            batchwise_prompt=args.batchwise_prompt,
            prompt_key_init=args.prompt_key_init,
            head_type=args.head_type,
            use_prompt_mask=args.use_prompt_mask,
        )

        if args.use_original_ckpt:
            # download ckpt from https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
            # and store it under the local name expected below
            if not os.path.exists('./data/imagenet21k_ViT-B_16.npz'):
                os.system('wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz -O ./data/imagenet21k_ViT-B_16.npz')
            lf = np.load('data/imagenet21k_ViT-B_16.npz')
            ckpt = {k: lf[k] for k in lf.files}

            def translate_name(name):
                # Map JAX/Flax parameter names from the official ViT checkpoint
                # to the timm-style names used by this model.
                # (A sanity-check sketch of these mappings follows the class below.)
                name = name.replace('Transformer/', '')
                name = name.replace('encoderblock_', 'blocks.')
                name = name.replace('/', '.')
                name = name.replace('LayerNorm_', 'norm')
                name = name.replace('norm0', 'norm1')
                name = name.replace('MlpBlock_3', 'mlp')
                name = name.replace('Dense_0', 'fc1')
                name = name.replace('Dense_1', 'fc2')
                name = name.replace('MultiHeadDotProductAttention_1', 'attn')
                name = name.replace('kernel', 'weight')
                name = name.replace('out', 'proj')
                name = name.replace('scale', 'weight')
                name = name.replace('cls', 'cls_token')
                name = name.replace('embedding', 'patch_embed.proj')
                name = name.replace('posembed_input.pos_patch_embed.proj', 'pos_embed')
                name = name.replace('encoder_norm', 'norm')
                return name

            ckpt = {translate_name(k): v for k, v in ckpt.items()}

            for block_id in range(12):
                # fuse the separate query/key/value projections into a single qkv weight
                q = ckpt[f'blocks.{block_id}.attn.query.weight'].reshape(768, -1)
                k = ckpt[f'blocks.{block_id}.attn.key.weight'].reshape(768, -1)
                v = ckpt[f'blocks.{block_id}.attn.value.weight'].reshape(768, -1)
                qkv = np.concatenate([q, k, v], axis=1)
                ckpt[f'blocks.{block_id}.attn.qkv.weight'] = qkv
                ckpt.pop(f'blocks.{block_id}.attn.query.weight')
                ckpt.pop(f'blocks.{block_id}.attn.key.weight')
                ckpt.pop(f'blocks.{block_id}.attn.value.weight')

                # same for the biases
                q = ckpt[f'blocks.{block_id}.attn.query.bias'].reshape(-1)
                k = ckpt[f'blocks.{block_id}.attn.key.bias'].reshape(-1)
                v = ckpt[f'blocks.{block_id}.attn.value.bias'].reshape(-1)
                qkv = np.concatenate([q, k, v], axis=0)
                ckpt[f'blocks.{block_id}.attn.qkv.bias'] = qkv
                ckpt.pop(f'blocks.{block_id}.attn.query.bias')
                ckpt.pop(f'blocks.{block_id}.attn.key.bias')
                ckpt.pop(f'blocks.{block_id}.attn.value.bias')

                # transpose the Flax (in, out) weight layout to torch's (out, in)
                ckpt[f'blocks.{block_id}.mlp.fc1.weight'] = ckpt[f'blocks.{block_id}.mlp.fc1.weight'].T
                ckpt[f'blocks.{block_id}.mlp.fc2.weight'] = ckpt[f'blocks.{block_id}.mlp.fc2.weight'].T
                ckpt[f'blocks.{block_id}.attn.qkv.weight'] = ckpt[f'blocks.{block_id}.attn.qkv.weight'].T
                ckpt[f'blocks.{block_id}.attn.proj.weight'] = ckpt[f'blocks.{block_id}.attn.proj.weight'].reshape(-1, 768).T

            # patch-embedding conv: HWIO -> OIHW
            ckpt['patch_embed.proj.weight'] = ckpt['patch_embed.proj.weight'].transpose(-1, -2, -4, -3)

            # remove head
            del ckpt['head.weight']
            del ckpt['head.bias']
            del ckpt['pre_logits.weight']
            del ckpt['pre_logits.bias']

            # convert to torch
            ckpt = {k: torch.from_numpy(v) for k, v in ckpt.items()}

            # load_state_dict returns (missing_keys, unexpected_keys), in that order
            missing, unexpected = self.original_model.load_state_dict(ckpt, strict=False)
            assert len([x for x in missing if 'head' not in x]) == 0, f"Missing keys: {missing}"
            assert len([x for x in unexpected if 'head' not in x]) == 0, f"Unexpected keys: {unexpected}"

            # extend pos_embed for the prompts
            ckpt['pos_embed'] = torch.cat((ckpt['pos_embed'], self.model.pos_embed[:, ckpt['pos_embed'].shape[1]:]), dim=1)
            missing, unexpected = self.model.load_state_dict(ckpt, strict=False)
            assert len([x for x in missing if 'prompt' not in x and 'head' not in x]) == 0, f"Missing keys: {missing}"
            assert len([x for x in unexpected if 'prompt' not in x and 'head' not in x]) == 0, f"Unexpected keys: {unexpected}"

        if args.freeze:
            # all parameters of the original ViT model are frozen
            for p in self.original_model.parameters():
                p.requires_grad = False

            # freeze the parameter groups listed in args.freeze
            # (e.g. blocks, patch_embed, cls_token)
            for n, p in self.model.named_parameters():
                if n.startswith(tuple(args.freeze)):
                    p.requires_grad = False
    def forward(self, x, return_reduce_sim_loss=False):
        # The frozen reference model only provides the query features,
        # so no gradients are needed here.
        with torch.no_grad():
            if self.original_model is not None:
                original_model_output = self.original_model(x)
                cls_features = original_model_output['pre_logits']
            else:
                cls_features = None

        outputs = self.model(x, task_id=-1, cls_features=cls_features, train=self.training)

        logits = outputs['logits']
        reduce_sim = outputs['reduce_sim'] if 'reduce_sim' in outputs else None

        if return_reduce_sim_loss:
            return {'logits': logits, 'reduce_sim': reduce_sim}
        else:
            return logits
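As a sanity check on the checkpoint-name translation in __init__, tracing the chain of replacements by hand gives mappings like the ones below. This is only a sketch: translate_name is local to __init__, so it assumes the helper were lifted to module scope; the input keys illustrate the layout of the official ViT-B/16 .npz file.

# Hypothetical sanity check, assuming translate_name is importable.
assert translate_name('Transformer/encoderblock_0/MlpBlock_3/Dense_0/kernel') \
    == 'blocks.0.mlp.fc1.weight'
assert translate_name('Transformer/encoderblock_0/MultiHeadDotProductAttention_1/out/kernel') \
    == 'blocks.0.attn.proj.weight'
assert translate_name('Transformer/posembed_input/pos_embedding') == 'pos_embed'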
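For context, a minimal sketch of how the return_reduce_sim_loss output is typically consumed in an L2P-style training step. The names args.pull_constraint_coeff, x, and y are assumptions (a loss weight and a batch of images/labels), not part of this module:

import torch.nn.functional as F

model = L2PModel(args)
model.train()

out = model(x, return_reduce_sim_loss=True)
loss = F.cross_entropy(out['logits'], y)
if out['reduce_sim'] is not None:
    # L2P maximizes the similarity between the query and the selected
    # prompt keys, so the term is subtracted from the classification loss.
    loss = loss - args.pull_constraint_coeff * out['reduce_sim']
loss.backward()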