Fisher KFAC (models.tak_utils.fisher_kfac)

Classes

class models.tak_utils.fisher_kfac.KFACComputer(device, debug_mode, train_percent=1.0, num_samples_expectation=2, fp_precision='fp64')[source]

Bases: Module

compute(net, head, delta_w_names, dataset, use_head=False)[source]
to_be_fishered(name, module, all_param_finetuned)[source]
to_be_fishered_layer_norm(name, module, all_param_finetuned)[source]

Functions

models.tak_utils.fisher_kfac.get_split(dataset)[source]
models.tak_utils.fisher_kfac.register_hooks(name, module, forward=True, backward=True, forward_hooks_dict=None, bacward_hooks_dict=None)[source]