Source code for models.slca_utils.slca

import logging
import sys

import numpy as np
import torch
from torch import optim
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.nn import functional as F
from tqdm import tqdm

from datasets import get_dataset
from models.slca_utils.base import BaseLearner
from models.slca_utils.inc_net import FinetuneIncrementalNet


class SLCA_Model(BaseLearner):
    def __init__(self, device, args):
        super().__init__(device, args)
        self.device = device
        self.args = args
        self._network = FinetuneIncrementalNet(args.feature_extractor_type, pretrained=True)
        # "Slow Learner": the pretrained backbone is trained at 1/100 of the
        # classifier head's learning rate (applied in get_optimizer below)
        self.bcb_lrscale = 1.0 / 100
        self.fix_bcb = False
        self.save_before_ca = False
        # ca_with_logit_norm > 0 enables per-task logit normalization during
        # the classifier-alignment stage; otherwise plain cross-entropy is used
        if self.args.ca_with_logit_norm > 0:
            self.logit_norm = self.args.ca_with_logit_norm
        else:
            self.logit_norm = None
        self.topk = 5

    @property
    def training(self):
        return self._network.training

    def to(self, device):
        self._network.to(device)

    def train(self, *args):
        self._network.train(*args)

    def eval(self):
        self._network.eval()

    def get_optimizer(self):
        lrate = self.args.lr
        base_params = self._network.convnet.parameters()
        base_fc_params = [p for p in self._network.fc.parameters() if p.requires_grad]
        head_scale = 1
        if not self.fix_bcb:
            # two parameter groups: the backbone at a 100x smaller learning
            # rate than the classifier head
            base_params = {'params': base_params, 'lr': lrate * self.bcb_lrscale,
                           'weight_decay': self.args.optim_wd}
            base_fc_params = {'params': base_fc_params, 'lr': lrate * head_scale,
                              'weight_decay': self.args.optim_wd}
            network_params = [base_params, base_fc_params]
        else:
            # frozen backbone: only the classifier head is optimized
            for p in base_params:
                p.requires_grad = False
            network_params = [{'params': base_fc_params, 'lr': lrate * head_scale,
                               'weight_decay': self.args.optim_wd}]
        optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9,
                              weight_decay=self.args.optim_wd)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                   milestones=self.args.milestones,
                                                   gamma=self.args.lr_decay)
        return optimizer, scheduler

    def after_task(self):
        self._known_classes = self._total_classes
        logging.info('Exemplar size: {}'.format(self.exemplar_size))
        self.save_checkpoint(self.log_path + '/' + self.model_prefix + '_seed{}'.format(self.seed),
                             head_only=self.fix_bcb)
        self._network.fc.recall()

    def my_compute_class_means(self, loader, offset_1, offset_2):
        print('Computing class means...', file=sys.stderr)
        class_vectors = {idx: [] for idx in range(offset_1, offset_2)}
        class_means, class_covs = {}, {}
        status = self._network.training
        self._network.eval()
        for data in tqdm(loader):
            imgs, labels = data[0], data[1]
            # in debug mode, stop once every class has accumulated a few
            # batches of features
            if self.args.debug_mode and all(len(class_vectors[idx]) >= 5
                                            for idx in range(offset_1, offset_2)):
                break
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            vectors = self._network.extract_vector(imgs)
            for c_idx in labels.unique():
                class_vectors[c_idx.item()].append(vectors[labels == c_idx].cpu().detach())
        class_vectors = {k: torch.cat(v, dim=0) for k, v in class_vectors.items()}
        for k in class_vectors.keys():
            class_means[k] = class_vectors[k].mean(dim=0)
            # eps * I shrinkage keeps the covariance positive-definite, so it
            # is a valid parameter for MultivariateNormal in stage 2
            # (a standalone toy rehearsal of this pipeline follows the class)
            class_covs[k] = torch.cov(class_vectors[k].T) + torch.eye(class_means[k].shape[-1]) * 1e-4
        print('Done.', file=sys.stderr)
        self._network.train(status)
        return class_means, class_covs

    def _stage2_compact_classifier(self, class_means, class_covs, offset_1, offset_2):
        seq_dataset = get_dataset(self.args)
        cpt = seq_dataset.N_CLASSES_PER_TASK
        if not isinstance(cpt, list):
            cpt = [cpt] * seq_dataset.N_TASKS  # normalize to one entry per task
        cumsum = np.cumsum(cpt)
        for p in self._network.fc.parameters():
            p.requires_grad = True
        run_epochs = self.args.ca_epochs
        crct_num = offset_2
        param_list = [p for p in self._network.fc.parameters() if p.requires_grad]
        network_params = [{'params': param_list, 'lr': self.args.lr,
                           'weight_decay': self.args.optim_wd}]
        optimizer = optim.SGD(network_params, lr=self.args.lr, momentum=0.9,
                              weight_decay=self.args.optim_wd)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=run_epochs)

        status = self._network.training
        self._network.eval()
        for epoch in range(run_epochs):
            losses = 0.
            sampled_data = []
            sampled_label = []
            num_sampled_pcls = 256
            # draw pseudo-features for every class seen so far from its
            # Gaussian statistics (see the standalone sketch after the class)
            for c_id in range(crct_num):
                t_id = np.argmax(cumsum > c_id)  # task that introduced class c_id
                # scale class means down slightly for older tasks; the current
                # task's means keep scale 1.0
                decay = (t_id + 1) / (self._cur_task + 1) * 0.1
                cls_mean = class_means[c_id].to(self._device) * (0.9 + decay)
                cls_cov = class_covs[c_id].to(self._device)
                m = MultivariateNormal(cls_mean.float(), cls_cov.float())
                sampled_data_single = m.sample(sample_shape=(num_sampled_pcls,))
                sampled_data.append(sampled_data_single)
                sampled_label.extend([c_id] * num_sampled_pcls)

            sampled_data = torch.cat(sampled_data, dim=0).float().to(self._device)
            sampled_label = torch.tensor(sampled_label).long().to(self._device)

            inputs = sampled_data
            targets = sampled_label
            sf_indexes = torch.randperm(inputs.size(0))
            inputs = inputs[sf_indexes]
            targets = targets[sf_indexes]

            for _iter in tqdm(range(crct_num)):
                if self.args.debug_mode and _iter >= 5:
                    break
                inp = inputs[_iter * num_sampled_pcls:(_iter + 1) * num_sampled_pcls]
                tgt = targets[_iter * num_sampled_pcls:(_iter + 1) * num_sampled_pcls]
                # inputs are sampled features, so the backbone is skipped:
                # only the fc head is evaluated and updated
                outputs = self._network(inp, bcb_no_grad=True, fc_only=True)
                logits = outputs['logits']

                if self.logit_norm is not None:
                    # normalize logits by the mean of the per-task logit norms,
                    # then temperature-scale by 1 / logit_norm
                    per_task_norm = []
                    prev_t_size = 0
                    cur_t_size = 0
                    for _ti in range(self._cur_task + 1):
                        cur_t_size += cpt[_ti]
                        temp_norm = torch.norm(logits[:, prev_t_size:cur_t_size],
                                               p=2, dim=-1, keepdim=True) + 1e-7
                        per_task_norm.append(temp_norm)
                        prev_t_size += cpt[_ti]
                    per_task_norm = torch.cat(per_task_norm, dim=-1)
                    norms = per_task_norm.mean(dim=-1, keepdim=True)
                    decoupled_logits = torch.div(logits[:, :crct_num], norms) / self.logit_norm
                    loss = F.cross_entropy(decoupled_logits, tgt)
                else:
                    loss = F.cross_entropy(logits[:, :crct_num], tgt)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
            scheduler.step()
        self._network.train(status)
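

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a self-contained,
# toy-sized rehearsal of the class-statistics pipeline above. It estimates a
# mean and a shrinkage-regularized covariance per class, then draws
# pseudo-features from the resulting Gaussians, the same mechanism used by
# my_compute_class_means and the sampling loop in _stage2_compact_classifier.
# All names, shapes, and sizes here are assumptions chosen for illustration.
def _sketch_class_gaussians(feat_dim=16, n_classes=3, n_per_class=64):
    feats = torch.randn(200, feat_dim)       # stand-in backbone features
    labels = torch.arange(200) % n_classes   # stand-in labels, evenly spread
    sampled, sampled_y = [], []
    for c in range(n_classes):
        fc = feats[labels == c]
        mean = fc.mean(dim=0)
        # same eps * I regularization as my_compute_class_means: keeps the
        # covariance positive-definite so MultivariateNormal accepts it
        cov = torch.cov(fc.T) + torch.eye(feat_dim) * 1e-4
        dist = MultivariateNormal(mean, cov)
        sampled.append(dist.sample((n_per_class,)))
        sampled_y.extend([c] * n_per_class)
    return torch.cat(sampled, dim=0), torch.tensor(sampled_y)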
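

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the per-task logit
# normalization from _stage2_compact_classifier, restated on a plain tensor.
# Each task's block of logits is L2-normalized jointly: logits are divided by
# the mean of the per-task norms, then temperature-scaled by 1 / logit_norm
# before cross-entropy. `classes_per_task` and `logit_norm` are toy
# assumptions for this sketch.
def _sketch_per_task_logit_norm(logits, classes_per_task=(2, 2), logit_norm=0.1):
    per_task_norm, start = [], 0
    for size in classes_per_task:
        block = logits[:, start:start + size]
        per_task_norm.append(torch.norm(block, p=2, dim=-1, keepdim=True) + 1e-7)
        start += size
    norms = torch.cat(per_task_norm, dim=-1).mean(dim=-1, keepdim=True)
    # decoupled logits, ready for F.cross_entropy against integer targets
    return logits / norms / logit_norm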