# Source code for models.tak_utils.backbone

import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
import types
import open_clip

from datasets.seq_8vision import Sequential8Vision
from backbone import MammothBackbone
from datasets.utils.continual_dataset import ContinualDataset

try:
    import clip
    from clip.model import CLIP
except ImportError:
    raise ImportError("Please install the CLIP package by running: "
                      "pip install git+https://github.com/openai/CLIP.git")

from models.tak_utils.templates import get_templates


def create_clip(name_clip_backbone, device) -> CLIP:
    """Load a pretrained CLIP model, patch its visual attention blocks via
    :func:`surgery`, and move the result to ``device``.

    :param name_clip_backbone: CLIP backbone name understood by ``clip.load``
    :param device: target device for the patched model
    :return: the patched CLIP model on ``device``
    """
    # Load on CPU first so the surgery (weight copies) happens off-device.
    model, _ = clip.load(name_clip_backbone, device=torch.device('cpu'), jit=False)
    surgery(model)
    return model.to(device)
@torch.no_grad()
def surgery(clip_model: CLIP):
    """Replace every attention module in the visual transformer with the local
    :class:`MultiheadAttention` (unfused q/k/v exposed as a single ``qkv``
    linear), copying the pretrained weights over, then swap the visual output
    projection for a proper linear layer (see :func:`replace_visual_outproj`).
    """
    embed_dim = clip_model.visual.class_embedding.shape[0]
    for block in clip_model.visual.transformer.resblocks:
        old_attn = block.attn
        pretrained = old_attn.state_dict()
        new_attn = MultiheadAttention(embed_dim, old_attn.num_heads, True).to('cpu')
        # Copy the fused in-projection and the output projection in place.
        new_attn.qkv.weight.copy_(pretrained['in_proj_weight'])
        new_attn.qkv.bias.copy_(pretrained['in_proj_bias'])
        new_attn.proj.weight.copy_(pretrained['out_proj.weight'])
        new_attn.proj.bias.copy_(pretrained['out_proj.bias'])
        del block.attn
        block.attn = new_attn
    replace_visual_outproj(clip_model)
class ClsEmbedder(nn.Module):
    """Holds the CLS token as a learnable parameter and prepends it to a
    batch of patch embeddings."""

    def __init__(self, class_embedding: torch.Tensor):
        super().__init__()
        self.register_parameter('class_embedding', nn.Parameter(class_embedding.clone()))

    def forward(self, x: torch.Tensor):
        """Prepend the class embedding to every sequence in the batch.

        :param x: patch embeddings of shape [batch, grid ** 2, width]
        :return: tensor of shape [batch, grid ** 2 + 1, width]
        """
        # Broadcast the single embedding to one token per batch element.
        cls_tok = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        return torch.cat([cls_tok, x], dim=1)
def get_custom_forward(old_forward):
    """Build a replacement for the visual transformer's ``forward`` that goes
    through the module's ``cls_token_layer`` and ``lin_proj`` submodules
    instead of the raw ``class_embedding`` / ``proj`` tensors.

    ``old_forward`` is not used by the returned closure; it is accepted for
    symmetry with the forward being replaced.
    """
    def custom_visual_forward(ext, x: torch.Tensor):
        # Patchify: [*, width, grid, grid] -> [*, grid ** 2, width]
        tokens = ext.conv1(x)
        tokens = tokens.flatten(2).permute(0, 2, 1)
        # Prepend the learnable CLS token: [*, grid ** 2 + 1, width]
        tokens = ext.cls_token_layer(tokens)
        tokens = tokens + ext.positional_embedding.to(tokens.dtype)
        tokens = ext.ln_pre(tokens)
        # Transformer runs in LND layout; convert back afterwards.
        tokens = ext.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)
        # Project only the CLS token into the output embedding space.
        cls_out = ext.ln_post(tokens[:, 0, :])
        return ext.lin_proj(cls_out)
    return custom_visual_forward
def replace_visual_outproj(clip_model):
    """Turn the visual ``proj`` tensor and ``class_embedding`` into proper
    submodules (``lin_proj`` / ``cls_token_layer``) so they appear in
    ``parameters()`` and ``state_dict()``, and bind a matching forward.
    """
    # replace the projection layer with a linear layer
    visual_proj: torch.Tensor = clip_model.visual.proj.clone()
    # FIX: in/out features were swapped in the original
    # (`nn.Linear(shape[1], shape[0])`).  CLIP's `visual.proj` has shape
    # (width, output_dim) and is applied as `x @ proj`, so the equivalent
    # Linear is in=shape[0] (width), out=shape[1] (output_dim) with weight
    # `proj.T` (shape (out, in)).  The direct Parameter assignment below made
    # the old code compute correctly anyway, but left `in_features` /
    # `out_features` metadata inconsistent with the actual weight.
    clip_model.visual.lin_proj = nn.Linear(visual_proj.shape[0], visual_proj.shape[1], bias=False)
    clip_model.visual.lin_proj.weight = nn.Parameter(visual_proj.T)
    clip_model.visual.lin_proj.requires_grad_(clip_model.visual.proj.requires_grad)
    # Null out then delete the old attribute so only `lin_proj` remains.
    clip_model.visual.proj = None
    del clip_model.visual.proj
    old_forward = clip_model.visual.forward.__func__
    cls_token_layer = ClsEmbedder(clip_model.visual.class_embedding.clone()).requires_grad_(True)
    clip_model.visual.register_module('cls_token_layer', cls_token_layer)
    clip_model.visual.class_embedding = None
    del clip_model.visual.class_embedding
    # Bind the replacement forward as a method on this visual module instance.
    clip_model.visual.forward = types.MethodType(get_custom_forward(old_forward), clip_model.visual)
class MultiheadAttention(torch.nn.Module):
    """Drop-in replacement for ``nn.MultiheadAttention`` with an explicit
    fused ``qkv`` linear, used by :func:`surgery` to patch CLIP's visual
    blocks.

    NOTE: self-attention only — ``key`` and ``value`` are ignored; q, k and v
    are all derived from ``query`` (in the patched visual transformer the
    three arguments are always the same tensor).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)

    def forward(self, query, key, value, need_weights=False, attn_mask=None):
        """Compute self-attention over ``query``.

        :param query: tensor of shape [seq, batch, dim] (NLD is LND here,
            matching ``nn.MultiheadAttention``'s default layout)
        :param key, value: ignored (self-attention assumed; see class note)
        :param need_weights: accepted for signature compatibility; the
            attention map is always returned
        :param attn_mask: optional additive float mask of shape broadcastable
            to [batch, heads, seq, seq]; ``-inf`` entries block attention
        :return: (output of shape [seq, batch, dim], attention weights of
            shape [batch, heads, seq, seq])
        """
        N, B, C = query.shape
        query = query.transpose(0, 1)
        qkv = self.qkv(query)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # FIX: the original silently dropped `attn_mask`; apply it additively
        # before the softmax (standard float-mask semantics, no-op when None).
        if attn_mask is not None:
            attn = attn + attn_mask.to(dtype=attn.dtype, device=attn.device)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        x = x.transpose(1, 0)  # back to [seq, batch, dim]
        return x, attn
def build_classification_head(clip_model, dataset, offset, eval=False, all_heads=False):
    """Build a frozen zero-shot classification head from CLIP text embeddings.

    For each class name, every prompt template is tokenized and encoded; the
    per-template embeddings are L2-normalized, averaged, re-normalized, and
    scaled by 100 to form that class's weight row.

    :param clip_model: CLIP model providing ``encode_text`` / ``text_projection``
    :param dataset: continual dataset supplying ``NAME`` and ``class_names``
    :param offset: (start, end) class-index pair selecting the current task's rows
    :param eval: if True, keep all classes seen so far (rows [0, offset[1]))
    :param all_heads: if True, keep every class and ignore ``offset``
    :return: a frozen :class:`ClassificationHead`
    """
    template = get_templates(dataset.NAME)
    classnames = dataset.class_names
    device = clip_model.text_projection.device
    if isinstance(dataset, Sequential8Vision):
        # Each sub-dataset has its own templates; track task boundaries so the
        # right template set is used for each class index.
        classes_cumsum = np.cumsum(dataset.N_CLASSES_PER_TASK)
        all_templates = template
        template = all_templates[0]
        cur_task = 0
    print('Building classification head.')
    with torch.no_grad():
        zeroshot_weights = []
        for class_idx, classname in enumerate(classnames):
            texts = []
            if isinstance(dataset, Sequential8Vision):
                if class_idx >= classes_cumsum[cur_task]:
                    cur_task += 1
                    template = all_templates[cur_task]
            for t in template:
                texts.append(t(classname))
            texts = open_clip.tokenize(texts).to(device)  # tokenize
            embeddings = clip_model.encode_text(texts)  # embed with text encoder
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
            embeddings = embeddings.mean(dim=0, keepdim=True)
            embeddings /= embeddings.norm()
            zeroshot_weights.append(embeddings)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
        zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2)
        zeroshot_weights *= 100.
        # NOTE(review): `.squeeze()` drops ALL singleton dims — with a single
        # class this would also drop the class axis; confirm num_classes > 1.
        zeroshot_weights = zeroshot_weights.squeeze().float()
        # Final layout: [num_classes, embed_dim]
        zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1)
    if all_heads:
        classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)
    else:
        # FIX: dropped the redundant no-op `[:]` that preceded the class
        # slice in the original (`zeroshot_weights[:][...]`).
        if eval:
            classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights[:offset[1]])
        else:
            classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights[offset[0]:offset[1]])
    classification_head.requires_grad_(False)
    return classification_head
class ClassificationHead(torch.nn.Linear):
    """Linear classifier whose weights are provided (e.g. zero-shot CLIP text
    embeddings), optionally L2-normalizing its inputs first.

    :param normalize: if True, inputs are L2-normalized before the linear map
    :param weights: weight matrix of shape [n_classes, feat_dim]
    :param biases: optional bias vector; a zero bias is used when omitted
    """

    def __init__(self, normalize, weights, biases=None):
        output_size, input_size = weights.shape
        super().__init__(input_size, output_size)
        self.normalize = normalize
        # Overwrite the randomly-initialized nn.Linear parameters.
        # (The original's `if weights is not None` guard was dead code:
        # `weights.shape` above would already have raised on None.)
        self.weight = nn.Parameter(weights.clone())
        if biases is not None:
            self.bias = nn.Parameter(biases.clone())
        else:
            self.bias = nn.Parameter(torch.zeros_like(self.bias, device=self.weight.device))

    def forward(self, inputs):
        """Apply optional L2 normalization, then the linear layer."""
        if self.normalize:
            inputs = inputs / inputs.norm(dim=-1, keepdim=True)
        return super().forward(inputs)

    # FIX: removed the redundant `__call__` override.  It delegated straight
    # to `forward`, which bypassed `nn.Module.__call__`'s hook machinery for
    # no benefit; outputs are identical without it.
class Backbone(MammothBackbone):
    """CLIP-visual-only backbone producing L2-normalized image features."""

    @torch.no_grad()
    def __init__(self, clip_model, dataset: ContinualDataset, args) -> None:
        super().__init__()
        self.dataset = dataset
        self.dtype = torch.float32
        self.args = args
        # NOTE: `.to(dtype=...)` casts `clip_model` in place before the copy.
        self.visual_encoder = deepcopy(clip_model.to(dtype=torch.float32).visual)
        # Re-sync weights from the source model after the deepcopy.
        self.copy_visual_encoder(clip_model)
        self.classes = self.dataset.get_class_names()
        # CLIP's standard logit-scale initialization: log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.task_id = 0

    def copy_visual_encoder(self, clip_model):
        """Load the visual-encoder weights from ``clip_model`` into the local copy."""
        self.visual_encoder.load_state_dict(clip_model.visual.state_dict())

    def forward(self, x):
        """Encode images and return unit-norm feature vectors."""
        feats = self.visual_encoder(x.type(self.dtype))
        return feats / feats.norm(dim=-1, keepdim=True)