VIT#
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Cloned and trimmed version of timm.models.vision_transformer.py, kept here for STABLE reference.
Check out https://github.com/pprp/timm/blob/master/timm/models/vision_transformer.py for the original file.
The following is the original docstring of the file.
Vision Transformer (ViT) in PyTorch
A PyTorch implementation of Vision Transformers as described in:
- ‘An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale’
- How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers
- FlexiViT: One Model for All Patch Sizes
- The official jax code is released and available at https://github.com/google-research/vision_transformer
- Acknowledgments:
The paper authors for releasing code and weights, thanks!
I fixed my class token impl based on Phil Wang’s https://github.com/lucidrains/vit-pytorch
Simple transformer style inspired by Andrej Karpathy’s https://github.com/karpathy/minGPT
Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020, Ross Wightman
Classes#
- class backbone.vit.Attention(dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0)[source]#
Bases:
Module
Attention layer as used in Vision Transformer.
- Parameters:
dim – Number of input channels
num_heads – Number of attention heads
qkv_bias – If True, add a learnable bias to q, k, v
attn_drop – Dropout rate for attention weights
proj_drop – Dropout rate after the final projection
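A minimal usage sketch, assuming the usual timm calling convention of tokens shaped [batch, num_tokens, dim]; the ViT-B/16 numbers (768 channels, 12 heads, 197 tokens) are only illustrative:

```python
import torch

from backbone.vit import Attention

# Self-attention over 197 tokens (1 class token + 14*14 patches of a
# 224x224 image at patch size 16), embed_dim=768, 12 heads.
attn = Attention(dim=768, num_heads=12, qkv_bias=True)
tokens = torch.randn(2, 197, 768)   # [batch, num_tokens, dim]
out = attn(tokens)                  # attention preserves the token shape
assert out.shape == tokens.shape
```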
- class backbone.vit.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0, init_values=None, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, attn_layer=<class 'backbone.vit.Attention'>, mlp_layer=<class 'backbone.vit.Mlp'>)[source]#
Bases:
Module
- class backbone.vit.Mlp(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=None, bias=True, drop=0.0, use_conv=False)[source]#
Bases:
Mlp
- class backbone.vit.VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, init_values=None, class_token=True, no_embed_class=False, pre_norm=False, fc_norm=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, weight_init='', embed_layer=<class 'timm.layers.patch_embed.PatchEmbed'>, norm_layer=None, act_layer=None, block_fn=<class 'backbone.vit.Block'>, attn_layer=None, mlp_layer=None, use_lora=False, args=None)[source]#
Bases:
MammothBackbone
Vision Transformer. This implementation supports LoRA (Low-Rank Adaptation) parameters if use_lora=True.
- A PyTorch impl of ‘An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale’
- forward(x, AB={}, returnt='out')[source]#
Compute the forward pass of ViT. Can take in an additional argument AB, which is a dictionary containing LoRA-style parameters for each block.
AB can contain (see the sketch below):
- a single value for each block (e.g. AB = {0: {“qkv”: torch.Tensor(…)}, 1: {“qkv”: torch.Tensor(…)}, …})
- a dictionary for each block with a single key B (e.g. AB = {0: {“qkv”: {“B”: torch.Tensor(…)}}})
- a dictionary for each block with both A and B keys of LoRA parameters (e.g. AB = {0: {“qkv”: {“A”: torch.Tensor(…), “B”: torch.Tensor(…)}}})
Supported keys for each block are qkv, proj, fc1, fc2.
NOTE: The values of AB are summed with the weights of the corresponding block.
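A hedged sketch of assembling an AB dictionary and passing it to forward. The factor shapes below are an assumption for illustration, derived from the qkv projection mapping embed_dim to 3 * embed_dim; the implementation may also expect a leading batch dimension on the LoRA tensors.

```python
import torch

from backbone.vit import VisionTransformer

# ViT-B/16-style configuration; num_classes=10 is arbitrary for the sketch.
model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768,
                          depth=12, num_heads=12, num_classes=10)
x = torch.randn(4, 3, 224, 224)

rank, dim = 4, 768
# LoRA-style factors for the qkv projection of block 0: the qkv weight maps
# dim -> 3 * dim, so the full delta B @ A has shape [3 * dim, dim] (assumed).
A = 0.01 * torch.randn(rank, dim)
B = torch.zeros(3 * dim, rank)

AB = {0: {"qkv": {"A": A, "B": B}}}   # per-block dict with both A and B factors
# AB = {0: {"qkv": B @ A}}            # equivalent single-value (full delta) form

logits = model(x, AB=AB)              # deltas are summed onto block 0's qkv weights
```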
- forward_features(x, AB={}, return_all=False)[source]#
Compute the forward pass of ViT (features only). Can take in an additional argument AB, which is a dictionary containing LoRA-style parameters for each block.
- Parameters:
x (Tensor) – input tensor
AB – dictionary containing LoRA-style parameters for each block
return_all – whether to return all intermediate features
- Returns:
features for each patch
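Continuing the sketch above, a short hedged example; it assumes return_all=True yields the intermediate output of every block rather than only the final features.

```python
feats = model.forward_features(x)                        # [batch, num_tokens, embed_dim]
all_feats = model.forward_features(x, return_all=True)   # per-block intermediates (assumed)
```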
- forward_head(x, pre_logits=False)[source]#
Compute the forward pass of ViT (head only). Expects input of shape [batch_size, num_patches, embed_dim].
- get_grads(discard_classifier=False)[source]#
Returns all the gradients concatenated in a single tensor.
- Returns:
gradients tensor
- Return type:
Tensor
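A minimal sketch of get_grads, continuing the example above; it assumes gradients are populated by an ordinary backward pass before the call (the labels and loss are placeholders).

```python
import torch.nn.functional as F

labels = torch.randint(0, 10, (x.size(0),))
loss = F.cross_entropy(model(x), labels)
loss.backward()

grads = model.get_grads()                                 # all gradients, flattened into one tensor
grads_no_head = model.get_grads(discard_classifier=True)  # drop the classifier's gradients
```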
Functions#
- backbone.vit.checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False, interpolation='bicubic', antialias=True)[source]#
Convert patch embedding weights from manual patchify + linear projection to convolution.
- backbone.vit.create_vision_transformer(variant, base_class=<class 'backbone.vit.VisionTransformer'>, pretrained=False, filter_fn=<function checkpoint_filter_fn>, **kwargs)[source]#
- backbone.vit.init_weights_vit_jax(module, name='', head_bias=0.0)[source]#
ViT weight initialization, matching JAX (Flax) impl
- backbone.vit.init_weights_vit_moco(module, name='')[source]#
ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed
- backbone.vit.init_weights_vit_timm(module, name='')[source]#
ViT weight initialization, original timm impl (for reproducibility)
- backbone.vit.resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=(), interpolation='bicubic', antialias=False)[source]#
Rescale the grid of position embeddings when loading from state_dict.
DEPRECATED: This function is being deprecated in favour of resample_abs_pos_embed.
- backbone.vit.vit_base_patch16_224_prompt_prototype(pretrained=False, pretrain_type='in21k-ft-in1k', **kwargs)[source]#
ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
By default, returns a model pre-trained on ImageNet-21k. Supports:
- Pre-train on ImageNet-21k (pretrain_type=’in21k’)
- Pre-train on ImageNet-21k and finetune on ImageNet-1k (pretrain_type=’in21k_old’)
- Pre-train with MoCoV3 on ImageNet-21k (pretrain_type=’in21k-ft-in1k’)
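A hedged construction sketch; whether extra keyword arguments such as num_classes are forwarded through **kwargs, and whether pretrained=True downloads the selected checkpoint, are assumptions here.

```python
import torch

from backbone.vit import vit_base_patch16_224_prompt_prototype

# Randomly initialized ViT-B/16; set pretrained=True to load weights for the
# chosen pretrain_type (assumed to require network access).
net = vit_base_patch16_224_prompt_prototype(pretrained=False, num_classes=10)
logits = net(torch.randn(1, 3, 224, 224))   # expected shape: [1, 10]
```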