UTILS

Classes

class models.lora_prototype_utils.utils.AlignmentLoss(seq_dataset, device)

Bases: Module

forward(classifier, features, labels)
norm(t)
normalize_logits(logits)
per_task_norms(logits)
set_current_task(current_task)
class models.lora_prototype_utils.utils.IncrementalClassifier(embed_dim, nb_classes, feat_expand=False)

Bases: Module

assign(classifier, which_heads=None)
backup()
build_optimizer_args(lr, wd=0)
disable_training()
enable_training()
forward(x)
get_device()
recall()
update(nb_classes, freeze_old=True)

Functions

models.lora_prototype_utils.utils.create_optimizer(optimizer_name, optimizer_arg, momentum=0.9)
models.lora_prototype_utils.utils.get_dist(dim, n_comp=5, n_iters=500)
models.lora_prototype_utils.utils.get_parameter(shape, device, type_init='orto', transpose=False)
models.lora_prototype_utils.utils.linear_probing_epoch(data_loader, loss_fn, classifier, optim, lr_scheduler, device, debug_mode=False)