CONDITIONAL BN
Classes
- class utils.conditional_bn.ConditionalBatchNorm1d(num_features, num_conditions)
Bases: Module
Conditional Batch Normalization for 1D inputs.
- forward(x, cond_id, flag_stop_grad=None)
Compute Conditional Batch Normalization for 1D inputs.
The input tensor x is normalized by the 1D batch normalization layer that cond_id selects.
- Parameters:
x (torch.Tensor) – The input tensor.
cond_id (torch.Tensor) – The index of the condition, which selects the batch normalization layer to apply.
flag_stop_grad (torch.Tensor, optional) – Flag that stops gradients from flowing to gamma and beta, the scale and shift parameters.
- Returns:
The output tensor.
- Return type:
torch.Tensor
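The selection described above can be illustrated with a minimal sketch. It is not the source of utils.conditional_bn: the class name ConditionalBatchNorm1dSketch and the attribute bns are hypothetical, the sketch assumes that cond_id holds a single index shared by the whole batch, and it leaves out flag_stop_grad because its exact semantics are not documented here.

```python
import torch
from torch import nn


class ConditionalBatchNorm1dSketch(nn.Module):
    """Hypothetical stand-in: one BatchNorm1d layer per condition."""

    def __init__(self, num_features, num_conditions):
        super().__init__()
        # One independent set of gamma, beta, and running statistics per condition.
        self.bns = nn.ModuleList(
            [nn.BatchNorm1d(num_features) for _ in range(num_conditions)]
        )

    def forward(self, x, cond_id):
        # Assumption: cond_id is a scalar (or one-element) tensor for the whole batch.
        return self.bns[int(cond_id)](x)


torch.manual_seed(0)
layer = ConditionalBatchNorm1dSketch(num_features=4, num_conditions=3)
x = torch.randn(8, 4)            # (batch, num_features)
out = layer(x, torch.tensor(1))  # normalize with the layer for condition 1
```

In this sketch only the selected layer sees the batch, so the running statistics and affine parameters of the other conditions are left unchanged by the call.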
- class utils.conditional_bn.ConditionalBatchNorm2d(num_features, num_conditions)
Bases: Module
Conditional Batch Normalization for 2D inputs.
- forward(x, cond_id, flag_stop_grad=None)
Compute Conditional Batch Normalization for 2D inputs.
The input tensor x is normalized by the 2D batch normalization layer that cond_id selects.
- Parameters:
x (torch.Tensor) – The input tensor.
cond_id (torch.Tensor) – The index of the condition, which selects the batch normalization layer to apply.
flag_stop_grad (torch.Tensor, optional) – Flag that stops gradients from flowing to gamma and beta, the scale and shift parameters.
- Returns:
The output tensor.
- Return type:
torch.Tensor
- class utils.conditional_bn.SelectClearGrad(*args, **kwargs)
Bases: Function
A custom autograd function that clears gradients for selected indices.
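As a rough illustration of what such a function might look like, the sketch below passes its input through unchanged in the forward pass and zeroes the gradient at masked positions in the backward pass. The name SelectClearGradSketch, the boolean mask argument, and the identity forward are assumptions made for this example; the actual signature and behavior of SelectClearGrad are not documented here.

```python
import torch


class SelectClearGradSketch(torch.autograd.Function):
    """Hypothetical sketch: identity forward, masked gradient in backward."""

    @staticmethod
    def forward(ctx, x, mask):
        # mask is a boolean tensor; True marks entries whose gradient is cleared.
        ctx.save_for_backward(mask)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[mask] = 0.0
        # One gradient per forward input; the mask itself receives none.
        return grad_input, None


w = torch.ones(4, requires_grad=True)
mask = torch.tensor([False, True, False, True])
y = SelectClearGradSketch.apply(w, mask)
(y * torch.arange(1.0, 5.0)).sum().backward()
print(w.grad)  # gradient entries at indices 1 and 3 are zero
```

Without the custom function, w.grad would be [1, 2, 3, 4]; with it, the masked entries are cleared while the rest pass through untouched.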