# Source code for models.slca_utils.convs.linears

'''
Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
'''
import torch
from torch import nn
from timm.layers import trunc_normal_
from copy import deepcopy


class SimpleContinualLinear(nn.Module):
    """Classifier made of one linear head per incremental task.

    Each call to :meth:`update` appends a new head. :meth:`forward` runs every
    head and concatenates their outputs along ``dim=1``, so the classes of
    head ``t`` follow those of heads ``0 .. t-1`` in the logits.

    Args:
        embed_dim: Input feature size of every head.
        nb_classes: Number of outputs of the initial head.
        feat_expand: If True, ``forward`` receives an indexable collection of
            features and feeds ``x[t]`` to head ``t``. Otherwise the same
            ``x`` is fed to every head.
        with_norm: If True, every head applies ``nn.LayerNorm(embed_dim)``
            before its linear layer.
    """

    def __init__(self, embed_dim, nb_classes, feat_expand=False, with_norm=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.feat_expand = feat_expand
        self.with_norm = with_norm
        self.heads = nn.ModuleList([self._build_head(nb_classes)])

    def _build_head(self, nb_classes):
        """Create one head: optional LayerNorm followed by an initialised Linear."""
        layers = []
        if self.with_norm:
            layers.append(nn.LayerNorm(self.embed_dim))
        fc = nn.Linear(self.embed_dim, nb_classes, bias=True)
        # Small truncated-normal weights and zero bias for every head.
        trunc_normal_(fc.weight, std=.02)
        nn.init.constant_(fc.bias, 0)
        layers.append(fc)
        return nn.Sequential(*layers)

    def backup(self):
        """Store a deep copy of the current state dict in ``self.old_state_dict``."""
        self.old_state_dict = deepcopy(self.state_dict())

    def recall(self):
        """Restore the parameters saved by the most recent :meth:`backup`.

        Raises:
            AttributeError: if :meth:`backup` has never been called.

        Note:
            ``load_state_dict`` is strict by default. Recalling therefore
            fails if heads were added or removed since the backup.
        """
        self.load_state_dict(self.old_state_dict)

    def update(self, nb_classes, freeze_old=True):
        """Append a new head with ``nb_classes`` outputs.

        Args:
            nb_classes: Number of outputs of the new head.
            freeze_old: If True, set ``requires_grad = False`` on every
                parameter of the heads that existed before this call. The new
                head stays trainable. If False, existing ``requires_grad``
                flags are left unchanged.
        """
        new_head = self._build_head(nb_classes)
        # Place the new head on the same device/dtype as the existing heads,
        # so forward() works right away even if the module was moved earlier
        # with .to()/.cuda()/.double()/.half().
        ref = next(self.heads.parameters(), None)
        if ref is not None:
            new_head = new_head.to(device=ref.device, dtype=ref.dtype)
        if freeze_old:
            # Freeze before appending so the new head is not affected.
            for p in self.heads.parameters():
                p.requires_grad = False
        self.heads.append(new_head)

    def forward(self, x):
        """Run all heads and concatenate their outputs.

        Args:
            x: Features shared by all heads, or, when ``feat_expand`` is True,
                an indexable collection whose ``x[t]`` is fed to head ``t``.
                Presumably 2-D ``(batch, embed_dim)`` tensors. TODO confirm
                against callers.

        Returns:
            dict: ``{'logits': ...}``, the head outputs concatenated along
            ``dim=1`` in head order.
        """
        logits = []
        for ti, head in enumerate(self.heads):
            fc_inp = x[ti] if self.feat_expand else x
            logits.append(head(fc_inp))
        return {'logits': torch.cat(logits, dim=1)}