CONDITIONAL BN
Classes
- class utils.conditional_bn.ConditionalBatchNorm1d(num_features, num_conditions)
Bases: Module
Conditional Batch Normalization for 1D inputs.
- forward(x, cond_id, flag_stop_grad=None)
Compute Conditional Batch Normalization for 1D inputs.
The input tensor x is normalized by the 1D batch normalization layer selected by cond_id.
- Parameters:
x (torch.Tensor) – The input tensor.
cond_id (torch.Tensor) – The index of the conditioning.
flag_stop_grad (torch.Tensor, optional) – The flag to stop gradients for gamma and beta.
- Returns:
The output tensor.
- Return type:
torch.Tensor
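As a rough illustration of the idea behind this layer, the sketch below normalizes the batch with a shared, non-affine BatchNorm1d and then applies a scale (gamma) and shift (beta) looked up per sample by cond_id. It is not the source of utils.conditional_bn: the class name SketchConditionalBatchNorm1d, the embedding-based lookup, and the assumption that cond_id holds one integer index per sample are all hypothetical, and flag_stop_grad is left out.

```python
import torch
import torch.nn as nn


class SketchConditionalBatchNorm1d(nn.Module):
    """Hypothetical stand-in, not the utils.conditional_bn implementation."""

    def __init__(self, num_features, num_conditions):
        super().__init__()
        # Shared normalization without its own affine parameters.
        self.bn = nn.BatchNorm1d(num_features, affine=False)
        # One (gamma, beta) row per condition (an assumption of this sketch).
        self.gamma = nn.Embedding(num_conditions, num_features)
        self.beta = nn.Embedding(num_conditions, num_features)
        nn.init.ones_(self.gamma.weight)
        nn.init.zeros_(self.beta.weight)

    def forward(self, x, cond_id):
        # x: (batch, num_features); cond_id: (batch,) integer indices (assumed).
        return self.gamma(cond_id) * self.bn(x) + self.beta(cond_id)


torch.manual_seed(0)
layer = SketchConditionalBatchNorm1d(num_features=4, num_conditions=3)
x = torch.randn(8, 4)
cond_id = torch.tensor([0, 0, 1, 1, 2, 2, 0, 1])
y = layer(x, cond_id)  # same shape as x
```

With gamma initialized to one and beta to zero, the sketch starts out as plain batch normalization; the conditions only diverge once training updates their rows.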
- class utils.conditional_bn.ConditionalBatchNorm2d(num_features, num_conditions)
Bases: Module
Conditional Batch Normalization for 2D inputs.
- forward(x, cond_id, flag_stop_grad=None)
Compute Conditional Batch Normalization for 2D inputs.
The input tensor x is normalized by the 2D batch normalization layer selected by cond_id.
- Parameters:
x (torch.Tensor) – The input tensor.
cond_id (torch.Tensor) – The index of the conditioning.
flag_stop_grad (torch.Tensor, optional) – The flag to stop gradients for gamma and beta.
- Returns:
The output tensor.
- Return type:
torch.Tensor
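One way to picture the role of flag_stop_grad in the 2D case is sketched below. The sketch assumes x has shape (N, C, H, W), that cond_id and flag_stop_grad each hold one entry per sample, and that flagged samples should contribute no gradient to gamma and beta. It uses detach() with torch.where for this, which is a substitute technique chosen for brevity and not necessarily what the module does; the class name and the embedding-based parameters are hypothetical.

```python
import torch
import torch.nn as nn


class SketchConditionalBatchNorm2d(nn.Module):
    """Hypothetical stand-in, not the utils.conditional_bn implementation."""

    def __init__(self, num_features, num_conditions):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.gamma = nn.Embedding(num_conditions, num_features)
        self.beta = nn.Embedding(num_conditions, num_features)
        nn.init.ones_(self.gamma.weight)
        nn.init.zeros_(self.beta.weight)

    def forward(self, x, cond_id, flag_stop_grad=None):
        gamma = self.gamma(cond_id)  # (N, C)
        beta = self.beta(cond_id)    # (N, C)
        if flag_stop_grad is not None:
            # Assumed semantics: True means "no gradient to gamma/beta
            # from this sample".
            stop = flag_stop_grad.bool().unsqueeze(1)  # (N, 1)
            gamma = torch.where(stop, gamma.detach(), gamma)
            beta = torch.where(stop, beta.detach(), beta)
        n, c = gamma.shape
        # Broadcast the per-sample scale and shift over H and W.
        return gamma.view(n, c, 1, 1) * self.bn(x) + beta.view(n, c, 1, 1)


torch.manual_seed(0)
layer2d = SketchConditionalBatchNorm2d(num_features=3, num_conditions=2)
x2d = torch.randn(4, 3, 5, 5)
cond_id2d = torch.tensor([0, 0, 1, 1])
flag2d = torch.tensor([False, False, True, True])
y2d = layer2d(x2d, cond_id2d, flag2d)
y2d.sum().backward()
# Condition 1 is used only by flagged samples, so its rows get zero gradient.
```

In this example the parameters of condition 0 still receive gradients, while those of condition 1 do not, because every sample that uses condition 1 is flagged.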
- class utils.conditional_bn.SelectClearGrad(*args, **kwargs)
Bases: Function
A custom autograd function that clears gradients for selected indices.
- static backward(ctx, grad_output)
Backward pass of the SelectClearGrad function.
- Parameters:
grad_output (torch.Tensor) – The gradient of the output.
- Returns:
The gradient of the input tensor and None.
- Return type:
Tuple[torch.Tensor, None]
- static forward(ctx, x, indices)
Forward pass of the SelectClearGrad function.
- Parameters:
x (torch.Tensor) – The input tensor.
indices (torch.Tensor) – The indices to clear gradients for.
- Returns:
The input tensor, reshaped to match the shape of x.
- Return type:
torch.Tensor
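The forward/backward pair described above can be sketched as a small torch.autograd.Function. The version below is a minimal illustration under stated assumptions, not the module's code: it assumes indices selects entries along the first dimension of x, that the forward pass is an identity, and that "clearing" means setting the matching gradient rows to zero. The name SketchSelectClearGrad is hypothetical.

```python
import torch


class SketchSelectClearGrad(torch.autograd.Function):
    """Hypothetical sketch: identity forward, zeroed gradient rows backward."""

    @staticmethod
    def forward(ctx, x, indices):
        # Keep the indices so backward knows which rows to clear.
        ctx.save_for_backward(indices)
        # Identity: return x with its shape unchanged.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[indices] = 0  # clear gradients for the selected rows
        # One gradient per forward input: x gets grad_input, indices gets None.
        return grad_input, None


x_in = torch.ones(4, 3, requires_grad=True)
rows = torch.tensor([1, 3])
out = SketchSelectClearGrad.apply(x_in, rows)
out.sum().backward()
# x_in.grad is zero in rows 1 and 3 and one in rows 0 and 2.
```

The second return value of backward is None because indices is an integer tensor and takes no gradient, which matches the documented return type Tuple[torch.Tensor, None].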