Source code for utils.autoaugment

from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
import kornia
import torch
from kornia.augmentation.container.params import ParamItem


class ImageNetPolicy(object):
    """Randomly choose one of the best 25 Sub-policies on ImageNet.

    Each call picks one ``SubPolicy`` uniformly at random and applies it to
    the image. Note that the last five entries repeat earlier sub-policies,
    so those five are chosen twice as often as the rest (presumably this
    mirrors the published policy table - confirm before deduplicating).

    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     ImageNetPolicy(),
    >>>     transforms.ToTensor()])

    Args:
        fillcolor: Fill colour forwarded to every ``SubPolicy``.
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
            # The five entries below repeat sub-policies listed above.
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
        ]

    def __call__(self, img):
        """Apply one uniformly chosen sub-policy to ``img`` and return the result."""
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"
class CIFAR10Policy(object):
    """Pick one of the 25 best CIFAR10 sub-policies at random on every call.

    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     CIFAR10Policy(),
    >>>     transforms.ToTensor()])

    Args:
        fillcolor: Fill colour handed to every ``SubPolicy``.
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # (p1, op1, magnitude_idx1, p2, op2, magnitude_idx2) per sub-policy.
        specs = (
            (0.1, "invert", 7, 0.2, "contrast", 6),
            (0.7, "rotate", 2, 0.3, "translateX", 9),
            (0.8, "sharpness", 1, 0.9, "sharpness", 3),
            (0.5, "shearY", 8, 0.7, "translateY", 9),
            (0.5, "autocontrast", 8, 0.9, "equalize", 2),
            (0.2, "shearY", 7, 0.3, "posterize", 7),
            (0.4, "color", 3, 0.6, "brightness", 7),
            (0.3, "sharpness", 9, 0.7, "brightness", 9),
            (0.6, "equalize", 5, 0.5, "equalize", 1),
            (0.6, "contrast", 7, 0.6, "sharpness", 5),
            (0.7, "color", 7, 0.5, "translateX", 8),
            (0.3, "equalize", 7, 0.4, "autocontrast", 8),
            (0.4, "translateY", 3, 0.2, "sharpness", 6),
            (0.9, "brightness", 6, 0.2, "color", 8),
            (0.5, "solarize", 2, 0.0, "invert", 3),
            (0.2, "equalize", 0, 0.6, "autocontrast", 0),
            (0.2, "equalize", 8, 0.6, "equalize", 4),
            (0.9, "color", 9, 0.6, "equalize", 6),
            (0.8, "autocontrast", 4, 0.2, "solarize", 8),
            (0.1, "brightness", 3, 0.7, "color", 0),
            (0.4, "solarize", 5, 0.9, "autocontrast", 3),
            (0.9, "translateY", 9, 0.7, "translateY", 9),
            (0.9, "autocontrast", 2, 0.8, "solarize", 3),
            (0.8, "equalize", 8, 0.1, "invert", 3),
            (0.7, "translateY", 9, 0.9, "autocontrast", 1),
        )
        self.policies = [SubPolicy(*spec, fillcolor) for spec in specs]

    def __call__(self, img):
        """Return ``img`` transformed by one randomly selected sub-policy."""
        chosen = self.policies[random.randrange(len(self.policies))]
        return chosen(img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"
def get_kornia_Cifar10Policy(fillcolor=(128, 128, 128)):
    """Build a Kornia-based AutoAugment transform applying one sub-policy per call.

    NOTE(review): despite the function name, the 25 sub-policies below are
    entry-for-entry the same (operation, probability, magnitude) settings as
    ``SVHNPolicy`` in this module, not those of ``CIFAR10Policy``. Confirm
    which policy callers actually expect before relying on the name.

    Args:
        fillcolor: Unused; accepted only for signature parity with the PIL
            policy classes.

    Returns:
        A ``CustomKorniaRandAugment`` built with ``n=1`` over ``policies``.
    """
    # Each entry: ((op_name, probability, magnitude), (op_name, probability, magnitude)).
    policies = [
        (("ShearX", 0.9, 4), ("Invert", 0.2, 3)),
        (("ShearY", 0.9, 8), ("Invert", 0.7, 5)),
        (("Equalize", 0.6, 5), ("Solarize", 0.6, 6)),
        (("Invert", 0.9, 3), ("Equalize", 0.6, 3)),
        (("Equalize", 0.6, 1), ("Rotate", 0.9, 3)),
        (("ShearX", 0.9, 4), ("AutoContrast", 0.8, 3)),
        (("ShearY", 0.9, 8), ("Invert", 0.4, 5)),
        (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
        (("Invert", 0.9, 6), ("AutoContrast", 0.8, 1)),
        (("Equalize", 0.6, 3), ("Rotate", 0.9, 3)),
        (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
        (("ShearY", 0.8, 8), ("Invert", 0.7, 4)),
        (("Equalize", 0.9, 5), ("TranslateY", 0.6, 6)),
        (("Invert", 0.9, 4), ("Equalize", 0.6, 7)),
        (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
        (("Invert", 0.8, 5), ("TranslateY", 0.0, 2)),
        (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
        (("Invert", 0.6, 4), ("Rotate", 0.8, 4)),
        (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
        (("ShearX", 0.1, 6), ("Invert", 0.6, 5)),
        (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
        (("ShearY", 0.8, 4), ("Invert", 0.8, 8)),
        (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
        (("ShearY", 0.8, 5), ("AutoContrast", 0.7, 3)),
        (("ShearX", 0.7, 2), ("Invert", 0.1, 5)),
    ]
    return CustomKorniaRandAugment(n=1, policy=policies)
class SVHNPolicy(object):
    """Pick one of the 25 best SVHN sub-policies at random on every call.

    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     SVHNPolicy(),
    >>>     transforms.ToTensor()])

    Args:
        fillcolor: Fill colour handed to every ``SubPolicy``.
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # (p1, op1, magnitude_idx1, p2, op2, magnitude_idx2) per sub-policy.
        specs = (
            (0.9, "shearX", 4, 0.2, "invert", 3),
            (0.9, "shearY", 8, 0.7, "invert", 5),
            (0.6, "equalize", 5, 0.6, "solarize", 6),
            (0.9, "invert", 3, 0.6, "equalize", 3),
            (0.6, "equalize", 1, 0.9, "rotate", 3),
            (0.9, "shearX", 4, 0.8, "autocontrast", 3),
            (0.9, "shearY", 8, 0.4, "invert", 5),
            (0.9, "shearY", 5, 0.2, "solarize", 6),
            (0.9, "invert", 6, 0.8, "autocontrast", 1),
            (0.6, "equalize", 3, 0.9, "rotate", 3),
            (0.9, "shearX", 4, 0.3, "solarize", 3),
            (0.8, "shearY", 8, 0.7, "invert", 4),
            (0.9, "equalize", 5, 0.6, "translateY", 6),
            (0.9, "invert", 4, 0.6, "equalize", 7),
            (0.3, "contrast", 3, 0.8, "rotate", 4),
            (0.8, "invert", 5, 0.0, "translateY", 2),
            (0.7, "shearY", 6, 0.4, "solarize", 8),
            (0.6, "invert", 4, 0.8, "rotate", 4),
            (0.3, "shearY", 7, 0.9, "translateX", 3),
            (0.1, "shearX", 6, 0.6, "invert", 5),
            (0.7, "solarize", 2, 0.6, "translateY", 7),
            (0.8, "shearY", 4, 0.8, "invert", 8),
            (0.7, "shearX", 9, 0.8, "translateY", 3),
            (0.8, "shearY", 5, 0.7, "autocontrast", 3),
            (0.7, "shearX", 2, 0.1, "invert", 5),
        )
        self.policies = [SubPolicy(*spec, fillcolor) for spec in specs]

    def __call__(self, img):
        """Return ``img`` transformed by one randomly selected sub-policy."""
        chosen = self.policies[random.randrange(len(self.policies))]
        return chosen(img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"
class SubPolicy(object):
    """Apply up to two PIL operations, each gated by its own probability.

    Args:
        p1: Probability of applying ``operation1``.
        operation1: Name of the first operation (a key of the table below,
            e.g. ``"rotate"`` or ``"shearX"``).
        magnitude_idx1: Index into the 10-step magnitude range of ``operation1``.
        p2: Probability of applying ``operation2``.
        operation2: Name of the second operation.
        magnitude_idx2: Index into the 10-step magnitude range of ``operation2``.
        fillcolor: Colour for pixels exposed by shear/translate/rotate. For
            rotation an int (grey level) or an RGB(A) sequence is honoured;
            any other value falls back to mid-grey ``(128, 128, 128)``.
    """

    def __init__(
        self,
        p1,
        operation1,
        magnitude_idx1,
        p2,
        operation2,
        magnitude_idx2,
        fillcolor=(128, 128, 128),
    ):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10,
        }

        # Previously the rotate background was hard-coded to grey and ignored
        # ``fillcolor``; it is now derived from it (identical for the default).
        rotate_background = self._rotate_background(fillcolor)

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotate in RGBA and use the rotated alpha as the composite mask, so
            # only the corners exposed by the rotation take the background.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(
                rot, Image.new("RGBA", rot.size, rotate_background), rot
            ).convert(img.mode)

        # Shear/translate/enhance ops pick a random sign per call; rotate does not.
        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            "shearY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            # Translations pass no resample argument, so PIL's default is used.
            "translateX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor,
            ),
            "translateY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor,
            ),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
                1 + magnitude * random.choice([-1, 1])
            ),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])
            ),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])
            ),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])
            ),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img),
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    @staticmethod
    def _rotate_background(fillcolor):
        """Return the RGBA background colour used by the rotate operation.

        The alpha component stays at 128, matching the historical background.
        """
        if isinstance(fillcolor, (int, float)):
            level = int(fillcolor)
            return (level, level, level, 128)
        if isinstance(fillcolor, (tuple, list)) and len(fillcolor) in (3, 4):
            return tuple(int(channel) for channel in fillcolor[:3]) + (128,)
        # Unrecognised specs (e.g. colour names, None) keep the old mid-grey.
        return (128, 128, 128, 128)

    def __call__(self, img):
        """Apply each operation with its probability, in order, and return the image."""
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
class Cutout:
    """Paint a random square patch of a PIL image with a fixed grey.

    Args:
        size: Nominal side length in pixels of the square patch. The patch is
            centred on a random pixel and clipped at the image border.
    """

    def __init__(self, size=16) -> None:
        self.size = size

    @staticmethod
    def _cutout_bounds(img_height, img_width, size):
        """Sample a patch centre and return its clipped ``(upper, lower)`` corners.

        ``upper`` is the inclusive top-left corner and ``lower`` the exclusive
        bottom-right corner, both as ``(row, col)`` in the given axis order.
        """
        height_loc = np.random.randint(low=0, high=img_height)
        width_loc = np.random.randint(low=0, high=img_width)
        half = int(size) // 2
        upper_coord = (max(0, height_loc - half), max(0, width_loc - half))
        lower_coord = (
            min(img_height, height_loc + half),
            min(img_width, width_loc + half),
        )
        # A size below 2 yields an empty patch, which is rejected here.
        assert lower_coord[0] - upper_coord[0] > 0
        assert lower_coord[1] - upper_coord[1] > 0
        return upper_coord, lower_coord

    def _create_cutout_mask(self, img_height, img_width, num_channels, size):
        """Creates a zero mask used for cutout of shape `img_height` x `img_width`.

        Args:
            img_height: Height of image cutout mask will be applied to.
            img_width: Width of image cutout mask will be applied to.
            num_channels: Number of channels in the image.
            size: Size of the zeros mask.

        Returns:
            A mask of shape `img_height` x `img_width` x `num_channels` with all
            ones except for a square of zeros of shape `size` x `size` (clipped
            at the border). This mask is meant to be elementwise multiplied
            with the original image. Additionally returns the `upper_coord`
            (top-left) and `lower_coord` (bottom-right, exclusive) which
            specify where the cutout mask was applied.
        """
        upper_coord, lower_coord = self._cutout_bounds(img_height, img_width, size)
        mask = np.ones((img_height, img_width, num_channels))
        mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = 0
        return mask, upper_coord, lower_coord

    def __call__(self, pil_img):
        """Return a copy of ``pil_img`` with one patch set to ``(125, 122, 113, 0)``."""
        pil_img = pil_img.copy()
        # PIL's ``size`` is (width, height) and pixel access is ``pixels[x, y]``,
        # so the first sampled axis is x. No full-image mask is needed here.
        width, height = pil_img.size
        (x_start, y_start), (x_stop, y_stop) = self._cutout_bounds(
            width, height, self.size
        )
        pixels = pil_img.load()
        for x in range(x_start, x_stop):
            for y in range(y_start, y_stop):
                pixels[x, y] = (125, 122, 113, 0)
        return pil_img
class RandomErasing(kornia.augmentation._2d.intensity.base.IntensityAugmentationBase2D):
    """Erase random rectangles, filling them with a scalar or per-channel value.

    ``value`` may be a scalar or a tuple/list with one entry per channel. The
    Kornia ``RectangleEraseGenerator`` only receives a scalar (the first entry
    of a sequence), while ``apply_transform`` fills with the full sequence.

    Args:
        scale: Forwarded to ``RectangleEraseGenerator`` (per Kornia docs, the
            range of the erased area as a proportion of the image area).
        ratio: Forwarded to ``RectangleEraseGenerator`` (aspect-ratio range).
        value: Scalar fill, or per-channel sequence of fills.
        same_on_batch: Forwarded to the Kornia base class.
        p: Probability of applying the augmentation, forwarded to the base.
        keepdim: Forwarded to the Kornia base class.
    """

    def __init__(
        self,
        scale=(0.02, 0.33),
        ratio=(0.3, 3.3),
        value=0.0,
        same_on_batch: bool = False,
        p: float = 0.5,
        keepdim: bool = False,
    ) -> None:
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        self.scale = scale
        self.ratio = ratio
        self.value = value
        # The generator takes a single scalar, so a sequence contributes its first entry.
        generator_value = value[0] if isinstance(value, (tuple, list)) else value
        self._param_generator = kornia.augmentation.random_generator.RectangleEraseGenerator(
            scale, ratio, generator_value
        )

    def _erase_fill_values(self, input, params):
        """Return fill values broadcastable against ``input`` (B, C, H, W)."""
        if isinstance(self.value, (tuple, list)):
            # Per-channel fill: (C,) -> (1, C, 1, 1).
            return torch.tensor(self.value).to(input).unsqueeze(-1).unsqueeze(-1).unsqueeze(0)
        # Per-sample scalar fill from the generator, expanded to input's shape.
        return (
            params["values"]
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .repeat(1, *input.shape[1:])
            .to(input)
        )

    def _erase_with(self, input, params, values):
        """Replace the sampled rectangles of ``input`` with ``values``."""
        _, c, h, w = input.size()
        bboxes = kornia.geometry.bbox.bbox_generator(
            params["xs"], params["ys"], params["widths"], params["heights"]
        )
        mask = kornia.geometry.bbox.bbox_to_mask(bboxes, w, h)  # Returns B, H, W
        mask = mask.unsqueeze(1).repeat(1, c, 1, 1).to(input)  # Transform to B, c, H, W
        return torch.where(mask == 1.0, values, input)

    def apply_transform(self, input, params, flags, transform=None):
        """Fill the sampled rectangles of ``input`` with the configured value."""
        return self._erase_with(input, params, self._erase_fill_values(input, params))

    def apply_transform_mask(self, input, params, flags, transform=None):
        """Erase the same rectangles on masks, always filling with zero."""
        values = torch.zeros_like(self._erase_fill_values(input, params))
        return self._erase_with(input, params, values)
class KorniaAugCutout(torch.nn.Module):
    """Tensor Cutout: erase one ``patch_size`` x ``patch_size`` square per image.

    The square is filled with the grey ``(125, 122, 113) / 255``, matching the
    PIL ``Cutout`` colour.

    Args:
        img_size: Image side length for square images, or an
            ``(height, width)`` pair.
        patch_size: Side length in pixels of the square to erase.
    """

    def __init__(self, img_size, patch_size=16) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size
        # RandomErasing's ``scale`` is the erased *area* as a fraction of the
        # image area. The previous ``img_size / patch_size`` (e.g. 2.0) asked
        # for more than the whole image, so everything was erased. With
        # ratio (1, 1) this fraction yields a patch_size-sided square.
        scale = self._erase_area_fraction(img_size, patch_size)
        ratio = (1, 1)
        self.base_transform = RandomErasing(
            p=1.0,
            scale=(scale, scale),
            ratio=ratio,
            same_on_batch=False,
            value=(125 / 255, 122 / 255, 113 / 255),
        )

    @staticmethod
    def _erase_area_fraction(img_size, patch_size):
        """Return ``patch_size**2`` as a fraction of the image area."""
        if isinstance(img_size, (tuple, list)):
            height, width = img_size
        else:
            height = width = img_size
        return (patch_size * patch_size) / (height * width)

    def __call__(self, img):
        # squeeze(0) drops the batch dim Kornia adds for an unbatched input.
        return self.base_transform(img).squeeze(0)
class CustomKorniaRandAugment(kornia.augmentation.auto.PolicyAugmentBase):
    """Apply ``n`` sub-policies drawn uniformly (with replacement) per call.

    Args:
        n: Number of sub-policies sampled for each forward pass.
        policy: Sequence of sub-policies; each sub-policy is a sequence of
            ``(op_name, probability, magnitude)`` triples, where ``op_name``
            names a class in ``kornia.augmentation.auto.operations.ops``.
    """

    def __init__(self, n: int, policy) -> None:
        super().__init__(policy)
        # len(self) is presumably the number of composed sub-policies - confirm.
        count = len(self)
        self.rand_selector = torch.distributions.Categorical(
            torch.tensor([1.0 / count] * count)
        )
        self.n = n

    def _getpolicy(self, name, p, m):
        """Instantiate the Kornia op ``name`` with probability ``p`` and magnitude ``m``.

        Shear/Solarize/Rotate/Translate/Contrast ops are called as
        ``op(m, p)``; every other op is called as ``op(p, m)``.
        NOTE(review): ``m`` is passed straight through as the op's argument
        rather than mapped from a 0-9 level to a value range, and for ops such
        as Invert/Equalize/AutoContrast it lands in whatever their second
        positional parameter is. Confirm against the Kornia op signatures.
        """
        lowered = name.lower()
        op_cls = getattr(kornia.augmentation.auto.operations.ops, name)
        magnitude_first = lowered.startswith('contrast') or any(
            key in lowered for key in ('shear', 'solarize', 'rotate', 'translate')
        )
        if magnitude_first:
            return op_cls(m, p)
        return op_cls(p, m)

    def compose_subpolicy_sequential(self, subpolicy):
        """Chain the ops of one sub-policy into a ``PolicySequential``."""
        operations = [self._getpolicy(name, p, m) for name, p, m in subpolicy]
        return kornia.augmentation.auto.PolicySequential(*operations)

    def get_forward_sequence(self, params=None):
        """Return the children to run: replayed from ``params`` or freshly sampled."""
        if params is not None:
            return self.get_children_by_params(params)
        indices = self.rand_selector.sample((self.n,))
        return self.get_children_by_indices(indices)

    def forward_parameters(self, batch_shape: torch.Size):
        """Sample parameters for each selected sub-policy as ``ParamItem`` entries."""
        params = []
        for name, module in self.get_forward_sequence():
            mod_param = module.forward_parameters(batch_shape)
            # One entry per child op; zip stops at the shorter of the two.
            child_params = [mp for _, mp in zip(module.named_children(), mod_param)]
            params.append(ParamItem(name, child_params))
        return params