# Source code for utils.main

"""
This script is the main entry point for the Mammoth project. It contains the main function `main()` that orchestrates the training process.

The script performs the following tasks:
- Imports necessary modules and libraries.
- Sets up the necessary paths and configurations.
- Parses command-line arguments.
- Initializes the dataset, model, and other components.
- Trains the model using the `train()` function.

To run the script, execute it directly or import it as a module and call the `main()` function.
"""
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# needed (don't change it)
import logging
import numpy  # noqa
import os
import sys
import time
import importlib
import socket
import datetime
import uuid
from argparse import ArgumentParser, Namespace
import torch

# Make the repository root (the parent of this file's directory) and its `datasets`, `backbone`
# and `models` folders importable, independently of the current working directory.
mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.extend([mammoth_path + suffix for suffix in ('', '/datasets', '/backbone', '/models')])

# `utils` is only importable once the repository root is on `sys.path` (see above);
# logging is configured right away so that every message below goes through it.
from utils import setup_logging  # noqa: E402
setup_logging()

if __name__ == '__main__':
    logging.info(f"Running Mammoth! on {socket.gethostname()}. (if you see this message more than once, you are probably importing something wrong)")

    from utils.conf import warn_once

    # Load environment variables from a local `.env` file, unless the test mode is active
    # (`MAMMOTH_TEST` set to anything other than '0'). `python-dotenv` is optional.
    try:
        if os.getenv('MAMMOTH_TEST', '0') != '0':
            warn_once("Running in test mode. Ignoring .env file.")
        else:
            from dotenv import load_dotenv
            load_dotenv()
    except ImportError:
        warn_once("Warning: python-dotenv not installed. Ignoring .env file.")


def lecun_fix():
    """
    Install a process-wide `urllib` opener that sends a browser-like ``User-agent`` header.

    The original note says that Yann LeCun moved his website to CloudFlare. The custom header is
    presumably needed so that downloads from it (e.g. the MNIST files) are not rejected when they use
    the default Python user agent.

    Note:
        This replaces the global opener used by ``urllib.request.urlopen`` for the whole process.
    """
    # Python 3 standard library: the former `six.moves` shim was only needed for Python 2 and
    # added an unnecessary third-party dependency to the program's entry point.
    import urllib.request

    opener = urllib.request.build_opener()
    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    urllib.request.install_opener(opener)
def check_args(args, dataset=None):
    """
    Validate the arguments through a (non-exhaustive) sequence of assertions.

    The dataset-independent checks always run; the dataset-specific ones run only when
    `dataset` is provided.

    Args:
        args: the parsed arguments to validate.
        dataset: optional dataset instance, used for the dataset-specific checks.

    Raises:
        AssertionError: as soon as one of the checks fails.
    """
    # at most one of the two label-percentage options may differ from 1
    assert args.label_perc_by_class == 1 or args.label_perc == 1, "Cannot use both `label_perc_by_task` and `label_perc_by_class`"

    if args.joint:
        assert args.start_from is None and args.stop_after is None, "Joint training does not support start_from and stop_after"
        assert not args.enable_other_metrics, "Joint training does not support other metrics"
        assert not args.eval_future, "Joint training does not support future evaluation (what is the future?)"

    assert 0 < args.label_perc <= 1, "label_perc must be in (0, 1]"
    assert not (args.savecheck and args.inference_only), "Should not save checkpoint in inference only mode"
    assert 0. <= args.noise_rate <= 1., "Noise rate must be in [0, 1]"

    if dataset is None:
        return

    from datasets.utils.gcl_dataset import GCLDataset, ContinualDataset

    if isinstance(dataset, GCLDataset):
        # options that the code rules out for general-continual-learning datasets
        assert args.n_epochs == 1, "GCLDataset is not compatible with multiple epochs"
        assert args.enable_other_metrics == 0, "GCLDataset is not compatible with other metrics (i.e., forward/backward transfer and forgetting)"
        assert args.eval_future == 0, "GCLDataset is not compatible with future evaluation"
        assert args.noise_rate == 0, "GCLDataset is not compatible with automatic noise injection"

    assert isinstance(dataset, (ContinualDataset, GCLDataset)), "Dataset must be an instance of `ContinualDataset` or `GCLDataset`"
def load_configs(parser: ArgumentParser) -> dict:
    """
    Load and merge the default configuration for the chosen model and dataset.

    The model's own arguments are first registered on `parser` (through the model's `get_parser`).
    The configuration is then assembled from, in increasing order of precedence:

    1. the dataset configuration (the dataset defaults updated with the selected dataset configuration
       file, then passed through the dataset class' `set_default_from_config`);
    2. the defaults set on the parser with `set_defaults`;
    3. the model configuration file.

    Args:
        parser: the argument parser, already containing the initial and configuration arguments.

    Returns:
        dict: the merged configuration, mapping argument names to their default values.

    Raises:
        AssertionError: if the model registers a `buffer_size` argument but no value was found for it.
        ValueError: if the `--buffer_size` value cannot be converted to an integer.
    """
    from models import get_model_class
    from models.utils import load_model_config

    from datasets import get_dataset_class
    from datasets.utils import get_default_args_for_dataset, load_dataset_config
    from utils.args import fix_model_parser_backwards_compatibility, get_single_arg_value

    args = parser.parse_known_args()[0]

    # load the model configuration
    # - get the model parser and fix the get_parser function for backwards compatibility
    model_parser = get_model_class(args).get_parser(parser)
    parser = fix_model_parser_backwards_compatibility(parser, model_parser)

    # a model is treated as rehearsal-based if it registers a `buffer_size` argument
    is_rehearsal = any(action.dest == 'buffer_size' for action in parser._actions)
    buffer_size = None
    if is_rehearsal:  # get buffer size
        buffer_size = get_single_arg_value(parser, 'buffer_size')
        assert buffer_size is not None, "Buffer size not found in the arguments. Please specify it with --buffer_size."
        try:
            buffer_size = int(buffer_size)  # try convert to int, check if it is a valid number
        except ValueError as err:
            raise ValueError(f'--buffer_size must be an integer but found {buffer_size}') from err

    # - get the defaults that were set with `set_defaults` in the parser
    base_config = parser._defaults.copy()

    # - get the configuration file for the model (it may depend on the buffer size)
    model_config = load_model_config(args, buffer_size=buffer_size)

    # update the dataset class with the configuration
    dataset_class = get_dataset_class(args)

    # load the dataset configuration. If the model config specifies a dataset config, use it.
    # Otherwise, use the dataset config selected with `args.dataset_config`.
    base_dataset_config = get_default_args_for_dataset(args.dataset)
    if 'dataset_config' in model_config:  # the model configuration selects the dataset config
        cnf_file_dataset_config = load_dataset_config(model_config['dataset_config'], args.dataset)
    else:
        cnf_file_dataset_config = load_dataset_config(args.dataset_config, args.dataset)

    dataset_config = {**base_dataset_config, **cnf_file_dataset_config}
    # the updated configuration file is cleaned from the dataset-specific arguments
    dataset_config = dataset_class.set_default_from_config(dataset_config, parser)

    # - merge the dataset and model configurations, with the model configuration taking precedence
    config = {**dataset_config, **base_config, **model_config}

    return config
def parse_args() -> Namespace:
    """
    Parse command line arguments for the mammoth program and sets up the `args` object.

    The parser is built incrementally, because the available arguments depend on the chosen
    model, dataset and backbone:

    1. the initial arguments (model, dataset, backbone) are added and parsed with `parse_known_args`;
    2. the configuration arguments are added and the model/dataset configuration is loaded with `load_configs`;
    3. the dataset/backbone-specific, management and experiment arguments are added;
    4. the loaded configuration is installed as the parser defaults;
    5. the full command line is parsed (optionally merged with the deprecated `best_args`);
    6. dynamically loaded arguments are cleaned;
    7. run metadata (uuid, timestamp, host, git hash and, with `savecheck`, the checkpoint name) is attached
       and the dataset-independent checks of `check_args` are run.

    Returns:
        args (argparse.Namespace): Parsed command line arguments.

    Raises:
        AssertionError: if no backbone can be determined, or if one of the checks in `check_args` fails.
    """
    from utils import create_if_not_exists
    from utils.conf import warn_once
    from utils.args import add_initial_args, add_management_args, add_experiment_args, add_configuration_args, clean_dynamic_args, \
        check_multiple_defined_arg_during_string_parse, add_dynamic_parsable_args, update_cli_defaults, get_single_arg_value

    from models import get_all_models

    # NOTE(review): presumably rejects command lines that define the same argument more than once — see `utils.args`
    check_multiple_defined_arg_during_string_parse()

    # `allow_abbrev=False` disables prefix matching of long options; presumably needed because
    # arguments are added incrementally and an abbreviation could otherwise match the wrong one
    parser = ArgumentParser(description='Mammoth - An Extendible (General) Continual Learning Framework for Pytorch', allow_abbrev=False)

    # 1) add arguments that include model, dataset, and backbone. These define the rest of the arguments.
    #   the backbone is optional as may be set by the dataset or the model. The dataset and model are required.
    add_initial_args(parser)
    # `parse_known_args` tolerates the arguments that have not been registered yet
    args = parser.parse_known_args()[0]

    if args.backbone is None:
        logging.warning('No backbone specified. Using default backbone (set by the dataset).')

    # 2) load the configuration arguments for the dataset and model
    add_configuration_args(parser, args)

    config = load_configs(parser)

    # 3) add the remaining arguments

    # - get the chosen backbone. The CLI argument takes precedence over the configuration file.
    backbone = args.backbone
    if backbone is None:
        if 'backbone' in config:
            backbone = config['backbone']
        else:
            backbone = get_single_arg_value(parser, 'backbone')
    assert backbone is not None, "Backbone not found in the arguments. Please specify it with --backbone or in the model or dataset configuration file."

    # - add the dynamic arguments defined by the chosen dataset and model
    add_dynamic_parsable_args(parser, args.dataset, backbone)

    # - add the main Mammoth arguments
    add_management_args(parser)
    add_experiment_args(parser)

    # 4) Once all arguments are in the parser, we can set the defaults using the loaded configuration
    update_cli_defaults(parser, config)

    # 5) parse the arguments
    if args.load_best_args:
        from utils.best_args import best_args

        warn_once("The `load_best_args` option is untested and not up to date.")

        is_rehearsal = any([p for p in parser._actions if p.dest == 'buffer_size'])  # check if model has a buffer

        args = parser.parse_args()
        # `joint` shares the best arguments of `sgd`
        if args.model == 'joint':
            best = best_args[args.dataset]['sgd']
        else:
            best = best_args[args.dataset][args.model]
        # rehearsal models are indexed by buffer size; the others use the entry at key -1
        if is_rehearsal:
            best = best[args.buffer_size]
        else:
            best = best[-1]

        # the best arguments are appended after the user's CLI; for plain `store` options argparse keeps the
        # last occurrence, so they override duplicated options
        to_parse = sys.argv[1:] + ['--' + k + '=' + str(v) for k, v in best.items()]
        # `list.remove` raises ValueError unless the flag was given literally as `--load_best_args`
        to_parse.remove('--load_best_args')
        args = parser.parse_args(to_parse)
        if args.model == 'joint' and args.dataset == 'mnist-360':
            args.model = 'joint_gcl'
    else:
        args = parser.parse_args()

    # 6) clean dynamically loaded args
    args = clean_dynamic_args(args)

    # 7) final checks and updates to the arguments
    # map the CLI model name to its registered name (later imported as the module `models.<name>` by `extend_args`)
    models_dict = get_all_models()
    args.model = models_dict[args.model]

    if args.lr_scheduler is not None:
        logging.info('`lr_scheduler` set to {}, overrides default from dataset.'.format(args.lr_scheduler))

    if args.seed is not None:
        from utils.conf import set_random_seed
        set_random_seed(args.seed)

    # Add uuid, timestamp and hostname for logging
    args.conf_jobnum = str(uuid.uuid4())
    args.conf_timestamp = str(datetime.datetime.now())
    args.conf_host = socket.gethostname()

    # Add the current git commit hash to the arguments if available
    # (any failure, e.g. GitPython missing or not a git checkout, is logged and leaves the hash to None)
    try:
        import git
        repo = git.Repo(path=os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        args.conf_git_hash = repo.head.object.hexsha
    except Exception:
        logging.error("Could not retrieve git hash.")
        args.conf_git_hash = None

    if args.savecheck:
        if not os.path.isdir('checkpoints'):
            create_if_not_exists("checkpoints")

        # checkpoint name: [<ckpt_name>_]<model>_<dataset>_<dataset_config>_<buffer_size or 0>_<n_epochs>_<timestamp>_<first uuid chunk>
        now = time.strftime("%Y%m%d-%H%M%S")
        uid = args.conf_jobnum.split('-')[0]
        extra_ckpt_name = "" if args.ckpt_name is None else f"{args.ckpt_name}_"
        args.ckpt_name = f"{extra_ckpt_name}{args.model}_{args.dataset}_{args.dataset_config}_{args.buffer_size if hasattr(args, 'buffer_size') else 0}_{args.n_epochs}_{str(now)}_{uid}"
        print("Saving checkpoint into", args.ckpt_name, file=sys.stderr)

    # only the dataset-independent checks: the dataset does not exist yet
    check_args(args)

    if args.validation is not None:
        logging.info(f"Using {args.validation}% of the training set as validation set.")
        logging.info(f"Validation will be computed with mode `{args.validation_mode}`.")

    return args
def extend_args(args, dataset):
    """
    Extend the command-line arguments with the default values from the dataset and the model.

    Arguments left to ``None`` are filled from the dataset: number of classes, epochs or iterations
    (depending on `fitting_mode`), and batch/minibatch sizes. The W&B entity and project fall back to the
    ``WANDB_ENTITY`` / ``WANDB_PROJECT`` environment variables. W&B is disabled when either is missing, or
    when it was already disabled (``--nowand`` or debug mode).

    Args:
        args: the parsed arguments. Modified in place.
        dataset: the dataset instance providing the defaults.

    Raises:
        AssertionError: if the `current` validation mode is requested for a setting other than
            `class-il` / `task-il`.
    """
    from datasets import ContinualDataset
    dataset: ContinualDataset = dataset  # noqa, used for type hinting

    if hasattr(args, 'num_classes') and args.num_classes is None:
        args.num_classes = dataset.N_CLASSES

    # the number of epochs / iterations is taken from the dataset only for `ContinualDataset`s
    if args.fitting_mode == 'epochs' and args.n_epochs is None and isinstance(dataset, ContinualDataset):
        args.n_epochs = dataset.get_epochs()
    elif args.fitting_mode == 'iters' and args.n_iters is None and isinstance(dataset, ContinualDataset):
        args.n_iters = dataset.get_iters()

    if args.batch_size is None:
        args.batch_size = dataset.get_batch_size()
        # models whose module exposes `Buffer` also get the dataset's default minibatch size
        if hasattr(importlib.import_module('models.' + args.model), 'Buffer') and (not hasattr(args, 'minibatch_size') or args.minibatch_size is None):
            args.minibatch_size = dataset.get_minibatch_size()
    else:
        # NOTE(review): this also overwrites a `minibatch_size` given together with `batch_size` — confirm intended.
        args.minibatch_size = args.batch_size

    if args.validation:
        if args.validation_mode == 'current':
            assert dataset.SETTING in ['class-il', 'task-il'], "`current` validation modes is only supported for class-il and task-il settings (requires a task division)."

    if args.debug_mode:
        print('Debug mode enabled: running only a few forward steps per epoch with W&B disabled.')
        args.nowand = 1

    if args.wandb_entity is None:
        args.wandb_entity = os.getenv('WANDB_ENTITY', None)
    if args.wandb_project is None:
        args.wandb_project = os.getenv('WANDB_PROJECT', None)

    if args.wandb_entity is None or args.wandb_project is None:
        logging.info('`wandb_entity` and `wandb_project` not set. Disabling wandb.')
        args.nowand = 1
    elif getattr(args, 'nowand', 0):
        # W&B was explicitly disabled (with `--nowand` or by debug mode above): do not re-enable it
        args.nowand = 1
    else:
        print('Logging to wandb: {}/{}'.format(args.wandb_entity, args.wandb_project))
        args.nowand = 0
def main(args=None):
    """
    Run a Mammoth experiment: set up the environment, build dataset, backbone and model, then train.

    Args:
        args: already-parsed arguments. If None, they are parsed from the command line with `parse_args`.

    Raises:
        NotImplementedError: if the requested `code_optimization` (BF16 or `torch.compile`) is not supported
            on this machine, or if `distributed == 'ddp'`.
        AssertionError: if `eval_future` is requested for a model that is not a `FutureModel`, or if
            `check_args` fails.
        Exception: if the batch is smaller than the number of GPUs with `distributed == 'dp'`.
    """
    from utils.conf import base_path, get_device
    from models import get_model
    from datasets import get_dataset
    from utils.training import train
    from models.utils.future_model import FutureModel
    from backbone import get_backbone

    # installs a global urllib opener with a custom User-agent
    lecun_fix()
    if args is None:
        args = parse_args()

    device = get_device(avail_devices=args.device)
    args.device = device

    # set base path
    base_path(args.base_path)

    if args.code_optimization != 0:
        torch.set_float32_matmul_precision('high' if args.code_optimization == 1 else 'medium')
        logging.info(f"Code_optimization is set to {args.code_optimization}")
        logging.info(f"Using {torch.get_float32_matmul_precision()} precision for matmul.")

        if args.code_optimization == 2:
            if not torch.cuda.is_bf16_supported():
                raise NotImplementedError('BF16 is not supported on this machine.')

    dataset = get_dataset(args)

    extend_args(args, dataset)
    check_args(args, dataset=dataset)

    backbone = get_backbone(args)
    logging.info(f"Using backbone: {args.backbone}")

    if args.code_optimization == 3:
        # check if the model is compatible with torch.compile
        # from https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
        # `get_device_capability` fails with an obscure CUDA error when no GPU is available
        if not torch.cuda.is_available():
            raise NotImplementedError('torch.compile is not supported on this machine.')
        if torch.cuda.get_device_capability()[0] < 7:
            raise NotImplementedError('torch.compile is not supported on this machine.')
        if os.name == 'nt':
            raise NotImplementedError("torch.compile is not supported on Windows. Check https://github.com/pytorch/pytorch/issues/90768 for updates.")

        print("================ Compiling model with torch.compile ================")
        logging.warning("`torch.compile` may break your code if you change the model after the first run!")
        print("This includes adding classifiers for new tasks, changing the backbone, etc.")
        print("ALSO: some models CHANGE the backbone during initialization. Remember to call `torch.compile` again after that.")
        print("====================================================================")
        backbone = torch.compile(backbone)

    loss = dataset.get_loss()
    model = get_model(args, backbone, loss, dataset.get_transform(), dataset=dataset)
    assert isinstance(model, FutureModel) or not args.eval_future, "Model does not support future_forward."

    if args.distributed == 'dp':
        from utils.distributed import make_dp
        if args.batch_size < torch.cuda.device_count():
            raise Exception(f"Batch too small for DataParallel (Need at least {torch.cuda.device_count()}).")

        model.net = make_dp(model.net)
        model.to('cuda:0')
        args.conf_ngpus = torch.cuda.device_count()
    elif args.distributed == 'ddp':
        # DDP breaks the buffer, it has to be synchronized.
        raise NotImplementedError('Distributed Data Parallel not supported yet.')

    # best effort: name the process after the run (model, buffer size or 0, dataset)
    try:
        import setproctitle
        # set job name
        setproctitle.setproctitle('{}_{}_{}'.format(args.model, args.buffer_size if 'buffer_size' in args else 0, args.dataset))
    except ImportError:
        pass  # optional dependency
    except Exception:
        logging.debug("Could not set the process title.", exc_info=True)

    train(model, dataset, args)
if __name__ == '__main__':
    # entry point when this file is executed as a script
    main()