"""
Base class for all models that use the Lipschitz regularization in LiDER (https://arxiv.org/pdf/2210.06443.pdf).
"""
import logging
import torch
import torch.nn.functional as F
from tqdm import tqdm
from typing import List
from models.utils.continual_model import ContinualModel
def add_lipschitz_args(parser):
    # BUFFER LIP LOSS
    parser.add_argument('--alpha_lip_lambda', type=float, required=False, default=0,
                        help='Lambda parameter for lipschitz minimization loss on buffer samples')

    # BUDGET LIP LOSS
    parser.add_argument('--beta_lip_lambda', type=float, required=False, default=0,
                        help='Lambda parameter for lipschitz budget distribution loss')

    # Extra
    parser.add_argument('--headless_init_act', type=str, choices=["relu", "lrelu"], default="relu")
    parser.add_argument('--grad_iter_step', type=int, required=False, default=-2,
                        help='Step from which to enable gradient computation.')
class LiderOptimizer(ContinualModel):
    """
    Superclass for all models that use the Lipschitz regularization in LiDER (https://arxiv.org/pdf/2210.06443.pdf).
    """
    def __init__(self, backbone, loss, args, transform, dataset=None):
        super().__init__(backbone, loss, args, transform, dataset=dataset)

        if self.args.alpha_lip_lambda == 0 and self.args.beta_lip_lambda == 0:
            logging.error("LiDER is enabled but both `alpha_lip_lambda` and `beta_lip_lambda` are 0. LiDER will not be used.")
    def transmitting_matrix(self, fm1: torch.Tensor, fm2: torch.Tensor):
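        # Build the transmitting matrix between two consecutive feature maps:
        # pool fm1 down to fm2's spatial resolution if needed, flatten both to
        # (B, C, H*W), and take their normalized batched product.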
        if fm1.size(2) > fm2.size(2):
            fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(-2), fm2.size(-1)))

        fm1 = fm1.view(fm1.size(0), fm1.size(1), -1)
        fm2 = fm2.view(fm2.size(0), fm2.size(1), -1).transpose(1, 2)

        fsp = torch.bmm(fm1, fm2) / fm1.size(2)
        return fsp
    def compute_transition_matrix(self, front: torch.Tensor, latter: torch.Tensor):
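        # Symmetrize the transmitting matrix (K = T @ T^T) so that its top
        # eigenvalue can be extracted by `top_eigenvalue` via power iteration.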
        return torch.bmm(self.transmitting_matrix(front, latter), self.transmitting_matrix(front, latter).transpose(2, 1))
    def top_eigenvalue(self, K: torch.Tensor, n_power_iterations=10):
        """
        Compute the top eigenvalue of a batch of matrices K using the power iteration method.
        Gradients are tracked only from iteration `args.grad_iter_step` onwards (counted from the end of the loop if negative).

        Args:
            K (torch.Tensor): The batch of matrices to compute the top eigenvalue of.
            n_power_iterations (int): The number of power iterations to run.

        Returns:
            torch.Tensor: The top eigenvalue of K.
        """
        if self.args.grad_iter_step < 0:
            start_grad_it = n_power_iterations + self.args.grad_iter_step + 1
        else:
            start_grad_it = self.args.grad_iter_step
        assert start_grad_it >= 0 and start_grad_it <= n_power_iterations

        v = torch.rand(K.shape[0], K.shape[1], 1).to(K.device, dtype=K.dtype)
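        # Power iteration: repeatedly apply K to a random vector and renormalize.
        # Gradients are only enabled from `start_grad_it` onwards to keep the
        # backward pass cheap.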
        for itt in range(n_power_iterations):
            with torch.set_grad_enabled(itt >= start_grad_it):
                m = torch.bmm(K, v)
                n = torch.norm(m, dim=1).unsqueeze(1) + torch.finfo(torch.float32).eps
                v = m / n

        top_eigenvalue = torch.sqrt(n / (torch.norm(v, dim=1).unsqueeze(1) + torch.finfo(torch.float32).eps))
        return top_eigenvalue
    def get_layer_lip_coeffs(self, features_a: torch.Tensor, features_b: torch.Tensor) -> torch.Tensor:
        """
        Compute the Lipschitz coefficient of a layer given its batches of input and output features.
        The coefficient is estimated following https://arxiv.org/pdf/2108.12905.pdf.

        Args:
            features_a (torch.Tensor): The batch of input features.
            features_b (torch.Tensor): The batch of output features.

        Returns:
            torch.Tensor: The Lipschitz coefficient of the layer.
        """
        features_a, features_b = features_a.double(), features_b.double()
        features_a, features_b = features_a / self.get_norm(features_a), features_b / self.get_norm(features_b)

        TM_s = self.compute_transition_matrix(features_a, features_b)
        L = self.top_eigenvalue(K=TM_s)
        return L
    def get_feature_lip_coeffs(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Compute the Lipschitz coefficient for all the layers of a network given a list of batches of features.
        The features are assumed to be ordered from the input to the output of the network.

        Args:
            features (List[torch.Tensor]): The list of features of each layer.

        Returns:
            List[torch.Tensor]: The list of Lipschitz coefficients for each layer.
        """
        N = len(features) - 1
        B = len(features[0])

        lip_values = [torch.zeros(B, device=self.device, dtype=features[0].dtype)] * N

        for i in range(N):
            fma, fmb = features[i], features[i + 1]
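            # Adaptively pool the later feature map's channel dimension down to the
            # number of channels of the earlier one before estimating the coefficient.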
            fmb = F.adaptive_avg_pool1d(fmb.reshape(*fmb.shape[:2], -1).permute(0, 2, 1), fma.shape[1]).permute(0, 2, 1).reshape(fmb.shape[0], -1, *fmb.shape[2:])

            L = self.get_layer_lip_coeffs(fma, fmb)
            L = L.reshape(B)
            lip_values[i] = L

        return lip_values
    @torch.no_grad()
    def init_net(self, dataset):
        """
        Compute the target Lipschitz coefficients for the network and initialize the network's Lipschitz coefficients to match them.

        Args:
            dataset (ContinualDataset): The dataset to use for the computation.
        """
        was_training = self.net.training
        self.net.eval()
        all_lips = []
        for i, data in enumerate(tqdm(dataset.train_loader, desc="Computing target L budget")):
            inputs, labels = data[0], data[1]

            if self.args.debug_mode and i > self.get_debug_iters():
                continue

            inputs, labels = inputs.to(self.device), labels.to(self.device)
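            # Inputs may carry an extra leading dimension (e.g. multiple augmented
            # views per sample); fold it into the batch dimension.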
            if len(inputs.shape) == 5:
                B, n, C, H, W = inputs.shape
                inputs = inputs.view(B * n, C, H, W)

            _, partial_features = self.net(inputs, returnt='full')

            lip_inputs = [inputs] + partial_features

            lip_values = self.get_feature_lip_coeffs(lip_inputs)
            # (B, F)
            lip_values = torch.stack(lip_values, dim=1)

            all_lips.append(lip_values)
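        # Average the per-sample estimates over the whole dataset to obtain a single
        # target Lipschitz budget per layer.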
        budget_lip = torch.cat(all_lips, dim=0).mean(0).detach().clone()
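        # One forward pass to count the feature maps, then register a learnable
        # per-layer budget (`lip_coeffs`) initialized to the computed targets.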
        inp = next(iter(dataset.train_loader))[0]
        _, teacher_feats = self.net(inp.to(self.device), returnt='full')

        self.net.lip_coeffs = torch.randn(len(teacher_feats), dtype=torch.float, device=self.device, requires_grad=True)
        self.net.lip_coeffs.data = budget_lip
        self.net.train(was_training)
    def get_norm(self, t: torch.Tensor):
        """
        Compute the norm of a tensor along its second dimension, with a small epsilon added to avoid division by zero.

        Args:
            t (torch.Tensor): The tensor.

        Returns:
            torch.Tensor: The norm of the tensor.
        """
        return torch.norm(t, dim=1, keepdim=True) + torch.finfo(torch.float32).eps
    def minimization_lip_loss(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute the Lipschitz minimization loss for a batch of features (eq. 8).

        Args:
            features (List[torch.Tensor]): The list of features of each layer. The features are assumed to be ordered from the input to the output of the network.

        Returns:
            torch.Tensor: The Lipschitz minimization loss.
        """
        lip_values = self.get_feature_lip_coeffs(features)
        # (B, F)
        lip_values = torch.stack(lip_values, dim=1)

        return lip_values.mean()
    def dynamic_budget_lip_loss(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute the dynamic budget Lipschitz loss for a batch of features (eq. 7).

        Args:
            features (List[torch.Tensor]): The list of features of each layer. The features are assumed to be ordered from the input to the output of the network.

        Returns:
            torch.Tensor: The dynamic budget Lipschitz loss.
        """
        loss = 0
        lip_values = self.get_feature_lip_coeffs(features)
        # (B, F)
        lip_values = torch.stack(lip_values, dim=1)
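        # The target budget is the learnable `lip_coeffs` (one value per layer),
        # passed through the activation selected by `headless_init_act` and
        # broadcast over the batch before the L1 comparison.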
if self.args.headless_init_act == "relu":
tgt = F.relu(self.net.lip_coeffs[:len(lip_values[0])])
elif self.args.headless_init_act == "lrelu":
tgt = F.leaky_relu(self.net.lip_coeffs[:len(lip_values[0])])
else:
raise NotImplementedError
tgt = tgt.unsqueeze(0).expand(lip_values.shape)
loss += F.l1_loss(lip_values, tgt)
return loss