# Source code for backbone.ResNet18_PNN

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import avg_pool2d, relu

from backbone.ResNetBlock import BasicBlock, ResNet, conv3x3
from backbone.utils.modules import AlphaModule, ListModule


class BasicBlockPnn(BasicBlock):
    """
    The basic block of ResNet, modified for PNN.

    Unlike the parent block, no ReLU is applied to the returned sum:
    the activation is added externally (PNN interleaves ReLU modules
    between blocks when building its layers).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute a forward pass through the block (pre-activation output).

        Args:
            x: input tensor (batch_size, channels, height, width)

        Returns:
            the block output plus the shortcut, WITHOUT a final ReLU
        """
        residual = self.shortcut(x)
        hidden = relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return hidden + residual
class ResNetPNN(ResNet):
    """
    ResNet network architecture modified for PNN (Progressive Neural Networks).

    Extends the base ResNet with frozen copies of previously trained columns
    (``old_cols``) and lateral "adaptor" connections that feed each old
    column's intermediate activations into the current column's layers.
    """

    def __init__(self, block: BasicBlock, num_blocks: List[int],
                 num_classes: int, nf: int,
                 old_cols: List[nn.Module] = None,
                 x_shape: torch.Size = None):
        """
        Instantiates the layers of the network.

        Args:
            block: the basic ResNet block
            num_blocks: the number of blocks per layer
            num_classes: the number of output classes
            nf: the number of filters
            old_cols: previously trained columns whose weights are copied
                into frozen side branches (default: no old columns)
            x_shape: shape of an input batch; required (asserted) when
                old columns exist, to size the adaptors' spatial dims
        """
        super(ResNetPNN, self).__init__(block, num_blocks, num_classes, nf)
        if old_cols is None:
            old_cols = []
        self.old_cols = old_cols
        self.x_shape = x_shape
        # With no previous columns this behaves as a plain ResNet:
        # none of the lateral machinery below is created.
        if len(old_cols) == 0:
            return
        assert self.x_shape is not None
        # Reset in_planes so _make_layer rebuilds from the first stage width.
        self.in_planes = self.nf
        # Extra classifier head fed by the adapted activations of stage 4.
        self.lateral_classifier = nn.Linear(nf * 8, num_classes)
        # adaptor4 operates after global pooling, hence spatial size (1, 1).
        self.adaptor4 = nn.Sequential(
            AlphaModule((nf * 8 * len(old_cols), 1, 1)),
            nn.Conv2d(nf * 8 * len(old_cols), nf * 8, 1),
            nn.ReLU()
        )
        # Containers old_layer0s..old_layer4s hold, per old column, frozen
        # copies of conv1 (index 0) and layer1..layer4 (indices 1..4).
        for i in range(5):
            setattr(self, 'old_layer' + str(i) + 's', ListModule())
        # adaptor1..adaptor3: 1x1-conv down-mix of the concatenated old
        # activations, followed by a fresh lateral_layer that projects them
        # to the next stage's resolution/width before being summed in forward.
        for i in range(1, 4):
            factor = 2 ** (i - 1)
            setattr(self, 'lateral_layer' + str(i + 1),
                    self._make_layer(block, nf * (2 ** i), num_blocks[i], stride=2)
                    )
            setattr(self, 'adaptor' + str(i), nn.Sequential(
                AlphaModule((nf * len(old_cols) * factor,
                             self.x_shape[2] // factor,
                             self.x_shape[3] // factor)),
                nn.Conv2d(nf * len(old_cols) * factor, nf * factor, 1),
                nn.ReLU(),
                getattr(self, 'lateral_layer' + str(i + 1))
            ))
        # Clone each old column's weights into the frozen side branches.
        for old_col in old_cols:
            self.in_planes = self.nf
            self.old_layer0s.append(conv3x3(3, nf * 1))
            self.old_layer0s[-1].load_state_dict(old_col.conv1.state_dict())
            for i in range(1, 5):
                factor = (2 ** (i - 1))
                layer = getattr(self, 'old_layer' + str(i) + 's')
                layer.append(self._make_layer(block, nf * factor, num_blocks[i - 1],
                                              stride=(1 if i == 1 else 2)))
                old_layer = getattr(old_col, 'layer' + str(i))
                layer[-1].load_state_dict(old_layer.state_dict())

    def _make_layer(self, block: BasicBlock, planes: int,
                    num_blocks: int, stride: int) -> nn.Module:
        """
        Instantiates a ResNet layer.

        ReLUs are interleaved between blocks (and appended at the end)
        because the PNN blocks return pre-activation outputs; the leading
        ReLU is then dropped so the layer starts with a block.

        Args:
            block: ResNet basic block
            planes: channels across the network
            num_blocks: number of blocks
            stride: stride of the first block (the rest use stride 1)

        Returns:
            ResNet layer as an nn.Sequential
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(nn.ReLU())
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        layers.append(nn.ReLU())
        # Skip the first ReLU: the layer's input is activated by the caller.
        return nn.Sequential(*(layers[1:]))

    def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
        """
        Compute a forward pass.

        Args:
            x: input tensor (batch_size, *input_shape)
            returnt: return type selector; only 'out' is supported

        Returns:
            output tensor (batch_size, output_classes)

        Raises:
            NotImplementedError: if ``returnt`` is not 'out'
        """
        # Lazily record the input shape if the constructor was not given one.
        if self.x_shape is None:
            self.x_shape = x.shape
        if len(self.old_cols) == 0:
            # No lateral columns: plain ResNet forward.
            return super(ResNetPNN, self).forward(x)
        else:
            # Old columns are frozen: evaluate them without gradients.
            # NOTE(review): the old conv1 outputs go through the CURRENT
            # column's bn1 — presumably intentional; confirm upstream.
            with torch.no_grad():
                out0_old = [relu(self.bn1(old(x))) for old in self.old_layer0s]
                out1_old = [old(out0_old[i]) for i, old in enumerate(self.old_layer1s)]
                out2_old = [old(out1_old[i]) for i, old in enumerate(self.old_layer2s)]
                out3_old = [old(out2_old[i]) for i, old in enumerate(self.old_layer3s)]
                out4_old = [old(out3_old[i]) for i, old in enumerate(self.old_layer4s)]
            out = relu(self.bn1(self.conv1(x)))
            out = F.relu(self.layer1(out))
            # Stage i takes the current activations plus the adapted
            # concatenation (over old columns) of stage i-1 activations.
            y = self.adaptor1(torch.cat(out1_old, 1))
            out = F.relu(self.layer2(out) + y)
            y = self.adaptor2(torch.cat(out2_old, 1))
            out = F.relu(self.layer3(out) + y)
            y = self.adaptor3(torch.cat(out3_old, 1))
            out = F.relu(self.layer4(out) + y)
            # Global average pooling on both the current and old branches.
            out = avg_pool2d(out, out.shape[2])
            out = out.view(out.size(0), -1)
            y = avg_pool2d(torch.cat(out4_old, 1), out4_old[0].shape[2])
            y = self.adaptor4(y)
            y = y.view(out.size(0), -1)
            y = self.lateral_classifier(y)
            # Final logits: current classifier plus the lateral head.
            out = self.classifier(out) + y
        if returnt == 'out':
            return out

        raise NotImplementedError("Unknown return type")
def resnet18_pnn(nclasses: int, nf: int = 64,
                 old_cols: List[nn.Module] = None,
                 x_shape: torch.Size = None):
    """
    Instantiates a ResNet18 network modified for PNN.

    Args:
        nclasses: number of output classes
        nf: number of filters
        old_cols: previously trained columns to attach as frozen lateral
            branches (default: none)
        x_shape: shape of an input batch; required by ResNetPNN when
            ``old_cols`` is non-empty

    Returns:
        ResNetPNN network
    """
    # Avoid a shared mutable default: normalize None to a fresh list.
    if old_cols is None:
        old_cols = []
    return ResNetPNN(BasicBlockPnn, [2, 2, 2, 2], nclasses, nf,
                     old_cols=old_cols, x_shape=x_shape)