Source code for models.utils

"""
Utility functions for models.
"""

from argparse import ArgumentParser, Namespace
import os
import sys

import yaml

from utils import smart_joint
from utils.conf import warn_once


[docs] def load_model_config(args: Namespace, buffer_size: int = None) -> dict: """ Loads the configuration file for the model. Args: args: the arguments which contains the hyperparameters buffer_size: if a method has a buffer, knowing the buffer_size is required to load the best configuration Returns: dict: the configuration of the model """ filepath = smart_joint('models', 'config', args.model + '.yaml') if hasattr(args, 'model_config') and args.model_config: assert args.model_config in ['best', 'default'] if not os.path.exists(filepath): if args.model_config == 'best': raise FileNotFoundError(f'Model configuration file {args.model_config} not found in {filepath}') else: warn_once(f'Trying to load default configuration for model {args.model} but no configuration file found in {filepath}.') return {} else: if not os.path.exists(filepath): return {} with open(filepath, 'r') as f: config = yaml.safe_load(f) if 'default' not in config: warn_once(f'No default configuration found in {filepath}.') default_config = {} else: default_config = config['default'] if args.model_config == 'default': return default_config else: assert args.dataset in config, f'No best configuration found in {filepath} for dataset {args.dataset}.' if buffer_size is not None: assert buffer_size in config[args.dataset], f'No best configuration found in {filepath} for buffer size {buffer_size}.' buffer_config = config[args.dataset][buffer_size] # get arguments for the buffer size only other_dataset_config = config[args.dataset] # get arguments for the dataset only del other_dataset_config[buffer_size] return {**default_config, **other_dataset_config, **buffer_config} # merge all arguments, with the buffer size overwriting the dataset return {**default_config, **config[args.dataset]}