Source code for backbone.utils.lora_utils

import collections.abc
from itertools import repeat
from torch import nn
import torch
import torch.nn.functional as F


def _ntuple(n):
    """Build a converter that normalises a value into a tuple.

    The returned callable converts any non-string iterable into a tuple of
    its items (its length is not checked against ``n``) and repeats any
    other value, including strings, ``n`` times.
    """
    def parse(x):
        is_collection = (
            isinstance(x, collections.abc.Iterable)
            and not isinstance(x, str)
        )
        # Iterables are taken as-is; scalars (and strings) are replicated.
        return tuple(x) if is_collection else (x,) * n

    return parse


# Converter for arguments that may be given once or once per layer of a pair.
to_2tuple = _ntuple(2)


def _identity(x):
    """Return ``x`` unchanged; stands in for dropout when it is disabled."""
    return x


class LoRALayer():
    """Mixin that stores the optional dropout callable of a LoRA layer.

    Args:
        lora_dropout: Dropout probability. When greater than zero,
            ``self.lora_dropout`` is an ``nn.Dropout`` with that
            probability; otherwise it is a pass-through callable.
    """

    def __init__(
        self,
        lora_dropout: float,
    ):
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            # A named module-level function rather than a lambda, so that
            # objects holding this attribute remain picklable.
            self.lora_dropout = _identity
class LoRALinear(nn.Linear, LoRALayer):
    """Linear layer with a frozen weight and an optional additive low-rank term.

    The low-rank factors are not stored on the module; they are handed to
    ``forward`` on every call through the ``AB`` argument.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        lora_dropout: Dropout probability forwarded to ``LoRALayer``.
        fan_in_fan_out: If True, the weight is stored transposed, i.e. as
            ``(in_features, out_features)``, and transposed back in ``forward``.
        **kwargs: Forwarded to ``nn.Linear`` (e.g. ``bias``).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, lora_dropout=lora_dropout)

        self.fan_in_fan_out = fan_in_fan_out
        # The base weight is frozen; only externally supplied factors adapt it.
        self.weight.requires_grad = False
        # Re-run the standard nn.Linear initialisation (kept from the original).
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        """Initialise weight and bias exactly as ``nn.Linear`` does."""
        nn.Linear.reset_parameters(self)

    def forward(self, x: torch.Tensor, AB: dict = None):
        """Apply the frozen linear map and add the low-rank term, if any.

        Args:
            x: Input tensor. When ``AB`` is given, dims 1 and 2 of ``x`` are
                swapped, so it must have at least three dimensions
                (assumed ``(batch, tokens, features)`` -- TODO confirm).
            AB: ``None`` for a plain linear layer; a dict with key ``'B'``
                and optionally ``'A'``; or any non-dict object, which is used
                directly as ``B``.

        Returns:
            ``F.linear(x, W, bias)`` plus, when ``AB`` is given, the product
            ``B @ (A @ x^T)`` (or ``B @ x^T`` without ``A``) summed over dim 1
            and transposed back to the layout of ``x``.
        """
        # Undo the storage transpose. The weight is 2-D, so the axes to swap
        # are (0, 1); swapping (1, 2) raised IndexError on every call.
        weight = self.weight.transpose(0, 1) if self.fan_in_fan_out else self.weight
        result = F.linear(x, weight, bias=self.bias)

        if AB is None:
            return result

        if isinstance(AB, dict):
            B = AB['B']
            A = AB.get('A')
        else:
            B, A = AB, None

        # Insert a dim at position 1 so the factors can broadcast along it;
        # that dim is summed away below.
        update = x.transpose(1, 2).unsqueeze(1)
        if A is not None:
            update = A @ update
        return result + (B @ update).sum(1).transpose(1, 2)
class LoRAAttention(nn.Module):
    """
    Attention layer as used in Vision Transformer. Adapted to support LoRA-style parameters.

    Args:
        dim: Number of input channels
        num_heads: Number of attention heads
        qkv_bias: If True, add a learnable bias to q, k, v
        attn_drop: Dropout rate for attention weights
        proj_drop: Dropout rate after the final projection
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = LoRALinear(dim, dim * 3, 0., bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = LoRALinear(dim, dim, 0.)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, AB: dict = None, **kwargs):
        """
        Forward pass of the attention layer. Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`).

        Args:
            x: Input tensor of shape (B, N, C)
            AB: Dictionary containing LoRA-style parameters for the layer; the
                entries under "qkv" and "proj" (if present) are passed to the
                corresponding projections
            **kwargs: Accepted and ignored

        Returns:
            Tensor of shape (B, N, C)
        """
        B, N, C = x.shape

        AB_qkv = AB.get("qkv") if AB is not None else None
        AB_proj = AB.get("proj") if AB is not None else None

        qkv = self.qkv(x, AB_qkv)
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # NOTE: flash attention is less debuggable than the original.
        # Use the explicit implementation in the else-branch if in trouble.
        if torch.__version__ >= '2.1.0':
            # The functional kernel applies dropout whenever dropout_p > 0,
            # regardless of train/eval mode, so gate it on self.training to
            # match the nn.Dropout used in the fallback branch.
            dropout_p = self.attn_drop.p if self.training else 0.
            x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=dropout_p)
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v)

        # (B, heads, N, head_dim) -> (B, N, C)
        x = x.transpose(1, 2).reshape(B, N, C)

        x = self.proj(x, AB_proj)
        x = self.proj_drop(x)
        return x
class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma``.

    Args:
        dim: Size passed to ``torch.ones`` to create ``gamma``.
        init_values: Initial value of every entry of ``gamma``.
        inplace: If True, the input tensor is scaled in place and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        initial_gamma = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, x):
        """Return ``x`` scaled by ``gamma`` (mutating ``x`` when ``inplace``)."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class LoRAMlp(nn.Module):
    """
    Two-layer MLP block in the style used by Vision Transformer and MLP-Mixer,
    built from `LoRALinear` layers so LoRA-style parameters can be supplied at call time.

    Args:
        in_features: Number of input features
        hidden_features: Width of the hidden layer (falls back to `in_features`)
        out_features: Number of output features (falls back to `in_features`)
        act_layer: Activation class, instantiated without arguments
        norm_layer: Optional normalisation class, instantiated with the hidden width
        bias: Single value or pair; one entry per linear layer
        drop: Single value or pair; one dropout rate per linear layer
        use_conv: Must be False
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        assert use_conv is False

        # Sub-modules are registered in the order they are applied.
        self.fc1 = LoRALinear(in_features, hidden_features, bias=bias_pair[0], lora_dropout=0.)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_features)
        self.fc2 = LoRALinear(hidden_features, out_features, bias=bias_pair[1], lora_dropout=0.)
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x: torch.Tensor, AB: dict = None, **kwargs):
        """
        Run the MLP. Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`).

        Args:
            x: Input tensor
            AB: Dictionary containing LoRA-style parameters for the layer; the
                entries under "fc1" and "fc2" (if present) go to the matching linear layer
            **kwargs: Accepted and ignored
        """
        if AB is None:
            AB_fc1, AB_fc2 = None, None
        else:
            AB_fc1, AB_fc2 = AB.get("fc1"), AB.get("fc2")

        hidden = self.fc1(x, AB_fc1)
        hidden = self.norm(self.drop1(self.act(hidden)))
        out = self.fc2(hidden, AB_fc2)
        return self.drop2(out)