Source code for datasets

"""
Datasets can be included either by registering them using the `register_dataset` decorator or by following the old naming convention:
- A single dataset is defined in a file named `<dataset_name>.py` in the `datasets` folder.
- The dataset class must inherit from `ContinualDataset`.
"""

import os
import sys
from typing import Callable
import importlib
import inspect
from argparse import Namespace

# Repository root: the parent of the directory that contains this file.
mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Put the root on the import path before the project imports that follow
# (`utils`, `datasets.utils...`) are executed.
sys.path.append(mammoth_path)
# Make the root the working directory; the helpers in this module list the
# `datasets` folder through paths relative to the current directory.
os.chdir(mammoth_path)

from utils import infer_args_from_signature, register_dynamic_module_fn
from utils.conf import warn_once
from datasets.utils.continual_dataset import ContinualDataset

# Registry filled through `register_dataset`.
# Template: {name: {'class': class, 'parsable_args': parsable_args}}
REGISTERED_DATASETS = {}


def register_dataset(name: str) -> Callable:
    """
    Decorator to register a ContinualDataset.

    It may decorate either a class that inherits from `ContinualDataset` or a
    function that returns a `ContinualDataset` instance. Once registered, the
    dataset is reachable through `get_dataset` and may expose additional
    keyword arguments that are set during parsing.

    The arguments are inferred from the *signature* of the dataset's class:
    the default value in the signature is the default of the argument. A
    default of `Parameter.empty` makes the argument required, a default of
    `None` makes it optional. The type of the argument is inferred from the
    default value (falling back to `str`).

    Args:
        name: the name of the dataset

    Returns:
        the decorator produced by `register_dynamic_module_fn` for the
        `REGISTERED_DATASETS` registry
    """
    # A new registration makes the memoized result of `get_dataset_names` stale.
    try:
        del get_dataset_names.names
    except AttributeError:
        pass  # nothing has been cached yet

    return register_dynamic_module_fn(name, REGISTERED_DATASETS, ContinualDataset)
def get_all_datasets_legacy():
    """
    Return the datasets in the `datasets` folder that follow the old naming
    convention (one `<dataset_name>.py` file per dataset).

    The folder is looked up relative to the current working directory.
    Entries whose name contains a double underscore (e.g. `__init__.py`,
    `__pycache__`) are ignored.

    Returns:
        list of file names stripped of their extension, in the order given by
        `os.listdir`
    """
    return [
        filename.split('.')[0]
        for filename in os.listdir('datasets')
        # only real Python sources: a substring test on 'py' would also accept
        # names such as `foo.pyc`, `copy_notes.txt` or a `pyutils` directory
        if '__' not in filename and filename.endswith('.py')
    ]
def get_dataset_names(names_only=False):
    """
    Return the available continual datasets.

    The result merges the entries of `REGISTERED_DATASETS` with the datasets
    found in the files listed by `get_all_datasets_legacy`. Every name is
    normalized by replacing `_` with `-`. The mapping is computed once and
    memoized on the function itself (attribute `names`); `register_dataset`
    deletes that attribute to force a refresh.

    Loading errors are not raised here: if a legacy dataset file fails to
    load, a warning is issued and the exception is stored as the value
    associated with that dataset, so that it can be raised when the dataset
    is actually requested (see `get_dataset_class`).

    Args:
        names_only (bool): whether to return only the names of the available datasets

    Returns:
        the list of dataset names if `names_only` is set, otherwise a dict
        mapping each name to `{'class': ..., 'parsable_args': ...}` (or to
        the exception caught while loading it)
    """

    def _is_legacy_dataset(obj):
        # A class with `ContinualDataset` among its ancestors that is neither
        # `GCLDataset` itself nor derived from it (checked on the MRO repr).
        return ('type' in str(type(obj))
                and 'ContinualDataset' in str(inspect.getmro(obj)[1:])
                and 'GCLDataset' not in str(inspect.getmro(obj)))

    def _is_gcl_dataset(obj):
        # A class with `GCLDataset` among its ancestors.
        return 'type' in str(type(obj)) and 'GCLDataset' in str(inspect.getmro(obj)[1:])

    def _dataset_names():
        # key: dataset name, value: {'class': dataset class, 'parsable_args': parsable_args}
        names = {}
        for dataset, dataset_conf in REGISTERED_DATASETS.items():
            names[dataset.replace('_', '-')] = {'class': dataset_conf['class'],
                                                'parsable_args': dataset_conf['parsable_args']}

        base_class_signature = inspect.signature(ContinualDataset.__init__)

        # datasets that follow the old naming convention: load the module and check for errors
        for dataset in get_all_datasets_legacy():
            # Keys of `names` are hyphenated, file stems use underscores:
            # compare the normalized form, otherwise the check never matches.
            dataset_key = dataset.replace('_', '-')
            if dataset_key in names:  # dataset registered with the new convention has priority
                continue

            try:
                mod = importlib.import_module('datasets.' + dataset)
                members = [getattr(mod, attr) for attr in mod.__dir__()]

                # plain ContinualDataset subclasses first, then the GCLDataset ones
                for selector in (_is_legacy_dataset, _is_gcl_dataset):
                    selected = [obj for obj in members if selector(obj)]
                    for cls in selected:
                        signature = inspect.signature(cls.__init__)
                        parsable_args = infer_args_from_signature(signature, excluded_signature=base_class_signature)
                        names[cls.NAME.replace('_', '-')] = {'class': cls, 'parsable_args': parsable_args}
            except Exception as e:
                # deliberately broad: keep the error so it can be raised when the dataset is requested
                warn_once(f'Error in dataset {dataset}')
                warn_once(e)
                names[dataset_key] = e
        return names

    if not hasattr(get_dataset_names, 'names'):
        setattr(get_dataset_names, 'names', _dataset_names())
    names = getattr(get_dataset_names, 'names')
    if names_only:
        return list(names.keys())
    return names
def get_dataset_config_names(dataset: str):
    """
    Return the names of the configurations available for a continual dataset.

    A configuration creates the dataset with specific hyperparameters and is
    selected with the `--dataset_config` attribute. Configurations are the
    `.yaml` files stored in the `datasets/configs/<dataset>` folder; files
    starting with `__` are skipped. The result for each dataset is memoized
    on the function (attribute `names`).

    Args:
        dataset: the name of the dataset

    Returns:
        the list of configuration names (empty if the folder does not exist)
    """
    try:
        cache = get_dataset_config_names.names
    except AttributeError:
        cache = get_dataset_config_names.names = {}

    if dataset not in cache:
        config_dir = f'datasets/configs/{dataset}'
        found = []
        if os.path.exists(config_dir):
            for entry in os.listdir(config_dir):
                if entry.endswith('.yaml') and not entry.startswith('__'):
                    found.append(entry.split('.yaml')[0])
        cache[dataset] = found

    return cache[dataset]
def get_dataset_class(args: Namespace, return_args=False) -> ContinualDataset:
    """
    Return the class of the selected continual dataset among those that are available.

    If an error was detected while loading the selected dataset, the stored
    exception is raised.

    Args:
        args (Namespace): the arguments which contain the `--dataset` attribute
        return_args (bool): whether to also return the parsable arguments of the dataset

    Raises:
        AssertionError: if the dataset is not available
        Exception: the error recorded by `get_dataset_names` for this dataset, if any

    Returns:
        the continual dataset class, or the tuple `(class, parsable_args)` if
        `return_args` is set
    """
    names = get_dataset_names()

    if args.dataset not in names:
        # Raised explicitly (same exception type as the former bare `assert`)
        # so that the check is not stripped when running with `python -O`.
        raise AssertionError(f"Dataset `{args.dataset}` is not available")

    entry = names[args.dataset]
    if isinstance(entry, Exception):
        # the dataset was found but failed to load: surface the original error
        raise entry

    if return_args:
        return entry['class'], entry['parsable_args']
    return entry['class']
def get_dataset(args: Namespace) -> ContinualDataset:
    """
    Build the continual dataset selected by `args`.

    The class and its parsable arguments are resolved through
    `get_dataset_class`; any error recorded while loading the dataset is
    raised by that call. Every parsable argument must be present in `args`
    and is forwarded to the constructor as a keyword argument.

    Args:
        args (Namespace): the arguments which contain the hyperparameters

    Raises:
        AssertionError: if the dataset is not available or some of its arguments are missing from `args`
        Exception: if an error was detected in the dataset

    Returns:
        the continual dataset instance
    """
    dataset_class, dataset_args = get_dataset_class(args, return_args=True)

    provided = vars(args)
    arg_names = list(dataset_args.keys())

    missing_args = [arg_name for arg_name in arg_names if arg_name not in provided]
    assert not missing_args, "Missing arguments for the dataset: " + ', '.join(missing_args)

    parsed_args = {arg_name: getattr(args, arg_name) for arg_name in arg_names}
    return dataset_class(args, **parsed_args)
# Import every Python file in the `datasets` folder (except this package
# initializer) so that the datasets they define get registered.
for file in os.listdir(os.path.dirname(__file__)):
    if file == '__init__.py' or not file.endswith('.py'):
        continue
    importlib.import_module(f'datasets.{file[:-3]}')