from argparse import Namespace
from collections.abc import Iterable
import inspect
import os
import sys
import string
import random
import logging
from typing import Callable, Dict, Type, TypeVar, Union, get_args, get_origin, Literal
import torch
import numpy as np
# Generic type variable shared by the type-checking and registration helpers below.
T = TypeVar("T")
[docs]
def in_notebook():
    """
    Detect whether the code is running inside a Jupyter notebook kernel.

    The detection mirrors ``tqdm.autonotebook``: IPython must already be imported,
    and the active shell's configuration must contain an ``IPKernelApp`` entry.

    Returns:
        True inside a notebook kernel; False in a plain interpreter, in an IPython
        console session, or whenever any step of the detection fails.
    """
    try:
        active_shell = sys.modules['IPython'].get_ipython()
        # only kernel-backed shells (notebooks) register an IPKernelApp section
        return 'IPKernelApp' in active_shell.config
    except Exception:
        # IPython not imported, no active shell, or unexpected config layout
        return False
[docs]
def check_fn_dynamic_type(fn: T, tp: Type[T], strict=True) -> bool:
    """
    Controls if the function respects the type `tp`.

    `tp` is expected to be a ``Callable[[A1, A2, ...], R]`` type. The parameters of `fn`
    are compared positionally with ``A1, A2, ...`` by (qualified) type name. The type of
    each parameter of `fn` is taken from its annotation when present, otherwise from the
    type of its default value. Only the overlapping prefix of the two lists is compared,
    and a ``Callable[..., R]`` accepts any signature.

    Args:
        fn: the function to be checked
        tp: the type to be respected
        strict: if True, raises an error if the function does not respect the type `tp`

    Returns:
        True if `fn` respects `tp`; False otherwise (only reachable when `strict` is False).

    Raises:
        ValueError: if `strict` is True and `fn` does not respect `tp`.
    """
    expected_args = get_args(tp)[0]
    if expected_args is Ellipsis:  # Callable[..., R]: any argument list is accepted
        return True

    def _type_name(t) -> str:
        # String annotations (e.g. under `from __future__ import annotations`) are used as-is;
        # classes are rendered from "<class 'module.Name'>" to "module.Name" so annotations
        # (type objects) and `tp` arguments are compared on the same footing.
        if isinstance(t, str):
            return t
        text = str(t)
        return text.split("'")[1] if text.startswith("<class '") else text

    type_args = [_type_name(arg) for arg in expected_args]
    fn_args = [
        _type_name(param.annotation if param.annotation is not inspect.Parameter.empty
                   else type(param.default))
        for param in inspect.signature(fn).parameters.values()
    ]
    if not all(f == t for f, t in zip(fn_args, type_args)):
        if strict:
            raise ValueError(f'{fn} does not respect type {tp}')
        return False
    return True
[docs]
def setup_logging():
    """
    Configures the logging module.

    Replaces every handler of the 'root' logger with a single stdout handler formatted as
    ``[LEVEL] timestamp - message``, and sets the level from the ``LOG_LEVEL`` environment
    variable (default ``INFO``; case-insensitive). Subsequent calls are no-ops.
    """
    # check if logging has already been configured
    if hasattr(setup_logging, 'done'):
        return
    formatter = logging.Formatter('[%(levelname)s] %(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    logger = logging.getLogger('root')
    # Iterate over a copy: removeHandler mutates `logger.handlers`, and removing from the
    # live list while iterating it would skip every other handler.
    for h in list(logger.handlers):
        logger.removeHandler(h)
    # level names in `logging` are upper-case; also accept e.g. LOG_LEVEL=debug
    logger.setLevel(os.getenv('LOG_LEVEL', 'INFO').upper())
    logger.addHandler(handler)
    setattr(setup_logging, 'done', True)
[docs]
def field_with_aliases(choices: dict) -> Callable[[str], str]:
    """
    Build a data type where for each key in `choices` there are a set of aliases.

    Example:
        Given the following dictionary:

        .. code-block:: python

            choices = {
                'a': ['a', 'alpha'],
                'b': ['b', 'beta']
            }

        The values 'a' and 'alpha' will be converted to 'a', and 'b' and 'beta' will be converted to 'b'.

    Args:
        choices: the dictionary containing the aliases

    Returns:
        the data type for argparse: a callable that maps a value to the first key (in
        dictionary order) whose aliases contain it. Non-string values are converted with
        ``str`` first, and a ``ValueError`` is raised when no alias matches (argparse reports
        it as an invalid value).
    """
    def _parse_field(value: str) -> str:
        if not isinstance(value, str):
            value = str(value)
        for key, aliases in choices.items():
            if value in aliases:
                return key
        raise ValueError(f'Value `{value}` does not match the provided choices `{choices}`')

    return _parse_field
[docs]
def binary_to_boolean_type(value: str) -> bool:
    """
    Converts a binary string to a boolean type.

    Matching is case-insensitive, and non-string inputs are converted with ``str`` first.
    Accepted true values: 'true', '1', 't', 'y', 'yes'.
    Accepted false values: 'false', '0', 'f', 'n', 'no'.

    Args:
        value: the binary string

    Returns:
        the boolean type

    Raises:
        ValueError: if `value` is not one of the accepted spellings.
    """
    if not isinstance(value, str):
        value = str(value)
    value = value.lower()
    true_values = ('true', '1', 't', 'y', 'yes')
    false_values = ('false', '0', 'f', 'n', 'no')
    # Explicit raise instead of `assert`: asserts vanish under `python -O` (silently mapping
    # garbage to False), and argparse turns ValueError, not AssertionError, into a usage error.
    if value not in true_values and value not in false_values:
        raise ValueError(f'Value `{value}` is not a valid boolean, expected one of {true_values + false_values}')
    return value in true_values
[docs]
def custom_str_underscore(value):
    """
    Stringify `value`, turn every underscore into a hyphen, and strip surrounding whitespace.
    """
    pieces = str(value).split("_")
    return '-'.join(pieces).strip()
[docs]
def smart_joint(*paths):
    """
    Join path components with ``os.path.join`` and normalise every backslash to a forward slash.
    """
    joined = os.path.join(*paths)
    return "/".join(joined.split("\\"))
[docs]
def create_if_not_exists(path: str) -> None:
    """
    Creates the specified folder (including missing parent folders) if it does not exist.

    If `path` already exists, whether as a folder or as any other file-system entry,
    nothing happens.

    Args:
        path: the complete path of the folder to be created
    """
    if not os.path.exists(path):
        # exist_ok avoids a FileExistsError if another process creates the folder
        # between the existence check above and this call
        os.makedirs(path, exist_ok=True)
[docs]
def none_or_float(value):
    """
    Parse `value` as a float, mapping the literal string ``'None'`` to ``None``.
    """
    return None if value == 'None' else float(value)
[docs]
def random_id(length=8, alphabet=string.ascii_letters + string.digits):
    """
    Generate a random identifier.

    Characters are drawn independently, with replacement, from `alphabet` using the
    module-level ``random`` generator, so results are reproducible under ``random.seed``.
    This generator is not suitable for security-sensitive tokens.

    Args:
        length: number of characters to produce
        alphabet: pool of characters to draw from

    Returns:
        the generated string
    """
    picked = random.choices(alphabet, k=length)
    return ''.join(picked)
[docs]
def infer_args_from_signature(signature: inspect.Signature, excluded_signature: inspect.Signature = None, ignore_args: list = None) -> dict:
    """
    Load the arguments of a function from its signature.

    For every eligible parameter, the returned dictionary maps its name to a spec of the form
    ``{'type': ..., 'required': bool[, 'default': ...][, 'choices': tuple]}``.

    - The type comes from the annotation if present, otherwise from the type of the default
      value, otherwise ``str``.
    - ``Literal[...]`` annotations become ``type=str`` with the literal values as ``choices``.
    - Parameters without a default are required, except ``num_classes``, which is optional
      with a ``None`` default.
    - Skipped parameters:
        - ``self``;
        - names starting with ``_``;
        - ``*args`` / ``**kwargs``;
        - parameters named in `excluded_signature` or `ignore_args`.

    Args:
        signature: the signature of the function
        excluded_signature: the signature of the function to be excluded from the arguments
        ignore_args: a list of arguments to be ignored when inferring the arguments from the signature.
            NOTE(review): besides matching by name, the first ``len(ignore_args)`` parameter
            positions are also skipped — presumably the ignored arguments are expected to be the
            leading ones; confirm with callers.

    Returns:
        the inferred arguments
    """
    excluded_args = [] if excluded_signature is None else list(excluded_signature.parameters.keys())
    if ignore_args is None:
        ignore_args = []
    excluded_args += ignore_args  # extends the local list only; `ignore_args` is not mutated
    n_ignored_args = len(ignore_args)

    parsable_args = {}
    for i, (arg_name, value) in enumerate(signature.parameters.items()):
        if arg_name in excluded_args:
            continue
        # *args / **kwargs cannot be expressed as a single named CLI argument
        if value.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        if arg_name == 'self' or arg_name.startswith('_') or i < n_ignored_args:
            continue

        default = value.default
        tp = str
        if value.annotation is not inspect.Parameter.empty:
            tp = value.annotation
        elif default is not inspect.Parameter.empty:
            tp = type(default)

        choices = None
        if get_origin(tp) == Literal:
            choices = get_args(tp)
            tp = str

        if default is inspect.Parameter.empty and arg_name != 'num_classes':
            parsable_args[arg_name] = {
                'type': tp,
                'required': True
            }
        else:
            parsable_args[arg_name] = {
                'type': tp,
                'required': False,
                'default': default if default is not inspect.Parameter.empty else None
            }
        if choices is not None:
            parsable_args[arg_name]['choices'] = choices
    return parsable_args
[docs]
def register_dynamic_module_fn(name: str, register: dict, tp: Type[T], ignore_args: list = None) -> Callable[[Union[T, Callable]], T]:
    """
    Build a decorator that registers a dynamic module in the specified dictionary.

    The registry key is `name`, lower-cased and with underscores replaced by hyphens.

    The decorated target may be a function, a subclass of `tp`, or an instance of `tp`. It is
    stored as ``register[key] = {'class': target, 'parsable_args': ...}``. The parsable
    arguments are inferred from the target's signature, using ``__init__`` for
    classes and instances.

    Targets that are classes or instances and lack a ``NAME`` attribute get it set to the
    registry key.

    Args:
        name: the name of the module
        register: the dictionary where the module will be registered
        tp: the type of the class, used to dynamically infer the arguments
        ignore_args: a list of arguments to be ignored when inferring the arguments from the signature

    Returns:
        the decorator; it returns the target unchanged after registering it.

    Raises:
        ValueError (when the decorator is applied):
            - if the key is already registered and the code is not running in a notebook.
              Inside a notebook the entry is overwritten with a warning instead.
            - if the target is neither a function nor a subclass/instance of `tp`.
    """
    name = name.replace('_', '-').lower()
    # `tp.__class__.__name__` would be the metaclass name (e.g. 'type'); report `tp` itself
    type_name = getattr(tp, '__name__', repr(tp))

    def register_network_fn(target: Union[T, Callable]) -> T:
        # check if the name is already registered
        if name in register:
            # presumably tolerated in notebooks so that re-running a cell can re-register
            if not in_notebook():
                raise ValueError(f"Name {name} already registered!")
            logging.warning(f"Name {name} already registered, overwriting it.")

        if inspect.isfunction(target):
            signature = inspect.signature(target)
        elif isinstance(target, tp) or (inspect.isclass(target) and issubclass(target, tp)):
            # `inspect.isclass` guard: issubclass raises TypeError on non-class objects
            signature = inspect.signature(target.__init__)
            if not hasattr(target, 'NAME'):
                setattr(target, 'NAME', name)  # set the name of the class
        else:
            raise ValueError(f"The registered class must be a subclass of {type_name} or a function returning {type_name}")

        parsable_args = infer_args_from_signature(signature, ignore_args=ignore_args)
        register[name] = {'class': target, 'parsable_args': parsable_args}
        return target

    return register_network_fn
[docs]
class disable_logging:
    """
    Wrapper for disabling logging for a specific block of code.

    Inside the ``with`` block, every logging call at severity `min_level` or below is
    suppressed process-wide via ``logging.disable``. On exit, the global disable
    threshold that was active before entering is restored, so the wrapper can be nested.

    Example:
        with disable_logging():
            noisy_function()
    """

    def __init__(self, min_level=logging.CRITICAL):
        self.logger = logging.getLogger()
        self.min_level = min_level

    def __enter__(self):
        # kept for backward compatibility; the root logger's level itself is never modified
        self.old_logging_level = self.logger.level
        # `logging.disable` stores its threshold on the logger manager; remember the
        # current one so it can be restored exactly on exit
        self.old_disable_level = self.logger.manager.disable
        logging.disable(self.min_level)
        return self

    def __exit__(self, exit_type, exit_value, exit_traceback):
        # Restore the previous global threshold (NOTSET re-enables everything). Passing the
        # root logger's *level* here would keep all messages at or below it disabled.
        logging.disable(self.old_disable_level)
[docs]
def to_parsable_obj(r: Union[Dict, Namespace, list, torch.Tensor, np.ndarray]) -> Union[Dict, list, str, int, float, bool]:
    """
    Convert a non-builtin object to a parsable (and loadable with `weights_only=True`) object.
    Looking at you, Namespace.

    Conversion rules, applied recursively:
      - ``Namespace`` -> dict of its attributes; dict -> dict with converted values;
      - ``torch.Tensor`` / ``np.ndarray`` -> nested lists (0-d inputs -> a Python scalar);
      - numpy scalars (``np.float64``, ``np.int64``, ``np.bool_``, ...) -> Python scalars;
      - ``None``, int, float, str, bool -> returned unchanged;
      - ``torch.device`` -> its string form;
      - other sized iterables (list, tuple, set, ...) -> list of converted items;
      - anything else -> ``str(r)``, with a logged warning.
    """
    if isinstance(r, Namespace):
        return to_parsable_obj(vars(r))
    if isinstance(r, dict):
        return {k: to_parsable_obj(v) for k, v in r.items()}

    if isinstance(r, torch.Tensor):
        r = r.detach().cpu().tolist()
    elif isinstance(r, np.ndarray):
        r = r.tolist()

    if r is None:
        return None
    # must precede the builtin check: np.float64 subclasses float but should still be unwrapped
    if isinstance(r, np.generic):
        return r.item()
    if isinstance(r, (int, float, str, bool)):
        return r
    if isinstance(r, torch.device):
        return str(r)
    # any length (including 0 and 1); unsized iterables such as generators are not consumed
    if isinstance(r, Iterable) and hasattr(r, '__len__'):
        return [to_parsable_obj(x) for x in r]

    logging.warning(f"Object {r} is not parsable, returning it as str.")
    return str(r)  # return as str if not parsable