"""Backbone registration utilities and the MammothBackbone base class for the Mammoth model."""

from argparse import Namespace
import importlib
import os
import math

import torch
import torch.nn as nn

from typing import Callable

from utils import register_dynamic_module_fn

# Registry of available backbone networks, filled by the `register_backbone` decorator.
# Template: {name: {'class': class, 'parsable_args': parsable_args}}
REGISTERED_BACKBONES = {}


def xavier(m: nn.Module) -> None:
    """
    Applies Xavier (Glorot) uniform initialization to linear modules.

    Modules whose class is not named ``Linear`` are left untouched, so this
    function is safe to use with ``nn.Module.apply``.

    Args:
        m: the module to be initialized

    Example::
        >>> net = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
        >>> net.apply(xavier)
    """
    if m.__class__.__name__ != 'Linear':
        return
    # fan_out is the first weight dimension, fan_in the second
    fan_out, fan_in = m.weight.data.size(0), m.weight.data.size(1)
    std = 1.0 * math.sqrt(2.0 / (fan_in + fan_out))
    bound = math.sqrt(3.0) * std
    m.weight.data.uniform_(-bound, bound)
    if m.bias is not None:
        m.bias.data.fill_(0.0)
def num_flat_features(x: torch.Tensor) -> int:
    """
    Computes the total number of items except the first (batch) dimension.

    Args:
        x: input tensor

    Returns:
        number of items from the second dimension onward (1 for a 1-D tensor)
    """
    return math.prod(x.size()[1:])
class MammothBackbone(nn.Module):
    """
    A backbone module for the Mammoth model.

    Args:
        **kwargs: additional keyword arguments. The optional ``device`` key
            selects the compute device (defaults to CPU).

    Methods:
        forward: Compute a forward pass.
        features: Get the features of the input tensor (same as forward but with returnt='features').
        get_params: Returns all the parameters concatenated in a single tensor.
        set_params: Sets the parameters to a given value.
        get_grads: Returns all the gradients concatenated in a single tensor.
        set_grads: Sets the gradients of all parameters from a single flat tensor.
    """

    def __init__(self, **kwargs) -> None:
        super(MammothBackbone, self).__init__()
        # Fall back to CPU unless an explicit device is passed by the caller.
        self.device = torch.device('cpu') if 'device' not in kwargs else kwargs['device']

    def to(self, device, *args, **kwargs):
        """Moves the module to `device` and records it in `self.device`."""
        super(MammothBackbone, self).to(device, *args, **kwargs)
        self.device = device
        return self

    def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
        """
        Compute a forward pass.

        Args:
            x: input tensor (batch_size, *input_shape)
            returnt: return type (a string among `out`, `features`, `both`, or `all`)

        Returns:
            output tensor

        Raises:
            NotImplementedError: subclasses must override this method.
        """
        raise NotImplementedError

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the features of the input tensor.

        Args:
            x: input tensor

        Returns:
            features tensor
        """
        return self.forward(x, returnt='features')

    def get_params(self) -> torch.Tensor:
        """
        Returns all the parameters concatenated in a single tensor.

        Returns:
            parameters tensor
        """
        return torch.nn.utils.parameters_to_vector(self.parameters())

    def set_params(self, new_params: torch.Tensor) -> None:
        """
        Sets the parameters to a given value.

        Args:
            new_params: concatenated values to be set
        """
        torch.nn.utils.vector_to_parameters(new_params, self.parameters())

    def get_grads(self) -> torch.Tensor:
        """
        Returns all the gradients concatenated in a single tensor.

        NOTE: every parameter must already have a populated `.grad`
        (i.e. backward/set_grads has been called), otherwise this raises.

        Returns:
            gradients tensor
        """
        return torch.cat([pp.grad.view(-1) for pp in self.parameters()])

    def set_grads(self, new_grads: torch.Tensor) -> None:
        """
        Sets the gradients of all parameters.

        Args:
            new_grads: concatenated gradient values to be set
        """
        progress = 0
        for pp in self.parameters():
            # numel() replaces the original torch.tensor(pp.size()).prod(),
            # which allocated a tensor and was computed twice per parameter
            n_params = pp.numel()
            pp.grad = new_grads[progress: progress + n_params].view(pp.size())
            progress += n_params
def register_backbone(name: str) -> Callable:
    """
    Decorator to register a backbone network for use in a Dataset.

    The decorator may be used on a class that inherits from `MammothBackbone`
    or on a function that returns a `MammothBackbone` instance. The registered
    model can be accessed using the `get_backbone` function and can include
    additional keyword arguments to be set during parsing.

    The arguments can be inferred by the *signature* of the backbone network's
    class, where the default value of each argument is its default. If the
    default is `Parameter.empty` the argument is required; if it is `None` the
    argument is optional. The type of the argument is inferred from the default
    value (default is `str`).

    Args:
        name: the name of the backbone network

    Returns:
        the registration decorator
    """
    decorator = register_dynamic_module_fn(name, REGISTERED_BACKBONES, MammothBackbone)
    return decorator
def get_backbone_class(name: str, return_args=False) -> MammothBackbone:
    """
    Get the backbone network class from the registered networks.

    Args:
        name: the name of the backbone network
        return_args: whether to also return the parsable arguments of the backbone network

    Returns:
        the backbone class, or a `(class, parsable_args)` tuple if `return_args` is True
    """
    # registry keys are normalized to lowercase, dash-separated names
    name = name.replace('_', '-').lower()
    assert name in REGISTERED_BACKBONES, f"Attempted to access non-registered network: {name}"
    cl = REGISTERED_BACKBONES[name]['class']
    if return_args:
        return cl, REGISTERED_BACKBONES[name]['parsable_args']
    # BUGFIX: the original fell through here and implicitly returned None
    # whenever return_args was False
    return cl
def get_backbone(args: Namespace) -> MammothBackbone:
    """
    Build the backbone network from the registered networks.

    Args:
        args: the arguments which contains the `--backbone` attribute and the
            additional arguments required by the backbone network

    Returns:
        the backbone model
    """
    backbone_class, backbone_args = get_backbone_class(args.backbone, return_args=True)

    # every parsable argument of the backbone must be present on the namespace
    available = vars(args)
    missing_args = [arg for arg in backbone_args.keys() if arg not in available]
    assert len(missing_args) == 0, "Missing arguments for the backbone network: " + ', '.join(missing_args)

    return backbone_class(**{arg: getattr(args, arg) for arg in backbone_args.keys()})
# import all files in the backbone folder to register the networks for file in os.listdir(os.path.dirname(__file__)): if file.endswith('.py') and file != '__init__.py': importlib.import_module(f'backbone.{file[:-3]}')