from torch import nn
from torch.autograd import Function


class SelectClearGrad(Function):
    """
    A custom autograd function that acts as the identity in the forward pass
    and zeroes the gradient at selected indices in the backward pass.
    """

    @staticmethod
    def forward(ctx, x, indices):
        """
        Forward pass of the SelectClearGrad function.

        Args:
            x (torch.Tensor): The input tensor.
            indices (torch.Tensor): The indices whose gradients will be cleared.

        Returns:
            torch.Tensor: The input tensor, returned unchanged as a view of itself.
        """
        ctx.indices = indices
        # view_as(x) produces a new tensor node, so autograd routes the
        # gradient through this function instead of skipping it.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of the SelectClearGrad function.

        Args:
            grad_output (torch.Tensor): The gradient of the output.

        Returns:
            Tuple[torch.Tensor, None]: The gradient for `x`, zeroed at the
            selected indices, and None for `indices`.
        """
        # Clone before masking: modifying grad_output in place is unsafe,
        # since autograd may reuse that tensor elsewhere in the graph.
        grad_input = grad_output.clone()
        grad_input[ctx.indices] = 0
        return grad_input, None
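

# A minimal usage sketch (added for illustration, not part of the original
# module): SelectClearGrad leaves the forward values untouched but zeroes the
# incoming gradient at the given indices. The index tensor below is assumed;
# any tensor valid for advanced indexing along dim 0 behaves the same way.
def _demo_select_clear_grad():
    import torch
    x = torch.randn(4, 3, requires_grad=True)
    idx = torch.tensor([0, 2])  # samples whose gradients we want cleared
    y = SelectClearGrad.apply(x, idx)
    y.sum().backward()
    # Rows 0 and 2 of x.grad are zero; rows 1 and 3 are all ones.
    print(x.grad)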


class ConditionalBatchNorm1d(nn.Module):
    """
    Conditional Batch Normalization for 1D inputs: a shared, non-affine
    BatchNorm1d followed by a per-condition affine transform.
    """

    def __init__(self, num_features, num_conditions):
        """
        Initializes the ConditionalBatchNorm1d module.

        Args:
            num_features (int): The number of input features.
            num_conditions (int): The number of distinct conditions, i.e. the
                size of the table of per-condition (gamma, beta) pairs.
        """
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm1d(num_features, affine=False)
        # Each condition owns one row of [gamma | beta] parameters.
        self.embed = nn.Embedding(num_conditions, num_features * 2)
        # Start close to the identity: gamma near 1, beta at 0.
        self.embed.weight.data[:, :num_features].normal_(1, 0.02)
        self.embed.weight.data[:, num_features:].zero_()

    def forward(self, x, cond_id, flag_stop_grad=None):
        """
        Compute Conditional Batch Normalization for 1D inputs.

        The input `x` is normalized by the shared BatchNorm1d layer, then
        scaled and shifted per sample by the (gamma, beta) pair selected
        by `cond_id`.

        Args:
            x (torch.Tensor): The input tensor of shape (batch, num_features).
            cond_id (torch.Tensor): The condition index for each sample.
            flag_stop_grad (torch.Tensor, optional): Indices of samples whose
                gamma and beta gradients should be zeroed.

        Returns:
            torch.Tensor: The output tensor, same shape as `x`.
        """
        out = self.bn(x)
        gamma, beta = self.embed(cond_id).chunk(2, 1)
        if flag_stop_grad is not None:
            gamma = SelectClearGrad.apply(gamma, flag_stop_grad)
            beta = SelectClearGrad.apply(beta, flag_stop_grad)
        out = gamma.view(-1, self.num_features) * out + beta.view(-1, self.num_features)
        return out
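

# A hedged usage sketch (shapes and values assumed, not from the original
# source): each sample picks its own (gamma, beta) pair via cond_id, and
# flag_stop_grad can suppress the affine gradients for a subset of samples.
def _demo_cbn1d():
    import torch
    cbn = ConditionalBatchNorm1d(num_features=8, num_conditions=3)
    x = torch.randn(4, 8)                 # (batch, num_features)
    cond_id = torch.tensor([0, 1, 2, 0])  # one condition per sample
    out = cbn(x, cond_id, flag_stop_grad=torch.tensor([1]))
    print(out.shape)                      # torch.Size([4, 8])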


class ConditionalBatchNorm2d(nn.Module):
    """
    Conditional Batch Normalization for 2D inputs: a shared, non-affine
    BatchNorm2d followed by a per-condition affine transform.
    """

    def __init__(self, num_features, num_conditions):
        """
        Initializes the ConditionalBatchNorm2d module.

        Args:
            num_features (int): The number of input features.
            num_conditions (int): The number of distinct conditions, i.e. the
                size of the table of per-condition (gamma, beta) pairs.
        """
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # Each condition owns one row of [gamma | beta] parameters.
        self.embed = nn.Embedding(num_conditions, num_features * 2)
        # Start close to the identity: gamma near 1, beta at 0.
        self.embed.weight.data[:, :num_features].normal_(1, 0.02)
        self.embed.weight.data[:, num_features:].zero_()

    def forward(self, x, cond_id, flag_stop_grad=None):
        """
        Compute Conditional Batch Normalization for 2D inputs.

        The input `x` is normalized by the shared BatchNorm2d layer, then
        scaled and shifted per sample by the (gamma, beta) pair selected
        by `cond_id`, broadcast over the spatial dimensions.

        Args:
            x (torch.Tensor): The input tensor of shape
                (batch, num_features, height, width).
            cond_id (torch.Tensor): The condition index for each sample.
            flag_stop_grad (torch.Tensor, optional): Indices of samples whose
                gamma and beta gradients should be zeroed.

        Returns:
            torch.Tensor: The output tensor, same shape as `x`.
        """
        out = self.bn(x)
        gamma, beta = self.embed(cond_id).chunk(2, 1)
        if flag_stop_grad is not None:
            gamma = SelectClearGrad.apply(gamma, flag_stop_grad)
            beta = SelectClearGrad.apply(beta, flag_stop_grad)
        out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
        return out
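

# A hedged usage sketch (shapes assumed): the 2D variant broadcasts gamma and
# beta over the spatial dimensions, so every pixel of a sample's feature maps
# is scaled and shifted by that sample's condition parameters.
def _demo_cbn2d():
    import torch
    cbn = ConditionalBatchNorm2d(num_features=16, num_conditions=5)
    x = torch.randn(2, 16, 7, 7)  # (batch, channels, H, W)
    cond_id = torch.tensor([3, 1])
    out = cbn(x, cond_id)
    print(out.shape)              # torch.Size([2, 16, 7, 7])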