Creating Custom Datasets#
This example shows how to create custom continual learning datasets in Mammoth Lite. You’ll learn to implement your own benchmarks, register them with the framework, and use them in experiments.
Note
This example is based on the Jupyter notebook examples/notebooks/create_a_dataset.ipynb. You can run it interactively or follow along here.
Learning Objectives#
By the end of this example, you’ll understand:
The two-layer structure of datasets in Mammoth Lite
How to create a data source class inheriting from MammothDataset
How to create a continual dataset class inheriting from ContinualDataset
How to register custom datasets with the framework
How to configure data transformations and task structures
Understanding Dataset Architecture#
Mammoth Lite uses a two-layer architecture for datasets:
- Layer 1: Data Source (MammothDataset). Handles loading and accessing individual data samples. Similar to PyTorch's Dataset class, but with continual learning requirements.
- Layer 2: Continual Dataset (ContinualDataset). Defines the continual learning scenario: how data is split into tasks, which transformations are applied, and the evaluation protocol.
This separation allows you to:
Reuse the same data source for different continual learning scenarios (see the sketch below)
Easily modify task structures without changing data loading logic
Support different evaluation protocols on the same underlying data
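For example, two benchmarks can share one data source and differ only in their task structure. Here is a minimal sketch (the class names are hypothetical and the bodies are trimmed to the two attributes that differ; the full set of required attributes appears later in this example):

from mammoth_lite import ContinualDataset

# Two hypothetical benchmarks sharing the same underlying data source:
# only the task structure changes, never the data loading logic.
class SeqCIFAR10x5(ContinualDataset):
    N_TASKS = 5             # 5 tasks of 2 classes each
    N_CLASSES_PER_TASK = 2

class SeqCIFAR10x2(ContinualDataset):
    N_TASKS = 2             # 2 tasks of 5 classes each
    N_CLASSES_PER_TASK = 5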
Data Source Requirements#
The data source class must provide:
Required Attributes:
* data: the training/testing data samples
* targets: the corresponding labels
* not_aug_transform: a transformation pipeline without data augmentation
Required Methods:
* __len__(): returns the number of samples
* __getitem__(index): returns (augmented_image, label, original_image)
The triple return format is crucial for continual learning algorithms that need both augmented data for training and original data for memory storage (rehearsal).
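To make this concrete, here is a hedged sketch of how a rehearsal-style method might consume the triple. The buffer object and its add_data method are hypothetical stand-ins for illustration, not the framework API:

# Minimal sketch of a rehearsal-style training step; `buffer` is a
# hypothetical replay buffer with an add_data method (illustration only).
def training_step(model, optimizer, criterion, buffer, train_loader):
    for aug_img, target, not_aug_img in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(aug_img), target)  # train on augmented data
        loss.backward()
        optimizer.step()
        # Store the original (non-augmented) images: replaying already
        # augmented images would stack augmentations on top of each other
        # when samples are drawn from the buffer in later tasks.
        buffer.add_data(examples=not_aug_img, labels=target)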
Creating a Custom Data Source#
Let’s create a custom CIFAR-10 data source that meets Mammoth Lite’s requirements:
from mammoth_lite import MammothDataset
from torchvision.datasets import CIFAR10
from torchvision import transforms
from PIL import Image
class MammothCIFAR10(MammothDataset, CIFAR10):
    """
    Custom CIFAR-10 dataset that returns the required triple format.

    Inherits from both MammothDataset (for continual learning features)
    and CIFAR10 (for data loading functionality).
    """

    def __init__(self, root, is_train=True, transform=None):
        """
        Initialize the dataset.

        Args:
            root: Directory to store/load data
            is_train: Whether to load training or test set
            transform: Data augmentation transforms
        """
        self.root = root
        # Non-augmenting transform used for the original image returned by
        # __getitem__ (just ToTensor here, so stored images stay unmodified)
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        # Download data only if not already present (avoids debug messages)
        super().__init__(root, is_train, transform,
                         download=not self._check_integrity())

    def __getitem__(self, index):
        """
        Get a sample from the dataset.

        Returns:
            tuple: (augmented_image, label, original_image)
        """
        # Get raw data and label
        img, target = self.data[index], self.targets[index]

        # Convert numpy array to PIL Image for transformations
        img = Image.fromarray(img, mode='RGB')
        original_img = img.copy()  # Important: copy to avoid modification

        # Apply non-augmenting transform to original image
        not_aug_img = self.not_aug_transform(original_img)

        # Apply augmenting transforms if specified
        if self.transform is not None:
            img = self.transform(img)

        return img, target, not_aug_img
Key Implementation Details#
- Multiple Inheritance: inheriting from both MammothDataset and CIFAR10 gives you continual learning features and CIFAR-10 data loading in one class.
- Image Copying: original_img = img.copy() is crucial; without the copy, the stored image could be modified by subsequent transforms.
- Transform Application: apply not_aug_transform to get tensor-converted but non-augmented images for memory storage.
- Triple Return Format: always return (augmented_image, label, original_image) from __getitem__.
Creating a Custom Continual Dataset#
Now create the continual learning scenario definition:
from mammoth_lite import register_dataset, ContinualDataset, base_path
from torchvision import transforms
@register_dataset(name='custom-cifar10')
class CustomSeqCifar10(ContinualDataset):
    """
    Custom Sequential CIFAR-10 continual learning benchmark.

    Splits CIFAR-10 into 5 tasks of 2 classes each.
    """

    # Required dataset metadata
    NAME = 'custom-cifar10'
    SETTING = 'class-il'     # Class-incremental learning
    N_CLASSES_PER_TASK = 2   # 2 classes per task
    N_TASKS = 5              # 5 total tasks

    # Dataset statistics for normalization
    MEAN = (0.4914, 0.4822, 0.4465)
    STD = (0.2470, 0.2435, 0.2615)

    # Training transformations (with data augmentation)
    TRANSFORM = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])

    # Test transformations (no augmentation)
    TEST_TRANSFORM = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])

    def get_data_loaders(self):
        """
        Create and return train and test datasets.

        Returns:
            tuple: (train_dataset, test_dataset)
        """
        # Create datasets using our custom data source
        train_dataset = MammothCIFAR10(
            root=base_path() + 'CIFAR10',
            is_train=True,
            transform=self.TRANSFORM
        )
        test_dataset = MammothCIFAR10(
            root=base_path() + 'CIFAR10',
            is_train=False,
            transform=self.TEST_TRANSFORM
        )
        return train_dataset, test_dataset

    @staticmethod
    def get_backbone():
        """
        Specify the default backbone architecture.

        Returns:
            str: Name of registered backbone
        """
        return "resnet18"
Required Attributes Explained#
- NAME: unique identifier for your dataset (by default, the same name passed to the @register_dataset decorator)
- SETTING: continual learning scenario type:
  - 'class-il': class-incremental learning
  - 'task-il': task-incremental learning
  - other settings are available in the full Mammoth framework, or can be defined as needed
- N_CLASSES_PER_TASK: number of classes introduced in each task (see the sketch after this list)
- N_TASKS: total number of sequential tasks
- MEAN/STD: dataset statistics for proper normalization
- TRANSFORM/TEST_TRANSFORM: data augmentation pipelines for training and testing
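To make the task structure concrete, the short sketch below computes which class labels fall into each task under the usual sequential split (the framework performs this split internally; the sketch only illustrates the convention):

N_TASKS = 5
N_CLASSES_PER_TASK = 2

# Sequential class-IL split: each task introduces the next block of labels.
for task_id in range(N_TASKS):
    start = task_id * N_CLASSES_PER_TASK
    classes = list(range(start, start + N_CLASSES_PER_TASK))
    print(f"Task {task_id}: classes {classes}")
# Task 0: classes [0, 1]
# Task 1: classes [2, 3]
# ...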
Required Methods#
get_data_loaders(): Returns train and test dataset instances
get_backbone(): Specifies default neural network architecture
Testing Your Custom Dataset#
Once defined, use your custom dataset like any built-in benchmark:
from mammoth_lite import load_runner, train
# Load SGD model with your custom dataset
model, dataset = load_runner(
    model='sgd',
    dataset='custom-cifar10',  # Use your custom dataset
    args={
        'lr': 0.1,
        'n_epochs': 1,
        'batch_size': 32
    }
)
# Verify dataset properties
print(f"Dataset: {dataset.NAME}")
print(f"Setting: {dataset.SETTING}")
print(f"Tasks: {dataset.N_TASKS}")
print(f"Classes per task: {dataset.N_CLASSES_PER_TASK}")
# Run training
train(model, dataset)
Expected Output:
Dataset: custom-cifar10
Setting: class-il
Tasks: 5
Classes per task: 2
Task 1: 100%|██████████| 313/313 [01:18<00:00, 19.95it/s]
Accuracy for task 1 [Class-IL]: 87.60% [Task-IL]: 87.60%
Data Augmentation Strategies#
Different Augmentations per Task#
class TaskSpecificTransforms(ContinualDataset):
    """
    Apply different augmentations for different tasks.

    Note: get_transforms is a custom helper you would call from your own
    get_data_loaders; it is not a built-in framework hook.
    """

    def get_transforms(self, task_id):
        """Return task-specific transforms."""
        base_transform = [transforms.ToTensor(),
                          transforms.Normalize(self.MEAN, self.STD)]

        if task_id == 0:
            # Minimal augmentation for the first task
            return transforms.Compose([
                transforms.RandomHorizontalFlip(0.3),
                *base_transform
            ])
        else:
            # Stronger augmentation for later tasks
            return transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
                *base_transform
            ])
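One possible way to wire this helper in, as a hedged sketch only: self.current_task is assumed here to track the index of the task being prepared, which is an assumption for illustration rather than a documented attribute, and the test transform is borrowed from CustomSeqCifar10 above.

# Hypothetical wiring: build each task's training set with its own transform.
class TaskSpecificTransformsWired(TaskSpecificTransforms):
    def get_data_loaders(self):
        # self.current_task: assumed attribute tracking the task being built
        transform = self.get_transforms(self.current_task)
        train_dataset = MammothCIFAR10(root=base_path() + 'CIFAR10',
                                       is_train=True, transform=transform)
        test_dataset = MammothCIFAR10(root=base_path() + 'CIFAR10',
                                      is_train=False,
                                      transform=CustomSeqCifar10.TEST_TRANSFORM)
        return train_dataset, test_dataset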
Progressive Difficulty#
class ProgressiveDifficultyDataset(ContinualDataset):
    """
    Increase task difficulty over time.

    Note: load_base_data and add_corruption are placeholders for your own
    loading and corruption logic; they are not framework methods.
    """

    def get_data_loaders(self, task_id):
        """Return datasets with task-specific difficulty."""
        # Base dataset
        train_data = self.load_base_data(train=True)
        test_data = self.load_base_data(train=False)

        # Apply task-specific modifications
        if task_id > 0:
            # Add noise, blur, or other corruptions
            corruption_level = task_id * 0.1
            train_data = self.add_corruption(train_data, corruption_level)
            test_data = self.add_corruption(test_data, corruption_level)

        return train_data, test_data
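The add_corruption placeholder could be implemented in many ways; here is one hedged sketch, written as a free function for brevity. It assumes the dataset stores its images in dataset.data as a uint8 numpy array, as CIFAR-10 does:

import numpy as np

def add_corruption(dataset, corruption_level):
    """Additive Gaussian pixel noise; stronger noise for later tasks.

    Assumes dataset.data is a uint8 numpy array of images (as in CIFAR-10).
    """
    noise = np.random.normal(0.0, corruption_level * 255.0,
                             size=dataset.data.shape)
    noisy = dataset.data.astype(np.float64) + noise
    dataset.data = np.clip(noisy, 0, 255).astype(np.uint8)
    return dataset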
Validation and Testing#
Dataset Validation#
def validate_dataset():
    """Test that the custom dataset works correctly."""
    # Test loading
    model, dataset = load_runner('sgd', 'custom-cifar10', {'n_epochs': 1})

    # Validate properties
    assert dataset.N_TASKS == 5
    assert dataset.N_CLASSES_PER_TASK == 2
    assert dataset.SETTING == 'class-il'

    # Test data access (note: despite its name, get_data_loaders
    # returns the dataset objects themselves)
    train_dataset, test_dataset = dataset.get_data_loaders()

    # Check the sample format
    sample = train_dataset[0]
    assert len(sample) == 3  # (augmented_image, label, original_image)

    print("✓ Dataset validation passed")

validate_dataset()
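You can also sanity-check the data source on its own, without going through the runner. A minimal sketch; the expected shapes assume the CIFAR-10 source and transforms defined above:

import torch

# Instantiate the data source directly (downloads CIFAR-10 on first run)
ds = MammothCIFAR10(root=base_path() + 'CIFAR10', is_train=True,
                    transform=CustomSeqCifar10.TRANSFORM)

aug_img, label, not_aug_img = ds[0]
assert isinstance(aug_img, torch.Tensor) and aug_img.shape == (3, 32, 32)
assert isinstance(not_aug_img, torch.Tensor) and not_aug_img.shape == (3, 32, 32)
assert isinstance(label, int)
print(f"✓ Data source OK: {len(ds)} samples")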
Next Steps#
Now that you can create custom datasets:
Explore Different Scenarios: Try domain-incremental or task-incremental settings
Create Complex Benchmarks: Combine multiple datasets or add synthetic corruptions
Custom Backbones: Learn to design architectures in Creating Custom Backbones
Share Your Work: Contribute interesting datasets to the community
Custom datasets enable you to test continual learning algorithms on your specific problem domains and research questions!