# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import PIL
import numpy as np
import torch
class DeNormalize:
    """Callable that inverts a channel-wise normalization (``x * std + mean``).

    The per-channel statistics are stored as tensors in ``self.mean`` and
    ``self.std`` and are moved to the device of the input the first time an
    input on a different device is seen.
    """

    def __init__(self, mean, std):
        """
        Initializes a DeNormalize object.

        Args:
            mean (list | tuple | np.ndarray | torch.Tensor): Mean value for each channel.
            std (list | tuple | np.ndarray | torch.Tensor): Standard deviation for each channel.
        """
        # torch.as_tensor converts lists, tuples and ndarrays uniformly and
        # returns an existing tensor unchanged, so no per-type branching is needed.
        self.mean = torch.as_tensor(mean)
        self.std = torch.as_tensor(std)

    def __call__(self, tensor: "torch.Tensor | PIL.Image.Image"):
        """
        Applies denormalization to the input tensor.

        Args:
            tensor (Tensor | PIL.Image.Image): Images of size ([B,] C, H, W) to be
                denormalized. A PIL image is first converted to a float tensor
                with its axes reordered from (H, W, C) to (C, H, W); pixel
                values are not rescaled.

        Returns:
            Tensor: Denormalized tensor. A 3-dimensional input is returned with
            a leading batch dimension of size 1, i.e. as (1, C, H, W).
        """
        # The annotation above is a string so that it is not evaluated when the
        # class is defined: a plain `import PIL` does not load the PIL.Image
        # submodule, and `X | Y` between types needs Python 3.10+.
        if not isinstance(tensor, torch.Tensor):
            # Import the submodule explicitly, and only on the non-tensor path,
            # so tensor inputs never depend on PIL.Image being loaded.
            from PIL import Image as pil_image

            if isinstance(tensor, pil_image.Image):
                tensor = torch.tensor(np.array(tensor).transpose(2, 0, 1)).float()

        # Promote a single image to a batch of one.
        if tensor.ndimension() == 3:
            tensor = tensor.unsqueeze(0)

        # Keep the statistics on the input's device; the move is cached on the
        # instance so later calls on the same device skip it.
        if tensor.device != self.mean.device:
            self.mean = self.mean.to(tensor.device)
            self.std = self.std.to(tensor.device)

        # (C,) -> (C, 1, 1) so the statistics broadcast over H and W (and B).
        return (tensor * self.std[:, None, None]) + self.mean[:, None, None]