Source code for utils

import inspect
import os
import sys
import string
import random
import logging
from typing import Callable, Type, TypeVar, Union, get_args
# Generic type variable used in the annotations of `check_fn_dynamic_type`
# and `register_dynamic_module_fn`.
T = TypeVar("T")


def _type_name(tp) -> str:
    """Return a comparable name for a class, a forward reference or a string annotation."""
    if isinstance(tp, str):
        # String annotations are already names.
        return tp
    # str(int) == "<class 'int'>": keep the quoted part when there is one,
    # otherwise fall back to the plain string form (e.g. typing constructs).
    pieces = str(tp).split("'")
    return pieces[1] if len(pieces) > 1 else pieces[0]


def check_fn_dynamic_type(fn: T, tp: Type[T], strict=True) -> bool:
    """
    Controls if the function respects the type `tp`.

    The parameters of `fn` are compared, in order, with the argument types of
    `tp` (the first element of `get_args(tp)`, i.e. the argument list of a
    `Callable[[...], ...]`). A parameter is described by its annotation or,
    when it has none, by the type of its default value. Both sides are reduced
    to type names before comparing, so an annotation may be either the class
    itself or its name as a string.

    The comparison is pairwise (`zip`): parameters or expected types beyond
    the shorter of the two lists are not checked.

    Args:
        fn: the function to be checked
        tp: the type to be respected
        strict: if True, raises an error if the function does not respect the type `tp`

    Returns:
        True if every compared parameter matches, False otherwise (only when
        `strict` is False).

    Raises:
        ValueError: if a parameter does not match and `strict` is True.
    """
    empty = inspect.Parameter.empty
    expected_names = [_type_name(arg) for arg in get_args(tp)[0]]
    actual_names = [
        _type_name(param.annotation if param.annotation is not empty else type(param.default))
        for param in inspect.signature(fn).parameters.values()
    ]

    if all(actual == expected for actual, expected in zip(actual_names, expected_names)):
        return True

    if strict:
        raise ValueError(f'{fn} does not respect type {tp}')
    return False
def setup_logging():
    """
    Configures the logging module.

    Attaches a stdout handler with the format
    `[LEVEL] day-month-year hour:minute:second - message` to the logger named
    'root' and sets that logger's level to INFO. A `done` attribute stored on
    this function marks that configuration already happened, so later calls
    return without adding another handler.
    """
    # The marker attribute only exists after a previous successful call.
    if hasattr(setup_logging, 'done'):
        return

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(
        logging.Formatter(
            '[%(levelname)s] %(asctime)s - %(message)s',
            datefmt='%d-%b-%y %H:%M:%S',
        )
    )

    target_logger = logging.getLogger('root')
    target_logger.addHandler(stdout_handler)
    target_logger.setLevel(logging.INFO)

    setup_logging.done = True
def field_with_aliases(choices: dict) -> Callable[[str], str]:
    """
    Build a data type where for each key in `choices` there are a set of aliases.

    Example:
        Given the following dictionary:

        .. code-block:: python

            choices = {
                'a': ['a', 'alpha'],
                'b': ['b', 'beta']
            }

        The values 'a' and 'alpha' will be converted to 'a', and 'b' and 'beta' will be converted to 'b'.

    Args:
        choices: the dictionary containing the aliases

    Returns:
        a function that maps a value to the key whose aliases contain it,
        usable as the data type for argparse. The returned function raises
        `ValueError` when the value matches no alias.
    """

    def _parse_field(value: str) -> str:
        # Non-string values are compared through their string form.
        if not isinstance(value, str):
            value = str(value)

        for key, aliases in choices.items():
            if value in aliases:
                return key

        raise ValueError(f'Value `{value}` does not match the provided choices `{choices}`')

    return _parse_field
def binary_to_boolean_type(value: str) -> bool:
    """
    Converts a binary string to a boolean type.

    The comparison is case-insensitive; non-string values are converted with
    `str` first. Accepted spellings are 'true', '1', 't', 'y', 'yes' for True
    and 'false', '0', 'f', 'n', 'no' for False.

    Args:
        value: the binary string

    Returns:
        the boolean type

    Raises:
        ValueError: if the value is not one of the accepted spellings.
    """
    if not isinstance(value, str):
        value = str(value)
    value = value.lower()

    true_values = ('true', '1', 't', 'y', 'yes')
    false_values = ('false', '0', 'f', 'n', 'no')

    if value in true_values:
        return True
    if value in false_values:
        return False

    # An explicit exception (rather than `assert`) survives `python -O` and is
    # one of the exception types argparse turns into a usage error.
    raise ValueError(
        f'Value `{value}` is not a valid boolean, expected one of {list(true_values + false_values)}'
    )
def custom_str_underscore(value):
    """
    Converts `value` to a string with every underscore replaced by a hyphen.

    Args:
        value: the value to be converted (passed through `str`)

    Returns:
        the hyphenated string, with leading and trailing whitespace removed
    """
    text = str(value)
    hyphenated = text.replace("_", '-')
    return hyphenated.strip()
def smart_joint(*paths):
    """
    Joins the given path components and normalises the separators.

    Args:
        paths: the path components, forwarded to `os.path.join`

    Returns:
        the joined path with every backslash replaced by a forward slash
    """
    joined = os.path.join(*paths)
    # Splitting on backslashes and re-joining with '/' swaps every separator.
    return "/".join(joined.split("\\"))
def create_if_not_exists(path: str) -> None:
    """
    Creates the specified folder if it does not exist.

    Missing parent folders are created as well. Nothing is done when `path`
    already exists, whether it is a folder or a file.

    Args:
        path: the complete path of the folder to be created
    """
    if not os.path.exists(path):
        # `exist_ok=True` covers the case where the folder is created by
        # someone else between the check above and this call.
        os.makedirs(path, exist_ok=True)
def none_or_float(value):
    """
    Parses a value that is either the literal string 'None' or a number.

    Args:
        value: the value to be parsed

    Returns:
        None if `value` equals 'None', otherwise `float(value)`
    """
    return None if value == 'None' else float(value)
def random_id(length=8, alphabet=string.ascii_letters + string.digits):
    """
    Returns a random string of the specified length.

    Characters are drawn with replacement from `alphabet` using the global
    `random` module generator.

    Args:
        length: the length of the string
        alphabet: the alphabet to be used

    Returns:
        the random string
    """
    picked = random.choices(alphabet, k=length)
    return ''.join(picked)
def infer_args_from_signature(signature: inspect.Signature, excluded_signature: inspect.Signature = None) -> dict:
    """
    Load the arguments of a function from its signature.

    Parameters named `self`, parameters whose name starts with an underscore,
    variadic parameters (`*args`, `**kwargs`) and parameters that also appear
    in `excluded_signature` are skipped. For every remaining parameter the
    type is its annotation, else the type of its default value, else `str`.

    Args:
        signature: the signature of the function
        excluded_signature: an optional signature whose parameter names are
            left out of the result (None excludes nothing)

    Returns:
        the inferred arguments: a dictionary mapping each parameter name to
        `{'type', 'required': True}` when it has no default, or to
        `{'type', 'required': False, 'default'}` otherwise
    """
    empty = inspect.Parameter.empty
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    excluded_args = set() if excluded_signature is None else set(excluded_signature.parameters)

    parsable_args = {}
    for arg_name, param in signature.parameters.items():
        if arg_name in excluded_args or arg_name == 'self' or arg_name.startswith('_'):
            continue
        if param.kind in variadic_kinds:
            # `*args` / `**kwargs` are catch-alls, not individual arguments.
            continue

        default = param.default
        if param.annotation is not empty:
            tp = param.annotation
        elif default is not empty:
            # NOTE(review): a default of None yields NoneType here — confirm
            # the consumer of 'type' handles it.
            tp = type(default)
        else:
            tp = str

        # `num_classes` is never marked as required, even without a default
        # (presumably it is supplied by the caller — confirm).
        if default is empty and arg_name != 'num_classes':
            parsable_args[arg_name] = {
                'type': tp,
                'required': True
            }
        else:
            parsable_args[arg_name] = {
                'type': tp,
                'required': False,
                'default': None if default is empty else default
            }

    return parsable_args
def register_dynamic_module_fn(name: str, register: dict, tp: Type[T]):
    """
    Register a dynamic module in the specified dictionary.

    Returns a decorator. When applied to a target, it stores
    `{'class': target, 'parsable_args': ...}` in `register` under the
    normalised name (underscores replaced by hyphens, lowercase) and returns
    the target unchanged. The parsable arguments are inferred from the
    target's own signature when it is a function, or from its `__init__`
    when it is an instance or a subclass of `tp`.

    Args:
        name: the name of the module
        register: the dictionary where the module will be registered
        tp: the type of the class, used to dynamically infer the arguments

    Returns:
        the decorator performing the registration. It raises `ValueError` if
        the name is already registered or if the target is neither a function
        nor an instance/subclass of `tp`.
    """
    name = name.replace('_', '-').lower()

    def register_network_fn(target: Union[T, Callable]) -> T:
        # check if the name is already registered
        if name in register:
            raise ValueError(f"Name {name} already registered!")

        # check if the target is a function or an instance/subclass of `tp`
        if inspect.isfunction(target):
            signature = inspect.signature(target)
        elif isinstance(target, tp) or (inspect.isclass(target) and issubclass(target, tp)):
            signature = inspect.signature(target.__init__)
        else:
            # Report the name of the expected type itself, not of its metaclass.
            tp_name = getattr(tp, '__name__', repr(tp))
            raise ValueError(f"The registered class must be a subclass of {tp_name} or a function returning {tp_name}")

        parsable_args = infer_args_from_signature(signature)
        register[name] = {'class': target, 'parsable_args': parsable_args}
        return target

    return register_network_fn
class disable_logging:
    """
    Wrapper for disabling logging for a specific block of code.

    On entry, `logging.disable(min_level)` suppresses every message of
    severity `min_level` and below, for all loggers. On exit, the global
    disable threshold that was active before entering is restored.

    Args:
        min_level: the highest severity to suppress inside the block
    """

    def __init__(self, min_level=logging.CRITICAL):
        self.logger = logging.getLogger()
        self.min_level = min_level

    def __enter__(self):
        # Remember the global disable threshold (not the logger's own level):
        # that is the value `logging.disable` overrides and must get back.
        self.old_logging_level = self.logger.manager.disable
        logging.disable(self.min_level)
        return self

    def __exit__(self, exit_type, exit_value, exit_traceback):
        logging.disable(self.old_logging_level)