Source code for models.l2p_utils.l2p_model

import torch
import torch.nn as nn
from datasets import get_dataset
from models.l2p_utils.vit_prompt import vit_base_patch16_224_l2p


class L2PModel(nn.Module):
    def __init__(self, args):
        super().__init__()

        dataset = get_dataset(args)
        n_classes = dataset.N_CLASSES

        # Frozen, pre-trained ViT used only to extract query features for prompt selection
        self.original_model = vit_base_patch16_224_l2p(
            pretrained=True,
            num_classes=n_classes,
            drop_rate=0.0,
            drop_path_rate=0.0,
        )
        self.original_model.eval()

        # Prompted ViT that is actually trained
        self.model = vit_base_patch16_224_l2p(
            pretrained=True,
            num_classes=n_classes,
            prompt_length=args.length,
            embedding_key=args.embedding_key,
            prompt_init=args.prompt_key_init,
            prompt_pool=args.prompt_pool,
            prompt_key=args.prompt_key,
            pool_size=args.pool_size_l2p,
            top_k=args.top_k,
            batchwise_prompt=args.batchwise_prompt,
            prompt_key_init=args.prompt_key_init,
            head_type=args.head_type,
            use_prompt_mask=args.use_prompt_mask,
        )

        if args.freeze:
            # all parameters of the original ViT are frozen
            for p in self.original_model.parameters():
                p.requires_grad = False

            # freeze the parameter groups listed in args.freeze (e.g. blocks, patch_embed, cls_token)
            for n, p in self.model.named_parameters():
                if n.startswith(tuple(args.freeze)):
                    p.requires_grad = False
    def forward(self, x, return_outputs=False):
        with torch.no_grad():
            if self.original_model is not None:
                # Query features from the frozen ViT; used to select prompts from the pool
                original_model_output = self.original_model(x)
                cls_features = original_model_output['pre_logits']
            else:
                cls_features = None

        # The prompted ViT runs with gradients enabled so the prompts and head can be trained
        outputs = self.model(x, task_id=-1, cls_features=cls_features, train=self.training)
        logits = outputs['logits']

        if return_outputs:
            return outputs
        else:
            return logits
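The following is a minimal, self-contained sketch of the prompt-selection step that the prompted ViT performs internally when ``prompt_pool=True``: the frozen model's ``pre_logits`` act as a query that is matched against learnable prompt keys, and the ``top_k`` best-matching prompts are prepended to the token sequence. The tensor shapes, variable names, and the cosine-similarity/"pull" loss below are illustrative assumptions; the actual implementation lives in ``models/l2p_utils/vit_prompt.py`` and may differ in detail.

import torch
import torch.nn.functional as F

pool_size, top_k, prompt_length, embed_dim = 10, 5, 5, 768

# Learnable prompt pool and one key per prompt (illustrative stand-ins)
prompts = torch.randn(pool_size, prompt_length, embed_dim, requires_grad=True)
prompt_keys = torch.randn(pool_size, embed_dim, requires_grad=True)

# cls_features: the 'pre_logits' of the frozen original_model, used as the query
cls_features = torch.randn(2, embed_dim)  # batch of 2 query vectors

# Cosine similarity between queries and prompt keys, then pick the top-k prompts
sim = F.normalize(cls_features, dim=1) @ F.normalize(prompt_keys, dim=1).T   # [2, pool_size]
topk_sim, idx = sim.topk(top_k, dim=1)                                        # [2, top_k]

# Gather the selected prompts and flatten them into a per-sample prompt sequence
# that is prepended to the patch tokens inside the transformer
batched_prompts = prompts[idx].reshape(2, top_k * prompt_length, embed_dim)

# A "pull" term of this kind encourages queries to move toward their selected keys
pull_loss = -topk_sim.mean()

In ``L2PModel.forward`` above, this selection happens inside ``self.model`` using the ``cls_features`` computed under ``torch.no_grad()``; only the prompts, keys, and classification head receive gradients when the backbone groups in ``args.freeze`` are frozen.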