Source code for models.attriclip

"""
DISCLAIMER: AttriCLIP **does not** reproduce the results in the paper (https://arxiv.org/pdf/2305.11488).
Unfortunately, the original implementation (https://github.com/bhrqw/AttriCLIP) did not reproduce the results either and is no longer available. This is a known issue (see https://github.com/bhrqw/SADA/issues/3).

This implementation is based on that code and on the information provided in the paper.
"""

from utils.args import *
from models.utils.continual_model import ContinualModel

from datasets import get_dataset
import wandb
from models.attriclip_utils.model import CoOp
from models.attriclip_utils.utils import cosine_loss
from utils.conf import get_device
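
# A minimal sketch of the two prompt-pool operations used below, as a reading
# aid. AttriCLIP keeps a pool of (key, prompt) pairs: each image feature
# selects its top-k most similar keys, and `cosine_loss` pulls the image
# feature towards the selected keys. The names `select_keys_sketch` and
# `cosine_loss_sketch`, the shapes, and the reductions are illustrative
# assumptions; the actual implementations live in models/attriclip_utils/.
import torch
import torch.nn.functional as F


def select_keys_sketch(image_feat: torch.Tensor, keys: torch.Tensor, k: int = 3) -> torch.Tensor:
    # image_feat: (B, D) image features; keys: (N, D) learnable pool keys.
    sim = F.normalize(image_feat, dim=-1) @ F.normalize(keys, dim=-1).t()  # (B, N)
    return sim.topk(k, dim=-1).indices  # (B, k) indices of the selected prompts


def cosine_loss_sketch(image_feat: torch.Tensor, chosen_keys: torch.Tensor) -> torch.Tensor:
    # image_feat: (B, D); chosen_keys: (B, K, D) keys selected for each image.
    sim = F.cosine_similarity(image_feat.unsqueeze(1), chosen_keys, dim=-1)  # (B, K)
    return (1.0 - sim).mean()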


class Attriclip(ContinualModel):
    """AttriCLIP: A Non-Incremental Learner for Incremental Knowledge Learning."""

    NAME = 'attriclip'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.add_argument("--num_prompt", type=int, default=10,
                            help='number of (key, prompt) pairs in the prompt pool')
        parser.add_argument("--text_prompt", type=int, default=3,
                            help='number of prompts selected from the pool for each image')
        parser.add_argument('--freeze_clip', type=int, default=1,
                            help='if set to 1, freeze the CLIP backbone weights')
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        seq_dataset = get_dataset(args) if dataset is None else dataset
        self.device = get_device()
        self.class_names = seq_dataset.get_class_names()
        # The given backbone is discarded: AttriCLIP builds its own
        # prompt-learning wrapper (CoOp) around a frozen CLIP model.
        backbone = CoOp(self.device, False, False, args)
        offset_1, offset_2 = seq_dataset.get_offsets(0)
        cur_class_names = self.class_names[offset_1:offset_2]
        backbone.init_model(class_names=cur_class_names, text_key=backbone.text_key,
                            text_prompt=backbone.text_prompt)
        super().__init__(backbone, loss, args, transform, dataset=dataset)

    def begin_task(self, dataset):
        self.offset_1, self.offset_2 = self.dataset.get_offsets(self.current_task)
        self.per_epoch_steps = len(dataset.train_loader)
        cur_class_names = self.class_names[self.offset_1:self.offset_2]
        # Rebuild the text classifier for the classes of the current task,
        # reusing the shared pool of keys and prompts.
        self.net.init_model(class_names=cur_class_names, text_key=self.net.text_key,
                            text_prompt=self.net.text_prompt)
        self.opt, self.custom_scheduler = self.net.get_optimizer(self.per_epoch_steps)
        self.net.model.eval()
        self.old_epoch = 0
        self.idx = 0
        self.iteration = 0
        self.opt.zero_grad()

    def observe(self, inputs, labels, not_aug_inputs, epoch=0):
        # Reset the within-epoch step counter whenever a new epoch starts.
        if self.old_epoch != epoch:
            self.idx = 0
            self.old_epoch = epoch
        labels = labels.long()
        log_dict = {}
        log_dict['lr'] = self.opt.param_groups[0]['lr']
        # The scheduler is stepped manually with the global iteration index.
        cur_iter_idx = epoch * self.per_epoch_steps + self.idx
        self.custom_scheduler.step(cur_iter_idx)
        output, ima_feat, key_choose, loss_m = self.net.model(inputs)
        # Cross-entropy on the current task's classes, plus the image-to-key
        # matching loss and the prompt loss returned by the model, combined
        # with the weights used by the original code.
        loss_main = self.loss(output, labels - self.offset_1)
        loss_k = cosine_loss(ima_feat, key_choose)
        loss = loss_main + 0.7 * loss_k + 0.3 * loss_m
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        self.idx += 1
        self.iteration += 1
        if not self.args.nowand:
            wandb.log(log_dict)
        return loss.item()

    def forward(self, x):
        # At inference time, classify against all classes seen so far.
        test_classes = self.class_names[:self.offset_2]
        logits = self.net.model(x, test_classes, test=True)
        return logits[:, :self.offset_2]
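
# Example usage (an assumption based on Mammoth's conventions and the parser
# defined above; the dataset name and hyperparameter values are illustrative):
#
#   python utils/main.py --model attriclip --dataset seq-cifar100 \
#       --lr 0.03 --num_prompt 10 --text_prompt 3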