from argparse import Namespace
import importlib
import inspect
import os
import math
import numpy as np
import torch
import torch.nn as nn
from typing import Callable
from utils import register_dynamic_module_fn
REGISTERED_BACKBONES = dict() # dictionary containing the registered networks. Template: {name: {'class': class, 'parsable_args': parsable_args}}
def xavier(m: nn.Module) -> None:
    """
    Applies Xavier (Glorot) uniform initialization to linear modules.

    The weight is drawn from U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
    and the bias (if present) is zeroed.

    Args:
        m: the module to be initialized

    Example::
        >>> net = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
        >>> net.apply(xavier)
    """
    # Match by class name (not isinstance) so only plain `Linear` layers are
    # touched, preserving the original selection behavior for subclasses.
    if m.__class__.__name__ == 'Linear':
        # Equivalent to the previous hand-rolled fan_in/fan_out computation
        # (fan_in = weight.size(1), fan_out = weight.size(0), gain = 1), but
        # delegates to torch's maintained implementation.
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
def num_flat_features(x: torch.Tensor) -> int:
    """
    Computes the total number of items except the first (batch) dimension.

    Args:
        x: input tensor

    Returns:
        number of items from the second dimension onward (1 for a 1-D tensor,
        since the empty product is 1)
    """
    return math.prod(x.size()[1:])
class MammothBackbone(nn.Module):
    """
    Base class for backbone networks used by the Mammoth model.

    Args:
        **kwargs: additional keyword arguments; if ``device`` is present it sets
            the initial device, otherwise CPU is assumed.

    Methods:
        forward: Compute a forward pass.
        features: Get the features of the input tensor (same as forward but with returnt='features').
        get_params: Returns all the parameters concatenated in a single tensor.
        set_params: Sets the parameters to a given value.
        get_grads: Returns all the gradients concatenated in a single tensor.
        set_grads: Sets the gradients of all parameters from a single flat tensor.
    """

    def __init__(self, **kwargs) -> None:
        super(MammothBackbone, self).__init__()
        # Track the device explicitly so subclasses can place auxiliary tensors.
        self.device = torch.device('cpu') if 'device' not in kwargs else kwargs['device']

    def to(self, device, *args, **kwargs):
        """Move the module to `device`, keeping `self.device` in sync."""
        super(MammothBackbone, self).to(device, *args, **kwargs)
        self.device = device
        return self

    def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
        """
        Compute a forward pass.

        Args:
            x: input tensor (batch_size, *input_shape)
            returnt: return type (a string among `out`, `features`, `both`, or `all`)

        Returns:
            output tensor

        Raises:
            NotImplementedError: subclasses must override this method.
        """
        raise NotImplementedError

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the features of the input tensor.

        Args:
            x: input tensor

        Returns:
            features tensor
        """
        return self.forward(x, returnt='features')

    def get_params(self) -> torch.Tensor:
        """
        Returns all the parameters concatenated in a single tensor.

        Returns:
            parameters tensor
        """
        return torch.nn.utils.parameters_to_vector(self.parameters())

    def set_params(self, new_params: torch.Tensor) -> None:
        """
        Sets the parameters to a given value.

        Args:
            new_params: concatenated values to be set
        """
        torch.nn.utils.vector_to_parameters(new_params, self.parameters())

    def get_grads(self) -> torch.Tensor:
        """
        Returns all the gradients concatenated in a single tensor.

        NOTE: assumes every parameter has a populated ``.grad`` (i.e. backward
        has been called); otherwise an AttributeError is raised.

        Returns:
            gradients tensor
        """
        return torch.cat([pp.grad.view(-1) for pp in self.parameters()])

    def set_grads(self, new_grads: torch.Tensor) -> None:
        """
        Sets the gradients of all parameters from a single flat tensor.

        Args:
            new_grads: concatenated gradient values, in the order returned by
                ``self.parameters()`` (fixed docstring: was wrongly documented
                as ``new_params``)
        """
        progress = 0
        for pp in self.parameters():
            # numel() replaces the former torch.tensor(pp.size()).prod(),
            # which built a tensor twice per parameter just to count elements.
            n = pp.numel()
            pp.grad = new_grads[progress:progress + n].view(pp.size())
            progress += n
def register_backbone(name: str) -> Callable:
    """
    Decorator to register a backbone network for use in a Dataset.

    It may decorate either a class inheriting from `MammothBackbone` or a
    function returning a `MammothBackbone` instance. A registered backbone is
    retrievable via `get_backbone` and may expose extra keyword arguments,
    inferred from the *signature* of the decorated callable: each default value
    becomes the argument's default (`Parameter.empty` makes it required,
    `None` makes it optional), and the argument type is inferred from the
    default value (falling back to `str`).

    Args:
        name: the name of the backbone network
    """
    # All bookkeeping is delegated to the shared dynamic-module registry helper.
    decorator = register_dynamic_module_fn(name, REGISTERED_BACKBONES, MammothBackbone)
    return decorator
def get_backbone_class(name: str, return_args=False) -> MammothBackbone:
    """
    Get the backbone network class from the registered networks.

    Args:
        name: the name of the backbone network (dashes are normalized to
            underscores and the name is lower-cased before lookup)
        return_args: whether to also return the parsable arguments of the
            backbone network

    Returns:
        the backbone class, or a ``(class, parsable_args)`` tuple when
        ``return_args`` is True
    """
    name = name.replace('-', '_').lower()
    assert name in REGISTERED_BACKBONES, "Attempted to access non-registered network"
    cl = REGISTERED_BACKBONES[name]['class']
    if return_args:
        return cl, REGISTERED_BACKBONES[name]['parsable_args']
    # BUG FIX: the function previously fell through here and returned None
    # when return_args was False, despite documenting "Returns: the backbone class".
    return cl
def get_backbone(args: Namespace) -> MammothBackbone:
    """
    Build the backbone network from the registered networks.

    Args:
        args: the arguments which contains the `--backbone` attribute and the
            additional arguments required by the backbone network

    Returns:
        the backbone model
    """
    backbone_class, backbone_args = get_backbone_class(args.backbone, return_args=True)
    # Fail early if any declared backbone argument was not parsed.
    provided = vars(args)
    missing_args = [key for key in backbone_args if key not in provided]
    assert len(missing_args) == 0, "Missing arguments for the backbone network: " + ', '.join(missing_args)
    return backbone_class(**{key: getattr(args, key) for key in backbone_args})
# Auto-discover the backbone package: importing every sibling module triggers
# their `register_backbone` decorators, populating REGISTERED_BACKBONES.
for _module_file in os.listdir(os.path.dirname(__file__)):
    if _module_file.endswith('.py') and _module_file != '__init__.py':
        importlib.import_module(f'backbone.{_module_file[:-3]}')