Source code for utils.mixup

# Copyright 2021-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.distributions.beta import Beta


def mixup(couples, alpha, force_lambda=None):
    """
    Applies mixup augmentation to the given couples of inputs.

    Each couple ``(i1, i2)`` is blended as ``lamda * i1 + (1 - lamda) * i2``. One mixing
    coefficient is used per element along the first dimension of the first input of the
    first couple, and the same coefficients are shared by every couple.

    Args:
        couples (list): A list of tuples, where each tuple contains two inputs to be mixed.
        alpha (float): The alpha parameter for the Beta distribution used to sample the mixing
            coefficients. Not used when `force_lambda` is given.
        force_lambda (float or None, optional): If not None, forces the use of a specific mixing
            coefficient for all inputs; nothing is sampled in that case.

    Returns:
        tuple or torch.Tensor: If more than one mixed input is generated, a tuple of mixed inputs
        is returned. Otherwise, a single mixed input is returned.

    Raises:
        ValueError: If the two inputs of a couple do not have the same shape.
    """
    reference = couples[0][0]
    batch_size = len(reference)

    if force_lambda is not None:
        # A forced coefficient needs no sampling: skipping the Beta draw leaves the
        # global RNG state untouched and does not require a valid `alpha`.
        lamda = torch.tensor(force_lambda).repeat((batch_size,)).to(reference.device)
    else:
        lamda = Beta(alpha, alpha).rsample((batch_size,)).to(reference.device)
        # Fold the sample onto [0.5, 1] so the first input of each couple dominates.
        lamda = torch.max(lamda, 1 - lamda)

    returns = []
    for i1, i2 in couples:
        # Explicit check instead of `assert`, which is stripped under `python -O` and
        # would let mismatched-but-broadcastable inputs be mixed silently.
        if i1.shape != i2.shape:
            raise ValueError(
                f"mixup inputs must share the same shape, got {tuple(i1.shape)} and {tuple(i2.shape)}"
            )
        # Add trailing singleton dimensions so the coefficients broadcast over i1.
        coeff = lamda.view([lamda.shape[0]] + [1] * (len(i1.shape) - 1))
        returns.append(coeff * i1 + (1 - coeff) * i2)

    return tuple(returns) if len(returns) > 1 else returns[0]