# Source code for models.pnn

# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from utils.args import ArgumentParser
from utils.conf import get_device
from models.utils.continual_model import ContinualModel
from backbone import get_backbone

def get_pnn_backbone(bone, old_cols=None, x_shape=None):
    """Build a Progressive Neural Network column matching the architecture of ``bone``.

    Args:
        bone: reference backbone instance. Only its type and its size
            attributes (``input_size``/``output_size`` or ``num_classes``/``nf``)
            are read.
        old_cols: previously created columns handed to the new column
            (``None`` when building the first column).
        x_shape: input shape, forwarded only to the ResNet-18 PNN variant.

    Returns:
        A new ``MNISTMLP_PNN`` or ``resnet18_pnn`` column.

    Raises:
        NotImplementedError: if ``bone`` is neither a ``BaseMNISTMLP`` nor a ``ResNet``.
    """
    # Backbone modules are imported locally, presumably so that only the
    # backbone actually requested gets loaded (NOTE(review): confirm intent).
    from backbone.MNISTMLP import BaseMNISTMLP
    from backbone.MNISTMLP_PNN import MNISTMLP_PNN
    from backbone.ResNetBlock import ResNet
    from backbone.ResNet18_PNN import resnet18_pnn

    if isinstance(bone, BaseMNISTMLP):
        return MNISTMLP_PNN(bone.input_size, bone.output_size, old_cols)
    if isinstance(bone, ResNet):
        # The second positional argument was missing (``num_classes,,`` is a
        # syntax error). NOTE(review): restored as the ResNet width ``bone.nf``;
        # confirm against resnet18_pnn's signature.
        return resnet18_pnn(bone.num_classes, bone.nf, old_cols, x_shape)
    raise NotImplementedError('Progressive Neural Networks is not implemented for this backbone')
class Pnn(ContinualModel):
    """Progressive Neural Networks.

    One network column is kept per task in ``self.nets``. When a task ends,
    the newest column is moved to the CPU. A fresh column, built from the
    current backbone with the list of previous columns as ``old_cols``, is
    then appended and becomes ``self.net``.
    """
    NAME = 'pnn'
    COMPATIBILITY = ['task-il']

    def __init__(self, backbone, loss, args, transform, dataset=None):
        # The first column is built before ``super().__init__`` because the
        # parent constructor receives it as the model backbone. ``self.nets``
        # is a plain list, so columns are not registered as submodules.
        self.nets = [get_pnn_backbone(backbone).to(get_device())]
        backbone = self.nets[-1]
        super(Pnn, self).__init__(backbone, loss, args, transform, dataset=dataset)
        # Input shape, recorded lazily by forward/observe and passed to
        # get_pnn_backbone when a new column is created in end_task.
        self.x_shape = None
        self.soft = torch.nn.Softmax(dim=0)
        self.logsoft = torch.nn.LogSoftmax(dim=0)
        # Index of the task currently being learned; it equals the index of
        # the newest column in ``self.nets``.
        self.task_idx = 0

    def forward(self, x, task_label):
        """Task-IL forward pass through the column belonging to ``task_label``.

        Args:
            x: input batch.
            task_label: integer task index, which is also the column index.

        Returns:
            The column's output, with every class outside the task's
            ``[start_idx, end_idx)`` range set to ``-inf``.
        """
        if self.x_shape is None:
            self.x_shape = x.shape

        start_idx, end_idx = self.dataset.get_offsets(task_label)

        if self.task_idx == 0:
            # Only a single column exists during the first task.
            out = self.net(x)
        else:
            # Previous columns are kept on the CPU (see end_task). Move the
            # requested column to the device for this call, then move it back
            # unless it is the column currently being trained.
            self.nets[task_label].to(self.device)
            out = self.nets[task_label](x)
            if self.task_idx != task_label:
                self.nets[task_label].cpu()

        # Mask out classes of previous tasks (Task-IL forward)...
        if start_idx > 0:
            out[:, :start_idx] = -torch.inf
        # ...and of later tasks. This one is unconditional so that the first
        # task is masked as well.
        out[:, end_idx:] = -torch.inf
        return out

    def end_task(self, dataset):
        """Move the current column off the device and start a new column for the next task."""
        # instantiate new column
        self.task_idx += 1
        self.nets[-1].cpu()
        new_col = get_pnn_backbone(get_backbone(self.args), self.nets, self.x_shape)
        self.nets.append(new_col.to(self.device))
        # The newest column becomes the network trained by observe(). The
        # optimizer is rebuilt afterwards. NOTE(review): presumably
        # get_optimizer() optimizes self.net's parameters; confirm in
        # ContinualModel.
        self.net = self.nets[-1]
        self.opt = self.get_optimizer()

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Run one optimization step of ``self.net`` (the newest column) on a batch.

        ``not_aug_inputs`` and ``epoch`` are accepted for interface
        compatibility and are not used here.

        Returns:
            The loss value as a Python number.
        """
        if self.x_shape is None:
            self.x_shape = inputs.shape

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        return loss.item()