Source code for models.slca_utils.convs.linears
'''
Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
'''
import torch
from torch import nn
from timm.layers import trunc_normal_
from copy import deepcopy
[docs]
class SimpleContinualLinear(nn.Module):
    """Linear classifier made of a growing list of independent heads.

    The module starts with a single head and gains one more on every call to
    :meth:`update`. Each head is an ``nn.Sequential`` holding an optional
    ``nn.LayerNorm`` followed by an ``nn.Linear``. The forward pass
    concatenates the outputs of all heads along dimension 1.

    Args:
        embed_dim: Size of the input features fed to every head.
        nb_classes: Number of outputs of the initial head.
        feat_expand: If true, ``forward`` receives an indexable input and
            head ``i`` is applied to ``x[i]``; otherwise every head is
            applied to the same ``x``.
        with_norm: If true, each head normalizes its input with a
            ``LayerNorm(embed_dim)`` before the linear layer.
    """

    def __init__(self, embed_dim, nb_classes, feat_expand=False, with_norm=False):
        super().__init__()
        # These attributes must be set first: _make_head reads them.
        self.embed_dim = embed_dim
        self.feat_expand = feat_expand
        self.with_norm = with_norm
        self.heads = nn.ModuleList([self._make_head(nb_classes)])

    def _make_head(self, nb_classes):
        """Build one head with ``nb_classes`` outputs and initialize it."""
        layers = [nn.LayerNorm(self.embed_dim)] if self.with_norm else []
        classifier = nn.Linear(self.embed_dim, nb_classes, bias=True)
        # Truncated-normal weights and zero bias replace the default
        # nn.Linear initialization.
        trunc_normal_(classifier.weight, std=.02)
        nn.init.constant_(classifier.bias, 0)
        layers.append(classifier)
        return nn.Sequential(*layers)

    def update(self, nb_classes, freeze_old=True):
        """Append a freshly initialized head with ``nb_classes`` outputs.

        Args:
            nb_classes: Number of outputs of the new head.
            freeze_old: If true, ``requires_grad`` is turned off for every
                parameter of the heads that existed before this call.
        """
        fresh_head = self._make_head(nb_classes)
        if freeze_old:
            # The new head is appended only afterwards, so it stays trainable.
            for param in self.heads.parameters():
                param.requires_grad = False
        self.heads.append(fresh_head)

    def forward(self, x):
        """Run every head and concatenate the results.

        Args:
            x: Input passed to each head, or, when ``feat_expand`` is set,
                an indexable object providing one input per head.
                Presumably shaped ``(batch, embed_dim)`` per head, since the
                outputs are joined on dimension 1 -- confirm with callers.

        Returns:
            A dict with the single key ``'logits'`` holding the outputs of
            all heads concatenated along dimension 1.
        """
        per_head = [
            head(x[idx] if self.feat_expand else x)
            for idx, head in enumerate(self.heads)
        ]
        return {'logits': torch.cat(per_head, dim=1)}