Source code for models.dap

"""
On the reproducibility of DAP
-----------------------------

The original implementation of DAP is available at `https://github.com/naver-ai/dap-cl` and features:
 - a custom backbone, available with the `--load_original_checkpoint` flag
 - majority voting during test time, available with the `--enable_test_time_majority_voting` flag
 - custom splits for many datasets. Specifically: `imagenet-r`, `chestx`, `eurosat-rgb`, `isic`, and `resisc45` hold out 20% of the training set and use it as the test set, while the other datasets keep their original splits. This differs from the Mammoth implementation, which follows the splits used by other common works (e.g., `eurosat-rgb` uses the split of CoOp).

An example invocation is sketched in the comment below.
"""

import logging
import os
import numpy as np
import torch
import torch.nn as nn
from argparse import ArgumentParser

from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel

from models.dap_utils.dap_model import DAPModel
from utils import binary_to_boolean_type


class DAP(ContinualModel):
    """Generating Instance-level Prompts for Rehearsal-free Continual Learning."""

    NAME = 'dap'
    COMPATIBILITY = ['class-il', 'task-il']

    net: DAPModel
    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.set_defaults(optimizer='adam')
        parser.add_argument('--sim_lambda', type=float, default=0.1)  # 2 for imagenet-r, 0.5 for resisc45, 0.1 by default
        parser.add_argument('--virtual_bs_n', type=int, default=1,
                            help='number of iterations over which gradients are accumulated (virtual batch size)')
        parser.add_argument('--enable_test_time_majority_voting', type=int, default=0,
                            help='Enable majority voting for selecting the prompts during test time. NOTE: '
                                 'this should be avoided, as it is not a fair comparison with other methods.')
        parser.add_argument('--task_emb', type=int, default=16, help='task embedding size')
        parser.add_argument('--num_dap_tokens', type=int, default=10, help='number of DAP tokens')
        parser.add_argument('--load_original_checkpoint', type=binary_to_boolean_type, default=0,
                            help='Load the original checkpoint. This requires the file `imagenet21k_ViT-B_16.npz` to be '
                                 'present in the ./data directory. You can download it following the instructions at '
                                 'https://github.com/naver-ai/dap-cl')
        return parser
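    # A minimal sketch of exercising the parser hook above in isolation
    # (illustrative only; Mammoth normally builds and wires this parser itself):
    #
    #   from argparse import ArgumentParser
    #   parser = DAP.get_parser(ArgumentParser())
    #   args, _ = parser.parse_known_args(['--sim_lambda', '0.5', '--virtual_bs_n', '4'])
    #   print(args.sim_lambda, args.virtual_bs_n)  # -> 0.5 4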
    def __init__(self, backbone, loss, args, transform, dataset=None):
        if args.enable_test_time_majority_voting:
            logging.warning('Majority voting is enabled during test time. The results will not be a fair comparison with other methods.')
        if args.load_original_checkpoint and not os.path.exists('./data/imagenet21k_ViT-B_16.npz'):
            raise FileNotFoundError('`imagenet21k_ViT-B_16.npz` not found in the ./data directory. Please follow the instructions at '
                                    'https://github.com/naver-ai/dap-cl to download the file.')
        super(DAP, self).__init__(backbone, loss, args, transform, dataset=dataset)

        self.net = DAPModel(backbone=self.net, n_tasks=self.n_tasks, num_classes=self.num_classes,
                            args=args, device=args.device)
        self.opt = self.get_optimizer()
    def get_optimizer(self):
        # DAP always uses Adam with AMSGrad, regardless of the `--optimizer` argument
        _param_groups = [{'params': p, 'lr': self.args.lr} for p in self.get_parameters()]
        return torch.optim.Adam(_param_groups, lr=self.args.lr, weight_decay=self.args.optim_wd,
                                betas=(0.9, 0.9), amsgrad=True)
    def get_parameters(self):
        return [p for p in self.net.parameters() if p.requires_grad]
    def begin_task(self, dataset: ContinualDataset) -> None:
        if self.current_task == 1:
            # freeze the `dap_downsample` layers after the first task
            for k, p in self.net.enc.named_parameters():
                if "dap_downsample" in k:
                    p.requires_grad = False

        # dataset.train_loader.dataset.transform = self.dataset.TEST_TRANSFORM
        self.net.train()
        self.opt = self.get_optimizer()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # mask out the logits of classes that have not been seen yet
        return super().forward(x)[:, :self.n_seen_classes]
    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        if self.args.dataset == 'seq-imagenet-r':
            outputs, reduce_sim = self.net(inputs, task_id=self.current_task, is_train=True,
                                           n_past_classes=self.n_past_classes, n_cur_classes=self.n_cur_classes,
                                           is_imgr=True)
        else:
            outputs, reduce_sim = self.net(inputs, task_id=self.current_task, is_train=True)

        # mask out the logits of past and future classes, restricting the loss to the current task
        outputs[:, :self.n_past_classes] = -np.inf
        outputs[:, self.n_seen_classes:] = -np.inf
        # outputs = outputs[:, self.n_past_classes:self.n_seen_classes]
        loss = self.loss(outputs, labels)
        loss -= self.args.sim_lambda * reduce_sim

        # gradients are accumulated over `virtual_bs_n` iterations before each optimizer step
        if self.epoch_iteration == 0:
            self.opt.zero_grad()
        (loss / self.args.virtual_bs_n).backward()
        if (self.epoch_iteration > 0 or self.args.virtual_bs_n == 1) and \
                self.epoch_iteration % self.args.virtual_bs_n == 0:
            nn.utils.clip_grad_norm_(self.net.parameters(), 1)
            self.opt.step()
            self.opt.zero_grad()

        return loss.item()
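# A self-contained sketch of the `virtual_bs_n` gradient-accumulation pattern
# used in `observe` above (illustrative only; `model`, `loader`, `criterion`,
# `opt`, and `virtual_bs_n` are hypothetical placeholders):
#
#   opt.zero_grad()
#   for it, (x, y) in enumerate(loader):
#       loss = criterion(model(x), y)
#       (loss / virtual_bs_n).backward()  # scale so accumulated grads match one large batch
#       if (it > 0 or virtual_bs_n == 1) and it % virtual_bs_n == 0:
#           nn.utils.clip_grad_norm_(model.parameters(), 1)
#           opt.step()
#           opt.zero_grad()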