SEQ DTD

Classes

class datasets.seq_dtd.MyDTD(root, split='train', transform=None, target_transform=None, download=False)

Bases: DTD

Custom DTD dataset that returns both augmented and non-augmented images.
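As a rough illustration of this pattern, the sketch below wraps a generic list of (image, label) pairs and returns the augmented image, its label, and a separately computed non-augmented copy. A non-augmented copy is useful, for example, when a replay buffer should store the clean sample rather than a random augmentation of it. The class name `AugPairDataset`, the `not_aug_transform` parameter, the `(augmented, target, original)` return order, and the use of plain Python callables in place of torchvision transforms are all assumptions for illustration; check the `MyDTD` source for its actual return signature.

```python
from typing import Any, Callable, List, Optional, Tuple


class AugPairDataset:
    """Hypothetical stand-in for MyDTD: yields (augmented, target, original)."""

    def __init__(
        self,
        samples: List[Tuple[Any, int]],
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[int], int]] = None,
        not_aug_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform
        # Applied to the raw image to build the non-augmented copy
        # (in a real pipeline, e.g. only resizing and tensor conversion).
        self.not_aug_transform = not_aug_transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, int, Any]:
        img, target = self.samples[index]
        # Both views are derived from the same raw image, independently of each other.
        original = self.not_aug_transform(img) if self.not_aug_transform else img
        augmented = self.transform(img) if self.transform else img
        if self.target_transform is not None:
            target = self.target_transform(target)
        return augmented, target, original


# Usage: a toy "image" is a list of pixel values; the "augmentation" negates them.
ds = AugPairDataset([([1, 2, 3], 0)], transform=lambda x: [-v for v in x])
aug, label, orig = ds[0]
print(aug, label, orig)  # [-1, -2, -3] 0 [1, 2, 3]
```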

class datasets.seq_dtd.SequentialDTD(args, transform_type='weak')

Bases: ContinualDataset

The Sequential DTD (Describable Textures Dataset) at 224x224 resolution, for use with ViT-B/16.

Parameters:
  • NAME (str) – name of the dataset.

  • SETTING (str) – setting of the dataset.

  • N_CLASSES_PER_TASK (int or list of int) – number of classes per task; a list gives a separate count for each task.

  • N_TASKS (int) – number of tasks.

  • N_CLASSES (int) – number of classes.

  • SIZE (tuple) – size of the images.

  • MEAN (tuple) – mean of the dataset.

  • STD (tuple) – standard deviation of the dataset.

  • TRANSFORM (torchvision.transforms) – transformation to apply to the data.

  • TEST_TRANSFORM (torchvision.transforms) – transformation to apply to the test data.
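For this dataset N_CLASSES_PER_TASK is a list ([10, 10, 10, 10, 7]) that sums to N_CLASSES = 47 over N_TASKS = 5, so the label range owned by each task follows from a running sum. The helper `task_class_ranges` below is a hypothetical sketch of that bookkeeping, not a function of the library, and it assumes classes are assigned to tasks in label order.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple, Union


def task_class_ranges(
    n_classes_per_task: Union[int, Sequence[int]], n_tasks: int
) -> List[Tuple[int, int]]:
    """Return half-open [start, end) label ranges, one per task (hypothetical helper)."""
    if isinstance(n_classes_per_task, int):
        # A single int means every task has the same number of classes.
        counts = [n_classes_per_task] * n_tasks
    else:
        counts = list(n_classes_per_task)
        if len(counts) != n_tasks:
            raise ValueError("need one class count per task")
    ends = list(accumulate(counts))   # cumulative class counts
    starts = [0] + ends[:-1]          # each task starts where the previous one ended
    return list(zip(starts, ends))


ranges = task_class_ranges([10, 10, 10, 10, 7], n_tasks=5)
print(ranges)  # [(0, 10), (10, 20), (20, 30), (30, 40), (40, 47)]
```

Under this assumption the last task covers labels 40 to 46, and the final end value equals N_CLASSES.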

MEAN = (0.48145466, 0.4578275, 0.40821073)
NAME: str = 'seq-dtd'
N_CLASSES: int = 47
N_CLASSES_PER_TASK: int = [10, 10, 10, 10, 7]
N_TASKS: int = 5
SETTING: str = 'class-il'
SIZE: Tuple[int] = (224, 224)
STD = (0.26862954, 0.26130258, 0.27577711)
TEST_TRANSFORM = Compose(
    Resize(size=256, interpolation=bilinear, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
TRANSFORM = Compose(
    RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=True)
    RandomHorizontalFlip(p=0.5)
    ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.1, 0.1))
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
get_backbone()
get_batch_size()
get_class_names()
get_data_loaders()
Return type:

Tuple[DataLoader, DataLoader]
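In the class-incremental setting (SETTING = 'class-il'), the loaders returned for a task are typically restricted to that task's classes. The sketch below shows that label filtering on plain lists of targets. The helper `split_for_task` and its `task_id` argument are hypothetical, and restricting the test split to the current task (with earlier tasks' test sets kept separately for evaluation) is an assumption rather than documented behavior; in PyTorch, such index lists could be passed to `torch.utils.data.Subset` before wrapping each split in a `DataLoader`.

```python
from typing import List, Sequence, Tuple


def split_for_task(
    train_targets: Sequence[int],
    test_targets: Sequence[int],
    classes_per_task: Sequence[int],
    task_id: int,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: indices of train/test samples whose label belongs to `task_id`."""
    # Labels of task `task_id` occupy [start, end) when classes are assigned in order.
    start = sum(classes_per_task[:task_id])
    end = start + classes_per_task[task_id]
    train_idx = [i for i, y in enumerate(train_targets) if start <= y < end]
    test_idx = [i for i, y in enumerate(test_targets) if start <= y < end]
    return train_idx, test_idx


# Toy label lists; with this dataset's per-task counts, task 4 (0-based) owns labels 40..46.
train_y = [0, 12, 41, 46, 39, 40]
test_y = [45, 3, 40]
print(split_for_task(train_y, test_y, [10, 10, 10, 10, 7], task_id=4))  # ([2, 3, 5], [0, 2])
```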

static get_denormalization_transform()
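The Normalize step in TRANSFORM and TEST_TRANSFORM computes x' = (x - mean) / std per channel, using the MEAN and STD listed above (the same constants as OpenAI's CLIP image preprocessing), after ToTensor has scaled pixels to [0, 1]. A denormalization transform presumably applies the inverse, x = x' * std + mean, for instance to visualize model inputs. The pure-Python sketch below works on a single RGB pixel; the function names are hypothetical.

```python
from typing import Sequence, Tuple

MEAN = (0.48145466, 0.4578275, 0.40821073)
STD = (0.26862954, 0.26130258, 0.27577711)


def normalize(pixel: Sequence[float]) -> Tuple[float, ...]:
    """Per-channel (x - mean) / std, the operation torchvision's Normalize performs."""
    return tuple((x - m) / s for x, m, s in zip(pixel, MEAN, STD))


def denormalize(pixel: Sequence[float]) -> Tuple[float, ...]:
    """Inverse of `normalize`: x * std + mean, per channel."""
    return tuple(x * s + m for x, m, s in zip(pixel, MEAN, STD))


mid_gray = (0.5, 0.5, 0.5)  # a pixel already scaled to [0, 1]
restored = denormalize(normalize(mid_gray))
print([round(v, 6) for v in restored])  # [0.5, 0.5, 0.5]
```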
get_epochs()
static get_loss()
static get_normalization_transform()
static get_prompt_templates()
static get_transform()