UTILS
Classes
- class models.tak_utils.utils.FisherLoader(fisher_cache, dataset_name, device, fp_precision='fp32')
Bases: object
Functions
- models.tak_utils.utils.get_delta_w_backbone(named_params, delta_w, delta_w_names, training_type, device)
- models.tak_utils.utils.get_delta_w_parameterlist(named_params, delta_w, delta_w_names, peft_type, device)
- models.tak_utils.utils.get_parameter(shape, device, type_init='orto', transpose=False, requires_grad=True)
- models.tak_utils.utils.get_params(net, features=True, classifier=False, offset_1=-1, offset_2=-1)
  Return type: Tensor
- models.tak_utils.utils.replace_non_dynamically_quantizable_linear(module)
  Recursively replace all NonDynamicallyQuantizableLinear layers in a model with Linear layers.
- models.tak_utils.utils.set_params(net, new_params, features=True, classifier=False, offset_1=-1, offset_2=-1)
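The recursive replacement performed by `replace_non_dynamically_quantizable_linear` can be sketched as follows. This is a minimal illustration, not the library's own code. It assumes the layer type in question is PyTorch's `torch.nn.modules.linear.NonDynamicallyQuantizableLinear` (the class `nn.MultiheadAttention` uses for its `out_proj`) and that the learned weights should be carried over unchanged. The helper name `swap_ndq_linear` is hypothetical.

```python
import torch
from torch import nn
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear


def swap_ndq_linear(module: nn.Module) -> nn.Module:
    """Hypothetical sketch: replace every NonDynamicallyQuantizableLinear
    inside `module` with a plain nn.Linear holding the same weights."""
    for name, child in module.named_children():
        if isinstance(child, NonDynamicallyQuantizableLinear):
            # Build a plain Linear with the same shape, bias flag, device and dtype.
            new = nn.Linear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                device=child.weight.device,
                dtype=child.weight.dtype,
            )
            # Copy the learned values so the forward pass is unchanged.
            with torch.no_grad():
                new.weight.copy_(child.weight)
                if child.bias is not None:
                    new.bias.copy_(child.bias)
            # Assigning a Module attribute re-registers it as a submodule.
            setattr(module, name, new)
        else:
            # Descend into containers such as nn.MultiheadAttention or nn.Sequential.
            swap_ndq_linear(child)
    return module
```

Copying into a fresh `nn.Linear` keeps outputs identical but creates new `Parameter` objects, so an optimizer built before the swap would need to be rebuilt; assigning `new.weight = child.weight` instead would keep the original parameters.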