# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from backbone.utils.lora_utils import LoRALayer
class ClipLinear(nn.Linear, LoRALayer):
    """
    Linear layer for CLIP backbones that can consume externally provided
    low-rank (LoRA-style) residual weights at forward time.

    The pre-trained weight is frozen; adaptation comes only from the `AB`
    argument passed to `forward`.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            lora_dropout: float = 0.,
            fan_in_fan_out: bool = False,
            **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, lora_dropout=lora_dropout)

        self.fan_in_fan_out = fan_in_fan_out
        # Freeze the pre-trained weight: only the injected AB residuals adapt.
        self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            # Store the weight as (in_features, out_features) when the wrapped
            # layer expects that layout.
            self.weight.data = self.weight.data.transpose(0, 1)
    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
    def forward(self, x: torch.Tensor, AB: dict = None):
        def T(w):
            # Undo the (in_features, out_features) storage layout before use.
            # The weight is 2D, so the transpose is over dims (0, 1).
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        result = F.linear(x, T(self.weight), bias=self.bias)

        if AB is not None:
            A = None
            if isinstance(AB, dict):
                B = AB['B']
                A = AB.get('A')
            else:
                B = AB
            # x is assumed to be (seq_len, batch, dim), as in CLIP's transformer
            # blocks; the permutes move the batch dimension so that A and B can
            # be applied as batched matrix products, then restore the layout.
            if A is not None:
                res = (B @ (A @ torch.permute(x, (1, 2, 0)).unsqueeze(1))).sum(1)
                return result + torch.permute(res, (2, 0, 1))
            res = (B @ torch.permute(x, (1, 2, 0)).unsqueeze(1)).sum(1)
            return result + torch.permute(res, (2, 0, 1))

        return result
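
# Usage sketch (not part of the original module): `x` is assumed to follow
# CLIP's (seq_len, batch, dim) token layout, and `AB` is assumed to carry
# low-rank residuals with `A` of shape (n, r, in_features) and `B` of shape
# (n, out_features, r); the rank r=4 and n=1 below are illustrative only.
#
#   layer = ClipLinear(512, 512, bias=True)
#   x = torch.randn(77, 8, 512)              # (seq_len, batch, dim)
#   A = torch.randn(1, 4, 512) * 0.01        # (n, r, in_features)
#   B = torch.zeros(1, 512, 4)               # (n, out_features, r)
#   out = layer(x, AB={'A': A, 'B': B})      # (77, 8, 512); with B = 0 this equals layer(x)
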
class IncrementalClassifier(nn.Module):

    def __init__(self, embed_dim: int, nb_classes: int):
        """
        Incremental classifier for continual learning.

        Args:
            embed_dim: int, dimension of the input features.
            nb_classes: int, number of classes to classify.
        """
        super().__init__()

        self.embed_dim = embed_dim
        heads = [nn.Linear(embed_dim, nb_classes, bias=True)]
        self.heads = nn.ModuleList(heads)
        self.old_state_dict = None
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    def update(self, nb_classes: int, freeze_old=True):
        """
        Add a new head to the classifier.

        Args:
            nb_classes: int, number of classes to add.
            freeze_old: bool, whether to freeze the old heads.
        """
        _fc = nn.Linear(self.embed_dim, nb_classes, bias=True).to(self.get_device())
        nn.init.trunc_normal_(_fc.weight, std=.02)
        nn.init.constant_(_fc.bias, 0)
        if freeze_old:
            for param in self.heads.parameters():
                param.requires_grad = False
        self.heads.append(_fc)
    def forward(self, x: torch.Tensor):
        """
        Forward pass.

        Compute the logits for each head and concatenate them.

        Args:
            x: torch.Tensor, input features.
        """
        return torch.cat([h(x) for h in self.heads], dim=1)
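

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): exercises
    # the incremental head on random features. The embedding size, class counts
    # and batch size below are illustrative only.
    torch.manual_seed(0)

    clf = IncrementalClassifier(embed_dim=512, nb_classes=10)

    # `update` relies on a `get_device` helper that this repository is assumed
    # to attach to nn.Module elsewhere; fall back to a shim if it is absent.
    if not hasattr(clf, "get_device"):
        clf.get_device = lambda: next(clf.parameters()).device

    feats = torch.randn(4, 512)
    print(clf(feats).shape)  # torch.Size([4, 10])

    # Grow the classifier for a new task and freeze the previous head.
    clf.update(nb_classes=5)
    print(clf(feats).shape)  # torch.Size([4, 15])
    print(all(not p.requires_grad for p in clf.heads[0].parameters()))  # True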