Source code for models.ranpac_utils.vit
# --------------------------------------------------------
# References:
# https://github.com/jxhe/unify-parameter-efficient-tuning
# --------------------------------------------------------
import math
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from timm.layers import DropPath
from timm.models.vision_transformer import PatchEmbed
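# Bottleneck adapter: down-projection -> ReLU -> dropout -> up-projection, with an
# optional LayerNorm applied before ("in") or after ("out") the bottleneck and a
# fixed or learnable output scale. Used by Block below for parameter-efficient tuning.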
class Adapter(nn.Module):
def __init__(self,
config=None,
d_model=None,
bottleneck=None,
dropout=0.0,
init_option="bert",
adapter_scalar="1.0",
adapter_layernorm_option="in"):
super().__init__()
self.n_embd = config.d_model if d_model is None else d_model
self.down_size = config.attn_bn if bottleneck is None else bottleneck
        # optional LayerNorm around the bottleneck: "in" = before, "out" = after, anything else = none
self.adapter_layernorm_option = adapter_layernorm_option
self.adapter_layer_norm_before = None
if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)
if adapter_scalar == "learnable_scalar":
self.scale = nn.Parameter(torch.ones(1))
else:
self.scale = float(adapter_scalar)
self.down_proj = nn.Linear(self.n_embd, self.down_size)
self.non_linear_func = nn.ReLU()
self.up_proj = nn.Linear(self.down_size, self.n_embd)
self.dropout = dropout
if init_option == "bert":
            raise NotImplementedError("BERT-style initialization is not implemented")
elif init_option == "lora":
with torch.no_grad():
nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
nn.init.zeros_(self.up_proj.weight)
nn.init.zeros_(self.down_proj.bias)
nn.init.zeros_(self.up_proj.bias)
def forward(self, x, add_residual=True, residual=None):
residual = x if residual is None else residual
if self.adapter_layernorm_option == 'in':
x = self.adapter_layer_norm_before(x)
down = self.down_proj(x)
down = self.non_linear_func(down)
down = nn.functional.dropout(down, p=self.dropout, training=self.training)
up = self.up_proj(down)
up = up * self.scale
if self.adapter_layernorm_option == 'out':
up = self.adapter_layer_norm_before(up)
if add_residual:
output = up + residual
else:
output = up
return output
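# Illustrative usage sketch (not part of the original module): run a dummy ViT-B/16
# token batch through a standalone Adapter. The shapes and hyper-parameters below
# are assumptions chosen for the example only.
def _adapter_usage_example():
    adapter = Adapter(d_model=768, bottleneck=64, dropout=0.1,
                      init_option="lora", adapter_scalar="0.1",
                      adapter_layernorm_option="none")
    tokens = torch.randn(2, 197, 768)            # (batch, num_tokens, d_model)
    out = adapter(tokens, add_residual=True)     # output keeps the input shape
    return out.shape                             # torch.Size([2, 197, 768])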
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,):
super().__init__()
self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
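    # (B, seq_len, dim) -> (B, num_heads, seq_len, head_dim), made contiguous for the
    # subsequent flattening into (B * num_heads, seq_len, head_dim)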
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(self, x):
B, N, C = x.shape
q = self.q_proj(x)
k = self._shape(self.k_proj(x), -1, B).view(B * self.num_heads, -1, self.head_dim)
v = self._shape(self.v_proj(x), -1, B).view(B * self.num_heads, -1, self.head_dim)
q = self._shape(q, N, B).view(B * self.num_heads, -1, self.head_dim)
# attn = (q @ k.transpose(-2, -1)) * self.scale
attn_weights = torch.bmm(q, k.transpose(1, 2)) * self.scale
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_probs = self.attn_drop(attn_weights)
attn_output = torch.bmm(attn_probs, v)
attn_output = attn_output.view(B, self.num_heads, N, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(B, N, C)
x = self.proj(attn_output)
x = self.proj_drop(x)
return x
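# Illustrative sketch (not part of the original module): the separate q/k/v
# projections implement standard multi-head self-attention, so the output keeps
# the (batch, tokens, dim) shape; dim must be divisible by num_heads.
def _attention_usage_example():
    attn = Attention(dim=768, num_heads=12, qkv_bias=True)
    x = torch.randn(2, 197, 768)    # (batch, tokens, embed_dim)
    y = attn(x)
    assert y.shape == x.shape
    return y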
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, config=None, layer_id=None):
super().__init__()
self.config = config
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.fc1 = nn.Linear(dim, mlp_hidden_dim)
self.fc2 = nn.Linear(mlp_hidden_dim, dim)
self.act = act_layer()
self.mlp_drop = nn.Dropout(drop)
if config.ffn_adapt:
self.adaptmlp = Adapter(self.config, dropout=0.1, bottleneck=config.ffn_num,
init_option=config.ffn_adapter_init_option,
adapter_scalar=config.ffn_adapter_scalar,
adapter_layernorm_option=config.ffn_adapter_layernorm_option,
)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
if self.config.ffn_adapt and self.config.ffn_option == 'parallel':
adapt_x = self.adaptmlp(x, add_residual=False)
residual = x
x = self.mlp_drop(self.act(self.fc1(self.norm2(x))))
x = self.drop_path(self.mlp_drop(self.fc2(x)))
if self.config.ffn_adapt:
if self.config.ffn_option == 'sequential':
x = self.adaptmlp(x)
elif self.config.ffn_option == 'parallel':
x = x + adapt_x
else:
                raise ValueError(f"unknown ffn_option: {self.config.ffn_option}")
x = residual + x
return x
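# Illustrative sketch (not part of the original module): Block only reads a handful
# of attributes from `config`, so a SimpleNamespace with those fields (an assumption,
# standing in for the real RanPAC tuning config) is enough to exercise it.
def _block_usage_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(d_model=768, ffn_adapt=True, ffn_option="parallel",
                          ffn_num=64, ffn_adapter_init_option="lora",
                          ffn_adapter_scalar="0.1",
                          ffn_adapter_layernorm_option="none")
    block = Block(dim=768, num_heads=12, qkv_bias=True, config=cfg, layer_id=0)
    x = torch.randn(2, 197, 768)
    return block(x).shape           # torch.Size([2, 197, 768])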
class VisionTransformer(nn.Module):
""" Vision Transformer with support for global average pooling
"""
def __init__(self, global_pool=False, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None, weight_init='', tuning_config=None):
super().__init__()
# print("I'm using ViT with adapters.")
self.tuning_config = tuning_config
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
config=tuning_config, layer_id=i,
)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size and not distilled:
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
# Classifier head(s)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
# self.init_weights(weight_init)
######### MAE begins ############
self.global_pool = global_pool
if self.global_pool:
self.fc_norm = norm_layer(embed_dim)
del self.norm # remove the original norm
######## Adapter begins #########
if tuning_config.vpt_on:
assert tuning_config.vpt_num > 0, tuning_config.vpt_num
            # prompt embeddings for VPT, registered as parameters (one per block)
self.embeddings = nn.ParameterList( # batch, num_prompt, embed_dim
[nn.Parameter(torch.empty(1, self.tuning_config.vpt_num, embed_dim)) for _ in
range(depth)])
for eee in self.embeddings:
torch.nn.init.xavier_uniform_(eee.data)
def get_classifier(self):
if self.dist_token is None:
return self.head
else:
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.num_tokens == 2:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.pos_drop(x)
for idx, blk in enumerate(self.blocks):
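            # with VPT, prepend this layer's prompt tokens before the block and
            # strip them again afterwards so the sequence length stays constant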
if self.tuning_config.vpt_on:
eee = self.embeddings[idx].expand(B, -1, -1)
x = torch.cat([eee, x], dim=1)
x = blk(x)
if self.tuning_config.vpt_on:
x = x[:, self.tuning_config.vpt_num:, :]
if self.global_pool:
x = x[:, 1:, :].mean(dim=1) # global pool without cls token
outcome = self.fc_norm(x)
else:
x = self.norm(x)
outcome = x[:, 0]
return outcome
def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
else:
x = self.head(x)
return x
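# Illustrative sketch (not part of the original module): build the adapted ViT-B/16
# backbone with a minimal tuning_config. The SimpleNamespace below is an assumption
# covering only the attributes this file actually reads; the real RanPAC config
# object supplies the same fields.
def _vit_usage_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(d_model=768, ffn_adapt=True, ffn_option="parallel",
                          ffn_num=64, ffn_adapter_init_option="lora",
                          ffn_adapter_scalar="0.1",
                          ffn_adapter_layernorm_option="none",
                          vpt_on=False, vpt_num=0)
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12,
                              num_heads=12, num_classes=0, tuning_config=cfg)
    images = torch.randn(2, 3, 224, 224)
    features = model(images)        # num_classes=0 -> identity head, shape (2, 768)
    return features.shape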