# Source code for models.utils
"""
Utility functions for models.
"""
from argparse import ArgumentParser, Namespace
import os
import sys
import yaml
from utils import smart_joint
from utils.conf import warn_once
# [docs]
def load_model_config(args: Namespace, buffer_size: int = None) -> dict:
    """
    Loads the configuration file for the model.

    The configuration is read from ``models/config/<args.model>.yaml``. The file may
    contain a ``default`` section and one section per dataset. A dataset section may
    in turn contain one sub-section per buffer size.

    Args:
        args: the arguments which contains the hyperparameters. ``args.model`` selects the
            configuration file. ``args.model_config`` (``'best'`` or ``'default'``) selects
            which configuration to load; a missing or empty value is treated as ``'default'``.
            ``args.dataset`` selects the dataset section when loading the ``'best'`` configuration.
        buffer_size: if a method has a buffer, knowing the buffer_size is required to load the best configuration

    Returns:
        dict: the configuration of the model. For ``'default'`` this is the ``default``
        section (or ``{}``). For ``'best'`` it is the ``default`` section merged with the
        dataset section, and, if ``buffer_size`` is given, with that buffer size's
        sub-section. Later sections override earlier ones.

    Raises:
        AssertionError: if ``args.model_config`` is not ``'best'`` or ``'default'``, or if the
            requested dataset / buffer size is missing when loading the ``'best'`` configuration.
        FileNotFoundError: if the ``'best'`` configuration is requested but the file does not exist.
    """
    filepath = smart_joint('models', 'config', args.model + '.yaml')

    # `model_config` may be absent from `args` (the original code guarded with hasattr but
    # then read the attribute unconditionally). An unspecified value means 'default'.
    requested = getattr(args, 'model_config', None)
    if requested:
        # Explicit raises (rather than `assert`) so validation survives `python -O`;
        # AssertionError is kept so existing callers see the same exception type.
        if requested not in ['best', 'default']:
            raise AssertionError(f"Unknown model_config '{requested}': expected 'best' or 'default'.")
        if not os.path.exists(filepath):
            if requested == 'best':
                raise FileNotFoundError(f'Model configuration file for model {args.model} not found in {filepath}')
            warn_once(f'Trying to load default configuration for model {args.model} but no configuration file found in {filepath}.')
            return {}
    else:
        requested = 'default'
        if not os.path.exists(filepath):
            return {}

    with open(filepath, 'r', encoding='utf-8') as f:
        # An empty YAML file parses to None; treat it as an empty mapping.
        config = yaml.safe_load(f) or {}

    if 'default' not in config:
        warn_once(f'No default configuration found in {filepath}.')
        default_config = {}
    else:
        # An empty `default:` section parses to None.
        default_config = config['default'] or {}

    if requested == 'default':
        return default_config

    if args.dataset not in config:
        raise AssertionError(f'No best configuration found in {filepath} for dataset {args.dataset}.')
    dataset_config = config[args.dataset] or {}

    if buffer_size is None:
        return {**default_config, **dataset_config}

    if buffer_size not in dataset_config:
        raise AssertionError(f'No best configuration found in {filepath} for buffer size {buffer_size}.')
    buffer_config = dataset_config[buffer_size] or {}  # arguments for this buffer size only
    # Dataset-level arguments without this buffer size's sub-section; built as a new dict
    # instead of deleting the key from the parsed configuration in place.
    other_dataset_config = {key: value for key, value in dataset_config.items() if key != buffer_size}
    # Merge order: buffer-size arguments override dataset ones, which override the defaults.
    return {**default_config, **other_dataset_config, **buffer_config}