"""
Adaptation of OpenAI's CLIP.
Requires:
- pip install git+https://github.com/openai/CLIP.git
.. note::
Checkpoints are loaded from the OpenAI repository.
* RN50: "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"
* RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"
* RN50x4: "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"
* RN50x16: "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"
* RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"
* ViT-B/32: "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
* ViT-B/16: "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
* ViT-L/14: "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
* ViT-L/14@336px: "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
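
The snippet below is a minimal, illustrative sketch (not part of this model's API)
of the zero-shot classification that ``FinalModel`` implements; the backbone name,
class names, and image path are placeholder assumptions.

.. code-block:: python

    import clip
    import torch
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)  # placeholder backbone

    # One text prompt per class; encode both modalities.
    class_names = ["cat", "dog"]  # placeholder class names
    text = clip.tokenize([f"a photo of a {c}" for c in class_names]).to(device)
    image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder image

    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

    # L2-normalize, then take the scaled cosine similarity (as in FinalModel.forward).
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)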
"""
import torch
import torch.nn as nn

from utils import binary_to_boolean_type

try:
    import clip
except ImportError:
    raise ImportError("Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git")

from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from utils.args import ArgumentParser
from utils.conf import get_device


class FinalModel(nn.Module):
    @torch.no_grad()
    def __init__(self, clip_model, dataset: ContinualDataset, args) -> None:
        super().__init__()
        self.dataset = dataset
        self.clip_model = clip_model
        self.args = args
        self.classes = self.dataset.get_class_names()

        if args.use_templates:
            templates = self.dataset.get_prompt_templates()
            text_inputs = []
            for t in templates:
                t_inputs = torch.cat([clip.tokenize(t.format(c)) for c in self.classes]).to(get_device())
                t_inputs = self.clip_model.encode_text(t_inputs)
                t_inputs /= t_inputs.norm(dim=-1, keepdim=True)  # normalize each template's features; double normalization is expected when using templates (see https://github.dev/KaiyangZhou/CoOp)
                text_inputs.append(t_inputs)
            self.text_features = torch.stack(text_inputs).mean(0)
        else:
            text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in self.classes]).to(get_device())
            self.text_features = self.clip_model.encode_text(text_inputs)

        self.text_features /= self.text_features.norm(dim=-1, keepdim=True)  # second normalization when templates are used

        self.task_id = 0

    @torch.no_grad()
    def forward(self, x):
        image_features = self.clip_model.encode_image(x)
        text_features = self.text_features

        image_features /= image_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * (image_features @ text_features.T)).softmax(dim=-1)

        return similarity


class CLIP(ContinualModel):
    """STATIC Continual Learning with CLIP"""
    NAME = 'clip'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        parser.set_defaults(lr=0, n_epochs=0)  # disable training by default
        parser.add_argument('--clip_backbone', type=str, default='ViT-L/14',
                            choices=list(clip.available_models()),
                            help='Backbone architecture for CLIP')
        parser.add_argument('--save_predictions', type=binary_to_boolean_type, default=0,
                            help='Whether to save predictions of the TRAINING set after each task')
        parser.add_argument('--use_templates', type=binary_to_boolean_type, default=0,
                            help='Whether to use prompt templates for CLIP. NOTE: Datasets NEED to have a `get_prompt_templates` method implemented.')
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        backbone, clip_transform = clip.load(args.clip_backbone, device=get_device())  # the given backbone is replaced by the CLIP model
        n_epochs = 1 if args.save_predictions else 0
        if args.n_epochs != n_epochs:
            print(f"CLIP is a STATIC model, setting n_epochs to {n_epochs}")
            args.n_epochs = n_epochs
        super().__init__(backbone, loss, args, transform, dataset=dataset)

        self.net = FinalModel(self.net, self.dataset, args)
        self.clip_transform = clip_transform

        self.predictions = []
        self.original_labels = []

    def begin_task(self, dataset):
        dataset.test_loaders[-1].dataset.transform = self.clip_transform  # evaluate with CLIP's own preprocessing
        if self.args.save_predictions:
            dataset.train_loader.dataset.transform = self.clip_transform

        if self.current_task != 0:
            self.net.task_id += 1

        self.eval()

    def end_task(self, dataset: ContinualDataset) -> None:
        if self.args.save_predictions:
            self.predictions = torch.cat(self.predictions, dim=0).cpu()
            self.original_labels = torch.cat(self.original_labels, dim=0).cpu()
            torch.save((self.predictions, self.original_labels), f'predictions_{self.args.dataset}_{self.current_task}.pt')
            print(f"Predictions saved for task {self.current_task} in 'predictions_{self.args.dataset}_{self.current_task}.pt'")
            self.predictions = []
            self.original_labels = []
        return super().end_task(dataset)

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        # No training is performed; optionally cache predictions on the training stream.
        if self.args.save_predictions:
            with torch.no_grad():
                self.predictions.append(self.net(inputs))
                self.original_labels.append(labels)
        return 0  # dummy loss, as the model is not trained

    @torch.no_grad()
    def forward(self, x):
        return self.net(x)[:, :self.n_seen_classes]  # mask out logits of classes not yet seen
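

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Shows how FinalModel performs zero-shot classification on its own, using a
# hypothetical stub in place of a real ContinualDataset. The backbone name,
# class names, and image path below are placeholder assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    from PIL import Image

    class _StubDataset:
        """Hypothetical stand-in exposing only what FinalModel needs."""

        def get_class_names(self):
            return ["cat", "dog"]  # placeholder class names

    device = get_device()
    clip_model, clip_transform = clip.load("ViT-B/32", device=device)  # placeholder backbone
    zero_shot = FinalModel(clip_model, _StubDataset(), SimpleNamespace(use_templates=False))

    image = clip_transform(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder image path
    probs = zero_shot(image)  # (1, num_classes) probabilities over the class prompts
    print(probs)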