MNIST 360

Classes

class datasets.mnist_360.MNIST360(args, is_train=False)

Bases: Dataset

A custom dataset class for MNIST360 that provides training and testing data in which the examples of each class are rotated by incrementally increasing angles.

Parameters:
  • args (object) – An object containing the arguments for the dataset.

  • is_train (bool) – A flag indicating whether the dataset is for training or testing.
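To make the "incremental rotation" in the class description concrete, here is a minimal sketch of an angle schedule. It assumes, without confirmation from these docs, that each class starts from its own initial angle and that every example drawn from that class is rotated a fixed step further than the previous one, wrapping at 360 degrees. The class `IncrementalAngles` and its parameters `init_deg` and `step_deg` are hypothetical and are not part of `datasets.mnist_360`.

```python
class IncrementalAngles:
    """Hypothetical helper: yields a rotation angle that grows by a fixed step per example.

    Assumption (not confirmed by the MNIST360 docs): each class starts at its own
    initial angle and advances by `step_deg` every time an example is drawn.
    """

    def __init__(self, init_deg: float, step_deg: float) -> None:
        self.current_deg = init_deg % 360.0  # normalise into [0, 360)
        self.step_deg = step_deg

    def next_angle(self) -> float:
        # Return the current angle, then advance it (with wrap-around) for the next example.
        angle = self.current_deg
        self.current_deg = (self.current_deg + self.step_deg) % 360.0
        return angle


# A class with 4 examples sweeps a full circle when the step is 360 / 4 degrees.
angles = IncrementalAngles(init_deg=0.0, step_deg=360.0 / 4)
print([angles.next_angle() for _ in range(5)])  # [0.0, 90.0, 180.0, 270.0, 0.0]
```

In a torch pipeline, each returned angle could be passed to an image-rotation routine such as `torchvision.transforms.functional.rotate`; how MNIST360 actually chooses its angles is not specified on this page.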

N_CLASSES

The number of classes in the dataset.

Type:

int

dataset

A list of data loaders for each class.

Type:

list

remaining_training_items

A list of the remaining training items for each class.

Type:

list

num_rounds

The number of rounds for each class.

Type:

int

args

An object containing the arguments for the dataset.

Type:

object

is_train

A flag indicating whether the dataset is for training or testing.

Type:

bool

is_over

A flag indicating whether the dataset is completed.

Type:

bool

completed_rounds

The number of completed rounds.

Type:

int

test_class

The current test class index.

Type:

int

test_iteration

The current test iteration index.

Type:

int

train_classes

A list of the current training classes.

Type:

list

active_train_loaders

A list of the active training data loaders.

Type:

list

current_items

The current number of items in the dataset.

Type:

int

N_CLASSES = 9
get_test_data()

Assembles the next examples of the current test class into a batch.

Returns:

A tuple of two tensors:

  • Tensor: The batch of examples.

  • Tensor: The labels of the examples.

Return type:

Tuple[Tensor, Tensor]
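The class-by-class test iteration implied by `test_class` and `test_iteration` can be sketched as follows. This is a minimal model under stated assumptions, not the library code: test examples are pre-grouped by class, every class is non-empty, a batch never mixes classes, and the label equals the class index. The class `EvalStream` and its `examples_per_class` and `batch_size` arguments are hypothetical; plain strings stand in for image tensors.

```python
from typing import List, Optional, Sequence, Tuple


class EvalStream:
    """Hypothetical sketch of class-by-class test batching (not the library code)."""

    def __init__(self, examples_per_class: Sequence[Sequence[str]], batch_size: int) -> None:
        self.examples_per_class = examples_per_class
        self.batch_size = batch_size
        self.test_class = 0      # index of the class currently being served
        self.test_iteration = 0  # position inside the current class
        self.is_over = False

    def get_test_data(self) -> Optional[Tuple[List[str], List[int]]]:
        """Return (examples, labels) for the next slice of the current class."""
        if self.is_over:
            return None
        items = self.examples_per_class[self.test_class]
        start = self.test_iteration
        batch = list(items[start:start + self.batch_size])
        labels = [self.test_class] * len(batch)  # assumption: label == class index
        self.test_iteration += len(batch)
        # Move to the next class once the current one is exhausted.
        if self.test_iteration >= len(items):
            self.test_class += 1
            self.test_iteration = 0
            self.is_over = self.test_class >= len(self.examples_per_class)
        return batch, labels


stream = EvalStream([["a0", "a1", "a2"], ["b0"]], batch_size=2)
print(stream.get_test_data())  # (['a0', 'a1'], [0, 0])
print(stream.get_test_data())  # (['a2'], [0])
print(stream.get_test_data())  # (['b0'], [1])
print(stream.is_over)          # True
```

Keeping each batch within one class means a per-class accuracy can be read directly from each batch, which is one reason such an ordering may be convenient.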

get_train_data()

Assembles the next examples of the current training classes into a single batch.

Returns:

A tuple of three tensors:

  • Tensor: The batch of examples.

  • Tensor: The labels of the examples.

  • Tensor: The batch of examples without augmentation.

Return type:

Tuple[Tensor, Tensor, Tensor]
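A minimal sketch of how a single training batch could be assembled from the active class streams, returning the same three values that get_train_data() documents. It assumes, as an illustration only, that each active class contributes up to a fixed number of examples per batch and that augmentation is a per-example function. The function `assemble_train_batch` and its parameters `active_streams`, `active_labels`, `per_class`, and `augment` are hypothetical; integers stand in for image tensors.

```python
from typing import Callable, Iterator, List, Sequence, Tuple


def assemble_train_batch(
    active_streams: Sequence[Iterator[int]],
    active_labels: Sequence[int],
    per_class: int,
    augment: Callable[[int], int],
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical sketch: draw up to `per_class` examples from each active class.

    Returns (augmented batch, labels, non-augmented batch).
    """
    raw: List[int] = []
    labels: List[int] = []
    for stream, label in zip(active_streams, active_labels):
        for _ in range(per_class):
            example = next(stream, None)
            if example is None:
                break  # this class has no remaining items
            raw.append(example)
            labels.append(label)
    # Augment a copy of the batch; the untouched `raw` list is returned as well.
    augmented = [augment(x) for x in raw]
    return augmented, labels, raw


streams = [iter([1, 2, 3]), iter([10])]
print(assemble_train_batch(streams, [0, 1], per_class=2, augment=lambda x: -x))
# ([-1, -2, -10], [0, 0, 1], [1, 2, 10])
```

Returning the non-augmented batch alongside the augmented one lets a continual-learning method store raw examples (for instance in a replay buffer) while training on augmented ones.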

init_test_loaders()

Initializes the test loader.

init_train_loaders()

Initializes the train loader.

reinit()
train_next_class()

Changes the pair of current training classes.
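The pair-of-classes progression driven by train_next_class(), together with `completed_rounds` and `num_rounds`, can be sketched as a schedule. This is an assumption-laden illustration, not verified library behaviour: it supposes that each call slides the pair forward by one class, wrapping modulo `N_CLASSES`, and that a round ends when the pair returns to (0, 1). The function `class_pair_schedule` is hypothetical.

```python
from typing import List, Tuple

N_CLASSES = 9  # matches the documented MNIST360.N_CLASSES


def class_pair_schedule(num_rounds: int, n_classes: int = N_CLASSES) -> List[Tuple[int, int]]:
    """Hypothetical sequence of (first, second) training classes over num_rounds rounds."""
    if n_classes < 2:
        raise ValueError("at least two classes are needed to form a pair")
    schedule: List[Tuple[int, int]] = []
    train_classes = (0, 1)
    completed_rounds = 0
    while completed_rounds < num_rounds:
        schedule.append(train_classes)
        # Slide the pair forward: the second class becomes the first one.
        train_classes = (train_classes[1], (train_classes[1] + 1) % n_classes)
        # Assumption: a round is complete once the pair wraps back to (0, 1).
        if train_classes == (0, 1):
            completed_rounds += 1
    return schedule


print(class_pair_schedule(num_rounds=1))
# [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 0)]
```

Under these assumptions every class appears in two consecutive pairs, so the data distribution shifts gradually rather than switching abruptly between disjoint tasks.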

class datasets.mnist_360.SequentialMNIST360(args)

Bases: GCLDataset

A dataset class for the MNIST-360 dataset in the general continual learning (GCL) setting.

NAME

The name of the dataset.

Type:

str

SETTING

The setting of the dataset.

Type:

str

N_CLASSES

The number of classes in the dataset.

Type:

int

TRANSFORM

The transformation to apply to the data.

Type:

torch.nn.Module

SIZE

The size of the input images.

Type:

tuple

args

An object containing the arguments for the dataset.

Type:

Namespace

NAME: str = 'mnist-360'
N_CLASSES: int = 9
SETTING: str = 'general-continual'
SIZE: Tuple[int] = (28, 28)
TRANSFORM = Identity()
get_backbone()
Return type:

Module

get_batch_size()
Return type:

int

get_data_loaders()

Creates the data loaders for the MNIST360 dataset, stores them on the current object, and returns them.

Returns:

A tuple (train_loader, test_loader):

  • train_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.

  • test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.

Return type:

Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]

static get_denormalization_transform()
get_epochs()
static get_loss()
Return type:

Callable

static get_normalization_transform()
static get_transform()

Functions

datasets.mnist_360.custom_collate_unbatch(batch)

Custom collate function to unbatch a batch of data.

Parameters:

batch (list) – A list of tensors representing a batch of data.

Returns:

A list of tensors, where each tensor is unbatched from the input batch.

Return type:

list
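The documentation does not say why unbatching is needed. One plausible scenario, assumed here rather than stated above: each dataset item is already a full batch, so a loader with batch_size=1 adds an extra leading dimension of size 1 that the collate function strips. The sketch uses NumPy arrays as stand-ins for torch tensors, and the name `unbatch_sketch` is hypothetical; it is not the library's `custom_collate_unbatch`.

```python
from typing import List

import numpy as np


def unbatch_sketch(batch: List[np.ndarray]) -> List[np.ndarray]:
    """Hypothetical stand-in for the documented unbatching behaviour.

    Assumes each array in `batch` has a leading batch dimension of size 1 and
    removes it; np.squeeze raises ValueError if that dimension is not 1.
    """
    return [np.squeeze(t, axis=0) for t in batch]


images = np.zeros((1, 16, 1, 28, 28))  # one wrapped batch of 16 MNIST-sized images
labels = np.arange(16).reshape(1, 16)  # the matching wrapped labels
out = unbatch_sketch([images, labels])
print([a.shape for a in out])  # [(16, 1, 28, 28), (16,)]
```

With torch, the equivalent per-tensor call would be `t.squeeze(0)`, and a function of this kind would be handed to `torch.utils.data.DataLoader` through its `collate_fn` argument.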