Source code for models.ranpac_utils.toolkit

import torch


def target2onehot(targets, n_classes, *, dtype=None):
    """Convert integer class labels into a one-hot matrix.

    Args:
        targets: Tensor of integer class indices. The indices are flattened
            and there must be ``targets.shape[0]`` of them, e.g. shape
            ``(N,)`` or ``(N, 1)``. Each value must lie in ``[0, n_classes)``.
        n_classes: Number of classes, i.e. the width of the output.
        dtype: Optional dtype of the result. ``None`` (the default) uses
            ``torch.get_default_dtype()``, matching the original behavior.

    Returns:
        Tensor of shape ``(targets.shape[0], n_classes)`` on
        ``targets.device``. Row ``i`` is all zeros except for a 1 at column
        ``targets[i]``.

    Raises:
        RuntimeError: Raised by ``scatter_`` when an index is out of range.
            On CUDA this surfaces as a device-side assert.
    """
    # Allocate directly on the target device. Building on the CPU and then
    # calling .to(device) would add a host allocation and a copy per call.
    onehot = torch.zeros(
        targets.shape[0], n_classes, dtype=dtype, device=targets.device
    )
    # scatter_ needs an int64 index with the same rank as `onehot`, so turn
    # the labels into a column vector. reshape (unlike view) also copes
    # with arbitrarily-strided inputs.
    index = targets.long().reshape(-1, 1)
    onehot.scatter_(dim=1, index=index, value=1.0)
    return onehot