Source code for models.tak_utils.templates

### compositional templates from: https://github.com/umd-huang-lab/perceptionCLIP ###
from datasets.seq_8vision import Sequential8Vision


def generate_composite_factors(templates, selected_factors=None):
    """Build every combination of contextual-factor values and their descriptions.

    Args:
        templates: mapping of factor name -> {value name -> list of description
            strings}, for example::

                templates = {
                    "factor_1": {
                        "value_1": ["description_1", "description_2"],
                        "value_2": ["description_1", "description_2"],
                    },
                    "factor_2": {
                        "value_1": ["description_1"],
                        "value_2": ["description_1", "description_2"],
                        "value_3": [""],
                    },
                }

            An empty description string contributes no text to the result.
        selected_factors: optional list of factor names to compose, in the
            order given. Names missing from ``templates`` are ignored. When
            falsy (``None`` or an empty list) all factors of ``templates`` are
            used in their own order.

    Returns:
        dict mapping a combination key (the chosen value names joined with
        ``"_"`` in factor order) to the list of composed descriptions: the
        cartesian product of the chosen values' description lists, joined with
        ``", "`` while skipping empty descriptions so no stray commas appear.
        With no factors left the result is ``{"": [""]}``.
    """
    # Restrict (and reorder) the factors when an explicit selection is given.
    if selected_factors:
        templates = {k: templates[k] for k in selected_factors if k in templates}

    # Base case: nothing left to compose -> a single empty combination.
    if not templates:
        return {"": [""]}

    # Peel off the first factor and compose the remaining ones recursively.
    key, sub_dict = next(iter(templates.items()))
    rest_templates = {k: v for k, v in templates.items() if k != key}

    # The combinations of the remaining factors are the same for every value of
    # the current factor, so compute them once rather than once per value.
    rest_composite = generate_composite_factors(rest_templates)

    composite = {}
    for sub_key, values in sub_dict.items():
        for rest_key, rest_values in rest_composite.items():
            # The base case yields an empty key; avoid a trailing "_".
            new_key = f"{sub_key}_{rest_key}" if rest_key else sub_key
            # Pair every description of this value with every composed
            # description of the rest; insert ", " only between two non-empty
            # parts.
            composite[new_key] = [
                v + (", " + rv if rv else "") if v else rv
                for v in values
                for rv in rest_values
            ]
    return composite
def compose_template(org_templates, factor_templates):
    """Append composed factor descriptions to the first main template.

    Args:
        org_templates: sequence of main-template callables; only
            ``org_templates[0]`` is used.
        factor_templates: mapping of combination key -> list of description
            strings (for example the result of ``generate_composite_factors``).

    Returns:
        A flat list of callables, one per description, in mapping order. Each
        maps ``c`` to ``main(c) + ", " + description + "."``, or to
        ``main(c) + "."`` when the description is empty.
    """
    composed = []
    for description_group in factor_templates.values():
        for descriptions in description_group:
            # Default arguments freeze the current main template and
            # description; a plain closure would see only the last loop value.
            if not descriptions:
                # Nothing to add: just terminate the main prompt.
                composed.append(
                    lambda c, main=org_templates[0], descriptions=descriptions: main(c) + "."
                )
                continue
            composed.append(
                lambda c, main=org_templates[0], descriptions=descriptions: main(
                    c) + ', ' + descriptions + "."
            )
    return composed
# Prompt templates, one list per dataset family. Each entry is a callable that
# takes a class name ``c`` (presumably a str) and returns one text prompt.
# The exact strings are behavior: changing any character changes the prompts.

cars_template = [
    lambda c: f'a photo of a {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'i love my {c}!',
    lambda c: f'a photo of my dirty {c}.',
    lambda c: f'a photo of my clean {c}.',
    lambda c: f'a photo of my new {c}.',
    lambda c: f'a photo of my old {c}.',
]

cifar10_template = [
    lambda c: f'a photo of a {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'a low contrast photo of a {c}.',
    lambda c: f'a high contrast photo of a {c}.',
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a photo of a big {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a low contrast photo of the {c}.',
    lambda c: f'a high contrast photo of the {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the big {c}.',
]

# Same 18 prompts as cifar10_template, but a separate list object; it is also
# reused below for several non-CIFAR datasets (e.g. 'seq-derm7pt', 'seq-isic').
cifar100_template = [
    lambda c: f'a photo of a {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'a low contrast photo of a {c}.',
    lambda c: f'a high contrast photo of a {c}.',
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a photo of a big {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a low contrast photo of the {c}.',
    lambda c: f'a high contrast photo of the {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the big {c}.',
]

dtd_template = [
    lambda c: f'a photo of a {c} texture.',
    lambda c: f'a photo of a {c} pattern.',
    lambda c: f'a photo of a {c} thing.',
    lambda c: f'a photo of a {c} object.',
    lambda c: f'a photo of the {c} texture.',
    lambda c: f'a photo of the {c} pattern.',
    lambda c: f'a photo of the {c} thing.',
    lambda c: f'a photo of the {c} object.',
]

eurosat_template = [
    lambda c: f'a centered satellite photo of {c}.',
    lambda c: f'a centered satellite photo of a {c}.',
    lambda c: f'a centered satellite photo of the {c}.',
]

flowers102_template = [
    lambda c: f'a photo of {c}, a type of flower.',
    lambda c: f'a close-up photo of a {c} in bloom.',
    lambda c: f'a macro photograph of the {c} blossom.',
    lambda c: f'a botanical photo of a {c}.',
    lambda c: f'a photo of the {c} in a garden.',
    lambda c: f'a detailed photo of a {c} flower.',
]

food101_template = [
    lambda c: f'a photo of {c}, a type of food.',
]

gtsrb_template = [
    lambda c: f'a zoomed in photo of a "{c}" traffic sign.',
    lambda c: f'a centered photo of a "{c}" traffic sign.',
    lambda c: f'a close up photo of a "{c}" traffic sign.',
]

oxfordiiitpet_template = [
    lambda c: f'a photo of a {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a close-up portrait of a {c}, a type of pet.',
    lambda c: f'a photo of a {c} looking at the camera.',
    lambda c: f'a {c} sitting indoors.',
    lambda c: f'a {c} posing for a portrait.',
]

pcam_template = [
    lambda c: f'a histopathology image of {c} lymph node tissue.',
    lambda c: f'a microscopic photo showing {c} tissue.',
    lambda c: f'a pathology slide containing {c} cells.',
    lambda c: f'a medical microscopy image of {c} tissue.',
]

mnist_template = [
    lambda c: f'a photo of the number: "{c}".',
]

# 80 prompts. NOTE(review): appears to be the CLIP ImageNet prompt ensemble
# (including its 'a origami' / 'a embroidered' wording) - confirm before editing.
imagenet_template = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
]

resisc45_template = [
    lambda c: f'satellite imagery of {c}.',
    lambda c: f'aerial imagery of {c}.',
    lambda c: f'satellite photo of {c}.',
    lambda c: f'aerial photo of {c}.',
    lambda c: f'satellite view of {c}.',
    lambda c: f'aerial view of {c}.',
    lambda c: f'satellite imagery of a {c}.',
    lambda c: f'aerial imagery of a {c}.',
    lambda c: f'satellite photo of a {c}.',
    lambda c: f'aerial photo of a {c}.',
    lambda c: f'satellite view of a {c}.',
    lambda c: f'aerial view of a {c}.',
    lambda c: f'satellite imagery of the {c}.',
    lambda c: f'aerial imagery of the {c}.',
    lambda c: f'satellite photo of the {c}.',
    lambda c: f'aerial photo of the {c}.',
    lambda c: f'satellite view of the {c}.',
    lambda c: f'aerial view of the {c}.',
]

stl10_template = [
    lambda c: f'a photo of a {c}.',
    lambda c: f'a photo of the {c}.',
]

sun397_template = [
    lambda c: f'a photo of a {c}.',
    lambda c: f'a photo of the {c}.',
]

svhn_template = [
    lambda c: f'a photo of the number: "{c}".',
]

fer2013_template = [
    lambda c: f'a portrait of a person with a {c} expression.',
    lambda c: f'a face showing a {c} emotion.',
    lambda c: f'a close-up of a face that looks {c}.',
    lambda c: f'a person looking {c}.',
    lambda c: f'a facial expression that appears {c}.',
]

# Main CUB-200 prompt used for composition. It deliberately has no trailing
# period: compose_template appends either ", <description>." or ".".
cub200_main_template = [
    lambda c: f'a photo of a {c}, a type of bird'
]

# Contextual factors for CUB-200 (factor -> value -> descriptions). The empty
# description ("others" / "normal") adds no text to the composed prompt.
cub200_factor_templates = {
    "size": {
        "others": [""],
        "small": ["small"],
        "big": ["big"],
    },
    "background": {
        "others": [""],
        "land": ["on land"],
        "water": ["on water"],
        "forest": ["in forest"],
        "sky": ["in sky"],
        "street": ["on street"],
        "grass": ["on grass"],
        "tree": ["on tree"],
        "flowers": ["with flowers"],
        "beach": ["on beach"],
        "human": ["with human"],
        "branch": ["on a branch"],
    },
    "condition": {
        "normal": [""],
        "cool": ["cool"],
        "nice": ["nice"],
        "weird": ["weird"],
    }
}

# Built at import time: one template per combination of factor values
# (3 sizes x 12 backgrounds x 4 conditions = 144), e.g.
# 'a photo of a {c}, a type of bird, small, on land, cool.'
cub200_template = compose_template(cub200_main_template, generate_composite_factors(cub200_factor_templates))

# Dataset name -> template list, consumed by get_templates. Several names share
# the same list object (e.g. all '*-cifar100*' keys use cifar100_template).
dataset_to_template = {
    'Cars': cars_template,
    'CIFAR10': cifar10_template,
    'CIFAR100': cifar100_template,
    'seq-derm7pt': cifar100_template,
    'seq-isic': cifar100_template,
    'seq-cifar100-224-5': cifar100_template,
    'seq-cifar100-224': cifar100_template,
    'seq-cifar100-224-5-permutato': cifar100_template,
    'joint-cifar100': cifar100_template,
    'CUB200': imagenet_template,  # TODO: experiment with the templates for this dataset???
    'CUB200CustomTemplates': cub200_template,
    'DomainNet': imagenet_template,  # TODO: experiment with the templates for this dataset???
    'DTD': dtd_template,
    'seq-dtd': dtd_template,
    'joint-dtd': dtd_template,
    'seq-gtsrb': gtsrb_template,
    'joint-gtsrb': gtsrb_template,
    'seq-mnist': mnist_template,
    'joint-mnist': mnist_template,
    'EuroSAT': eurosat_template,
    'seq-eurosat-rgb': eurosat_template,
    'joint-eurosat-rgb': eurosat_template,
    'seq-resisc45': resisc45_template,
    'joint-resisc45': resisc45_template,
    'seq-cars196': cars_template,
    'joint-cars196': cars_template,
    'seq-cropdisease': cifar100_template,
    # NOTE(review): uses the composition base template, whose prompt has no
    # trailing period unlike every other list here - confirm this is intended.
    'seq-cub200': cub200_main_template,
    'Food101': food101_template,
    'Flowers102': flowers102_template,
    'seq-flowers102': flowers102_template,
    'joint-flowers102': flowers102_template,
    'GTSRB': gtsrb_template,
    'MNIST': mnist_template,
    'seq-mnist-224': mnist_template,
    'joint-mnist-224': mnist_template,
    'ImageNet': imagenet_template,
    'seq-imagenet-r': imagenet_template,
    'seq-imagenet1k': imagenet_template,
    'seq-imagenet21k': imagenet_template,
    'ImageNetR': imagenet_template,
    'OxfordIIITPet': oxfordiiitpet_template,
    'seq-oxfordiiitpet': oxfordiiitpet_template,
    'joint-oxfordiiitpet': oxfordiiitpet_template,
    'PCAM': pcam_template,
    'seq-pcam': pcam_template,
    'joint-pcam': pcam_template,
    'RESISC45': resisc45_template,
    'STL10': stl10_template,
    'seq-stl10': stl10_template,
    'joint-stl10': stl10_template,
    'SUN397': sun397_template,
    'seq-sun397': sun397_template,
    'joint-sun397': sun397_template,
    'SVHN': svhn_template,
    'seq-svhn': svhn_template,
    'joint-svhn': svhn_template,
    'FER2013': fer2013_template,
    'seq-fer2013': fer2013_template,
    'joint-fer2013': fer2013_template,
}
def get_templates(dataset_name):
    """Return the prompt templates registered for ``dataset_name``.

    A trailing ``"Val"`` suffix is stripped before lookup, so ``"XVal"``
    resolves to the templates of ``"X"``. For ``"seq-8vision"`` a list with
    one template list per name in ``Sequential8Vision.DATASET_NAMES`` is
    returned, in that order.

    Args:
        dataset_name: key of ``dataset_to_template`` (optionally suffixed with
            ``"Val"``), or ``"seq-8vision"``.

    Returns:
        The template list for the dataset (a list of template lists for
        ``"seq-8vision"``).

    Raises:
        AssertionError: if the dataset name is not registered.
    """
    suffix = 'Val'
    if dataset_name.endswith(suffix):
        # Strip only the trailing suffix; str.replace would also delete any
        # other occurrence of "Val" inside the name.
        return get_templates(dataset_name[:-len(suffix)])
    if dataset_name == "seq-8vision":
        return [dataset_to_template[dset_name] for dset_name in Sequential8Vision.DATASET_NAMES]
    if dataset_name not in dataset_to_template:
        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`; the AssertionError type is kept for existing callers.
        raise AssertionError(f'Unsupported dataset: {dataset_name}')
    return dataset_to_template[dataset_name]