Source code for models.ranpac_utils.vit

# --------------------------------------------------------
# References:
# https://github.com/jxhe/unify-parameter-efficient-tuning
# --------------------------------------------------------

import math
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from timm.layers import DropPath
from timm.models.vision_transformer import PatchEmbed


class Adapter(nn.Module):
    def __init__(self,
                 config=None,
                 d_model=None,
                 bottleneck=None,
                 dropout=0.0,
                 init_option="bert",
                 adapter_scalar="1.0",
                 adapter_layernorm_option="in"):
        super().__init__()
        self.n_embd = config.d_model if d_model is None else d_model
        self.down_size = config.attn_bn if bottleneck is None else bottleneck

        # _before
        self.adapter_layernorm_option = adapter_layernorm_option

        self.adapter_layer_norm_before = None
        if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
            self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)

        if adapter_scalar == "learnable_scalar":
            self.scale = nn.Parameter(torch.ones(1))
        else:
            self.scale = float(adapter_scalar)

        self.down_proj = nn.Linear(self.n_embd, self.down_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.down_size, self.n_embd)

        self.dropout = dropout
        if init_option == "bert":
            raise NotImplementedError
        elif init_option == "lora":
            with torch.no_grad():
                nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
                nn.init.zeros_(self.up_proj.weight)
                nn.init.zeros_(self.down_proj.bias)
                nn.init.zeros_(self.up_proj.bias)

    def forward(self, x, add_residual=True, residual=None):
        residual = x if residual is None else residual
        if self.adapter_layernorm_option == 'in':
            x = self.adapter_layer_norm_before(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)
        down = nn.functional.dropout(down, p=self.dropout, training=self.training)
        up = self.up_proj(down)

        up = up * self.scale

        if self.adapter_layernorm_option == 'out':
            up = self.adapter_layer_norm_before(up)

        if add_residual:
            output = up + residual
        else:
            output = up

        return output
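
# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal forward pass through Adapter; d_model and bottleneck are passed
# explicitly here so no external config object is needed. The function name
# and all values below are assumptions chosen for illustration.
def _adapter_usage_example():
    adapter = Adapter(d_model=768, bottleneck=64, init_option="lora",
                      adapter_scalar="0.1", adapter_layernorm_option="in")
    tokens = torch.randn(2, 197, 768)          # (batch, tokens, embed_dim)
    out = adapter(tokens, add_residual=True)   # same shape as the input
    return out.shape                           # torch.Size([2, 197, 768])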

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, x):
        B, N, C = x.shape
        q = self.q_proj(x)
        k = self._shape(self.k_proj(x), -1, B).view(B * self.num_heads, -1, self.head_dim)
        v = self._shape(self.v_proj(x), -1, B).view(B * self.num_heads, -1, self.head_dim)
        q = self._shape(q, N, B).view(B * self.num_heads, -1, self.head_dim)

        # attn = (q @ k.transpose(-2, -1)) * self.scale
        attn_weights = torch.bmm(q, k.transpose(1, 2)) * self.scale
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_probs = self.attn_drop(attn_weights)
        attn_output = torch.bmm(attn_probs, v)

        attn_output = attn_output.view(B, self.num_heads, N, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(B, N, C)

        x = self.proj(attn_output)
        x = self.proj_drop(x)

        return x
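
# --- Usage sketch (illustrative only, not part of the original module) ---
# Attention expects (batch, tokens, dim) input with dim divisible by
# num_heads. The function name and shapes below are assumptions.
def _attention_usage_example():
    attn = Attention(dim=768, num_heads=12, qkv_bias=True)
    tokens = torch.randn(2, 197, 768)
    out = attn(tokens)       # (2, 197, 768) after the output projection
    return out.shape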

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, config=None, layer_id=None):
        super().__init__()
        self.config = config
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, mlp_hidden_dim)
        self.fc2 = nn.Linear(mlp_hidden_dim, dim)
        self.act = act_layer()
        self.mlp_drop = nn.Dropout(drop)

        if config.ffn_adapt:
            self.adaptmlp = Adapter(self.config, dropout=0.1, bottleneck=config.ffn_num,
                                    init_option=config.ffn_adapter_init_option,
                                    adapter_scalar=config.ffn_adapter_scalar,
                                    adapter_layernorm_option=config.ffn_adapter_layernorm_option,
                                    )

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))

        if self.config.ffn_adapt and self.config.ffn_option == 'parallel':
            adapt_x = self.adaptmlp(x, add_residual=False)

        residual = x
        x = self.mlp_drop(self.act(self.fc1(self.norm2(x))))
        x = self.drop_path(self.mlp_drop(self.fc2(x)))

        if self.config.ffn_adapt:
            if self.config.ffn_option == 'sequential':
                x = self.adaptmlp(x)
            elif self.config.ffn_option == 'parallel':
                x = x + adapt_x
            else:
                raise ValueError(self.config.ffn_adapt)

        x = residual + x
        return x
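
# --- Usage sketch (illustrative only, not part of the original module) ---
# Block reads its adapter settings from a config object; the SimpleNamespace
# below only mimics the fields the code above accesses, and its values are
# assumptions rather than the project's defaults.
def _block_usage_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        ffn_adapt=True,                     # enable the bottleneck adapter
        ffn_option="parallel",              # 'parallel' or 'sequential'
        ffn_num=64,                         # adapter bottleneck width
        ffn_adapter_init_option="lora",     # 'bert' raises NotImplementedError above
        ffn_adapter_scalar="0.1",
        ffn_adapter_layernorm_option="none",
    )
    block = Block(dim=768, num_heads=12, qkv_bias=True, config=cfg, layer_id=0)
    out = block(torch.randn(2, 197, 768))   # (2, 197, 768)
    return out.shape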

class VisionTransformer(nn.Module):
    """ Vision Transformer with support for global average pooling
    """

    def __init__(self, global_pool=False, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
                 distilled=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed,
                 norm_layer=None, act_layer=None, weight_init='', tuning_config=None):
        super().__init__()
        # print("I'm using ViT with adapters.")
        self.tuning_config = tuning_config
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
                config=tuning_config, layer_id=i,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # self.init_weights(weight_init)

        ######### MAE begins ############
        self.global_pool = global_pool
        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm

        ######## Adapter begins #########
        if tuning_config.vpt_on:
            assert tuning_config.vpt_num > 0, tuning_config.vpt_num
            # properly registered
            self.embeddings = nn.ParameterList(  # batch, num_prompt, embed_dim
                [nn.Parameter(torch.empty(1, self.tuning_config.vpt_num, embed_dim)) for _ in range(depth)])
            for eee in self.embeddings:
                torch.nn.init.xavier_uniform_(eee.data)

    def init_weights(self, mode=''):
        raise NotImplementedError()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for idx, blk in enumerate(self.blocks):
            if self.tuning_config.vpt_on:
                eee = self.embeddings[idx].expand(B, -1, -1)
                x = torch.cat([eee, x], dim=1)
            x = blk(x)
            if self.tuning_config.vpt_on:
                x = x[:, self.tuning_config.vpt_num:, :]

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x
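
# --- Usage sketch (illustrative only, not part of the original module) ---
# Building the backbone end to end. The tuning_config fields mirror what the
# classes above read (d_model, ffn_*, vpt_*); the function name and all
# values here are assumptions for illustration, not the project's defaults.
def _vit_usage_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        d_model=768,                         # read by Adapter when d_model is not passed explicitly
        ffn_adapt=True, ffn_option="parallel", ffn_num=64,
        ffn_adapter_init_option="lora", ffn_adapter_scalar="0.1",
        ffn_adapter_layernorm_option="none",
        vpt_on=False, vpt_num=0,             # visual prompt tokens disabled
    )
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12,
                              num_heads=12, num_classes=0, global_pool=False,
                              tuning_config=cfg)
    feats = model(torch.randn(1, 3, 224, 224))   # (1, 768) CLS features when num_classes == 0
    return feats.shape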