# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from utils.args import ArgumentParser
from utils.conf import get_device
from models.utils.continual_model import ContinualModel
from backbone import get_backbone
def get_pnn_backbone(bone, old_cols=None, x_shape=None):
    """Build the Progressive Neural Network column that corresponds to ``bone``.

    Args:
        bone: an instantiated backbone. Only its type and a few size
            attributes (``input_size``/``output_size`` for MLPs,
            ``num_classes``/``nf`` for ResNets) are read.
        old_cols: previously trained columns, forwarded unchanged to the PNN
            backbone constructor (``None`` when building the first column).
        x_shape: input shape, forwarded only to the ResNet PNN builder.

    Returns:
        A freshly constructed PNN backbone module.

    Raises:
        NotImplementedError: if ``bone`` is neither a ``BaseMNISTMLP`` nor a
            ``ResNet``.
    """
    # Imported lazily, inside the function, exactly as before.
    from backbone.MNISTMLP import BaseMNISTMLP
    from backbone.MNISTMLP_PNN import MNISTMLP_PNN
    from backbone.ResNetBlock import ResNet
    from backbone.ResNet18_PNN import resnet18_pnn

    # The MLP check comes first; order matters if a class ever derives from both.
    if isinstance(bone, BaseMNISTMLP):
        return MNISTMLP_PNN(bone.input_size, bone.output_size, old_cols)

    if isinstance(bone, ResNet):
        # NOTE(review): every ResNet is mapped to the ResNet-18 PNN builder,
        # whatever its depth -- confirm this is intended for deeper ResNets.
        return resnet18_pnn(bone.num_classes, bone.nf, old_cols, x_shape)

    raise NotImplementedError('Progressive Neural Networks is not implemented for this backbone')
class Pnn(ContinualModel):
    """Progressive Neural Networks.

    One network column is kept per task in ``self.nets``. The column of the
    task currently being trained (``self.net``) stays on the device. Columns
    of earlier tasks are parked on the CPU by :meth:`end_task` and moved to
    the device only while :meth:`forward` uses them. Each new column is built
    from the previous ones, which are passed as ``old_cols`` to
    :func:`get_pnn_backbone`.
    """
    NAME = 'pnn'
    COMPATIBILITY = ['task-il']

    def __init__(self, backbone, loss, args, transform, dataset=None):
        # Swap the given backbone for its PNN counterpart (the first column,
        # with no previous columns) before handing it to the parent class.
        self.nets = [get_pnn_backbone(backbone).to(get_device())]
        backbone = self.nets[-1]
        super(Pnn, self).__init__(backbone, loss, args, transform, dataset=dataset)
        # Input shape, recorded on the first forward/observe call and forwarded
        # to get_pnn_backbone whenever a new column is built.
        self.x_shape = None
        # NOTE(review): these are not used within this class; dim=0 normalizes
        # over the first (presumably batch) dimension -- confirm before use.
        self.soft = torch.nn.Softmax(dim=0)
        self.logsoft = torch.nn.LogSoftmax(dim=0)
        # Index of the task (and column) currently being trained.
        self.task_idx = 0

    def forward(self, x, task_label):
        """Task-IL forward pass through the column belonging to ``task_label``.

        Args:
            x: input batch.
            task_label: index of the task whose column and class range to use.

        Returns:
            The column's output. Every entry outside the task's
            ``[start_idx, end_idx)`` class range, as returned by
            ``self.dataset.get_offsets``, is set to ``-inf``.
        """
        if self.x_shape is None:
            self.x_shape = x.shape
        start_idx, end_idx = self.dataset.get_offsets(task_label)
        if self.task_idx == 0:
            out = self.net(x)
        else:
            column = self.nets[task_label]
            column.to(self.device)
            try:
                out = column(x)
            finally:
                # Only the column being trained lives on the device. Return any
                # other column to the CPU even if its forward pass raised.
                if self.task_idx != task_label:
                    column.cpu()
        # Task-IL masking: suppress classes of all other tasks (before and after
        # this task's range).
        if start_idx > 0:
            out[:, :start_idx] = -torch.inf
        out[:, end_idx:] = -torch.inf
        return out

    def end_task(self, dataset):
        """Park the current column on the CPU and start a new column for the next task.

        Args:
            dataset: not used here; presumably kept to match the
                ``ContinualModel`` hook signature.
        """
        self.task_idx += 1
        self.nets[-1].cpu()
        # Pass a snapshot of the existing columns. Passing ``self.nets`` itself
        # would let the new column's ``old_cols`` alias the very list it is
        # appended to below (and grow with every later task) if the backbone
        # keeps a reference to it.
        old_cols = list(self.nets)
        new_col = get_pnn_backbone(get_backbone(self.args), old_cols, self.x_shape).to(self.device)
        self.nets.append(new_col)
        self.net = new_col
        # The optimizer must track the parameters of the new column.
        self.opt = self.get_optimizer()

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Run one optimization step of the current column on a batch.

        Args:
            inputs: (augmented) input batch.
            labels: targets for ``inputs``.
            not_aug_inputs: not used here.
            epoch: not used here.

        Returns:
            The loss value for this batch as a Python number.
        """
        if self.x_shape is None:
            self.x_shape = inputs.shape
        self.net.to(self.device)
        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
        return loss.item()