Source code for models.coda_prompt_utils.vit

'''
 * Based on the ViT implementation from the BLIP code base
 * https://github.com/salesforce/BLIP
'''

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.layers import trunc_normal_, DropPath

from backbone.vit import Mlp, VisionTransformer as MammothVP


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, prompt=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        if prompt is not None:
            pk, pv = prompt
            pk = pk.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            pv = pv.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            k = torch.cat((pk, k), dim=2)
            v = torch.cat((pv, v), dim=2)

        if torch.__version__ >= '2.1.0':
            x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p)
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
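
# Illustrative usage sketch (not part of the upstream module): the `prompt`
# argument above is expected to be a (prefix_keys, prefix_values) pair, each of
# shape (B, prompt_length, dim); the prefixes are prepended to K and V only, so
# the number of query positions and the output shape are unchanged. The sizes
# below (ViT-B/16-like: dim=768, 12 heads, 196 patches + 1 CLS token, prompt
# length 8) are assumptions chosen for illustration.
def _example_prefix_prompt_attention():
    attn = Attention(dim=768, num_heads=12, qkv_bias=True)
    x = torch.randn(2, 197, 768)        # (B, N, C) token sequence
    pk = torch.randn(2, 8, 768)         # prefix keys,   (B, prompt_length, C)
    pv = torch.randn(2, 8, 768)         # prefix values, (B, prompt_length, C)
    out = attn(x, prompt=(pk, pv))      # K/V grow to 197 + 8; out stays (2, 197, 768)
    return out
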
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, prompt=None):
        x = x + self.drop_path(self.attn(self.norm1(x), prompt=prompt))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
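
# Illustrative usage sketch (not part of the upstream module): a Block simply
# forwards the same (prefix_keys, prefix_values) pair into its Attention layer,
# while the residual connections and MLP are untouched by the prompt. Sizes are
# the same assumed ViT-B/16-like shapes used in the attention example above.
def _example_block_with_prompt():
    blk = Block(dim=768, num_heads=12, qkv_bias=True)
    x = torch.randn(2, 197, 768)
    pk, pv = torch.randn(2, 8, 768), torch.randn(2, 8, 768)
    return blk(x, prompt=(pk, pv))      # output shape stays (2, 197, 768)
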
class VisionTransformer(MammothVP):
    def __init__(self, qk_scale=None, args=None, **kwargs):
        super().__init__(args=args, **kwargs)

        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))

        self.blocks = nn.ModuleList([
            Block(
                dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
                drop=self.pos_drop.p, attn_drop=self.attn_drop_rate, drop_path=self.dpr[i],
                norm_layer=self.norm_layer, act_layer=self.act_layer)
            for i in range(self.depth)])

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, prompt=None, q=None, train=False, task_id=None):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)

        prompt_loss = torch.zeros((1,), requires_grad=True).to(x.device)
        for i, blk in enumerate(self.blocks):
            if prompt is not None:
                if train:
                    p_list, loss, x = prompt.forward(q, i, x, train=True, task_id=task_id)
                    prompt_loss += loss
                else:
                    p_list, _, x = prompt.forward(q, i, x, train=False, task_id=task_id)
            else:
                p_list = None

            x = blk(x, prompt=p_list)

        x = self.norm(x)

        return x, prompt_loss
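
# Illustrative sketch (not part of the upstream module): the `prompt` object
# passed to forward() is expected to expose
#     forward(query, layer_idx, x, train, task_id) -> (p_list, loss, x)
# where `p_list` is either None or a (prefix_keys, prefix_values) pair consumed
# by Attention at that layer. The dummy below only mimics that interface for
# shape checking; the real CODA-Prompt module lives elsewhere in the code base,
# and the layer indices, prompt length, and dim here are assumptions.
class _DummyPrompt:
    def __init__(self, e_layers=(0, 1, 2, 3, 4), length=8, dim=768):
        self.e_layers = set(e_layers)   # layers that receive prefix prompts
        self.length = length            # prompt length per layer
        self.dim = dim                  # embedding dimension

    def forward(self, q, layer_idx, x, train=False, task_id=None):
        if layer_idx not in self.e_layers:
            return None, 0, x
        B = x.shape[0]
        pk = torch.zeros(B, self.length, self.dim, device=x.device)
        pv = torch.zeros(B, self.length, self.dim, device=x.device)
        return (pk, pv), 0, x           # zero auxiliary loss in this dummy

# A call would then look like (assuming `vit` is an instance of the class above
# and `query` is the per-sample query feature used to select prompts):
#     features, prompt_loss = vit(images, prompt=_DummyPrompt(), q=query, train=True, task_id=0)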