# Source code for models.utils

"""
Utility functions for models.
"""

from argparse import ArgumentParser, Namespace
import os
import sys
from typing import Optional

import yaml

from utils import smart_joint
from utils.conf import warn_once


def find_file_ignore_underscores(file: str, basepath: Optional[str] = None) -> Optional[str]:
    """
    Returns the file name by ignoring the underscores and dashes.

    The lookup is an exact match after removing every ``_`` and ``-`` from both
    ``file`` and each entry of ``basepath`` (e.g. ``er-ace.yaml`` matches
    ``er_ace.yaml`` and ``erace.yaml``).

    Args:
        file: the file name
        basepath: the base path to search the file (defaults to the current working directory)

    Returns:
        str: ``basepath`` joined with the first matching entry (entries are scanned in
        sorted order), or None if no entry matches or ``basepath`` is not an existing directory
    """
    basepath = basepath if basepath is not None else os.getcwd()
    # A missing search directory means "no match" instead of an unhandled
    # FileNotFoundError from os.listdir; None is already the "not found" result.
    if not os.path.isdir(basepath):
        return None

    def _normalize(name: str) -> str:
        # Drop the separators so that naming-style differences are ignored.
        return name.replace('_', '').replace('-', '')

    target = _normalize(file)
    # os.listdir order is arbitrary; sorting makes the pick deterministic when
    # several entries normalize to the same name.
    for entry in sorted(os.listdir(basepath)):
        if _normalize(entry) == target:
            return smart_joint(basepath, entry)
    return None
def load_model_config(args: Namespace, buffer_size: Optional[int] = None) -> dict:
    """
    Loads the configuration file for the model.

    The file ``<args.model>.yaml`` is looked up in ``models/config`` (ignoring underscores
    and dashes). It may contain a ``default`` section and one section per dataset; a dataset
    section may hold dataset-level arguments plus one sub-section per buffer size.

    Args:
        args: the arguments which contains the hyperparameters. Reads ``args.model``,
            ``args.dataset`` and, if present, ``args.model_config`` (``'best'`` or ``'default'``).
        buffer_size: if a method has a buffer, knowing the buffer_size is required to load the best configuration

    Returns:
        dict: the configuration of the model.
            - ``model_config == 'default'``: only the ``default`` section.
            - otherwise (``'best'``, or ``model_config`` unset while the file exists): the ``default``
              section updated with the dataset section and, if ``buffer_size`` is given, with that
              buffer size's sub-section (later sources take precedence).
            - ``{}`` when no configuration file exists and ``model_config`` is not ``'best'``.

    Raises:
        FileNotFoundError: if ``model_config == 'best'`` and no configuration file exists.
        AssertionError: if ``model_config`` has an unsupported value, or the dataset / buffer size
            has no entry in the configuration file.
    """
    config_dir = smart_joint('models', 'config')
    filepath = find_file_ignore_underscores(args.model + '.yaml', config_dir)
    # `args` may not define `model_config` at all; treat that like an unset (falsy) value
    # instead of failing later with an AttributeError.
    model_config = getattr(args, 'model_config', None)
    file_missing = filepath is None or not os.path.exists(filepath)

    if model_config:
        assert model_config in ['best', 'default']
        if file_missing:
            if model_config == 'best':
                raise FileNotFoundError(f'Model configuration file {args.model}.yaml not found in {config_dir}')
            warn_once(f'Trying to load default configuration for model {args.model} but no configuration file found in {config_dir}.')
            return {}
    elif file_missing:
        return {}

    with open(filepath, 'r') as f:
        # An empty YAML file loads as None; normalize it so the checks below
        # report a missing section instead of raising a TypeError.
        config = yaml.safe_load(f) or {}

    if 'default' not in config:
        warn_once(f'No default configuration found in {filepath}.')
        default_config = {}
    else:
        # `default:` with no value loads as None; treat it as an empty section.
        default_config = config['default'] or {}

    if model_config == 'default':
        return default_config

    assert args.dataset in config, f'No best configuration found in {filepath} for dataset {args.dataset}.'
    dataset_config = config[args.dataset] or {}

    if buffer_size is None:
        return {**default_config, **dataset_config}

    assert buffer_size in dataset_config, f'No best configuration found in {filepath} for buffer size {buffer_size}.'
    buffer_config = dataset_config[buffer_size]  # get arguments for the buffer size only
    # Keep only dataset-level arguments: drop the selected buffer-size section and any other
    # int-keyed entries (presumably the sections of other buffer sizes), so they do not leak
    # into the merged result. A new dict is built instead of deleting from the loaded config.
    other_dataset_config = {
        key: value for key, value in dataset_config.items()
        if key != buffer_size and not isinstance(key, int)
    }
    # merge all arguments, with the buffer size overwriting the dataset and the dataset overwriting the defaults
    return {**default_config, **other_dataset_config, **buffer_config}