Source code for datasets.seq_tinyimagenet

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

from backbone.ResNetBlock import resnet18
from datasets.transforms.denormalization import DeNormalize
from datasets.utils.continual_dataset import (ContinualDataset, fix_class_names_order,
                                              store_masked_loaders)
from utils import smart_joint
from utils.conf import base_path
from datasets.utils import set_default_from_args


class TinyImagenet(Dataset):
    """Defines the Tiny Imagenet dataset.

    Images and labels are loaded from 20 pre-processed ``.npy`` chunks per split,
    ``processed/x_<split>_NN.npy`` (images) and ``processed/y_<split>_NN.npy``
    (labels) under ``root``, where ``<split>`` is ``train`` or ``val`` and ``NN``
    runs from ``01`` to ``20``.

    Args:
        root: directory holding the processed dataset.
        train: load the training split if True, otherwise the validation split.
        transform: optional callable applied to each PIL image.
        target_transform: optional callable applied to each target.
        download: if True and ``root`` is missing or empty, download the processed
            archive and unzip it into ``root``.
    """

    # Number of .npy chunks each split (and each of x / y) is stored in.
    _N_CHUNKS = 20

    def __init__(self, root: str, train: bool = True, transform: Optional[nn.Module] = None,
                 target_transform: Optional[nn.Module] = None, download: bool = False) -> None:
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        if download:
            if os.path.isdir(root) and len(os.listdir(root)) > 0:
                logging.info('Download not needed, files already on disk.')
            else:
                # Aliased so it does not shadow the ``download`` flag above.
                from onedrivedownloader import download as onedrive_download
                logging.info('Downloading dataset')
                ln = "https://unimore365-my.sharepoint.com/:u:/g/personal/263133_unimore_it/EVKugslStrtNpyLGbgrhjaABqRHcE3PB_r2OEaV7Jy94oQ?e=9K29aD"
                onedrive_download(ln, filename=smart_joint(root, 'tiny-imagenet-processed.zip'),
                                  unzip=True, unzip_path=root, clean=True)

        split = 'train' if self.train else 'val'
        self.data = self._load_chunks('x', split)
        self.targets = self._load_chunks('y', split)

    def _load_chunks(self, prefix: str, split: str) -> np.ndarray:
        """Load ``processed/<prefix>_<split>_NN.npy`` for every chunk and concatenate them along axis 0."""
        chunks = [np.load(smart_joint(self.root, 'processed/%s_%s_%02d.npy' % (prefix, split, num + 1)))
                  for num in range(self._N_CHUNKS)]
        # Concatenate the list directly. Wrapping it in np.array() first made a
        # full extra copy of the data and could not handle chunks of unequal length.
        return np.concatenate(chunks)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(img, target)``.

        If the instance has a ``logits`` attribute, return
        ``(img, target, original_img, logits[index])`` instead, where
        ``original_img`` is a copy of the PIL image taken before ``transform``.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        # NOTE(review): assumes the stored images are floats in [0, 1] — confirm.
        img = Image.fromarray(np.uint8(255 * img))

        has_logits = hasattr(self, 'logits')
        # Only pay for the copy when it is actually returned.
        original_img = img.copy() if has_logits else None

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if has_logits:
            # NOTE(review): original_img is a PIL image here (MyTinyImagenet returns a
            # tensor instead); presumably not batchable by the default collate — confirm.
            return img, target, original_img, self.logits[index]

        return img, target
class MyTinyImagenet(TinyImagenet):
    """Tiny Imagenet variant whose items also carry the non-augmented image.

    ``__getitem__`` yields ``(img, target, not_aug_img)``. When the instance has a
    ``logits`` attribute, the matching logits are appended as a fourth element.
    ``not_aug_img`` is the pre-``transform`` image passed through
    ``self.not_aug_transform``.
    """

    def __init__(self, root: str, train: bool = True, transform: Optional[nn.Module] = None,
                 target_transform: Optional[nn.Module] = None, download: bool = False) -> None:
        super().__init__(root, train, transform, target_transform, download)

    def __getitem__(self, index):
        raw = self.data[index]
        label = self.targets[index]

        # Convert to a PIL Image, consistently with the other datasets.
        pil_img = Image.fromarray(np.uint8(255 * raw))
        # Tensor view of the image before any augmentation is applied.
        not_aug_img = self.not_aug_transform(pil_img.copy())

        out_img = pil_img if self.transform is None else self.transform(pil_img)
        if self.target_transform is not None:
            label = self.target_transform(label)

        if not hasattr(self, 'logits'):
            return out_img, label, not_aug_img
        return out_img, label, not_aug_img, self.logits[index]
class SequentialTinyImagenet(ContinualDataset):
    """The Sequential Tiny Imagenet dataset.

    The ``N_CLASSES`` (20 x 10 = 200) classes are split into ``N_TASKS`` tasks of
    ``N_CLASSES_PER_TASK`` classes each, under the ``'class-il'`` setting.

    Attributes:
        NAME (str): name of the dataset.
        SETTING (str): setting of the dataset.
        N_CLASSES_PER_TASK (int): number of classes per task.
        N_TASKS (int): number of tasks.
        N_CLASSES (int): number of classes.
        SIZE (tuple): size of the images.
        MEAN (tuple): mean of the dataset.
        STD (tuple): standard deviation of the dataset.
        TRANSFORM (torchvision.transforms): transformations to apply to the dataset.
    """

    # NAME is kept first: the decorators below may rely on class-body state.
    NAME = 'seq-tinyimg'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 20
    N_TASKS = 10
    N_CLASSES = N_CLASSES_PER_TASK * N_TASKS
    MEAN = (0.4802, 0.4480, 0.3975)
    STD = (0.2770, 0.2691, 0.2821)
    SIZE = (64, 64)
    # Training augmentation: padded random crop, horizontal flip, then normalization.
    TRANSFORM = transforms.Compose([
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
        """Build the train and test datasets and hand them to ``store_masked_loaders``.

        Returns:
            the ``(train, test)`` pair returned by ``store_masked_loaders``.
        """
        train_tf = self.TRANSFORM
        # Test images are only converted and normalized, never augmented.
        eval_tf = transforms.Compose([transforms.ToTensor(), self.get_normalization_transform()])

        train_ds = MyTinyImagenet(base_path() + 'TINYIMG', train=True,
                                  download=True, transform=train_tf)
        test_ds = TinyImagenet(base_path() + 'TINYIMG', train=False,
                               download=True, transform=eval_tf)

        train_loader, test_loader = store_masked_loaders(train_ds, test_ds, self)
        return train_loader, test_loader

    @set_default_from_args("backbone")
    def get_backbone():
        # NOTE(review): deliberately takes no ``self``; presumably
        # ``set_default_from_args`` adapts the call — confirm before changing.
        return "resnet18"

    @staticmethod
    def get_loss():
        """Return the loss function (cross-entropy)."""
        return F.cross_entropy

    def get_transform(self):
        """Return a transform that converts input to PIL and then applies ``TRANSFORM``."""
        return transforms.Compose([transforms.ToPILImage(), self.TRANSFORM])

    @staticmethod
    def get_normalization_transform():
        """Return a ``Normalize`` transform with the dataset mean/std."""
        return transforms.Normalize(SequentialTinyImagenet.MEAN, SequentialTinyImagenet.STD)

    @staticmethod
    def get_denormalization_transform():
        """Return a ``DeNormalize`` transform with the dataset mean/std."""
        return DeNormalize(SequentialTinyImagenet.MEAN, SequentialTinyImagenet.STD)

    @set_default_from_args('n_epochs')
    def get_epochs(self):
        return 50

    @set_default_from_args('batch_size')
    def get_batch_size(self):
        return 32

    def get_class_names(self):
        """Return the class names, computing and caching them on first call."""
        if self.class_names is None:
            self.class_names = fix_class_names_order(CLASS_NAMES, self.args)
        return self.class_names
# Human-readable names of the 200 Tiny Imagenet classes, consumed by
# SequentialTinyImagenet.get_class_names through fix_class_names_order.
# NOTE(review): presumably position i names label i in the processed targets — confirm.
# No name contains whitespace, so splitting the block yields exactly these strings in order.
CLASS_NAMES = """
    egyptian_cat reel volleyball rocking_chair lemon bullfrog basketball cliff
    espresso plunger parking_meter german_shepherd dining_table monarch brown_bear school_bus
    pizza guinea_pig umbrella organ oboe maypole goldfish potpie
    hourglass seashore computer_keyboard arabian_camel ice_cream nail space_heater cardigan
    baboon snail coral_reef albatross spider_web sea_cucumber backpack labrador_retriever
    pretzel king_penguin sulphur_butterfly tarantula lesser_panda pop_bottle banana sock
    cockroach projectile beer_bottle mantis freight_car guacamole remote_control european_fire_salamander
    lakeside chimpanzee pay-phone fur_coat alp lampshade torch abacus
    moving_van barrel tabby goose koala bullet_train cd_player teapot
    birdhouse gazelle academic_gown tractor ladybug miniskirt golden_retriever triumphal_arch
    cannon neck_brace sombrero gasmask candle desk frying_pan bee
    dam spiny_lobster police_van ipod punching_bag beacon jellyfish wok
    potter's_wheel sandal pill_bottle butcher_shop slug hog cougar crane
    vestment dragonfly cash_machine mushroom jinrikisha water_tower chest snorkel
    sunglasses fly limousine black_stork dugong sports_car water_jug suspension_bridge
    ox ice_lolly turnstile christmas_stocking broom scorpion wooden_spoon picket_fence
    rugby_ball sewing_machine steel_arch_bridge persian_cat refrigerator barn apron yorkshire_terrier
    swimming_trunks stopwatch lawn_mower thatch fountain black_widow bikini plate
    teddy barbershop confectionery beach_wagon scoreboard orange flagpole american_lobster
    trolleybus drumstick dumbbell brass bow_tie convertible bighorn orangutan
    american_alligator centipede syringe go-kart brain_coral sea_slug cliff_dwelling mashed_potato
    viaduct military_uniform pomegranate chain kimono comic_book trilobite bison
    pole boa_constrictor poncho bathtub grasshopper walking_stick chihuahua tailed_frog
    lion altar obelisk beaker bell_pepper bannister bucket magnetic_compass
    meat_loaf gondola standard_poodle acorn lifeboat binoculars cauliflower african_elephant
""".split()