from argparse import Namespace
import logging
import torch
import torch.nn as nn
from torchvision import transforms
try:
    from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
except ImportError:
    raise ImportError("Please install the HuggingFace Transformers package by running: pip install transformers")
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from utils import binary_to_boolean_type
from utils.args import ArgumentParser
class FinalModel(nn.Module):
    """LLaVA wrapper that turns generated text answers into one-hot predictions.

    The module loads a ``LlavaForConditionalGeneration`` checkpoint together
    with its processor, builds a single text prompt from the dataset's class
    names, and at inference time maps every generated answer back to a class
    index.
    """

    @torch.no_grad()
    def __init__(self, dataset: "ContinualDataset", args: Namespace, denorm_transform, device):
        """Load the model and processor and prepare the classification prompt.

        Args:
            dataset: provides ``get_class_names()`` and ``NAME``.
            args: must expose ``load_in_4bit``, ``llava_model_name``,
                ``classification_prompt`` and ``base_prompt``.
            denorm_transform: callable applied to the (CPU) input batch in
                :meth:`forward` before it is handed to the processor.
            device: device used for the model and for the returned predictions.

        Raises:
            ImportError: if ``bitsandbytes``, ``accelerate`` or
                ``sentencepiece`` is not installed.
        """
        super().__init__()
        # moved here to avoid having to install them when building docs
        try:
            import bitsandbytes  # noqa: F401  (availability check only)
        except ImportError:
            raise ImportError("Please install the BitsAndBytes package by running: `pip install -i https://pypi.org/simple/ bitsandbytes`")
        try:
            import accelerate  # noqa: F401  (availability check only)
        except ImportError:
            raise ImportError("Please install the accelerate package by running: `pip install accelerate`")
        try:
            import sentencepiece  # noqa: F401  (availability check only)
        except ImportError:
            raise ImportError("Please install the sentencepiece package by running: `pip install sentencepiece`")

        self.denorm_transform = denorm_transform
        self.device = device

        quantization_config = BitsAndBytesConfig(load_in_4bit=args.load_in_4bit, bnb_4bit_compute_dtype=torch.float16)
        self.model = LlavaForConditionalGeneration.from_pretrained(args.llava_model_name, device_map=device, quantization_config=quantization_config)
        self.processor = AutoProcessor.from_pretrained(args.llava_model_name, pad_token="<pad>", use_fast=False)

        # Plain names go into the prompt; the numbered variants are what the
        # generated answers are matched against in `get_closest_classname`.
        class_names = [' '.join(c.lower().split('_')) for c in dataset.get_class_names()]
        self.class_names = [f'({i+1}) {c}' for i, c in enumerate(class_names)]

        self.prompt = self._build_prompt(args.base_prompt, args.classification_prompt, class_names, dataset.NAME)
        # Row i of the identity matrix is the one-hot vector for class i.
        self.eye = torch.eye(len(class_names))

    @staticmethod
    def _build_prompt(base_prompt: str, classification_prompt: str, class_names, dataset_name: str) -> str:
        """Fill the prompt placeholders and embed the result in ``base_prompt``.

        ``<classnames>`` is replaced by the space-joined ``class_names`` and
        ``<datasetname>`` by a cleaned-up ``dataset_name``. Both substitutions
        are applied to the same string, so a prompt may use either, both, or
        neither placeholder.
        """
        prompt = classification_prompt.replace('<classnames>', ' '.join(class_names))
        if '<datasetname>' in prompt:
            # Only touch the dataset name when it is actually requested.
            readable_name = dataset_name.replace('seq-', '').replace('-224', '').replace('-', ' ')
            prompt = prompt.replace('<datasetname>', readable_name)
        return base_prompt.replace('<prompt>', prompt)

    def get_closest_classname(self, pred_class_name: str) -> int:
        """Return the index of the first class name matching the prediction.

        A class matches when the normalized prediction, or any non-blank
        ``.``-separated fragment of it, is a substring of the numbered class
        name. Returns ``-1`` when the prediction is empty or nothing matches.
        """
        pred_class_name = pred_class_name.lower().replace('_', ' ').strip()
        if not pred_class_name:
            # An empty string is a substring of everything; treat it as "no answer".
            return -1

        # Blank fragments (e.g. from a trailing '.') would match every class.
        fragments = [frag.strip() for frag in pred_class_name.split('.')]
        fragments = [frag for frag in fragments if frag]

        for idx, candidate in enumerate(self.class_names):
            if pred_class_name in candidate or any(frag in candidate for frag in fragments):
                return idx
        return -1

    @torch.no_grad()
    def forward(self, x):
        """Generate an answer per image and return one-hot class predictions.

        Returns a tensor with ``len(x)`` rows and ``len(self.class_names)``
        columns on ``self.device``; rows whose answer matched no class stay
        all-zero.
        """
        x = self.denorm_transform(x.cpu())
        # do_rescale=False: presumably the denormalized images are already in
        # the value range the processor expects -- TODO confirm.
        inputs = self.processor(text=[self.prompt] * len(x),
                                images=x.to(self.device),
                                padding=True,
                                return_tensors='pt',
                                do_rescale=False).to(self.device)
        outputs = self.model.generate(**inputs, max_new_tokens=20)
        decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)

        # Extract the class names from the output
        out_class_names = [output.split('ASSISTANT:')[-1].strip().lower() for output in decoded_outputs]

        # Convert the class names to a prediction tensor. The explicit dtype
        # keeps the tensor usable as an index even for an empty batch.
        prediction = torch.tensor([self.get_closest_classname(class_name) for class_name in out_class_names],
                                  dtype=torch.long)
        matched = prediction != -1
        preds = torch.zeros(len(prediction), len(self.class_names))
        preds[matched] = self.eye[prediction[matched]]
        return preds.to(self.device)
 
class Llava(ContinualModel):
    """STATIC Continual Learning with LLAVA."""
    NAME = 'llava'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        """Register the LLaVA-specific options on ``parser`` and return it."""
        # disable training by default
        parser.set_defaults(lr=0, n_epochs=0)
        parser.add_argument(
            '--llava_model_name',
            type=str,
            default='llava-hf/llava-1.5-7b-hf',
            help='Name of the LLAVA model to use',
        )
        parser.add_argument(
            '--base_prompt',
            type=str,
            default="USER: <image>\n<prompt>\nASSISTANT:",
            help='Base prompt for the LLAVA model',
        )
        parser.add_argument(
            '--load_in_4bit',
            type=binary_to_boolean_type,
            default=1,
            help='Load the model in 4-bit mode',
        )
        parser.add_argument(
            '--classification_prompt',
            type=str,
            help='Prompt to use for classification. If <classnames> is present, it will be replaced with the class names. If <datasetname> is present, it will be replaced with the dataset name',
            default="What object is in the photo ? Pick only one and answer with the class name <classnames>",
        )
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        """Set up the wrapper around :class:`FinalModel`.

        The ``backbone`` argument is ignored (``None`` is forwarded to the
        parent constructor) and ``args.n_epochs`` is forced to 0, with a
        warning, if it was set to anything else.
        """
        if args.n_epochs != 0:
            logging.warning(f"LLAVA is a STATIC model, setting n_epochs to {0}")
            args.n_epochs = 0

        # The supplied backbone is deliberately discarded.
        super().__init__(None, loss, args, transform, dataset=dataset)

        denormalize = self.dataset.get_denormalization_transform()
        self.net = FinalModel(
            self.dataset,
            args,
            denorm_transform=denormalize,
            device=self.device,
        )
        self.predictions = []
        self.original_labels = []

    def begin_task(self, dataset):
        """Switch the newest test loader to ``ToTensor`` and enter eval mode."""
        latest_test_loader = dataset.test_loaders[-1]
        latest_test_loader.dataset.transform = transforms.ToTensor()
        self.eval()

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """No-op training step; always returns 0."""
        # do nothing
        return 0

    @torch.no_grad()
    def forward(self, x):
        """Return the predictions of ``self.net`` cut to the first ``self.n_seen_classes`` columns."""
        all_class_preds = self.net(x)
        return all_class_preds[:, :self.n_seen_classes]