FISHER DIAG

Classes

class models.tak_utils.fisher_diag.DiagComputer(device, debug_mode, train_percent=1.0, num_samples_expectation=0, fp_precision='fp64')

Bases: Module

compute(net, head, delta_w_names, dataset, use_head=False)
to_be_fishered(name, module, all_param_finetuned)
to_be_fishered_layer_norm(name, module, all_param_finetuned)
class models.tak_utils.fisher_diag.LossDiagComputer(device, debug_mode, train_percent=1.0, fp_precision='fp64')

Bases: DiagComputer

compute(net, head, delta_w_names, dataset, use_head=False)
class models.tak_utils.fisher_diag.LossDiagComputerSampling(device, debug_mode, train_percent=1.0, fp_precision='fp64')

Bases: DiagComputer

compute(net, head, delta_w_names, dataset, use_head=False)

Functions

models.tak_utils.fisher_diag.get_split(dataset)
models.tak_utils.fisher_diag.hook_backward_cls_token_diag(module, _, grad_output)
models.tak_utils.fisher_diag.hook_backward_diag(module, _, grad_output)
models.tak_utils.fisher_diag.hook_backward_layer_norm_diag(module, _, grad_output)
models.tak_utils.fisher_diag.register_hooks(name, module, forward=True, backward=True, forward_hooks_dict=None, bacward_hooks_dict=None)