import torch
import torch.nn as nn
from backbone.vit import create_vision_transformer
from models.coda_prompt_utils import gram_schmidt
from models.coda_prompt_utils.vit import VisionTransformer
[docs]
class CodaPrompt(nn.Module):
    """CODA-Prompt module: an attention-weighted pool of prompt components.

    For every layer index in ``self.e_layers`` three parameters are registered:
    ``e_p_{layer}`` (prompt components, ``(pool, e_p_length, emb_d)``),
    ``e_k_{layer}`` (keys, ``(pool, key_d)``) and ``e_a_{layer}`` (attention
    vectors, ``(pool, key_d)``). The pool of ``e_pool_size`` components is
    split into equal slices, one per task; ``forward`` trains only the slice of
    the current task (``self.task_count``) and detaches earlier slices.

    Args:
        emb_d: embedding dimension of the prompt components.
        n_tasks: number of tasks the pool is split across.
        prompt_param: sequence ``(pool_size, prompt_length, ortho_mu)``.
        key_dim: dimension of keys and attention vectors (default 768).
    """

    def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768):
        super().__init__()
        self.task_count = 0
        self.emb_d = emb_d
        self.key_d = key_dim
        self.n_tasks = n_tasks
        self._init_smart(emb_d, prompt_param)
        pt = self._components_per_task()
        for e in self.e_layers:
            # For model saving/loading simplicity the full parameters are
            # created here; however, the components of each new task are
            # re-initialized in process_task_count, in the spirit of continual
            # learning, since the number of tasks is not known up front.
            #
            # The original paper used orthogonal init at the start; this
            # modification is fairer for continual learning and has little
            # effect on performance.
            p = tensor_prompt(self.e_pool_size, self.e_p_length, emb_d)
            k = tensor_prompt(self.e_pool_size, self.key_d)
            a = tensor_prompt(self.e_pool_size, self.key_d)
            # only the first task's slice [0, pt) is orthogonalized now
            p = gram_schmidt(p, start_c=0, end_c=pt)
            k = gram_schmidt(k, start_c=0, end_c=pt)
            a = gram_schmidt(a, start_c=0, end_c=pt)
            setattr(self, f'e_p_{e}', p)
            setattr(self, f'e_k_{e}', k)
            setattr(self, f'e_a_{e}', a)

    def _init_smart(self, emb_d, prompt_param):
        """Read pool size, prompt length and ortho strength from ``prompt_param``.

        ``emb_d`` is accepted for interface compatibility but not used here.
        """
        self.e_pool_size = int(prompt_param[0])
        self.e_p_length = int(prompt_param[1])
        # transformer layer indices that receive prompts
        self.e_layers = [0, 1, 2, 3, 4]
        # strength of the orthogonality penalty
        self.ortho_mu = prompt_param[2]

    def _components_per_task(self):
        """Number of pool components reserved for each task."""
        return int(self.e_pool_size / self.n_tasks)

    def _task_range(self):
        """Return ``(start, end)`` of the current task's slice of the pool."""
        pt = self._components_per_task()
        return int(self.task_count * pt), int((self.task_count + 1) * pt)

    def process_task_count(self):
        """Advance to the next task and re-initialize its slice of every pool.

        The new task's components are re-orthogonalized with Gram-Schmidt, in
        the spirit of continual learning (the original paper used orthogonal
        init at the start; this has little effect on performance).
        """
        self.task_count += 1
        s, f = self._task_range()
        for e in self.e_layers:
            K = getattr(self, f'e_k_{e}')
            A = getattr(self, f'e_a_{e}')
            P = getattr(self, f'e_p_{e}')
            # call order (k, a, p) kept as in the original implementation
            k = gram_schmidt(K, s, f)
            a = gram_schmidt(A, s, f)
            p = gram_schmidt(P, s, f)
            setattr(self, f'e_p_{e}', p)
            setattr(self, f'e_k_{e}', k)
            setattr(self, f'e_a_{e}', a)

    def forward(self, x_querry, l, x_block, train=False, task_id=None):
        """Build prefix prompts for layer ``l``.

        Args:
            x_querry: 2-D query features ``(batch, key_d)``.
            l: layer index; layers not in ``self.e_layers`` get no prompt.
            x_block: passed through unchanged.
            train: if True, use only the current task's slice (with earlier
                slices detached) and compute the ortho penalty.
            task_id: unused; kept for interface compatibility.

        Returns:
            ``([Ek, Ev], loss, x_block)`` for prompted layers, where ``Ek`` and
            ``Ev`` split the prompt length in half; ``(None, 0, x_block)``
            otherwise.
        """
        if l not in self.e_layers:
            return None, 0, x_block

        K = getattr(self, f'e_k_{l}')
        A = getattr(self, f'e_a_{l}')
        p = getattr(self, f'e_p_{l}')
        s, f = self._task_range()

        # freeze/control past tasks
        if train:
            if self.task_count > 0:
                # past-task components are detached so they receive no gradient
                K = torch.cat((K[:s].detach().clone(), K[s:f]), dim=0)
                A = torch.cat((A[:s].detach().clone(), A[s:f]), dim=0)
                p = torch.cat((p[:s].detach().clone(), p[s:f]), dim=0)
            else:
                K = K[s:f]
                A = A[s:f]
                p = p[s:f]
        else:
            # inference uses every component learned so far
            K = K[0:f]
            A = A[0:f]
            p = p[0:f]

        # (b x d) * (k x d) -> (b x k x d): query re-weighted by each attention vector
        a_querry = torch.einsum('bd,kd->bkd', x_querry, A)
        # cosine similarity between each attended query and its key -> (b x k)
        n_K = nn.functional.normalize(K, dim=1)
        q = nn.functional.normalize(a_querry, dim=2)
        aq_k = torch.einsum('bkd,kd->bk', q, n_K)
        # weighted sum of components: (b x k) * (k x plen x d) -> (b x plen x d)
        P_ = torch.einsum('bk,kld->bld', aq_k, p)

        # split prompt length into key / value prefixes
        i = int(self.e_p_length / 2)
        Ek = P_[:, :i, :]
        Ev = P_[:, i:, :]

        loss = 0
        if train and self.ortho_mu > 0:
            loss = ortho_penalty(K) * self.ortho_mu
            loss += ortho_penalty(A) * self.ortho_mu
            loss += ortho_penalty(p.view(p.shape[0], -1)) * self.ortho_mu

        return [Ek, Ev], loss, x_block
[docs]
def ortho_penalty(t):
    """Mean squared deviation of the Gram matrix ``t @ t.T`` from the identity.

    Zero when the rows of the 2-D tensor ``t`` are orthonormal.
    """
    # allocate the identity directly on t's device instead of CPU + copy
    eye = torch.eye(t.shape[0], device=t.device)
    return ((t @ t.T - eye) ** 2).mean()
[docs]
def tensor_prompt(a, b, c=None, ortho=False):
    """Create a trainable float32 ``nn.Parameter`` of shape ``(a, b)`` or ``(a, b, c)``.

    Args:
        a, b: leading dimensions.
        c: optional third dimension; omitted when None.
        ortho: if True initialize with ``nn.init.orthogonal_``, otherwise
            with ``nn.init.uniform_`` (U(0, 1) by default).

    Returns:
        The initialized parameter.
    """
    shape = (a, b) if c is None else (a, b, c)
    # torch.empty replaces the legacy torch.FloatTensor(...) constructor
    p = torch.nn.Parameter(torch.empty(*shape, dtype=torch.float32), requires_grad=True)
    if ortho:
        nn.init.orthogonal_(p)
    else:
        nn.init.uniform_(p)
    return p
[docs]
class Model(nn.Module):
    """ViT-B/16 feature encoder with an optional CODA-Prompt module and a linear head.

    Args:
        num_classes: output size of the linear classifier ``self.last``.
        pt: if True, load pretrained ``vit_base_patch16_224`` weights into the
            encoder.
        prompt_param: ``(n_tasks, prompt_args)`` forwarded to ``CodaPrompt``;
            when None no prompt module is built and ``forward`` runs the plain
            encoder.
    """

    def __init__(self, num_classes=10, pt=False, prompt_param=None):
        super().__init__()
        # read by forward and passed to the encoder; presumably set by the caller
        self.task_id = None

        # feature encoder
        vit_model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12,
                                      num_heads=12, drop_path_rate=0)
        if pt:
            load_dict = create_vision_transformer('vit_base_patch16_224', base_class=VisionTransformer,
                                                  pretrained=True, num_classes=0).state_dict()
            # drop classification-head weights if present (pop tolerates either key missing)
            load_dict.pop('head.weight', None)
            load_dict.pop('head.bias', None)
            missing, unexpected = vit_model.load_state_dict(load_dict, strict=False)
            assert len([m for m in missing if 'head' not in m]) == 0, f"Missing keys: {missing}"
            assert len(unexpected) == 0, f"Unexpected keys: {unexpected}"

        # classifier
        self.last = nn.Linear(768, num_classes)
        # prompt_param defaults to None; build no prompt then (forward handles it)
        self.prompt = None if prompt_param is None else CodaPrompt(768, prompt_param[0], prompt_param[1])
        self.feat = vit_model

    def forward(self, x, pen=False, train=False):
        """Classify ``x``.

        Args:
            x: input batch for the encoder.
            pen: if True, return the features before ``self.last``.
            train: forwarded to the prompt module; when True and a prompt
                module exists, the prompt loss is returned as well.

        Returns:
            ``out``, or ``(out, prompt_loss)`` when prompting and ``train``.
        """
        if self.prompt is not None:
            # query comes from a prompt-free encoder pass without gradients
            with torch.no_grad():
                q, _ = self.feat(x)
                q = q[:, 0, :]  # first token — presumably the CLS token
            out, prompt_loss = self.feat(x, prompt=self.prompt, q=q, train=train, task_id=self.task_id)
            out = out[:, 0, :]
        else:
            out, _ = self.feat(x)
            out = out[:, 0, :]
        out = out.view(out.size(0), -1)
        if not pen:
            out = self.last(out)
        if self.prompt is not None and train:
            return out, prompt_loss
        return out