CONDITIONAL BN

Classes

class utils.conditional_bn.ConditionalBatchNorm1d(num_features, num_conditions)

Bases: Module

Conditional Batch Normalization for 1D inputs.

forward(x, cond_id, flag_stop_grad=None)

Compute Conditional Batch Normalization for 1D inputs.

The 1D batch normalization layer selected by cond_id is applied to the input tensor x.

Parameters:
  • x (torch.Tensor) – The input tensor.

  • cond_id (torch.Tensor) – The index of the condition, which selects the batch normalization layer to apply.

  • flag_stop_grad (torch.Tensor, optional) – The flag to stop gradients for gamma and beta.

Returns:

The output tensor.

Return type:

torch.Tensor
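The selection behaviour described above can be illustrated with a minimal sketch. It is not the utils.conditional_bn implementation: the class name SketchConditionalBN1d is hypothetical, cond_id is assumed to be a single integer index for the whole batch, and flag_stop_grad is left out because its exact semantics are not documented here.

```python
import torch
from torch import nn


class SketchConditionalBN1d(nn.Module):
    """Hypothetical stand-in: one BatchNorm1d layer per condition."""

    def __init__(self, num_features: int, num_conditions: int):
        super().__init__()
        # Each condition owns its gamma, beta, and running statistics.
        self.bns = nn.ModuleList(
            [nn.BatchNorm1d(num_features) for _ in range(num_conditions)]
        )

    def forward(self, x: torch.Tensor, cond_id: torch.Tensor) -> torch.Tensor:
        # Assumption: cond_id holds a single integer selecting one layer.
        return self.bns[int(cond_id)](x)


torch.manual_seed(0)
layer = SketchConditionalBN1d(num_features=4, num_conditions=3)
x = torch.randn(8, 4)               # (batch, num_features)
out = layer(x, torch.tensor(0))     # normalize with the layer for condition 0
```

In this sketch only the layer for condition 0 updates its running statistics during the call; the layers for the other conditions are untouched.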

class utils.conditional_bn.ConditionalBatchNorm2d(num_features, num_conditions)

Bases: Module

Conditional Batch Normalization for 2D inputs.

forward(x, cond_id, flag_stop_grad=None)

Compute Conditional Batch Normalization for 2D inputs.

The 2D batch normalization layer selected by cond_id is applied to the input tensor x.

Parameters:
  • x (torch.Tensor) – The input tensor.

  • cond_id (torch.Tensor) – The index of the condition, which selects the batch normalization layer to apply.

  • flag_stop_grad (torch.Tensor, optional) – The flag to stop gradients for gamma and beta.

Returns:

The output tensor.

Return type:

torch.Tensor
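For 2D inputs, the per-condition gamma and beta have to be broadcast over the spatial dimensions. The sketch below shows one common formulation, assuming a shared affine-free BatchNorm2d plus per-condition gamma and beta stored in embedding tables, with cond_id holding one index per sample. Whether utils.conditional_bn is built this way is not stated in this reference; SketchConditionalBN2d and its attribute names are hypothetical.

```python
import torch
from torch import nn


class SketchConditionalBN2d(nn.Module):
    """Hypothetical stand-in: shared normalization, per-condition affine."""

    def __init__(self, num_features: int, num_conditions: int):
        super().__init__()
        # Normalization only; the affine part is handled per condition below.
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.gamma = nn.Embedding(num_conditions, num_features)
        self.beta = nn.Embedding(num_conditions, num_features)
        nn.init.ones_(self.gamma.weight)   # start as identity scaling
        nn.init.zeros_(self.beta.weight)   # start with zero shift

    def forward(self, x: torch.Tensor, cond_id: torch.Tensor) -> torch.Tensor:
        # x: (N, C, H, W); cond_id: (N,) integer indices (assumed layout).
        gamma = self.gamma(cond_id)[:, :, None, None]  # (N, C, 1, 1)
        beta = self.beta(cond_id)[:, :, None, None]    # (N, C, 1, 1)
        return gamma * self.bn(x) + beta


torch.manual_seed(0)
layer2d = SketchConditionalBN2d(num_features=3, num_conditions=2)
x2d = torch.randn(4, 3, 5, 5)
cond = torch.tensor([0, 1, 1, 0])
out2d = layer2d(x2d, cond)
```

Indexing with None adds the two trailing singleton axes, so each sample's gamma and beta apply to every spatial position of its feature maps.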

class utils.conditional_bn.SelectClearGrad(*args, **kwargs)

Bases: Function

A custom autograd function that clears gradients for selected indices.

static backward(ctx, grad_output)

Backward pass of the SelectClearGrad function.

Parameters:

grad_output (torch.Tensor) – The gradient of the output.

Returns:

The gradient with respect to the input tensor x, and None for indices, which receives no gradient.

Return type:

Tuple[torch.Tensor, None]

static forward(ctx, x, indices)

Forward pass of the SelectClearGrad function.

Parameters:
  • x (torch.Tensor) – The input tensor.

  • indices – The indices for which gradients are cleared.

Returns:

The input tensor, reshaped to match the shape of x.

Return type:

torch.Tensor
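The forward/backward contract described above (pass the input through, then clear gradients at selected indices) can be sketched with a custom torch.autograd.Function. This is an illustration under assumptions, not the library's code: SketchSelectClearGrad is a hypothetical name, and indices is assumed to index the first dimension of x.

```python
import torch
from torch.autograd import Function


class SketchSelectClearGrad(Function):
    """Hypothetical illustration of clearing gradients at selected indices."""

    @staticmethod
    def forward(ctx, x, indices):
        # Remember which entries should receive no gradient.
        ctx.save_for_backward(indices)
        # Identity in the forward pass.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Assumption: indices select rows along the first dimension.
        grad_input[indices] = 0
        # One return value per forward input; indices gets no gradient.
        return grad_input, None


x = torch.ones(4, 3, requires_grad=True)
indices = torch.tensor([1, 3])
y = SketchSelectClearGrad.apply(x, indices)
y.sum().backward()
# x.grad now has ones in rows 0 and 2 and zeros in rows 1 and 3.
```

Custom functions are invoked through apply rather than by calling forward directly, which is what lets autograd route the backward pass through the custom backward method.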