Source code for utils.kornia_utils

from typing import List, Union
import kornia
from torch import nn
import torch
from torchvision import transforms
from kornia.augmentation.container.params import ParamItem
from kornia.constants import Resample

from utils.autoaugment import get_kornia_Cifar10Policy


class KorniaMultiAug(kornia.augmentation.AugmentationSequential):
    """
    A custom augmentation class that performs multiple Kornia augmentations.

    Args:
        n_augs (int): The number of augmentations to apply.
        aug_list (List[kornia.augmentation.AugmentationBase2D]): The list of augmentations to apply.

    Methods:
        forward: Overrides the forward method to apply the transformation without gradient computation.
    """

    def __init__(self, n_augs: int, aug_list: List[kornia.augmentation.AugmentationBase2D]):
        super().__init__(*aug_list)
        self.n_augs = n_augs

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overrides the forward method to apply the transformation without gradient computation.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The transformed tensor.
        """
        original_shape = x.shape
        x = super().forward(x.repeat(self.n_augs, 1, 1, 1))
        x = x.reshape(self.n_augs, *original_shape)
        return x
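

# Illustrative usage sketch (not part of the module API): shows how KorniaMultiAug stacks
# `n_augs` independently augmented views of a batch. The augmentation list and tensor
# sizes below are arbitrary example choices.
def _example_kornia_multi_aug() -> torch.Tensor:
    aug_list = [
        kornia.augmentation.RandomHorizontalFlip(p=0.5),
        kornia.augmentation.RandomGrayscale(p=0.2),
    ]
    multi_aug = KorniaMultiAug(n_augs=2, aug_list=aug_list)
    batch = torch.rand(8, 3, 32, 32)  # (B, C, H, W)
    views = multi_aug(batch)          # (n_augs, B, C, H, W) == (2, 8, 3, 32, 32)
    return views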


class KorniaAugNoGrad(kornia.augmentation.AugmentationSequential):
    """
    A custom augmentation class that applies Kornia augmentations without gradient computation.

    Inherits from `kornia.augmentation.AugmentationSequential`.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Methods:
        _do_transform: Performs the transformation without gradient computation.
        forward: Overrides the forward method to apply the transformation without gradient computation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _do_transform(self, *args, **kwargs) -> torch.Tensor:
        """
        Performs the transformation without gradient computation.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            torch.Tensor: The transformed tensor.
        """
        x = super().forward(*args, **kwargs)
        return x

    @torch.no_grad()
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Overrides the forward method to apply the transformation without gradient computation.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            torch.Tensor: The transformed tensor.
        """
        return self._do_transform(*args, **kwargs)
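

# Illustrative usage sketch (not part of the module API): KorniaAugNoGrad behaves like
# kornia.augmentation.AugmentationSequential, but its forward pass runs under
# torch.no_grad(), so no autograd graph is built through the augmentations.
# The augmentations and tensor sizes below are arbitrary example choices.
def _example_kornia_aug_no_grad() -> torch.Tensor:
    aug = KorniaAugNoGrad(
        kornia.augmentation.RandomHorizontalFlip(p=0.5),
        kornia.augmentation.Normalize(mean=torch.tensor([0.5] * 3), std=torch.tensor([0.5] * 3), p=1.0),
        same_on_batch=True,
    )
    batch = torch.rand(4, 3, 32, 32, requires_grad=True)
    out = aug(batch)
    # Gradients are not tracked through the augmentation pipeline.
    assert not out.requires_grad
    return out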


def _convert_interpolation_to_resample(interpolation: int) -> int:
    """Maps a torchvision interpolation mode to the corresponding Kornia `Resample` constant."""
    interpolation_name = transforms.InterpolationMode(interpolation).name
    if hasattr(Resample, interpolation_name):
        return getattr(Resample, interpolation_name)
    else:
        raise NotImplementedError(f"Interpolation mode {interpolation_name} not supported by Kornia.")


def to_kornia_transform(transform: transforms.Compose, apply: bool = True) -> Union[List[kornia.augmentation.AugmentationBase2D], KorniaAugNoGrad]:
    """
    Converts PIL transforms to Kornia transforms.

    Args:
        transform (transforms.Compose): The torchvision transform to be converted.
        apply (bool, optional): Whether to convert the processed kornia transforms list into a KorniaAugNoGrad object. Defaults to True.

    Returns:
        Union[List[kornia.augmentation.AugmentationBase2D], KorniaAugNoGrad]: The converted Kornia transforms.
    """
    if isinstance(transform, kornia.augmentation.AugmentationSequential) or \
            (isinstance(transform, nn.Sequential) and isinstance(transform[0], kornia.augmentation.AugmentationBase2D)):
        return transform

    if not isinstance(transform, list):
        if hasattr(transform, "transforms"):
            transform = list(transform.transforms)
        else:
            transform = [transform]

    ts = []
    for t in transform:
        if isinstance(t, transforms.RandomResizedCrop):
            ts.append(kornia.augmentation.RandomResizedCrop(size=t.size, scale=t.scale, ratio=t.ratio,
                                                            resample=_convert_interpolation_to_resample(t.interpolation)))
        elif isinstance(t, transforms.RandomHorizontalFlip):
            ts.append(kornia.augmentation.RandomHorizontalFlip(p=t.p))
        elif isinstance(t, transforms.RandomVerticalFlip):
            ts.append(kornia.augmentation.RandomVerticalFlip(p=t.p))
        elif isinstance(t, transforms.RandomRotation):
            ts.append(kornia.augmentation.RandomRotation(degrees=t.degrees,
                                                         resample=_convert_interpolation_to_resample(t.interpolation)))
        elif isinstance(t, transforms.RandomGrayscale):
            ts.append(kornia.augmentation.RandomGrayscale(p=t.p))
        elif isinstance(t, transforms.RandomAffine):
            ts.append(kornia.augmentation.RandomAffine(degrees=t.degrees, translate=t.translate, scale=t.scale,
                                                       shear=t.shear,
                                                       resample=_convert_interpolation_to_resample(t.interpolation),
                                                       fill=t.fill))
        elif isinstance(t, transforms.RandomPerspective):
            ts.append(kornia.augmentation.RandomPerspective(distortion_scale=t.distortion_scale, p=t.p,
                                                            resample=_convert_interpolation_to_resample(t.interpolation),
                                                            fill=t.fill))
        elif isinstance(t, transforms.RandomCrop):
            ts.append(kornia.augmentation.RandomCrop(size=t.size, padding=t.padding, pad_if_needed=t.pad_if_needed,
                                                     fill=t.fill, padding_mode=t.padding_mode))
        elif isinstance(t, transforms.RandomErasing):
            ts.append(kornia.augmentation.RandomErasing(p=t.p, scale=t.scale, ratio=t.ratio, value=t.value,
                                                        inplace=t.inplace))
        elif isinstance(t, transforms.ColorJitter):
            ts.append(kornia.augmentation.ColorJitter(brightness=t.brightness, contrast=t.contrast,
                                                      saturation=t.saturation, hue=t.hue))
        elif isinstance(t, transforms.RandomApply):
            ts.append(kornia.augmentation.RandomApply(t.transforms, p=t.p))
        elif isinstance(t, transforms.RandomChoice):
            ts.append(kornia.augmentation.RandomChoice(t.transforms))
        elif isinstance(t, transforms.RandomOrder):
            ts.append(kornia.augmentation.RandomOrder(t.transforms))
        elif isinstance(t, transforms.Compose):
            ts.extend(to_kornia_transform(t, apply=False))
        elif isinstance(t, (transforms.ToTensor, transforms.ToPILImage)):
            pass
        elif isinstance(t, transforms.Normalize):
            ts.append(kornia.augmentation.Normalize(mean=t.mean, std=t.std, p=1))
        elif isinstance(t, transforms.Resize):
            ts.append(kornia.augmentation.Resize(size=t.size, antialias=t.antialias,
                                                 resample=_convert_interpolation_to_resample(t.interpolation)))
        elif "cifar10policy" in str(type(t)).lower():
            ts.append(get_kornia_Cifar10Policy())
        else:
            raise NotImplementedError(f"Unsupported transform: {type(t)}")

    if not apply:
        return ts

    return KorniaAugNoGrad(*ts, same_on_batch=True)


class CustomKorniaRandAugment(kornia.augmentation.auto.PolicyAugmentBase):
    """
    A custom augmentation class that applies RandAugment as a Kornia augmentation.

    Inherits from `kornia.augmentation.auto.PolicyAugmentBase`.

    Args:
        n (int): The number of augmentations to apply.
        policy: The policy of augmentations to apply.

    Attributes:
        rand_selector (torch.distributions.Categorical): A categorical distribution for selecting augmentations randomly.
        n (int): The number of augmentations to apply.

    Methods:
        _getpolicy: Returns the Kornia augmentation operation based on the name, probability, and magnitude.
        compose_subpolicy_sequential: Composes a subpolicy of augmentations sequentially.
        get_forward_sequence: Returns the forward sequence of augmentations based on the selected indices or parameters.
        forward_parameters: Computes the forward parameters for the augmentations.
    """

    def __init__(self, n: int, policy) -> None:
        super().__init__(policy)
        selection_weights = torch.tensor([1.0 / len(self)] * len(self))
        self.rand_selector = torch.distributions.Categorical(selection_weights)
        self.n = n

    def _getpolicy(self, name, p, m):
        """
        Returns the Kornia augmentation operation based on the name, probability, and magnitude.

        Args:
            name (str): The name of the augmentation operation.
            p (float): The probability of applying the augmentation.
            m (float): The magnitude of the augmentation.

        Returns:
            kornia.augmentation.auto.operations.ops: The Kornia augmentation operation.
        """
        if 'shear' in name.lower() or 'solarize' in name.lower() or 'rotate' in name.lower() \
                or 'translate' in name.lower() or name.lower().startswith('contrast'):
            # These Kornia ops take the magnitude and probability in the opposite order.
            return getattr(kornia.augmentation.auto.operations.ops, name)(m, p)
        else:
            return getattr(kornia.augmentation.auto.operations.ops, name)(p, m)

    def compose_subpolicy_sequential(self, subpolicy):
        """
        Composes a subpolicy of augmentations sequentially.

        Args:
            subpolicy (List[Tuple[str, float, float]]): The subpolicy of augmentations.

        Returns:
            kornia.augmentation.auto.PolicySequential: The composed subpolicy of augmentations.
        """
        return kornia.augmentation.auto.PolicySequential(*[self._getpolicy(name, p, m) for (name, p, m) in subpolicy])

    def get_forward_sequence(self, params=None):
        """
        Returns the forward sequence of augmentations based on the selected indices or parameters.

        Args:
            params (List[ParamItem], optional): The parameters of the augmentations. Defaults to None.

        Returns:
            List[Tuple[str, kornia.augmentation.auto.operations.ops]]: The forward sequence of augmentations.
        """
        if params is None:
            idx = self.rand_selector.sample((self.n,))
            return self.get_children_by_indices(idx)

        return self.get_children_by_params(params)

    def forward_parameters(self, batch_shape: torch.Size):
        """
        Computes the forward parameters for the augmentations.

        Args:
            batch_shape (torch.Size): The shape of the input batch.

        Returns:
            List[ParamItem]: The forward parameters for the augmentations.
        """
        named_modules = self.get_forward_sequence()

        params = []
        for name, module in named_modules:
            mod_param = module.forward_parameters(batch_shape)
            param = ParamItem(name, [ParamItem(mname, mp)[1] for (mname, _), mp in zip(module.named_children(), mod_param)])
            params.append(param)

        return params
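

# Illustrative usage sketch (not part of the module API): call pattern for
# CustomKorniaRandAugment. The `policy` argument is assumed to be a list of subpolicies,
# each a list of (op_name, probability, magnitude) triples whose op_name matches a class
# in kornia.augmentation.auto.operations.ops; concrete op names and magnitudes are left
# to the caller. The batch size, image size and n=2 below are arbitrary example choices.
def _example_custom_rand_augment(policy) -> torch.Tensor:
    rand_augment = CustomKorniaRandAugment(n=2, policy=policy)
    batch = torch.rand(4, 3, 32, 32)
    # forward_parameters samples `n` sub-policies at random; passing the sampled
    # parameters back to forward reuses the same random choice.
    params = rand_augment.forward_parameters(batch.shape)
    return rand_augment(batch, params=params)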