# Source code for utils.spkdloss

import torch
from torch import nn
from torch.nn import functional as F


class SPKDLoss(nn.Module):
    """Similarity-Preserving Knowledge Distillation (SPKD) loss.

    Implements the pairwise-similarity matching loss from
    "Similarity-Preserving Knowledge Distillation" (Tung & Mori, ICCV 2019):
    the squared Frobenius norm of the difference between the normalized
    batch self-similarity matrices of the teacher and student outputs.

    :param reduction: if ``'batchmean'``, the loss is divided by
        ``batch_size ** 2``; any other value returns the unreduced loss.
    """

    def __init__(self, reduction):
        super().__init__()
        self.reduction = reduction

    def matmul_and_normalize(self, z):
        """Return the row-normalized self-similarity (Gram) matrix of ``z``.

        ``z`` is flattened to shape ``(batch, -1)`` before computing
        ``z @ z.T`` (a ``batch x batch`` matrix).

        NOTE(review): the bare ``1`` is the positional ``p`` argument of
        ``F.normalize`` (L1 row normalization, with default ``dim=1``);
        the SPKD paper uses L2 row normalization — confirm this is intended.
        """
        z = torch.flatten(z, 1)
        return F.normalize(torch.matmul(z, torch.t(z)), 1)

    def compute_spkd_loss(self, teacher_outputs, student_outputs):
        """Squared Frobenius norm of the similarity-matrix difference."""
        g_t = self.matmul_and_normalize(teacher_outputs)
        g_s = self.matmul_and_normalize(student_outputs)
        return torch.norm(g_t - g_s) ** 2

    def forward(self, teacher_outputs, student_outputs):
        """Compute the SPKD loss for a batch of teacher/student outputs.

        Both inputs must share the same leading (batch) dimension.
        """
        batch_size = teacher_outputs.shape[0]
        spkd_losses = self.compute_spkd_loss(teacher_outputs, student_outputs)
        # .sum() is a no-op on the scalar norm but kept for safety if the
        # per-pair losses ever become element-wise.
        spkd_loss = spkd_losses.sum()
        return spkd_loss / (batch_size ** 2) if self.reduction == 'batchmean' else spkd_loss