Source code for utils.args

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import sys

# When this file is executed directly (not imported), append the parent of this
# file's directory to `sys.path` -- presumably so the project-level imports
# below (`backbone`, `datasets`, `models`, `utils`) resolve; confirm against the repo layout.
# `os` and `mammoth_path` are only defined in this case and are reused by the
# documentation-generation script at the bottom of the file.
if __name__ == '__main__':
    import os
    mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(mammoth_path)

from argparse import ArgumentParser, Namespace

from backbone import REGISTERED_BACKBONES
from datasets import get_dataset_names, get_dataset_config_names
from models import get_all_models
from models.utils.continual_model import ContinualModel
from utils import binary_to_boolean_type, custom_str_underscore, field_with_aliases


def get_single_arg_value(parser: ArgumentParser, arg_name: str):
    """
    Returns the value of a single argument without explicitly parsing the arguments.

    The parser default takes precedence: if the action registered for `arg_name`
    has a non-`None` default, that default is returned and `sys.argv` is not inspected.
    Otherwise `sys.argv` is scanned for the first token whose key (the part before
    the first `=`) is one of the action's option strings or its `dest`.

    Args:
        parser: the argument parser
        arg_name: the name (`dest`) of the argument

    Returns:
        the default of the argument if it is not `None`; otherwise the raw string
        found on the command line (`--arg=value` or `--arg value`); `None` if the
        argument is not present or has no value after it.

    Raises:
        AssertionError: if `arg_name` does not match exactly one action of the parser.
    """
    matches = [action for action in parser._actions if action.dest == arg_name]
    assert len(matches) == 1, f'Argument {arg_name} not found in the parser.'
    action = matches[0]

    # if the argument has a default value, return it
    if action.default is not None:
        return action.default

    # otherwise, search for the argument in the sys.argv
    argv = sys.argv
    for idx, token in enumerate(argv):
        # split on the FIRST `=` only, so that values containing `=` are kept intact
        key, separator, value = token.partition('=')
        if key not in action.option_strings and key != action.dest:
            continue
        if separator:
            return value
        # `--arg value` form: the value is the next token, if there is one
        return argv[idx + 1] if idx + 1 < len(argv) else None

    return None
def update_cli_defaults(parser: ArgumentParser, cnf: dict) -> None:
    """
    Overrides the parser defaults with the entries of a configuration dictionary.

    Every action (except `help`) whose `dest` appears in `cnf` gets the configured
    value as its default and is no longer marked as `required`.

    Args:
        parser: the argument parser
        cnf: the configuration dictionary

    Returns:
        None
    """
    parser.set_defaults(**cnf)

    configured_actions = (
        act for act in parser._actions
        if act.dest != 'help' and act.dest in cnf
    )
    for act in configured_actions:
        act.default = cnf[act.dest]
        # a value coming from the configuration satisfies the requirement
        act.required = False
def fix_model_parser_backwards_compatibility(main_parser: ArgumentParser, model_parser: ArgumentParser = None) -> ArgumentParser:
    """
    Merges a model-specific parser into the main parser, for models whose
    `get_parser` method builds its own parser.

    Args:
        main_parser: the main parser
        model_parser: the parser of the model (may be `None` or the main parser itself)

    Returns:
        the main parser, extended with the model arguments and defaults
    """
    if model_parser is None:
        return main_parser

    if main_parser != model_parser:
        # add the arguments of the model parser that the main parser does not know yet
        known_dests = {act.dest for act in main_parser._actions}
        for candidate in model_parser._actions:
            if candidate.dest == 'help' or candidate.dest in known_dests:
                continue
            main_parser._add_action(candidate)
            known_dests.add(candidate.dest)

        # propagate the `set_defaults` values of the model parser to the main parser
        model_defaults = model_parser._defaults
        for act in main_parser._actions:
            if act.dest not in model_defaults:
                continue
            act.default = model_defaults[act.dest]
            act.required = False

    return main_parser
def build_parsable_args(parser: ArgumentParser, spec: dict) -> None:
    """
    Builds the argument parser given a specification and extends the given parser.

    The specification dictionary can either be a simple list of key-value argument or follow the format:

    .. code-block:: python

        {
            'name': {
                'type': type,
                'default': default,
                'choices': choices,
                'help': help,
                'required': True/False
            }
        }

    If the specification is a simple list of key-value arguments, the value of the argument is the default value.
    If the default is set to `inspect.Parameter.empty`, the argument is required (and has no default).
    The type of the argument is inferred from the default value (default is `str`, also used when the
    default is `None` or `inspect.Parameter.empty`).

    Arguments whose name is already registered in the parser are skipped with a warning.

    Args:
        parser: the argument parser (or argument group) to extend in place
        spec: the specification dictionary

    Returns:
        None
    """
    import inspect  # local import: only needed to recognise the "required" sentinel

    for name, arg_spec in spec.items():
        # check if the argument is already defined in the parser
        if any(action.dest == name for action in parser._actions):
            # `logging.warn` is deprecated (and removed in Python 3.13)
            logging.warning(f"Argument `{name}` is already defined in the parser. Skipping...")
            continue

        if isinstance(arg_spec, dict):
            arg_type = arg_spec.get('type', str)
            arg_default = arg_spec.get('default', None)
            arg_choices = arg_spec.get('choices', None)
            arg_help = arg_spec.get('help', '')
            arg_required = arg_spec.get('required', False)
        else:
            arg_default = arg_spec
            arg_choices = None
            arg_help = ''
            arg_required = False
            # `type(None)` / `type(Parameter.empty)` cannot convert command-line strings
            if arg_spec is None or arg_spec is inspect.Parameter.empty:
                arg_type = str
            else:
                arg_type = type(arg_spec)

        # the "no default" sentinel marks the argument as mandatory
        if arg_default is inspect.Parameter.empty:
            arg_default = None
            arg_required = True

        parser.add_argument(f'--{name}', type=arg_type, default=arg_default, choices=arg_choices,
                            help=arg_help, required=arg_required)
def clean_dynamic_args(args: Namespace) -> Namespace:
    """
    Replaces the dictionary-valued `backbone`, `model` and `dataset` arguments
    with their registered name (the `type` entry of the dictionary).

    Args:
        args: the namespace to clean (modified in place)

    Returns:
        the same namespace
    """
    for field in ('backbone', 'model', 'dataset'):
        value = getattr(args, field)
        if isinstance(value, dict):
            setattr(args, field, value['type'])
    return args
def add_dynamic_parsable_args(parser: ArgumentParser, dataset: str, backbone: str) -> None:
    """
    Add the additional arguments of the chosen dataset and backbone to the parser.

    Both `dataset` and `backbone` can be given either as the registered name or as a
    dictionary with a mandatory `type` key (the registered name) and an optional `args`
    key whose entries extend/override the registered `parsable_args`.

    Args:
        parser: the parser instance to extend
        dataset: the dataset name (or dictionary)
        backbone: the backbone name (or dictionary)

    Raises:
        AssertionError: if a dictionary is given without the `type` key.
        KeyError: if the dataset or backbone name is not registered.
    """
    ds_group = parser.add_argument_group('Dataset arguments', 'Arguments used to define the dataset.')
    registered_datasets = get_dataset_names()
    if isinstance(dataset, dict):
        assert 'type' in dataset, "The dataset `type` (i.e., the registered name) must be defined in the dictionary."
        # normalise exactly like the plain-string branch below (underscores -> dashes);
        # the previous `-` -> `_` replacement mirrored the backbone convention instead.
        # NOTE(review): assumes datasets are registered with dashes -- confirm against `get_dataset_names`.
        ds_name = dataset['type'].replace('_', '-').lower()
        ds_args = {**registered_datasets[ds_name]['parsable_args'], **(dataset.get('args') or {})}
    else:
        ds_args = registered_datasets[dataset.replace('_', '-').lower()]['parsable_args']
    build_parsable_args(ds_group, ds_args)

    bk_group = parser.add_argument_group('Backbone arguments', 'Arguments used to define the backbone network.')
    if isinstance(backbone, dict):
        assert 'type' in backbone, "The backbone `type` (i.e., the registered name) must be defined in the dictionary."
        bk_name = backbone['type'].replace('-', '_').lower()
        # `args` is optional: a dictionary with only `type` uses the registered arguments as-is
        bk_args = {**REGISTERED_BACKBONES[bk_name]['parsable_args'], **(backbone.get('args') or {})}
    else:
        bk_args = REGISTERED_BACKBONES[backbone.replace('-', '_').lower()]['parsable_args']
    build_parsable_args(bk_group, bk_args)
# model dynamic arguments? maybe in the future...
def add_configuration_args(parser: ArgumentParser, args: Namespace) -> None:
    """
    Registers the `--dataset_config` and `--model_config` arguments.

    Args:
        parser: the parser instance to extend
        args: the namespace providing the `dataset` whose configuration names are offered as choices

    Returns:
        None
    """
    dataset_config_help = (
        'The configuration used for this dataset (e.g., number of tasks, transforms, backbone architecture, etc.).'
        'The available configurations are defined in the `datasets/config/<dataset>` folder.'
    )
    model_config_help = (
        'The configuration used for this model. The available configurations are defined in the `models/config/<model>.yaml` file '
        'and include a `default` (dataset-agostic) configuration and a `best` configuration (dataset-specific). '
        'If not provided, the `default` configuration is used.'
    )

    config_group = parser.add_argument_group(
        'Configuration arguments',
        'Arguments used to define the dataset and model configurations.',
    )

    available_dataset_configs = get_dataset_config_names(args.dataset)
    config_group.add_argument('--dataset_config', type=str, choices=available_dataset_configs,
                              help=dataset_config_help)

    model_config_type = field_with_aliases({'default': ['base', 'default'], 'best': ['best']})
    config_group.add_argument('--model_config', type=model_config_type, default='default',
                              help=model_config_help)
def add_initial_args(parser) -> ArgumentParser:
    """
    Registers the initial arguments (`--dataset`, `--model`, `--backbone`,
    `--load_best_args`) on the given parser.

    Args:
        parser: the parser instance to extend

    Returns:
        the same parser
    """
    dataset_choices = get_dataset_names(names_only=True)
    parser.add_argument('--dataset', type=custom_str_underscore, required=True,
                        choices=dataset_choices,
                        help='Which dataset to perform experiments on.')

    model_choices = list(get_all_models().keys())
    parser.add_argument('--model', type=custom_str_underscore, required=True,
                        help='Model name.', choices=model_choices)

    backbone_choices = list(REGISTERED_BACKBONES.keys())
    parser.add_argument('--backbone', type=custom_str_underscore,
                        help='Backbone network name.', choices=backbone_choices)

    deprecated_help = ('(deprecated) Loads the best arguments for each method, dataset and memory buffer. '
                       'NOTE: This option is deprecated and not up to date.')
    parser.add_argument('--load_best_args', action='store_true', help=deprecated_help)

    return parser
def add_experiment_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments used by all the models.

    Four argument groups are registered: experiment settings, validation and fitting,
    optimizer and learning rate scheduler, and noisy-label settings.

    Args:
        parser: the parser instance

    Returns:
        None
    """
    exp_group = parser.add_argument_group('Experiment arguments', 'Arguments used to define the experiment settings.')

    exp_group.add_argument('--lr', required=True, type=float,
                           help='Learning rate. This should either be set as default by the model '
                           '(with `set_defaults <https://docs.python.org/3/library/argparse.html#argparse.ArgumentParser.set_defaults>`_),'
                           ' by the dataset (with `set_default_from_args`, see :ref:`module-datasets.utils`), or with `--lr=<value>`.')
    exp_group.add_argument('--batch_size', type=int, help='Batch size.')
    exp_group.add_argument('--label_perc_by_task', '--label_perc', '--lpt', type=float, default=1, dest='label_perc',
                           help='Percentage in (0-1] of labeled examples per task.')
    # the help previously said "per task" (copied from the option above)
    exp_group.add_argument('--label_perc_by_class', '--lpc', type=float, default=1, dest='label_perc_by_class',
                           help='Percentage in (0-1] of labeled examples per class.')
    exp_group.add_argument('--joint', type=int, choices=(0, 1), default=0, help='Train model on Joint (single task)?')
    exp_group.add_argument('--eval_future', type=int, choices=(0, 1), default=0, help='Evaluate future tasks?')

    validation_group = parser.add_argument_group('Validation and fitting arguments',
                                                 'Arguments used to define the validation strategy and the method used to fit the model.')

    validation_group.add_argument('--validation', type=float,
                                  help='Percentage of samples FOR EACH CLASS drawn from the training set to build the validation set.')
    validation_group.add_argument('--validation_mode', type=str, choices=['complete', 'current'], default='current',
                                  help='Mode used for validation. Must be used in combination with `validation` argument. Possible values:'
                                  ' - `current`: uses only the current task for validation (default).'
                                  ' - `complete`: uses data from both current and past tasks for validation.')
    validation_group.add_argument('--fitting_mode', type=str, choices=['epochs', 'iters', 'time', 'early_stopping'], default='epochs',
                                  help='Strategy used for fitting the model. Possible values:'
                                  ' - `epochs`: fits the model for a fixed number of epochs (default). NOTE: this option is controlled by the `n_epochs` argument.'
                                  ' - `iters`: fits the model for a fixed number of iterations. NOTE: this option is controlled by the `n_iters` argument.'
                                  ' - `early_stopping`: fits the model until early stopping criteria are met. This option requires a validation set (see `validation` argument).'
                                  ' The early stopping criteria are: if the validation loss does not decrease for `early_stopping_patience` epochs, the training stops.')
    validation_group.add_argument('--early_stopping_patience', type=int, default=5,
                                  help='Number of epochs to wait before stopping the training if the validation loss does not decrease. Used only if `fitting_mode=early_stopping`.')
    validation_group.add_argument('--early_stopping_metric', type=str, default='loss', choices=['loss', 'accuracy'],
                                  help='Metric used for early stopping. Used only if `fitting_mode=early_stopping`.')
    validation_group.add_argument('--early_stopping_freq', type=int, default=1,
                                  help='Frequency of validation evaluation. Used only if `fitting_mode=early_stopping`.')
    validation_group.add_argument('--early_stopping_epsilon', type=float, default=1e-6,
                                  help='Minimum improvement required to consider a new best model. Used only if `fitting_mode=early_stopping`.')
    validation_group.add_argument('--n_epochs', type=int,
                                  help='Number of epochs. Used only if `fitting_mode=epochs`.')
    validation_group.add_argument('--n_iters', type=int,
                                  help='Number of iterations. Used only if `fitting_mode=iters`.')

    opt_group = parser.add_argument_group('Optimizer and learning rate scheduler arguments',
                                          'Arguments used to define the optimizer and the learning rate scheduler.')

    opt_group.add_argument('--optimizer', type=str, default='sgd',
                           choices=ContinualModel.AVAIL_OPTIMS,
                           help='Optimizer.')
    opt_group.add_argument('--optim_wd', type=float, default=0.,
                           help='optimizer weight decay.')
    opt_group.add_argument('--optim_mom', type=float, default=0.,
                           help='optimizer momentum.')
    opt_group.add_argument('--optim_nesterov', type=binary_to_boolean_type, default=0,
                           help='optimizer nesterov momentum.')
    opt_group.add_argument('--lr_scheduler', type=str, help='Learning rate scheduler.')
    opt_group.add_argument('--scheduler_mode', type=str, choices=['epoch', 'iter'], default='epoch',
                           help='Scheduler mode. Possible values:'
                           ' - `epoch`: the scheduler is called at the end of each epoch.'
                           ' - `iter`: the scheduler is called at the end of each iteration.')
    opt_group.add_argument('--lr_milestones', type=int, default=[], nargs='+',
                           help='Learning rate scheduler milestones (used if `lr_scheduler=multisteplr`).')
    opt_group.add_argument('--sched_multistep_lr_gamma', type=float, default=0.1,
                           help='Learning rate scheduler gamma (used if `lr_scheduler=multisteplr`).')

    noise_group = parser.add_argument_group('Noise arguments', 'Arguments used to define the noisy-label settings.')

    noise_group.add_argument('--noise_type', type=field_with_aliases({
        'symmetric': ['symmetric', 'sym', 'symm'],
        'asymmetric': ['asymmetric', 'asym', 'asymm']
    }), default='symmetric',
        help='Type of noise to apply. The symmetric type is supported by all datasets, while the asymmetric must be supported explicitly by the dataset (see `datasets/utils/label_noise`).')
    noise_group.add_argument('--noise_rate', type=float, default=0,
                             help='Noise rate in [0-1].')
    noise_group.add_argument('--disable_noisy_labels_cache', type=binary_to_boolean_type, default=0,
                             help='Disable caching the noisy label targets? NOTE: if the seed is not set, the noisy labels will be different at each run with this option disabled.')
    noise_group.add_argument('--cache_path_noisy_labels', type=str, default='noisy_labels',
                             help='Path where to save the noisy labels cache. The path is relative to the `base_path`.')
def add_management_args(parser: ArgumentParser) -> None:
    """
    Registers the management arguments (reproducibility, paths, devices, logging,
    debugging, checkpoints) and the Wandb arguments on the given parser.

    Args:
        parser: the parser instance

    Returns:
        None
    """
    # each entry is (option strings, keyword arguments for `add_argument`), in registration order
    management_options = [
        (('--seed',), {'type': int, 'default': None,
                       'help': 'The random seed. If not provided, a random seed will be used.'}),
        (('--permute_classes',), {'type': binary_to_boolean_type, 'default': 0,
                                  'help': 'Permute classes before splitting into tasks? This applies the seed before permuting if the `seed` argument is present.'}),
        (('--base_path',), {'type': str, 'default': "./data/",
                            'help': 'The base path where to save datasets, logs, results.'}),
        (('--results_path',), {'type': str, 'default': "results/",
                               'help': 'The path where to save the results. NOTE: this path is relative to `base_path`.'}),
        (('--device',), {'type': str,
                         'help': 'The device (or devices) available to use for training. '
                                 'More than one device can be specified by separating them with a comma. '
                                 'If not provided, the code will use the least used GPU available (if there are any), otherwise the CPU. '
                                 'MPS is supported and is automatically used if no GPU is available and MPS is supported. '
                                 'If more than one GPU is available, Mammoth will use the least used one if `--distributed=no`.'}),
        (('--notes',), {'type': str, 'default': None,
                        'help': 'Helper argument to include notes for this run. Example: distinguish between different versions of a model and allow separation of results'}),
        (('--eval_epochs',), {'type': int, 'default': None,
                              'help': 'Perform inference on validation every `eval_epochs` epochs. If not provided, the model is evaluated ONLY at the end of each task.'}),
        (('--non_verbose',), {'default': 0, 'type': binary_to_boolean_type,
                              'help': 'Make progress bars non verbose'}),
        (('--disable_log',), {'default': 0, 'type': binary_to_boolean_type,
                              'help': 'Disable logging?'}),
        (('--num_workers',), {'type': int, 'default': None,
                              'help': 'Number of workers for the dataloaders (default=infer from number of cpus).'}),
        (('--enable_other_metrics',), {'default': 0, 'type': binary_to_boolean_type,
                                       'help': 'Enable computing additional metrics: forward and backward transfer.'}),
        (('--debug_mode',), {'type': binary_to_boolean_type, 'default': 0,
                             'help': 'Run only a few training steps per epoch. This also disables logging on wandb.'}),
        (('--inference_only',), {'default': 0, 'type': binary_to_boolean_type,
                                 'help': 'Perform inference only for each task (no training).'}),
        (('-O', '--code_optimization'), {'type': int, 'default': 0, 'choices': [0, 1, 2, 3],
                                         'help': 'Optimization level for the code.'
                                                 '0: no optimization.'
                                                 '1: Use TF32, if available.'
                                                 '2: Use BF16, if available.'
                                                 '3: Use BF16 and `torch.compile`. BEWARE: torch.compile may break your code if you change the model after the first run! Use with caution.'}),
        (('--distributed',), {'type': str, 'default': 'no', 'choices': ['no', 'dp', 'ddp'],
                              'help': 'Enable distributed training?'}),
        (('--savecheck',), {'choices': ['last', 'task'], 'type': str,
                            'help': 'Save checkpoint every `task` or at the end of the training (`last`).'}),
        (('--loadcheck',), {'type': str, 'default': None,
                            'help': 'Path of the checkpoint to load (.pt file for the specific task)'}),
        (('--ckpt_name',), {'type': str,
                            'help': '(optional) checkpoint save name.'}),
        (('--start_from',), {'type': int, 'default': None, 'help': "Task to start from"}),
        (('--stop_after',), {'type': int, 'default': None, 'help': "Task limit"}),
    ]
    wandb_options = [
        (('--wandb_name',), {'type': str, 'default': None,
                             'help': 'Wandb name for this run. Overrides the default name (`args.model`).'}),
        (('--wandb_entity',), {'type': str, 'help': 'Wandb entity'}),
        (('--wandb_project',), {'type': str, 'help': 'Wandb project name'}),
    ]

    mng_group = parser.add_argument_group('Management arguments',
                                          'Generic arguments to manage the experiment reproducibility, logging, debugging, etc.')
    for flags, options in management_options:
        mng_group.add_argument(*flags, **options)

    wandb_group = parser.add_argument_group('Wandb arguments', 'Arguments to manage logging on Wandb.')
    for flags, options in wandb_options:
        wandb_group.add_argument(*flags, **options)
def add_rehearsal_args(parser: ArgumentParser) -> None:
    """
    Registers the arguments shared by every rehearsal-based method:
    the mandatory `--buffer_size` and the optional `--minibatch_size`.

    Args:
        parser: the parser instance

    Returns:
        None
    """
    rehearsal_group = parser.add_argument_group(
        'Rehearsal arguments',
        'Arguments shared by all rehearsal-based methods.',
    )
    rehearsal_options = (
        ('--buffer_size', {'type': int, 'required': True, 'help': 'The size of the memory buffer.'}),
        ('--minibatch_size', {'type': int, 'help': 'The batch size of the memory buffer.'}),
    )
    for flag, options in rehearsal_options:
        rehearsal_group.add_argument(flag, **options)
def _looks_like_number(token: str) -> bool:
    """Returns True if the command-line token is a numeric literal (e.g. `-1`, `-0.5`)."""
    try:
        float(token)
    except ValueError:
        return False
    return True


def check_multiple_defined_arg_during_string_parse() -> None:
    """
    Check if an argument is defined multiple times during the string parsing.
    Prevents the user from typing the same argument multiple times as:
    `--arg1=val1 --arg1=val2`.

    The check works on the raw `sys.argv[1:]` tokens: a token contributes an argument
    name if it contains `=` (the name is the part before the first `=`) or if it starts
    with `-`. Tokens that are negative numbers (e.g. the `-1` in `--start_from -1`) are
    values, not argument names, and are ignored.

    Raises:
        ValueError: if the same argument name appears more than once.
    """
    seen_names = set()
    for token in sys.argv[1:]:
        if '=' in token:
            arg_name = token.split('=', 1)[0]
        elif token.startswith('-') and not _looks_like_number(token):
            arg_name = token
        else:
            # plain value (including negative numbers): nothing to check
            continue

        if arg_name in seen_names:
            raise ValueError(f"Argument `{arg_name}` is defined multiple times.")
        seen_names.add(arg_name)
class _DocsArgs: """ This class is used to generate the documentation of the arguments. """ def __init__(self, name: str, tp: str, choices: str, default: str, help_: str): if tp is None: tp = 'unknown' elif tp.__name__ == '_parse_field': tp = 'field with aliases (str)' elif tp.__name__ == 'binary_to_boolean_type': tp = '0|1|True|False -> bool' elif tp.__name__ == 'custom_str_underscore': tp = 'str (with underscores replaced by dashes)' else: tp = tp.__name__ self.name = name self.type = tp self.choices = choices self.default = default self.help = help_ def parse_choices(self) -> str: if self.choices is None: return '' return ', '.join([c.keys() if isinstance(c, dict) else str(c) for c in self.choices]) def __str__(self): tb = f"""**\\-\\-{self.name}** : {self.type} \t*Help*: {self.help}\n \t- *Default*: ``{self.default}``""" if self.choices is not None: tb += f"\n\t- *Choices*: ``{self.parse_choices()}``" return tb class _DocArgsGroup: """ This class is used to generate the documentation of the arguments. """ def __init__(self, group_name: str, group_desc: str, doc_args: _DocsArgs): self.group_name = group_name self.group_desc = group_desc self.doc_args = doc_args def __str__(self): args_str = '\n'.join([arg.__str__() for arg in self.doc_args]) s = f""".. rubric:: {self.group_name.capitalize()}\n\n""" if self.group_desc: s += f"*{self.group_desc}*\n\n" s += args_str return s def _parse_actions(actions: list, group_name: str, group_desc: str) -> _DocArgsGroup: """ Parses the actions of the parser. 
Args: actions: the actions to parse group_name: the name of the group group_desc: the description of the group Returns: an instance of _DocArgsGroup containing the parsed actions """ docs_args = [] for action in actions: if action.dest == 'help': continue docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help)) return _DocArgsGroup(group_name, group_desc, docs_args) if __name__ == '__main__': print("Generating documentation for the arguments...") os.chdir(mammoth_path) parser = ArgumentParser() add_initial_args(parser) parser.add_argument('--dataset_config', type=str, help='The configuration used for this dataset (e.g., number of tasks, transforms, backbone architecture, etc.).' 'The available configurations are defined in the `datasets/config/<dataset>` folder.') docs_args = [] for action in parser._actions: if action.dest == 'help': continue docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help)) with open('docs/utils/args.rst', 'w') as f: f.write('.. _module-args:\n\n') f.write('Arguments\n') f.write('=========\n\n') f.write('.. rubric:: MAIN MAMMOTH ARGS\n\n') for arg in docs_args: f.write(str(arg) + '\n\n') parser = ArgumentParser() add_experiment_args(parser) docs_args = [] for group in parser._action_groups: if len([a for a in group._group_actions if a.dest != 'help']) == 0: continue docs_args.append(_parse_actions(group._group_actions, group.title, group.description)) with open('docs/utils/args.rst', 'a') as f: f.write('.. rubric:: EXPERIMENT-RELATED ARGS\n\n') for arg in docs_args: f.write(str(arg) + '\n\n') parser = ArgumentParser() add_management_args(parser) docs_args = [] for group in parser._action_groups: if len([a for a in group._group_actions if a.dest != 'help']) == 0: continue docs_args.append(_parse_actions(group._group_actions, group.title, group.description)) with open('docs/utils/args.rst', 'a') as f: f.write('.. 
rubric:: MANAGEMENT ARGS\n\n') for arg in docs_args: f.write(str(arg) + '\n\n') parser = ArgumentParser() add_rehearsal_args(parser) docs_args = [] for action in parser._actions: if action.dest == 'help': continue docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help)) with open('docs/utils/args.rst', 'a') as f: f.write('.. rubric:: REEHARSAL-ONLY ARGS\n\n') for arg in docs_args: f.write(str(arg) + '\n\n') print("Saving documentation in docs/utils/args.rst") print("Done!") from models import get_model_names for model_name, model_class in get_model_names().items(): parser = model_class.get_parser(ArgumentParser()) model_args_groups = [] for group in parser._action_groups: if len([a for a in group._group_actions if a.dest != 'help']) == 0: continue model_args_groups.append(_parse_actions(group._group_actions, group.title, group.description)) model_filename = model_name.replace("-", "_") with open(f'docs/models/{model_filename}_args.rst', 'w') as f: f.write(f'Arguments\n') f.write(f'~~~~~~~~~~~\n\n') for arg in model_args_groups: f.write(str(arg) + '\n\n') print(f"Saving documentation in docs/models/{model_filename}_args.rst")