Source code for models.idefics

from argparse import Namespace
import os
import torch
import torch.nn as nn
from torchvision import transforms

try:
    import bitsandbytes
except ImportError:
    raise ImportError("Please install the BitsAndBytes package by running: `pip install -i https://pypi.org/simple/ bitsandbytes`")

try:
    import accelerate
except ImportError:
    raise ImportError("Please install the accelerate package by running: `pip install accelerate`")

try:
    from transformers import BitsAndBytesConfig, IdeficsForVisionText2Text, AutoProcessor
    from transformers.generation import GenerationConfig
except ImportError:
    raise ImportError("Please install the HuggingFace Transformers package by running: pip install transformers")

from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from utils.args import ArgumentParser


class FinalModel(nn.Module):
    """Wrap a 4-bit quantized IDEFICS model as an image classifier.

    The model is prompted with a classification instruction and one image per
    sample; the generated text is matched against the dataset class names and
    converted into a one-hot prediction row (all zeros when nothing matches).
    """

    @torch.no_grad()
    def __init__(self, dataset: "ContinualDataset", args: Namespace, denorm_transform, device):
        """Load the processor and quantized model and prepare the prompt.

        Args:
            dataset: dataset object; this method reads ``get_class_names()``
                and ``NAME`` from it.
            args: parsed arguments; this method reads ``idefics_model_name``
                and ``classification_prompt``.
            denorm_transform: callable applied to the (CPU) input batch in
                :meth:`forward` before images are converted to PIL.
            device: device passed as ``device_map`` when loading the model and
                used for the processed inputs and the returned predictions.
        """
        super().__init__()
        self.denorm_transform = denorm_transform
        self.device = device

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            llm_int8_skip_modules=["lm_head", "embed_tokens"],
        )

        self.processor = AutoProcessor.from_pretrained(args.idefics_model_name, use_auth_token=False, use_fast=False)
        self.model = IdeficsForVisionText2Text.from_pretrained(
            args.idefics_model_name,
            quantization_config=quantization_config,
            device_map=self.device,
        )

        # Generation args
        exit_condition = self.processor.tokenizer("<end_of_utterance>", add_special_tokens=False).input_ids
        bad_words_ids = self.processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
        self.model.config.max_new_tokens = 200
        self.model.config.min_length = 1
        self.model.config.eos_token_id = exit_condition
        self.model.config.bad_words_ids = bad_words_ids
        self.model.config.output_logits = True
        self.model.config.output_scores = False
        self.model.config.return_dict_in_generate = True

        # NOTE(review): gen_cfg is stored but forward() does not pass it to
        # generate() explicitly -- confirm whether that is intentional.
        self.gen_cfg = GenerationConfig.from_model_config(self.model.config)

        class_names = [' '.join(c.lower().split('_')) for c in dataset.get_class_names()]
        # Numbered variants are what generated answers are matched against.
        self.class_names = [f'({i+1}) {c}' for i, c in enumerate(class_names)]

        self.classification_prompt = self._fill_prompt(args.classification_prompt, class_names, dataset.NAME)

        # Row i of the identity matrix is the one-hot vector for class i.
        self.eye = torch.eye(len(class_names))

    @staticmethod
    def _fill_prompt(template, class_names, dataset_name):
        """Substitute the ``<classnames>`` and ``<datasetname>`` placeholders.

        Both placeholders are optional and are replaced independently, so a
        template may contain either, both, or neither of them.

        Args:
            template: prompt text, possibly containing the placeholders.
            class_names: list whose ``str()`` replaces ``<classnames>``.
            dataset_name: raw dataset name; ``'seq-'`` and ``'-224'`` are
                removed and remaining dashes become spaces before it replaces
                ``<datasetname>``.

        Returns:
            The prompt with every placeholder occurrence substituted.
        """
        readable_name = dataset_name.replace('seq-', '').replace('-224', '').replace('-', ' ')
        filled = template.replace('<classnames>', str(class_names))
        return filled.replace('<datasetname>', readable_name)

    def get_closest_classname(self, pred_class_name):
        """Return the index of the first class name matching the prediction.

        A class matches when the normalized prediction, or any non-empty
        dot-separated fragment of it, is a substring of the numbered class
        name.

        Args:
            pred_class_name: text produced by the model for one sample.

        Returns:
            The index into ``self.class_names`` of the first match, or ``-1``
            when nothing matches or the prediction is empty.
        """
        # get the index of the closest class name
        pred_class_name = pred_class_name.lower().replace('_', ' ').strip()
        if not pred_class_name:
            # An empty string is a substring of every name; treat as no match.
            return -1

        # Drop empty fragments for the same reason (e.g. a trailing period).
        fragments = [part.strip() for part in pred_class_name.split('.')]
        fragments = [part for part in fragments if part]

        for index, candidate in enumerate(self.class_names):
            if pred_class_name in candidate or any(part in candidate for part in fragments):
                return index
        return -1

    @torch.no_grad()
    def forward(self, x):
        """Classify a batch of images by prompting the model.

        Args:
            x: batch of images; it is moved to CPU, passed through
                ``denorm_transform`` and each item is converted to a PIL image.

        Returns:
            A tensor on ``self.device`` with one row per sample and one column
            per class: a one-hot row for a recognized answer, zeros otherwise.
        """
        x = self.denorm_transform(x.cpu())

        to_pil = transforms.ToPILImage()
        prompts = [
            [self.classification_prompt, to_pil(image), "<end_of_utterance>", "\nAssistant: "]
            for image in x
        ]

        inputs = self.processor(prompts, return_tensors="pt").to(self.device)
        bad_words_ids = self.processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids

        generated_ids = self.model.generate(**inputs, max_new_tokens=200, bad_words_ids=bad_words_ids)
        # NOTE(review): indexing with [0] presumably selects the generated
        # sequences from a dict-style generate() output -- confirm.
        outputs = self.processor.batch_decode(generated_ids[0], skip_special_tokens=True)

        # Extract the class names from the output
        out_class_names = [output.lower().split('assistant:')[-1].strip().lower() for output in outputs]

        # Convert the class names to a prediction tensor
        prediction = torch.tensor(
            [self.get_closest_classname(class_name) for class_name in out_class_names],
            dtype=torch.long,  # keeps an empty batch usable as an index
        )
        recognized = prediction != -1
        preds = torch.zeros(len(prediction), len(self.class_names))
        preds[recognized] = self.eye[prediction[recognized]]

        return preds.to(self.device)
class Idefics(ContinualModel):
    """STATIC Continual Learning with IDEFICS.

    The model is never trained: ``observe`` is a no-op and predictions come
    from a :class:`FinalModel` stored in ``self.net``.
    """
    NAME = 'idefics'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        """Register the IDEFICS-specific command-line options.

        Args:
            parser: argument parser to extend.

        Returns:
            The same parser, with ``lr`` and ``n_epochs`` defaulting to 0 and
            the ``--idefics_model_name`` / ``--classification_prompt`` options
            added.
        """
        parser.set_defaults(lr=0, n_epochs=0)  # disable training by default
        parser.add_argument('--idefics_model_name', type=str, default='HuggingFaceM4/idefics-9b-instruct',
                            help='Name of the IDEFICS model to use')
        parser.add_argument('--classification_prompt', type=str,
                            help='Prompt to use for classification. If <classnames> is present, it will be replaced with the class names. If <datasetname> is present, it will be replaced with the dataset name',
                            default="Instruction: Classify the following image into a single category from the following list: <classnames>.\n")
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        """Build the static model.

        Args:
            backbone: ignored; it is replaced with ``None`` before being
                forwarded to the parent constructor.
            loss: forwarded to the parent constructor.
            args: parsed arguments; ``n_epochs`` is forced to 0.
            transform: forwarded to the parent constructor.
            dataset: forwarded to the parent constructor.
        """
        backbone = None

        if args.n_epochs != 0:
            print("IDEFICS is a STATIC model, setting n_epochs to 0")
            args.n_epochs = 0

        os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable tokenizers parallelism

        super().__init__(backbone, loss, args, transform, dataset=dataset)

        # self.dataset and self.device are presumably set by the parent
        # constructor -- verify against ContinualModel.
        denorm_transform = self.dataset.get_denormalization_transform()
        self.net = FinalModel(self.dataset, args, denorm_transform=denorm_transform, device=self.device)

        self.predictions = []
        self.original_labels = []

    def begin_task(self, dataset):
        """Set the newest test loader's transform to ``ToTensor`` and enter eval mode.

        Args:
            dataset: object exposing ``test_loaders``; only the last loader is
                modified.
        """
        dataset.test_loaders[-1].dataset.transform = transforms.ToTensor()
        self.eval()

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Skip training; the arguments are ignored.

        Returns:
            Always 0.
        """
        # do nothing
        return 0

    @torch.no_grad()
    def forward(self, x):
        """Return the predictions of ``self.net`` for the first ``n_seen_classes`` columns.

        Args:
            x: input batch forwarded to ``self.net``.
        """
        # n_seen_classes is presumably maintained by the parent class -- verify.
        return self.net(x)[:, :self.n_seen_classes]