VIT

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


Cloned and trimmed version of timm.models.vision_transformer.py, kept here for stable reference.

Check out https://github.com/pprp/timm/blob/master/timm/models/vision_transformer.py for the original file.

The following is the original docstring of the file.


Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in:

  • ‘An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale’

  • ‘How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers’

  • ‘FlexiViT: One Model for All Patch Sizes’
The official JAX code is released and available at
Acknowledgments:

Hacked together by / Copyright 2020, Ross Wightman

Classes

class backbone.vit.Attention(dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0)

Bases: Module

Attention layer as used in Vision Transformer.

Parameters:
  • dim – Number of input channels

  • num_heads – Number of attention heads

  • qkv_bias – If True, add a learnable bias to q, k, v

  • attn_drop – Dropout rate for attention weights

  • proj_drop – Dropout rate after the final projection

forward(x, **kwargs)

Forward pass of the attention layer.

Parameters:

x – Input tensor of shape [batch_size, num_tokens, dim]
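For orientation, the sketch below re-implements multi-head self-attention with the same constructor arguments as Attention, following the standard timm formulation (one fused qkv projection, scaled dot-product attention, an output projection). It is an illustration only: the class name SketchAttention is hypothetical, and it is not the code of backbone.vit.Attention (it also omits the **kwargs hook).

```python
import torch
import torch.nn as nn


class SketchAttention(nn.Module):
    """Hypothetical stand-in mirroring the signature of backbone.vit.Attention."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # A single linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape  # batch, tokens, channels
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge the heads back
        return self.proj_drop(self.proj(x))


attn = SketchAttention(dim=64, num_heads=8, qkv_bias=True)
tokens = torch.randn(2, 197, 64)  # e.g. 196 patch tokens + 1 class token
out = attn(tokens)
print(out.shape)  # torch.Size([2, 197, 64])
```

The output keeps the input shape, which is what lets attention layers be stacked inside residual blocks.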

class backbone.vit.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0, init_values=None, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, attn_layer=<class 'backbone.vit.Attention'>, mlp_layer=<class 'backbone.vit.Mlp'>)

Bases: Module

forward(x, **kwargs)

class backbone.vit.LayerScale(dim, init_values=1e-05, inplace=False)

Bases: Module

forward(x)

class backbone.vit.Mlp(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=None, bias=True, drop=0.0, use_conv=False)

Bases: Mlp

forward(x, **kwargs)

class backbone.vit.VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, init_values=None, class_token=True, no_embed_class=False, pre_norm=False, fc_norm=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, weight_init='', embed_layer=<class 'timm.layers.patch_embed.PatchEmbed'>, norm_layer=None, act_layer=None, block_fn=<class 'backbone.vit.Block'>, attn_layer=None, mlp_layer=None, use_lora=False, args=None)

Bases: MammothBackbone

Vision Transformer. This implementation supports LoRA (Low-Rank Adaptation) parameters if use_lora=True.

A PyTorch implementation of ‘An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale’ (https://arxiv.org/abs/2010.11929).

forward(x, AB={}, returnt='out')

Compute the forward pass of ViT. It also accepts an optional argument AB, a dictionary containing LoRA-style parameters for each block.

AB can contain, for each block, one of the following:

  • a single value (e.g. AB = {0: {“qkv”: torch.Tensor(…)}, 1: {“qkv”: torch.Tensor(…)}, …})

  • a dictionary with a single key B (e.g. AB = {0: {“qkv”: {“B”: torch.Tensor(…)}}})

  • a dictionary with both the A and B keys of the LoRA parameters (e.g. AB = {0: {“qkv”: {“A”: torch.Tensor(…), “B”: torch.Tensor(…)}}})

Supported keys for each block are qkv, proj, fc1, fc2.

NOTE: The values of AB are summed with the weights of the corresponding block.
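As a rough illustration of the A-and-B form, the sketch below builds an AB dictionary for the qkv layer of block 0 and applies it as a weight shift, which is one reading of the NOTE. It assumes the common LoRA convention (A of shape [rank, in_features], B of shape [out_features, rank], update B @ A); those shapes and the helper lora_linear are assumptions for illustration, not Mammoth's code, and the single-value and B-only forms are not covered here.

```python
import torch
import torch.nn.functional as F

embed_dim, rank = 768, 4

# AB in the documented "A and B" form: block index -> layer key -> {"A", "B"}.
# Shapes follow the usual LoRA convention (an assumption, not taken from Mammoth):
# A is [rank, in_features], B is [out_features, rank], so B @ A matches the weight.
AB = {
    0: {"qkv": {"A": torch.randn(rank, embed_dim) * 0.01,
                "B": torch.zeros(3 * embed_dim, rank)}},
}


def lora_linear(x, weight, bias, ab):
    """Hypothetical helper: a linear layer whose weight is shifted by B @ A."""
    delta = ab["B"] @ ab["A"]  # [out_features, in_features]
    return F.linear(x, weight + delta, bias)


qkv = torch.nn.Linear(embed_dim, 3 * embed_dim)  # stands in for a block's qkv layer
x = torch.randn(2, 197, embed_dim)
y = lora_linear(x, qkv.weight, qkv.bias, AB[0]["qkv"])
print(y.shape)  # torch.Size([2, 197, 2304])
```

Initialising B to zeros, as in the original LoRA recipe, makes the adapted layer start out identical to the frozen one; only a rank-4 factor pair per layer then needs to be learned.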

Parameters:
  • x (Tensor) – input tensor

  • AB (dict) – dictionary containing LoRA-style parameters for each block

  • returnt – return type; one of out, features, both, or full

Returns:

output tensor

forward_features(x, AB={}, return_all=False)

Compute the forward pass of ViT (features only). It also accepts an optional argument AB, a dictionary containing LoRA-style parameters for each block.

Parameters:
  • x (Tensor) – input tensor

  • AB – dictionary containing LoRA-style parameters for each block

  • return_all – whether to return all intermediate features

Returns:

features for each patch

forward_head(x, pre_logits=False)

Compute the forward pass of ViT (head only). Expects input of shape [batch_size, num_tokens, embed_dim], i.e. the patch tokens plus any prefix tokens such as the class token.

Parameters:
  • x (Tensor) – input tensor

  • pre_logits (bool) – whether to return the pre-logits (pooled features) or the final class scores

Returns:

output tensor with shape [batch_size, num_classes] if pre_logits is False, else [batch_size, embed_dim]
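The pooling step that turns [batch_size, num_tokens, embed_dim] into [batch_size, embed_dim] can be sketched as follows. This is modelled on timm's forward_head and assumes global_pool is 'token' (take the class token) or 'avg' (mean over the patch tokens), with one prefix token; the function name and its explicit head/fc_norm arguments are hypothetical.

```python
import torch
import torch.nn as nn


def sketch_forward_head(x, head, fc_norm, global_pool="token", num_prefix_tokens=1, pre_logits=False):
    """Hypothetical re-implementation of the pooling + classifier step.

    x: [batch_size, num_tokens, embed_dim]; the first num_prefix_tokens entries
    are prefix tokens (e.g. the class token).
    """
    if global_pool == "avg":
        # Average over the patch tokens only, skipping the prefix tokens.
        x = x[:, num_prefix_tokens:].mean(dim=1)
    elif global_pool == "token":
        # Use the class token as the image representation.
        x = x[:, 0]
    x = fc_norm(x)
    return x if pre_logits else head(x)


embed_dim, num_classes = 768, 10
head = nn.Linear(embed_dim, num_classes)
fc_norm = nn.Identity()
feats = torch.randn(4, 197, embed_dim)

logits = sketch_forward_head(feats, head, fc_norm)
pooled = sketch_forward_head(feats, head, fc_norm, global_pool="avg", pre_logits=True)
print(logits.shape, pooled.shape)  # torch.Size([4, 10]) torch.Size([4, 768])
```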

get_grads(discard_classifier=False)

Returns all the gradients concatenated in a single tensor. If discard_classifier is True, the gradients of the classification head are excluded.

Returns:

gradients tensor

Return type:

Tensor

get_params(discard_classifier=False)

Returns all the parameters concatenated in a single tensor. If discard_classifier is True, the parameters of the classification head are excluded.

Returns:

parameters tensor

Return type:

Tensor
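A minimal sketch of the flattening that get_params and get_grads describe, on a toy model. The helpers, the toy class, and the assumption that the classifier's parameters are recognised by a name prefix ("head") are all hypothetical; zero-filling missing gradients is a design choice of this sketch, not documented behaviour.

```python
import torch
import torch.nn as nn


def flat_params(model, discard_classifier=False, classifier_prefix="head"):
    """Hypothetical helper: concatenate every parameter into one 1-D tensor.

    classifier_prefix is an assumed name for the classification head.
    """
    return torch.cat([p.reshape(-1) for name, p in model.named_parameters()
                      if not (discard_classifier and name.startswith(classifier_prefix))])


def flat_grads(model, discard_classifier=False, classifier_prefix="head"):
    """Hypothetical helper: same traversal as flat_params, but over .grad."""
    grads = []
    for name, p in model.named_parameters():
        if discard_classifier and name.startswith(classifier_prefix):
            continue
        # Fill missing gradients with zeros so the layout matches flat_params.
        g = p.grad if p.grad is not None else torch.zeros_like(p)
        grads.append(g.reshape(-1))
    return torch.cat(grads)


class Tiny(nn.Module):
    """Toy model with a body and a classifier named 'head'."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 8)  # 64 weights + 8 biases
        self.head = nn.Linear(8, 3)  # 24 weights + 3 biases

    def forward(self, x):
        return self.head(torch.relu(self.body(x)))


model = Tiny()
model(torch.randn(5, 8)).sum().backward()
print(flat_params(model).numel(), flat_params(model, discard_classifier=True).numel())  # 99 72
print(flat_grads(model).numel())  # 99
```

Because both helpers walk named_parameters() in the same order, entry i of the gradient vector always belongs to entry i of the parameter vector.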

init_weights(mode='')

Functions

backbone.vit.checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False, interpolation='bicubic', antialias=True)

Convert the patch-embedding weights from a manual patchify + linear projection to a convolution.
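Why such a conversion is possible: a linear projection applied to flattened, non-overlapping p x p patches is the same computation as a convolution with kernel size and stride p, once the [embed_dim, C*p*p] weight is reshaped to [embed_dim, C, p, p]. The sketch checks this numerically. It assumes the old linear weights flattened each patch in (channel, row, column) order; that ordering, and all sizes below, are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W, p, D = 2, 3, 32, 32, 16, 8  # batch, channels, image size, patch size, embed dim
img = torch.randn(B, C, H, W)
lin_w = torch.randn(D, C * p * p)  # weight of a "manual patchify + linear" embedding
bias = torch.randn(D)

# Manual patchify: cut the image into non-overlapping p x p patches and flatten
# each patch in (channel, row, column) order -> [B, num_patches, C*p*p].
patches = img.unfold(2, p, p).unfold(3, p, p)  # [B, C, H/p, W/p, p, p]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)
tokens_linear = patches @ lin_w.T + bias  # [B, num_patches, D]

# The same weights reshaped into a conv kernel with stride = kernel size = p.
conv_w = lin_w.reshape(D, C, p, p)
tokens_conv = F.conv2d(img, conv_w, bias, stride=p)  # [B, D, H/p, W/p]
tokens_conv = tokens_conv.flatten(2).transpose(1, 2)  # [B, num_patches, D]

match = torch.allclose(tokens_linear, tokens_conv, atol=1e-3)
print(tokens_linear.shape, match)
```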

backbone.vit.create_vision_transformer(variant, base_class=<class 'backbone.vit.VisionTransformer'>, pretrained=False, filter_fn=<function checkpoint_filter_fn>, **kwargs)

backbone.vit.get_init_weights_vit(mode='jax', head_bias=0.0)

backbone.vit.init_weights_vit_jax(module, name='', head_bias=0.0)

ViT weight initialization, matching the JAX (Flax) implementation.

backbone.vit.init_weights_vit_moco(module, name='')

ViT weight initialization, matching the MoCo v3 implementation minus the fixed PatchEmbed.

backbone.vit.init_weights_vit_timm(module, name='')

ViT weight initialization, original timm implementation (for reproducibility).
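For reference, the original timm scheme initialises every nn.Linear weight from a truncated normal with std 0.02 and zeroes its bias, deferring to a module's own init_weights() otherwise. The sketch below is modelled on timm's init_weights_vit_timm; that backbone.vit's copy is identical is an assumption, and sketch_init_vit_timm is a hypothetical name.

```python
import torch
import torch.nn as nn


def sketch_init_vit_timm(module, name=""):
    """Hypothetical re-implementation modelled on timm's init_weights_vit_timm."""
    if isinstance(module, nn.Linear):
        # Truncated normal (bounds at +/-2 absolute, effectively untruncated for std 0.02).
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()


model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
# Visit every submodule together with its qualified name, as a name-aware init expects.
for name, m in model.named_modules():
    sketch_init_vit_timm(m, name)
print(model[0].weight.std().item())  # close to 0.02
```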

backbone.vit.resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=(), interpolation='bicubic', antialias=False)

Rescale the grid of position embeddings when loading from state_dict.

DEPRECATED: this function is deprecated in favour of resample_abs_pos_embed.

Adapted from:

https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
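The idea behind rescaling the grid: split off the prefix (class) tokens, view the remaining embeddings as a 2-D image with embed_dim channels, interpolate it to the new grid size, and flatten it back. The sketch below illustrates this under the assumption of a square source grid; it reproduces neither resize_pos_embed nor resample_abs_pos_embed exactly, and its name and arguments are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def sketch_resize_pos_embed(posemb, new_grid, num_prefix_tokens=1, interpolation="bicubic"):
    """Hypothetical sketch: resample a [1, prefix + gs*gs, dim] position embedding
    to a new_grid = (h, w) patch grid, leaving the prefix tokens untouched."""
    prefix, grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
    gs_old = int(math.sqrt(grid.shape[0]))  # assumes a square source grid
    # [gs*gs, dim] -> [1, dim, gs, gs] so the grid can be interpolated like an image.
    grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_grid, mode=interpolation, align_corners=False)
    # Back to a token sequence: [1, h*w, dim].
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_grid[0] * new_grid[1], -1)
    return torch.cat([prefix, grid], dim=1)


posemb = torch.randn(1, 1 + 14 * 14, 768)  # ViT-B/16 at 224x224: 14x14 patches + class token
resized = sketch_resize_pos_embed(posemb, (24, 24))  # e.g. for 384x384 inputs
print(resized.shape)  # torch.Size([1, 577, 768])
```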

backbone.vit.vit_backbone(num_classes, pretrained=True, pretrain_type='in21k-ft-in1k')

backbone.vit.vit_base_patch16_224_prompt_prototype(pretrained=False, pretrain_type='in21k-ft-in1k', **kwargs)

ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).

By default, returns a model pre-trained on ImageNet-21k. Supported values of pretrain_type:

  • ‘in21k’: pre-trained on ImageNet-21k

  • ‘in21k_old’: pre-trained on ImageNet-21k and fine-tuned on ImageNet-1k

  • ‘in21k-ft-in1k’: pre-trained with MoCoV3 on ImageNet-21k

Parameters:
  • pretrained (bool) – Load pre-trained weights.

  • pretrain_type (str) – Type of pre-training. Default is ‘in21k-ft-in1k’. Other options are ‘in21k’, ‘in21k_old’, and ‘in1k’.

  • **kwargs – Additional arguments to pass to the model.