Source code for models.ranpac_utils.toolkit
import torch
[docs]
def target2onehot(targets, n_classes):
    """Convert a tensor of class indices into a one-hot float matrix.

    Args:
        targets: Tensor of class indices. It is cast to ``long`` and
            flattened to one index per row, so its total number of elements
            must equal ``targets.shape[0]`` (e.g. shape ``(N,)`` or
            ``(N, 1)``). Every index must lie in ``[0, n_classes)``;
            ``scatter_`` raises otherwise.
        n_classes: Number of columns (classes) in the output.

    Returns:
        A ``float32`` tensor of shape ``(targets.shape[0], n_classes)`` on
        ``targets.device``. Row ``i`` holds ``1.0`` at column ``targets[i]``
        and ``0.0`` elsewhere.
    """
    # Allocate directly on the target device instead of building a CPU
    # tensor and copying it over with ``.to(...)``.
    onehot = torch.zeros(targets.shape[0], n_classes, device=targets.device)
    # scatter_ needs an int64 index with the same rank as ``onehot``. reshape
    # (unlike view) also accepts index tensors that cannot be viewed in place.
    onehot.scatter_(dim=1, index=targets.long().reshape(-1, 1), value=1.0)
    return onehot