from argparse import Namespace

import torch
from torch import nn

from models.dualprompt_utils.vision_transformer import vit_base_patch16_224_dualprompt

class Model(nn.Module):
    """DualPrompt wrapper holding a frozen query ViT and a prompted ViT."""

    def __init__(self, args: Namespace, n_classes: int):
        super().__init__()
        self.n_classes = n_classes

        # Frozen copy of the pre-trained ViT, used only to extract the
        # [CLS] features that query the prompt pool.
        self.original_model = vit_base_patch16_224_dualprompt(
            pretrained=args.pretrained,
            num_classes=n_classes,
            drop_rate=args.drop,
            drop_path_rate=args.drop_path,
        )
        self.original_model.eval()
        # Prompted ViT: the same backbone extended with task-shared
        # G-Prompts and task-specific E-Prompts (DualPrompt).
        self.model = vit_base_patch16_224_dualprompt(
            pretrained=args.pretrained,
            num_classes=n_classes,
            drop_rate=args.drop,
            drop_path_rate=args.drop_path,
            prompt_length=args.length,
            embedding_key=args.embedding_key,
            prompt_init=args.prompt_key_init,
            prompt_pool=args.prompt_pool,
            prompt_key=args.prompt_key,
            pool_size=args.size,
            top_k=args.top_k,
            batchwise_prompt=args.batchwise_prompt,
            prompt_key_init=args.prompt_key_init,
            head_type=args.head_type,
            use_prompt_mask=args.use_prompt_mask,
            use_g_prompt=args.use_g_prompt,
            g_prompt_length=args.g_prompt_length,
            g_prompt_layer_idx=args.g_prompt_layer_idx,
            use_prefix_tune_for_g_prompt=args.use_prefix_tune_for_g_prompt,
            use_e_prompt=args.use_e_prompt,
            e_prompt_layer_idx=args.e_prompt_layer_idx,
            use_prefix_tune_for_e_prompt=args.use_prefix_tune_for_e_prompt,
            same_key_value=args.same_key_value,
        )
        if args.freeze:
            # The query model is frozen entirely; only its features are used.
            for p in self.original_model.parameters():
                p.requires_grad = False
            # Freeze the listed parameter groups of the prompted model
            # (e.g. the backbone), leaving prompts and the head trainable.
            for n, p in self.model.named_parameters():
                if n.startswith(tuple(args.freeze)):
                    p.requires_grad = False
    def forward(self, x, task_id, train=False, return_outputs=False):
        # Extract query features with the frozen model; no gradients needed.
        with torch.no_grad():
            if self.original_model is not None:
                original_model_output = self.original_model(x)
                cls_features = original_model_output['pre_logits']
            else:
                cls_features = None

        # The prompted model matches cls_features against the prompt keys to
        # select prompts, then runs the prompted forward pass.
        outputs = self.model(x, task_id=task_id, cls_features=cls_features, train=train)

        if return_outputs:
            return outputs
        return outputs['logits']
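

# Minimal usage sketch (not part of the module): builds the argument
# namespace consumed above and runs one forward pass. All hyperparameter
# values below are illustrative assumptions, not the repository's defaults;
# running it requires the pretrained backbone to be available.
if __name__ == '__main__':
    args = Namespace(
        pretrained=True, drop=0.0, drop_path=0.0,
        length=5, embedding_key='cls', prompt_key_init='uniform',
        prompt_pool=True, prompt_key=True, size=10, top_k=1,
        batchwise_prompt=True, head_type='token', use_prompt_mask=True,
        use_g_prompt=True, g_prompt_length=5, g_prompt_layer_idx=[0, 1],
        use_prefix_tune_for_g_prompt=True,
        use_e_prompt=True, e_prompt_layer_idx=[2, 3, 4],
        use_prefix_tune_for_e_prompt=True, same_key_value=False,
        # Assumed prefixes of backbone parameters to freeze.
        freeze=['blocks', 'patch_embed', 'cls_token', 'norm', 'pos_embed'],
    )
    net = Model(args, n_classes=100)
    x = torch.randn(2, 3, 224, 224)
    logits = net(x, task_id=0, train=False)  # shape: (2, 100)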