# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
from argparse import Namespace
from typing import Dict, List
from torch import nn
import importlib
import inspect
mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(mammoth_path)
os.chdir(mammoth_path)
from models.utils.continual_model import ContinualModel
from utils.conf import warn_once
[docs]
def get_all_models() -> List[dict]:
return {model.split('.')[0].replace('_', '-'): model.split('.')[0] for model in os.listdir('models')
if not model.find('__') > -1 and not os.path.isdir('models/' + model)}
[docs]
def get_model(args: Namespace, backbone: nn.Module, loss, transform, dataset) -> ContinualModel:
"""
Return the class of the selected continual model among those that are available.
If an error was detected while loading the available datasets, it raises the appropriate error message.
Args:
args (Namespace): the arguments which contains the `--model` attribute
backbone (nn.Module): the backbone of the model
loss: the loss function
transform: the transform function
dataset: the instance of the dataset
Exceptions:
AssertError: if the model is not available
Exception: if an error is detected in the model
Returns:
the continual model instance
"""
model_name = args.model.replace('_', '-')
names = get_model_names()
assert model_name in names
return get_model_class(args)(backbone, loss, args, transform, dataset)
[docs]
def get_model_class(args: Namespace) -> ContinualModel:
"""
Return the class of the selected continual model among those that are available.
If an error was detected while loading the available datasets, it raises the appropriate error message.
Args:
args (Namespace): the arguments which contains the `--model` attribute
Exceptions:
AssertError: if the model is not available
Exception: if an error is detected in the model
Returns:
the continual model class
"""
names = get_model_names()
model_name = args.model.replace('_', '-')
assert model_name in names
if isinstance(names[model_name], Exception):
raise names[model_name]
return names[model_name]
[docs]
def get_model_names() -> Dict[str, ContinualModel]:
"""
Return the available continual model names and classes.
Returns:
A dictionary containing the names of the available continual models and their classes.
"""
def _get_names():
names: Dict[str, ContinualModel] = {}
for model_name, model in get_all_models().items():
try:
mod = importlib.import_module('models.' + model)
model_classe_name = [x for x in mod.__dir__() if 'type' in str(type(getattr(mod, x)))
and 'ContinualModel' in str(inspect.getmro(getattr(mod, x))[1:])][-1]
c = getattr(mod, model_classe_name)
names[c.NAME.replace('_', '-')] = c
except Exception as e:
warn_once("Error in model", model)
names[model.replace('_', '-')] = e
return names
if not hasattr(get_model_names, 'names'):
setattr(get_model_names, 'names', _get_names())
return getattr(get_model_names, 'names')