# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import sys
if __name__ == '__main__':
import os
mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(mammoth_path)
from argparse import ArgumentParser, Namespace
from backbone import REGISTERED_BACKBONES
from datasets import get_dataset_names, get_dataset_config_names
from models import get_all_models
from models.utils.continual_model import ContinualModel
from utils import binary_to_boolean_type, custom_str_underscore, field_with_aliases
[docs]
def get_single_arg_value(parser: ArgumentParser, arg_name: str):
    """
    Returns the value of a single argument without explicitly parsing the arguments.

    The command line (``sys.argv``) is searched first, so that an explicit
    ``--arg=value`` or ``--arg value`` overrides the parser's default (the
    original implementation returned the default even when the user had
    overridden it on the command line). If the argument is not found on the
    command line, the parser default (possibly ``None``) is returned.

    Args:
        parser: the argument parser
        arg_name: the name (dest) of the argument

    Returns:
        str: the raw string value from the command line, or the parser default
    """
    actions = [action for action in parser._actions if action.dest == arg_name]
    assert len(actions) == 1, f'Argument {arg_name} not found in the parser.'
    action = actions[0]

    # search for the argument in sys.argv; an explicit CLI value takes precedence
    for i, arg in enumerate(sys.argv):
        arg_k = arg.split('=')[0]
        if arg_k in action.option_strings or arg_k == action.dest:
            if '=' in arg:
                # split only on the first '=' so values containing '=' survive intact
                return arg.split('=', 1)[1]
            elif i + 1 < len(sys.argv):  # guard against a dangling flag at the end
                return sys.argv[i + 1]

    # not on the command line: fall back to the parser default (may be None)
    return action.default
[docs]
def update_cli_defaults(parser: ArgumentParser, cnf: dict) -> None:
    """
    Updates the default values of the parser with the values in the configuration dictionary.

    If an argument is defined as `required` in the parser but a default value is
    provided in the configuration dictionary, the argument is set as not required.

    Args:
        parser: the argument parser
        cnf: the configuration dictionary

    Returns:
        None
    """
    parser.set_defaults(**cnf)

    # mirror the configuration onto each individual action and lift the
    # `required` constraint for any argument the configuration covers
    for act in parser._actions:
        if act.dest == 'help' or act.dest not in cnf:
            continue
        act.default = cnf[act.dest]
        act.required = False
[docs]
def fix_model_parser_backwards_compatibility(main_parser: ArgumentParser, model_parser: ArgumentParser = None) -> ArgumentParser:
    """
    Fix the backwards compatibility of the `get_parser` method of the models.

    Args:
        main_parser: the main parser
        model_parser: the parser of the model

    Returns:
        the fixed parser
    """
    if model_parser is None:
        return main_parser

    if main_parser != model_parser:
        known_dests = {a.dest for a in main_parser._actions}
        # merge the model-specific arguments into the main parser,
        # skipping anything the main parser already defines
        for act in model_parser._actions:
            if act.dest == 'help' or act.dest in known_dests:
                continue
            main_parser._add_action(act)
            known_dests.add(act.dest)
        # propagate any defaults registered on the model parser (via
        # set_defaults) onto the corresponding main-parser actions
        model_defaults = model_parser._defaults
        for act in main_parser._actions:
            if act.dest in model_defaults:
                act.default = model_defaults[act.dest]
                act.required = False

    return main_parser
[docs]
def build_parsable_args(parser: ArgumentParser, spec: dict) -> None:
    """
    Builds the argument parser given a specification and extends the given parser.

    The specification dictionary can either be a simple list of key-value argument or follow the format:

    .. code-block:: python

        {
            'name': {
                'type': type,
                'default': default,
                'choices': choices,
                'help': help,
                'required': True/False
            }
        }

    If the specification is a simple list of key-value arguments, the value of the argument is the default value. If the default is set to `inspect.Parameter.empty`, the argument is required. The type of the argument is inferred from the default value (default is `str`).

    Args:
        parser: the argument parser
        spec: the specification dictionary

    Returns:
        the argument parser
    """
    import inspect

    for name, arg_spec in spec.items():
        # skip arguments that are already defined in the parser
        if any(action.dest == name for action in parser._actions):
            # `logging.warn` is deprecated; use `logging.warning`
            logging.warning(f"Argument `{name}` is already defined in the parser. Skipping...")
            continue

        if isinstance(arg_spec, dict):
            arg_type = arg_spec.get('type', str)
            arg_default = arg_spec.get('default', None)
            arg_choices = arg_spec.get('choices', None)
            arg_help = arg_spec.get('help', '')
            arg_required = arg_spec.get('required', False)
        elif arg_spec is inspect.Parameter.empty:
            # no default value was supplied: per the documented contract the
            # argument is mandatory and its type falls back to `str`
            arg_type = str
            arg_default = None
            arg_choices = None
            arg_help = ''
            arg_required = True
        else:
            # plain key-value spec: the value IS the default and fixes the type
            arg_choices = None
            arg_help = ''
            arg_type = type(arg_spec)
            arg_default = arg_spec
            arg_required = False

        parser.add_argument(f'--{name}', type=arg_type, default=arg_default, choices=arg_choices, help=arg_help, required=arg_required)
[docs]
def clean_dynamic_args(args: Namespace) -> Namespace:
    """
    Extracts the registered name from the dictionary arguments.
    """
    # each field may be either a plain registered name (str) or a dict
    # with a 'type' key holding the registered name
    for field in ('backbone', 'model', 'dataset'):
        value = getattr(args, field)
        if isinstance(value, dict):
            setattr(args, field, value['type'])
    return args
[docs]
def add_dynamic_parsable_args(parser: ArgumentParser, dataset: str, backbone: str) -> None:
    """
    Add the additional arguments of the chosen dataset and backbone to the parser.

    Args:
        parser: the parser instance to extend
        dataset: the dataset name, or a dict with a 'type' key (registered name) and an 'args' dict
        backbone: the backbone name, or a dict with a 'type' key (registered name) and an 'args' dict
    """
    ds_group = parser.add_argument_group('Dataset arguments', 'Arguments used to define the dataset.')
    registered_datasets = get_dataset_names()
    if isinstance(dataset, dict):
        assert 'type' in dataset, "The dataset `type` (i.e., the registered name) must be defined in the dictionary."
        # registered dataset names use dashes (see the non-dict branch below);
        # the original code normalized to underscores here, which could never
        # match the same registry keyed with dashes
        ds_name = dataset['type'].replace('_', '-').lower()
        ds_args = {**registered_datasets[ds_name]['parsable_args'], **dataset['args']}
        dataset = ds_name
    else:
        ds_args = registered_datasets[dataset.replace('_', '-').lower()]['parsable_args']
    build_parsable_args(ds_group, ds_args)

    bk_group = parser.add_argument_group('Backbone arguments', 'Arguments used to define the backbone network.')
    if isinstance(backbone, dict):
        assert 'type' in backbone, "The backbone `type` (i.e., the registered name) must be defined in the dictionary."
        # registered backbone names use underscores (see the non-dict branch below)
        bk_name = backbone['type'].replace('-', '_').lower()
        bk_args = {**REGISTERED_BACKBONES[bk_name]['parsable_args'], **backbone['args']}
        backbone = bk_name
    else:
        bk_args = REGISTERED_BACKBONES[backbone.replace('-', '_').lower()]['parsable_args']
    build_parsable_args(bk_group, bk_args)
    # model dynamic arguments? maybe in the future...
[docs]
def add_configuration_args(parser: ArgumentParser, args: Namespace) -> None:
    """
    Arguments that need to define the configuration of the dataset and model.

    Args:
        parser: the parser instance to extend
        args: the partially-parsed namespace; `args.dataset` selects which dataset configurations are valid choices
    """
    config_group = parser.add_argument_group('Configuration arguments', 'Arguments used to define the dataset and model configurations.')
    config_group.add_argument('--dataset_config', type=str,
                              choices=get_dataset_config_names(args.dataset),
                              help='The configuration used for this dataset (e.g., number of tasks, transforms, backbone architecture, etc.).'
                              'The available configurations are defined in the `datasets/config/<dataset>` folder.')
    # fixed help-text typo: "dataset-agostic" -> "dataset-agnostic"
    config_group.add_argument('--model_config', type=field_with_aliases({'default': ['base', 'default'], 'best': ['best']}), default='default',
                              help='The configuration used for this model. The available configurations are defined in the `models/config/<model>.yaml` file '
                              'and include a `default` (dataset-agnostic) configuration and a `best` configuration (dataset-specific). '
                              'If not provided, the `default` configuration is used.')
[docs]
def add_initial_args(parser) -> ArgumentParser:
    """
    Returns the initial parser for the arguments.
    """
    # the dataset and model choices fully determine the rest of the CLI,
    # so they are the only required arguments at this stage
    parser.add_argument('--dataset', type=custom_str_underscore, required=True,
                        choices=get_dataset_names(names_only=True),
                        help='Which dataset to perform experiments on.')
    parser.add_argument('--model', type=custom_str_underscore, required=True,
                        choices=list(get_all_models().keys()),
                        help='Model name.')
    parser.add_argument('--backbone', type=custom_str_underscore,
                        choices=list(REGISTERED_BACKBONES.keys()),
                        help='Backbone network name.')
    parser.add_argument('--load_best_args', action='store_true',
                        help='(deprecated) Loads the best arguments for each method, dataset and memory buffer. '
                        'NOTE: This option is deprecated and not up to date.')
    return parser
[docs]
def add_experiment_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments used by all the models.

    Args:
        parser: the parser instance

    Returns:
        None
    """
    # --- core experiment settings: learning rate, batch size, supervision ---
    exp_group = parser.add_argument_group('Experiment arguments', 'Arguments used to define the experiment settings.')
    exp_group.add_argument('--lr', required=True, type=float, help='Learning rate. This should either be set as default by the model '
                           '(with `set_defaults <https://docs.python.org/3/library/argparse.html#argparse.ArgumentParser.set_defaults>`_),'
                           ' by the dataset (with `set_default_from_args`, see :ref:`module-datasets.utils`), or with `--lr=<value>`.')
    exp_group.add_argument('--batch_size', type=int, help='Batch size.')
    # NOTE(review): the help text of `--label_perc_by_class` says "per task" even
    # though the flag name suggests "per class" -- likely copy-paste; confirm upstream.
    exp_group.add_argument('--label_perc_by_task', '--label_perc', '--lpt', type=float, default=1,
                           dest='label_perc', help='Percentage in (0-1] of labeled examples per task.')
    exp_group.add_argument('--label_perc_by_class', '--lpc', type=float, default=1, dest='label_perc_by_class',
                           help='Percentage in (0-1] of labeled examples per task.')
    exp_group.add_argument('--joint', type=int, choices=(0, 1), default=0, help='Train model on Joint (single task)?')
    exp_group.add_argument('--eval_future', type=int, choices=(0, 1), default=0, help='Evaluate future tasks?')

    # --- validation strategy and fitting/early-stopping criteria ---
    validation_group = parser.add_argument_group('Validation and fitting arguments', 'Arguments used to define the validation strategy and the method used to fit the model.')
    validation_group.add_argument('--validation', type=float, help='Percentage of samples FOR EACH CLASS drawn from the training set to build the validation set.')
    validation_group.add_argument('--validation_mode', type=str, choices=['complete', 'current'], default='current',
                                  help='Mode used for validation. Must be used in combination with `validation` argument. Possible values:'
                                  ' - `current`: uses only the current task for validation (default).'
                                  ' - `complete`: uses data from both current and past tasks for validation.')
    validation_group.add_argument('--fitting_mode', type=str, choices=['epochs', 'iters', 'time', 'early_stopping'], default='epochs',
                                  help='Strategy used for fitting the model. Possible values:'
                                  ' - `epochs`: fits the model for a fixed number of epochs (default). NOTE: this option is controlled by the `n_epochs` argument.'
                                  ' - `iters`: fits the model for a fixed number of iterations. NOTE: this option is controlled by the `n_iters` argument.'
                                  ' - `early_stopping`: fits the model until early stopping criteria are met. This option requires a validation set (see `validation` argument).'
                                  ' The early stopping criteria are: if the validation loss does not decrease for `early_stopping_patience` epochs, the training stops.')
    validation_group.add_argument('--early_stopping_patience', type=int, default=5,
                                  help='Number of epochs to wait before stopping the training if the validation loss does not decrease. Used only if `fitting_mode=early_stopping`.')
    validation_group.add_argument('--early_stopping_metric', type=str, default='loss', choices=['loss', 'accuracy'],
                                  help='Metric used for early stopping. Used only if `fitting_mode=early_stopping`.')
    validation_group.add_argument('--early_stopping_freq', type=int, default=1,
                                  help='Frequency of validation evaluation. Used only if `fitting_mode=early_stopping`.')
    validation_group.add_argument('--early_stopping_epsilon', type=float, default=1e-6,
                                  help='Minimum improvement required to consider a new best model. Used only if `fitting_mode=early_stopping`.')
    validation_group.add_argument('--n_epochs', type=int,
                                  help='Number of epochs. Used only if `fitting_mode=epochs`.')
    validation_group.add_argument('--n_iters', type=int,
                                  help='Number of iterations. Used only if `fitting_mode=iters`.')

    # --- optimizer and learning-rate scheduler configuration ---
    opt_group = parser.add_argument_group('Optimizer and learning rate scheduler arguments', 'Arguments used to define the optimizer and the learning rate scheduler.')
    opt_group.add_argument('--optimizer', type=str, default='sgd',
                           choices=ContinualModel.AVAIL_OPTIMS,
                           help='Optimizer.')
    opt_group.add_argument('--optim_wd', type=float, default=0.,
                           help='optimizer weight decay.')
    opt_group.add_argument('--optim_mom', type=float, default=0.,
                           help='optimizer momentum.')
    opt_group.add_argument('--optim_nesterov', type=binary_to_boolean_type, default=0,
                           help='optimizer nesterov momentum.')
    opt_group.add_argument('--lr_scheduler', type=str, help='Learning rate scheduler.')
    opt_group.add_argument('--scheduler_mode', type=str, choices=['epoch', 'iter'], default='epoch',
                           help='Scheduler mode. Possible values:'
                           ' - `epoch`: the scheduler is called at the end of each epoch.'
                           ' - `iter`: the scheduler is called at the end of each iteration.')
    opt_group.add_argument('--lr_milestones', type=int, default=[], nargs='+',
                           help='Learning rate scheduler milestones (used if `lr_scheduler=multisteplr`).')
    opt_group.add_argument('--sched_multistep_lr_gamma', type=float, default=0.1,
                           help='Learning rate scheduler gamma (used if `lr_scheduler=multisteplr`).')

    # --- noisy-label settings (symmetric noise supported everywhere;
    #     asymmetric requires explicit dataset support) ---
    noise_group = parser.add_argument_group('Noise arguments', 'Arguments used to define the noisy-label settings.')
    noise_group.add_argument('--noise_type', type=field_with_aliases({
        'symmetric': ['symmetric', 'sym', 'symm'],
        'asymmetric': ['asymmetric', 'asym', 'asymm']
    }), default='symmetric',
        help='Type of noise to apply. The symmetric type is supported by all datasets, while the asymmetric must be supported explicitly by the dataset (see `datasets/utils/label_noise`).')
    noise_group.add_argument('--noise_rate', type=float, default=0,
                             help='Noise rate in [0-1].')
    noise_group.add_argument('--disable_noisy_labels_cache', type=binary_to_boolean_type, default=0,
                             help='Disable caching the noisy label targets? NOTE: if the seed is not set, the noisy labels will be different at each run with this option disabled.')
    noise_group.add_argument('--cache_path_noisy_labels', type=str, default='noisy_labels',
                             help='Path where to save the noisy labels cache. The path is relative to the `base_path`.')
[docs]
def add_management_args(parser: ArgumentParser) -> None:
    """
    Adds the management arguments.

    Args:
        parser: the parser instance

    Returns:
        None
    """
    # --- reproducibility, paths, logging, debugging, checkpointing ---
    mng_group = parser.add_argument_group('Management arguments', 'Generic arguments to manage the experiment reproducibility, logging, debugging, etc.')
    mng_group.add_argument('--seed', type=int, default=None,
                           help='The random seed. If not provided, a random seed will be used.')
    mng_group.add_argument('--permute_classes', type=binary_to_boolean_type, default=0,
                           help='Permute classes before splitting into tasks? This applies the seed before permuting if the `seed` argument is present.')
    mng_group.add_argument('--base_path', type=str, default="./data/",
                           help='The base path where to save datasets, logs, results.')
    mng_group.add_argument('--results_path', type=str, default="results/",
                           help='The path where to save the results. NOTE: this path is relative to `base_path`.')
    mng_group.add_argument('--device', type=str,
                           help='The device (or devices) available to use for training. '
                           'More than one device can be specified by separating them with a comma. '
                           'If not provided, the code will use the least used GPU available (if there are any), otherwise the CPU. '
                           'MPS is supported and is automatically used if no GPU is available and MPS is supported. '
                           'If more than one GPU is available, Mammoth will use the least used one if `--distributed=no`.')
    mng_group.add_argument('--notes', type=str, default=None,
                           help='Helper argument to include notes for this run. Example: distinguish between different versions of a model and allow separation of results')
    mng_group.add_argument('--eval_epochs', type=int, default=None,
                           help='Perform inference on validation every `eval_epochs` epochs. If not provided, the model is evaluated ONLY at the end of each task.')
    mng_group.add_argument('--non_verbose', default=0, type=binary_to_boolean_type, help='Make progress bars non verbose')
    mng_group.add_argument('--disable_log', default=0, type=binary_to_boolean_type, help='Disable logging?')
    mng_group.add_argument('--num_workers', type=int, default=None, help='Number of workers for the dataloaders (default=infer from number of cpus).')
    mng_group.add_argument('--enable_other_metrics', default=0, type=binary_to_boolean_type,
                           help='Enable computing additional metrics: forward and backward transfer.')
    mng_group.add_argument('--debug_mode', type=binary_to_boolean_type, default=0, help='Run only a few training steps per epoch. This also disables logging on wandb.')
    mng_group.add_argument('--inference_only', default=0, type=binary_to_boolean_type,
                           help='Perform inference only for each task (no training).')
    # `-O` mirrors compiler-style optimization levels (0 = off .. 3 = BF16 + torch.compile)
    mng_group.add_argument('-O', '--code_optimization', type=int, default=0, choices=[0, 1, 2, 3],
                           help='Optimization level for the code.'
                           '0: no optimization.'
                           '1: Use TF32, if available.'
                           '2: Use BF16, if available.'
                           '3: Use BF16 and `torch.compile`. BEWARE: torch.compile may break your code if you change the model after the first run! Use with caution.')
    mng_group.add_argument('--distributed', type=str, default='no', choices=['no', 'dp', 'ddp'], help='Enable distributed training?')
    mng_group.add_argument('--savecheck', choices=['last', 'task'], type=str, help='Save checkpoint every `task` or at the end of the training (`last`).')
    mng_group.add_argument('--loadcheck', type=str, default=None, help='Path of the checkpoint to load (.pt file for the specific task)')
    mng_group.add_argument('--ckpt_name', type=str, help='(optional) checkpoint save name.')
    mng_group.add_argument('--start_from', type=int, default=None, help="Task to start from")
    mng_group.add_argument('--stop_after', type=int, default=None, help="Task limit")

    # --- Weights & Biases logging ---
    wandb_group = parser.add_argument_group('Wandb arguments', 'Arguments to manage logging on Wandb.')
    wandb_group.add_argument('--wandb_name', type=str, default=None,
                             help='Wandb name for this run. Overrides the default name (`args.model`).')
    wandb_group.add_argument('--wandb_entity', type=str, help='Wandb entity')
    wandb_group.add_argument('--wandb_project', type=str, help='Wandb project name')
[docs]
def add_rehearsal_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments used by all the rehearsal-based methods

    Args:
        parser: the parser instance

    Returns:
        None
    """
    rehearsal_group = parser.add_argument_group('Rehearsal arguments', 'Arguments shared by all rehearsal-based methods.')
    # the buffer size is the only mandatory rehearsal hyper-parameter
    rehearsal_group.add_argument('--buffer_size', type=int, required=True,
                                 help='The size of the memory buffer.')
    rehearsal_group.add_argument('--minibatch_size', type=int,
                                 help='The batch size of the memory buffer.')
[docs]
def check_multiple_defined_arg_during_string_parse() -> None:
    """
    Check if an argument is defined multiple times during the string parsing.

    Prevents the user from typing the same argument multiple times as:
    `--arg1=val1 --arg1=val2`.

    Raises:
        ValueError: if the same flag appears more than once in ``sys.argv``
    """
    seen = set()
    for arg in sys.argv[1:]:  # the index was unused in the original loop
        if '=' in arg:
            arg_name = arg.split('=')[0]
        else:
            # only flag tokens (starting with '-') are tracked;
            # bare values (e.g. the `val` in `--arg val`) are ignored
            arg_name = arg if arg.startswith('-') else None
        if arg_name is not None:
            if arg_name in seen:
                raise ValueError(f"Argument `{arg_name}` is defined multiple times.")
            seen.add(arg_name)
class _DocsArgs:
    """
    This class is used to generate the documentation of the arguments.
    """

    def __init__(self, name: str, tp: str, choices: str, default: str, help_: str):
        # translate the custom parsing callables (identified by their
        # __name__) into a human-readable type description
        if tp is None:
            tp = 'unknown'
        elif tp.__name__ == '_parse_field':
            tp = 'field with aliases (str)'
        elif tp.__name__ == 'binary_to_boolean_type':
            tp = '0|1|True|False -> bool'
        elif tp.__name__ == 'custom_str_underscore':
            tp = 'str (with underscores replaced by dashes)'
        else:
            tp = tp.__name__
        self.name = name
        self.type = tp
        self.choices = choices
        self.default = default
        self.help = help_

    def parse_choices(self) -> str:
        """Return the available choices rendered as a comma-separated string."""
        if self.choices is None:
            return ''
        # dict choices are rendered by their keys; the keys must be joined
        # into a string explicitly -- the original code passed the raw
        # `dict_keys` view to `str.join`, which raises a TypeError
        return ', '.join([', '.join(c.keys()) if isinstance(c, dict) else str(c) for c in self.choices])

    def __str__(self):
        tb = f"""**\\-\\-{self.name}** : {self.type}
\t*Help*: {self.help}\n
\t- *Default*: ``{self.default}``"""
        if self.choices is not None:
            tb += f"\n\t- *Choices*: ``{self.parse_choices()}``"
        return tb
class _DocArgsGroup:
    """
    This class is used to generate the documentation of the arguments.
    """

    def __init__(self, group_name: str, group_desc: str, doc_args: _DocsArgs):
        self.group_name = group_name
        self.group_desc = group_desc
        self.doc_args = doc_args

    def __str__(self):
        # rubric header, optional italicized description, then each argument
        parts = [f""".. rubric:: {self.group_name.capitalize()}\n\n"""]
        if self.group_desc:
            parts.append(f"*{self.group_desc}*\n\n")
        parts.append('\n'.join(str(arg) for arg in self.doc_args))
        return ''.join(parts)
def _parse_actions(actions: list, group_name: str, group_desc: str) -> _DocArgsGroup:
    """
    Parses the actions of the parser.

    Args:
        actions: the actions to parse
        group_name: the name of the group
        group_desc: the description of the group

    Returns:
        an instance of _DocArgsGroup containing the parsed actions
    """
    # the implicit `help` action carries no documentation value and is dropped
    parsed = [
        _DocsArgs(act.dest, act.type, act.choices, act.default, act.help)
        for act in actions
        if act.dest != 'help'
    ]
    return _DocArgsGroup(group_name, group_desc, parsed)
if __name__ == '__main__':
    # Generates the .rst documentation for all the argument groups and for
    # each model-specific parser. Must run from the repository root, hence
    # the chdir to `mammoth_path` (computed at the top of the file).
    print("Generating documentation for the arguments...")
    os.chdir(mammoth_path)

    # MAIN MAMMOTH ARGS: the initial parser plus the dataset configuration.
    parser = ArgumentParser()
    add_initial_args(parser)
    parser.add_argument('--dataset_config', type=str,
                        help='The configuration used for this dataset (e.g., number of tasks, transforms, backbone architecture, etc.).'
                        'The available configurations are defined in the `datasets/config/<dataset>` folder.')

    docs_args = []
    for action in parser._actions:
        if action.dest == 'help':
            continue
        docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help))

    with open('docs/utils/args.rst', 'w') as f:
        f.write('.. _module-args:\n\n')
        f.write('Arguments\n')
        f.write('=========\n\n')
        f.write('.. rubric:: MAIN MAMMOTH ARGS\n\n')
        for arg in docs_args:
            f.write(str(arg) + '\n\n')

    # EXPERIMENT-RELATED ARGS, grouped by argument group.
    parser = ArgumentParser()
    add_experiment_args(parser)
    docs_args = []
    for group in parser._action_groups:
        if len([a for a in group._group_actions if a.dest != 'help']) == 0:
            continue
        docs_args.append(_parse_actions(group._group_actions, group.title, group.description))

    with open('docs/utils/args.rst', 'a') as f:
        f.write('.. rubric:: EXPERIMENT-RELATED ARGS\n\n')
        for arg in docs_args:
            f.write(str(arg) + '\n\n')

    # MANAGEMENT ARGS, grouped by argument group.
    parser = ArgumentParser()
    add_management_args(parser)
    docs_args = []
    for group in parser._action_groups:
        if len([a for a in group._group_actions if a.dest != 'help']) == 0:
            continue
        docs_args.append(_parse_actions(group._group_actions, group.title, group.description))

    with open('docs/utils/args.rst', 'a') as f:
        f.write('.. rubric:: MANAGEMENT ARGS\n\n')
        for arg in docs_args:
            f.write(str(arg) + '\n\n')

    # REHEARSAL-ONLY ARGS (typo fixed: was "REEHARSAL" in the generated docs).
    parser = ArgumentParser()
    add_rehearsal_args(parser)
    docs_args = []
    for action in parser._actions:
        if action.dest == 'help':
            continue
        docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help))

    with open('docs/utils/args.rst', 'a') as f:
        f.write('.. rubric:: REHEARSAL-ONLY ARGS\n\n')
        for arg in docs_args:
            f.write(str(arg) + '\n\n')

    print("Saving documentation in docs/utils/args.rst")
    print("Done!")

    # Model-specific parsers: one .rst file per registered model.
    # Imported here (not at the top) to avoid the import cost when this
    # module is used as a library.
    from models import get_model_names

    for model_name, model_class in get_model_names().items():
        parser = model_class.get_parser(ArgumentParser())
        model_args_groups = []
        for group in parser._action_groups:
            if len([a for a in group._group_actions if a.dest != 'help']) == 0:
                continue
            model_args_groups.append(_parse_actions(group._group_actions, group.title, group.description))

        model_filename = model_name.replace("-", "_")
        with open(f'docs/models/{model_filename}_args.rst', 'w') as f:
            f.write('Arguments\n')  # f-string prefixes removed: no placeholders
            f.write('~~~~~~~~~~~\n\n')
            for arg in model_args_groups:
                f.write(str(arg) + '\n\n')
        print(f"Saving documentation in docs/models/{model_filename}_args.rst")