Source code for models.ranpac_utils.inc_net

from argparse import Namespace
import copy
from functools import partial
import math
import torch
from torch import nn
import torch.nn.functional as F
from backbone import MammothBackbone
from backbone.vit import vit_base_patch16_224_prompt_prototype
from models.ranpac_utils.vit import VisionTransformer


class RanPACNet(MammothBackbone):
    def __init__(self, backbone: MammothBackbone):
        super(RanPACNet, self).__init__()
        self.fc = None
        self.device = backbone.device

        tuning_config = Namespace(ffn_adapt=True,
                                  ffn_option="parallel",
                                  ffn_adapter_layernorm_option="none",
                                  ffn_adapter_init_option="lora",
                                  ffn_adapter_scalar="0.1",
                                  ffn_num=64,
                                  d_model=768,
                                  vpt_on=False,
                                  vpt_num=0,
                                  )

        self.convnet = VisionTransformer(patch_size=16, embed_dim=768, depth=12,
                                         num_heads=12, mlp_ratio=4, qkv_bias=True,
                                         norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                         num_classes=0, global_pool=False,
                                         drop_path_rate=0.0, tuning_config=tuning_config)

        backbone.head = nn.Identity()
        state_dict = backbone.state_dict()

        # first, split the fused qkv projections into the separate q/k/v
        # projections expected by the adapter-enabled ViT
        for key in list(state_dict.keys()):
            if 'qkv.weight' in key:
                qkv_weight = state_dict.pop(key)
                q_weight = qkv_weight[:768]
                k_weight = qkv_weight[768:768 * 2]
                v_weight = qkv_weight[768 * 2:]
                state_dict[key.replace('qkv.weight', 'q_proj.weight')] = q_weight
                state_dict[key.replace('qkv.weight', 'k_proj.weight')] = k_weight
                state_dict[key.replace('qkv.weight', 'v_proj.weight')] = v_weight
            elif 'qkv.bias' in key:
                qkv_bias = state_dict.pop(key)
                q_bias = qkv_bias[:768]
                k_bias = qkv_bias[768:768 * 2]
                v_bias = qkv_bias[768 * 2:]
                state_dict[key.replace('qkv.bias', 'q_proj.bias')] = q_bias
                state_dict[key.replace('qkv.bias', 'k_proj.bias')] = k_bias
                state_dict[key.replace('qkv.bias', 'v_proj.bias')] = v_bias
        # second, modify the mlp.fc.weight to match fc.weight
        for key in list(state_dict.keys()):
            if 'mlp.fc' in key:
                fc_weight = state_dict.pop(key)
                state_dict[key.replace('mlp.', '')] = fc_weight

        missing, unexpected = self.convnet.load_state_dict(state_dict, strict=False)
        assert len([m for m in missing if 'adaptmlp' not in m]) == 0, f"Missing keys: {missing}"
        assert len(unexpected) == 0, f"Unexpected keys: {unexpected}"

        # freeze all but the adapter
        for name, p in self.convnet.named_parameters():
            if name in missing:
                p.requires_grad = True
            else:
                p.requires_grad = False
        self.convnet.eval()

    def update_fc(self, nb_classes):
        fc = CosineLinear(self.convnet.embed_dim, nb_classes).to(self.device)
        if self.fc is not None:
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data)
            fc.sigma.data = self.fc.sigma.data
            weight = torch.cat([weight, torch.zeros(nb_classes - nb_output, self.convnet.embed_dim).to(self.device)])
            fc.weight = nn.Parameter(weight)
        del self.fc
        self.fc = fc

    def forward(self, x):
        x = self.convnet(x)
        out = self.fc(x)
        return out

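# --- Illustrative usage (not part of the original module) -------------------
# A minimal sketch of how RanPACNet is typically driven in a class-incremental
# loop: the cosine classifier head is grown with update_fc() before each task
# while the adapted ViT backbone stays frozen. The backbone constructor call,
# the cumulative class counts, and the batch size are assumptions made purely
# for illustration.
def _example_ranpac_usage():
    backbone = vit_base_patch16_224_prompt_prototype()  # assumed constructor call
    net = RanPACNet(backbone)
    for nb_seen_classes in (10, 20):            # hypothetical cumulative class counts
        net.update_fc(nb_seen_classes)          # grow the head, copying old class weights
        x = torch.randn(2, 3, 224, 224).to(net.device)
        logits = net(x)['logits']               # forward() returns the CosineLinear dict
        assert logits.shape == (2, nb_seen_classes)
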
class CosineLinear(nn.Module):
    def __init__(self, in_features, out_features, nb_proxy=1, sigma=True):
        super(CosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features * nb_proxy
        self.nb_proxy = nb_proxy
        self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
        else:
            self.register_parameter('sigma', None)
        self.reset_parameters()
        self.use_RP = False

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def forward(self, input):
        if not self.use_RP:
            out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
        else:
            if self.W_rand is not None:
                inn = torch.nn.functional.relu(input @ self.W_rand)
            else:
                inn = input
                # inn = torch.bmm(input[:, 0:100].unsqueeze(-1), input[:, 0:100].unsqueeze(-2)).flatten(start_dim=1)  # interaction terms instead of RP
            out = F.linear(inn, self.weight)
        if self.sigma is not None:
            out = self.sigma * out
        return {'logits': out}
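
# --- Illustrative usage of the random-projection path (not in the original) --
# A minimal sketch of how the use_RP branch of CosineLinear is exercised:
# features are expanded with a fixed random matrix W_rand and a ReLU before the
# linear scoring. The dimensions and the way W_rand and weight are filled here
# are assumptions for illustration only; in RanPAC the weight comes from a
# closed-form ridge-regression solution over the projected features.
def _example_cosine_linear_rp():
    d, M, num_classes = 768, 10000, 10          # hypothetical feature / projection sizes
    head = CosineLinear(d, num_classes)
    head.use_RP = True
    head.W_rand = torch.randn(d, M)             # frozen random projection
    head.weight = nn.Parameter(torch.zeros(num_classes, M))  # replaces the cosine weight
    x = torch.randn(4, d)
    logits = head(x)['logits']                  # ReLU(x @ W_rand) @ weight.T, scaled by sigma
    assert logits.shape == (4, num_classes)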