# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace
from copy import deepcopy
from typing import Callable, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

from datasets.perm_mnist import MyMNIST
from datasets.transforms.rotation import IncrementalRotation
from datasets.utils.gcl_dataset import GCLDataset
from datasets.utils.validation import get_train_val
from utils.conf import base_path, create_seeded_dataloader
from datasets.utils import set_default_from_args

def custom_collate_unbatch(batch) -> List[torch.Tensor]:
    """
    Custom collate function to unbatch a batch of data.

    The samples are first merged with PyTorch's default collate function and
    the leading dimension of every resulting tensor is then squeezed away
    (``squeeze(0)`` is a no-op when that dimension is larger than one).

    Args:
        batch (list): A list of samples, as handed to a ``collate_fn`` by a
            ``DataLoader``.

    Returns:
        list: A list of tensors, one per sample field, with dimension 0
        squeezed.
    """
    collated = torch.utils.data.default_collate(batch)
    return [field.squeeze(0) for field in collated]
class MNIST360(torch.utils.data.Dataset):
    """A custom dataset class for MNIST360 that provides training and testing
    data with incremental rotation for each class.

    The object is its own iterator: ``iter(obj)`` re-initialises the internal
    loaders and ``next(obj)`` returns one (already batched) training or test
    batch until ``is_over`` becomes true.

    Args:
        args (object): An object containing the arguments for the dataset.
            The code reads ``args.validation`` and ``args.batch_size`` and
            forwards ``args`` to ``create_seeded_dataloader``.
        is_train (bool): A flag indicating whether the dataset is for training or testing.

    Attributes:
        N_CLASSES (int): The number of classes in the dataset.
        dataset (list): In training mode, one list of data loaders per class
            (one loader per chunk); in test mode, one data loader per class.
        remaining_training_items (list): For each class, the number of items of
            every not-yet-activated training chunk.
        num_rounds (int): The number of rounds for each class.
        args (object): An object containing the arguments for the dataset.
        is_train (bool): A flag indicating whether the dataset is for training or testing.
        is_over (bool): A flag indicating whether the dataset is completed.
        completed_rounds (int): The number of completed rounds.
        test_class (int): The current test class index.
        test_iteration (int): The current test iteration index.
        train_classes (list): A list of the current training classes.
        active_train_loaders (list): A list of the active training data loaders.
        active_remaining_items (list): Items still to be drawn from each of
            the two active training loaders.
    """

    N_CLASSES = 9

    def __init__(self, args: Namespace, is_train: bool = False) -> None:
        super().__init__()
        self.num_rounds = 3
        self.args = args
        self.is_train = is_train
        self.reinit()

    def _activate_train_pair(self) -> None:
        """Pops the next chunk (loader and item count) of both current training classes."""
        first_cls, second_cls = self.train_classes
        self.active_train_loaders = [
            self.dataset[first_cls].pop(),
            self.dataset[second_cls].pop()]
        self.active_remaining_items = [
            self.remaining_training_items[first_cls].pop(),
            self.remaining_training_items[second_cls].pop()]

    def train_next_class(self) -> None:
        """
        Changes the couple of current training classes.

        Both class indices advance by one and wrap around at ``N_CLASSES``.
        Each time the first index wraps to 0 a round is completed; after
        ``num_rounds`` rounds the dataset is flagged as over, otherwise the
        next chunk of the two new classes is activated.
        """
        for pos in range(len(self.train_classes)):
            self.train_classes[pos] += 1
            if self.train_classes[pos] == self.N_CLASSES:
                self.train_classes[pos] = 0

        if self.train_classes[0] == 0:
            self.completed_rounds += 1
            # Use the configured number of rounds instead of a literal 3 so
            # that it stays in sync with the chunking in init_train_loaders.
            if self.completed_rounds == self.num_rounds:
                self.is_over = True

        if not self.is_over:
            self._activate_train_pair()

    def init_train_loaders(self) -> None:
        """
        Initializes the train loader.

        For every class, the samples of that class are split into
        ``num_rounds * 2`` consecutive chunks; each chunk gets its own
        single-sample shuffled data loader and its size is recorded in
        ``remaining_training_items``. All chunks of a class share one
        ``IncrementalRotation`` transform instance.
        """
        self.remaining_training_items = []
        self.dataset = []

        train_dataset = MyMNIST(base_path() + 'MNIST',
                                train=True, download=True)
        if self.args.validation:
            test_transform = transforms.ToTensor()
            train_dataset, self.val_dataset = get_train_val(
                train_dataset, test_transform, 'mnist-360', val_perc=self.args.validation / 100)

        n_chunks = self.num_rounds * 2
        for j in range(self.N_CLASSES):
            self.dataset.append([])
            self.remaining_training_items.append([])
            train_mask = np.isin(np.array(train_dataset.targets), [j])
            train_rotation = IncrementalRotation(init_deg=(j - 1) * 60,
                                                 increase_per_iteration=360.0 / train_mask.sum())
            # Chunk size does not depend on the chunk index: compute it once.
            numbers_per_batch = train_mask.sum() // n_chunks + 1
            for k in range(n_chunks):
                chunk = slice(k * numbers_per_batch, (k + 1) * numbers_per_batch)
                tmp_train_dataset = deepcopy(train_dataset)
                tmp_train_dataset.data = tmp_train_dataset.data[train_mask][chunk]
                tmp_train_dataset.targets = tmp_train_dataset.targets[train_mask][chunk]
                tmp_train_dataset.transform = transforms.Compose(
                    [train_rotation, transforms.ToTensor()])
                self.dataset[-1].append(create_seeded_dataloader(
                    self.args, tmp_train_dataset, non_verbose=True,
                    batch_size=1, shuffle=True, num_workers=0))
                # Number of samples in this chunk; consumed by get_train_data.
                self.remaining_training_items[-1].append(
                    tmp_train_dataset.data.shape[0])

    def init_test_loaders(self) -> None:
        """
        Initializes the test loader.

        Builds one non-shuffled data loader per class, each with its own
        ``IncrementalRotation`` transform. When ``args.validation`` is set the
        validation split of the training set is used instead of the MNIST
        test set.
        """
        self.remaining_training_items = []
        self.dataset = []

        if self.args.validation:
            test_transform = transforms.ToTensor()
            train_dataset = MyMNIST(base_path() + 'MNIST',
                                    train=True, download=True)
            # Same fraction as in init_train_loaders (validation is divided by
            # 100 there), so that both modes request the same split.
            _, test_dataset = get_train_val(
                train_dataset, test_transform, 'mnist-360', val_perc=self.args.validation / 100)
        else:
            test_dataset = MNIST(base_path() + 'MNIST',
                                 train=False, download=True)

        for j in range(self.N_CLASSES):
            tmp_test_dataset = deepcopy(test_dataset)
            test_mask = np.isin(np.array(tmp_test_dataset.targets), [j])
            tmp_test_dataset.data = tmp_test_dataset.data[test_mask]
            tmp_test_dataset.targets = tmp_test_dataset.targets[test_mask]
            test_rotation = IncrementalRotation(
                increase_per_iteration=360.0 / test_mask.sum())
            tmp_test_dataset.transform = transforms.Compose(
                [test_rotation, transforms.ToTensor()])
            self.dataset.append(create_seeded_dataloader(
                self.args, tmp_test_dataset, non_verbose=True,
                batch_size=self.args.batch_size, shuffle=False, num_workers=0))

    def get_train_data(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Ensembles the next examples of the current classes in a single batch.

        The batch is split between the two active classes proportionally to
        the number of items each of them still has to provide. When either
        class runs out of items, the next couple of classes is selected.

        Returns:
            Tensor: The batch of examples.
            Tensor: The labels of the examples.
            Tensor: The batch of examples without augmentation.
        """
        remaining_0, remaining_1 = self.active_remaining_items
        share_0 = remaining_0 / (remaining_0 + remaining_1)
        batch_size_0 = min(int(round(share_0 * self.args.batch_size)), remaining_0)
        batch_size_1 = min(self.args.batch_size - batch_size_0, remaining_1)

        x_train, y_train, x_train_naug = [], [], []
        for loader, n_items in zip(self.active_train_loaders,
                                   (batch_size_0, batch_size_1)):
            for _ in range(n_items):
                # NOTE(review): a fresh iterator is created for every sample,
                # so each draw comes from a newly created loader iterator
                # rather than continuing a single pass — confirm this is
                # intended.
                i_x_train, i_y_train, i_x_train_naug = next(iter(loader))
                x_train.append(i_x_train)
                y_train.append(i_y_train)
                x_train_naug.append(i_x_train_naug)

        x_train, y_train, x_train_naug = \
            torch.cat(x_train), torch.cat(y_train), torch.cat(x_train_naug)

        self.active_remaining_items[0] -= batch_size_0
        self.active_remaining_items[1] -= batch_size_1

        if self.active_remaining_items[0] <= 0 or \
                self.active_remaining_items[1] <= 0:
            self.train_next_class()

        return x_train, y_train, x_train_naug

    def get_test_data(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Ensembles the next examples of the current class in a batch.

        The number of items already served for the current class is tracked
        through ``test_iteration``; the last batch of a class is truncated so
        that the total served equals the size of the class dataset, after
        which the next class is selected.

        Returns:
            Tensor: The batch of examples.
            Tensor: The labels of the examples.
        """
        # NOTE(review): `next(iter(loader))` restarts the loader on every
        # call; with shuffle=False this presumably returns the first batch
        # each time — confirm whether that is intended.
        x_test, y_test = next(iter(self.dataset[self.test_class]))
        residual_items = len(self.dataset[self.test_class].dataset) - \
            self.test_iteration * self.args.batch_size - len(x_test)
        self.test_iteration += 1

        if residual_items <= 0:
            if residual_items < 0:
                # Negative residual: drop the surplus items from the tail.
                x_test = x_test[:residual_items]
                y_test = y_test[:residual_items]
            self.test_iteration = 0
            self.test_class += 1
            if self.test_class == self.N_CLASSES:
                self.is_over = True

        return x_test, y_test

    def reinit(self) -> None:
        """Resets the iteration state and rebuilds the train or test loaders."""
        self.is_over = False
        self.completed_rounds, self.test_class, self.test_iteration = 0, 0, 0
        self.train_classes = [0, 1]

        if self.is_train:
            self.init_train_loaders()
            self._activate_train_pair()
        else:
            self.init_test_loaders()

    def __iter__(self):
        self.reinit()
        return self

    def __next__(self):
        if self.is_over:
            raise StopIteration
        if self.is_train:
            return self.get_train_data()
        return self.get_test_data()
class SequentialMNIST360(GCLDataset):
    """
    A dataset class for the MNIST-360 dataset in the context of general-continual learning.

    Attributes:
        NAME (str): The name of the dataset.
        SETTING (str): The setting of the dataset.
        N_CLASSES (int): The number of classes in the dataset.
        TRANSFORM (torch.nn.Module): The transformation to apply to the data.
        SIZE (tuple): The size of the input images.
        args (Namespace): An object containing the arguments for the dataset.
    """

    NAME = 'mnist-360'
    SETTING = 'general-continual'
    N_CLASSES = 9
    TRANSFORM = torch.nn.Identity()
    SIZE = (28, 28)

    def __init__(self, args: Namespace) -> None:
        super().__init__(args)
        self.args = args
        assert args.label_perc == 1, "MNIST-360 does not support partial labels."

    def get_data_loaders(self) -> Tuple["MNIST360", "MNIST360"]:
        """
        Get the data loaders for the MNIST360 dataset, add them to the current object and return them.

        The returned objects are ``MNIST360`` instances rather than PyTorch
        ``DataLoader`` objects: they are iterated directly and yield batches.

        Returns:
            train_loader (MNIST360): Batch iterator for the training dataset.
            test_loader (MNIST360): Batch iterator for the test dataset.
        """
        train_dataset = MNIST360(self.args, is_train=True)
        test_dataset = MNIST360(self.args, is_train=False)

        # dataset is already shuffled and batched internally - no need for a dataloader
        self.test_loaders.append(test_dataset)
        self.train_loader = train_dataset

        return train_dataset, test_dataset

    # NOTE(review): defined without `self`; presumably `set_default_from_args`
    # supports zero-argument functions — confirm against the decorator.
    @set_default_from_args("backbone")
    def get_backbone() -> str:
        """Returns the name of the default backbone for this dataset."""
        return "mnistmlp"

    @staticmethod
    def get_loss() -> Callable:
        """Returns the loss function used for this dataset (cross-entropy)."""
        return F.cross_entropy

    @staticmethod
    def get_transform():
        """No additional training transform is defined for this dataset."""
        return None

    @staticmethod
    def get_normalization_transform():
        """No normalization transform is defined for this dataset."""
        return None

    @staticmethod
    def get_denormalization_transform():
        """No denormalization transform is defined for this dataset."""
        return None

    @set_default_from_args('batch_size')
    def get_batch_size(self) -> int:
        """Returns the default batch size for this dataset."""
        return 16

    @set_default_from_args('n_epochs')
    def get_epochs(self) -> int:
        """Returns the default number of epochs for this dataset."""
        return 1
if __name__ == "__main__":
    # Debug entry point: iterate over the whole train and test streams and
    # dump every batch as an image grid under ../data/results/mnist360images/tmp.
    ds = SequentialMNIST360(Namespace(validation=False, label_perc=1, n_epochs=1, batch_size=16,
                                      permute_classes=False, joint=False, num_workers=0, seed=None))
    train, test = ds.get_data_loaders()

    # load all data and save it in results/tmp
    import os

    import torchvision
    from PIL import Image
    from tqdm import tqdm

    def _len_or_none(stream):
        """Returns len(stream), or None when the object defines no length."""
        try:
            return len(stream)
        except TypeError:
            # The batch iterators may not implement __len__; tqdm accepts
            # total=None and then simply shows a running counter.
            return None

    os.makedirs('../data/results/mnist360images/tmp/train', exist_ok=True)
    for i, (x, y, _) in tqdm(enumerate(train), total=_len_or_none(train)):
        torchvision.utils.save_image(x, f'../data/results/mnist360images/tmp/train/{i}.png')

    os.makedirs('../data/results/mnist360images/tmp/test', exist_ok=True)
    for i, (x, y) in tqdm(enumerate(test), total=_len_or_none(test)):
        torchvision.utils.save_image(x, f'../data/results/mnist360images/tmp/test/{i}.png')