TAK#

Arguments#

TAK clip

--clip_backbone : str

Help: Backbone architecture for CLIP

  • Default: ViT-B/16

  • Choices: ViT-B/16, ViT-B/32, ViT-L/14

--ft_linears : 0|1|True|False -> bool

Help: Set to 1 to fine-tune linear layers

  • Default: 1

--ft_attention : 0|1|True|False -> bool

Help: Set to 1 to fine-tune attention layers

  • Default: 1

--ft_ln : 0|1|True|False -> bool

Help: Set to 1 to fine-tune layer norm

  • Default: 1

--ft_class_embed : 0|1|True|False -> bool

Help: Set to 1 to fine-tune class embedding layers

  • Default: 1

--ft_proj : 0|1|True|False -> bool

Help: Set to 1 to fine-tune projection layers

  • Default: 1

--ft_pos_embed : 0|1|True|False -> bool

Help: Set to 1 to fine-tune positional embedding

  • Default: 0

--ft_conv : 0|1|True|False -> bool

Help: Set to 1 to fine-tune convolutional layers

  • Default: 0

TAK merging

--merging : str

Help: Merging strategy for task vectors

  • Default: ta

  • Choices: ta, dare, iso, ties, tsv

--alpha_merging : str

Help: Alpha used during merging. NOTE: some merging strategies (ta, dare, ties) rescale alpha_merging by the total number of tasks. To avoid errors, the value ‘one’ ensures that alpha=1 for each task

  • Default: one
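How a per-task alpha interacts with task-vector merging can be sketched as follows. This is an illustrative toy, not the library's implementation: the function name, the dict-of-floats weight representation, and the "divide alpha by the number of tasks" rescaling rule are all assumptions based on the flag description above.

```python
# Toy sketch of task-arithmetic merging with a per-task alpha.
# NOTE: names and the rescaling rule are assumptions for illustration;
# they do not mirror the library's internal code.

def merge_task_vectors(pretrained, task_vectors, alpha_merging="one"):
    """Merge task vectors into the pretrained weights.

    pretrained: dict name -> float (flattened toy weights)
    task_vectors: list of dicts with the same keys (theta_task - theta_pre)
    alpha_merging: "one" keeps alpha=1 for each task; a numeric value is
                   rescaled by the number of tasks, as 'ta'-style
                   strategies do per the NOTE above.
    """
    n_tasks = len(task_vectors)
    if alpha_merging == "one":
        alpha = 1.0
    else:
        alpha = float(alpha_merging) / n_tasks  # strategy-dependent rescaling
    merged = dict(pretrained)
    for tv in task_vectors:
        for name, delta in tv.items():
            merged[name] += alpha * delta
    return merged
```

With `alpha_merging="one"`, two task vectors of +0.5 each move a weight of 1.0 to 2.0; with a numeric `"1.0"`, the rescaled alpha of 0.5 per task yields 1.5 instead.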

TAK main

--save_task_vectors : 0|1|True|False -> bool

Help: Save computed task vectors?

  • Default: 0

--virtual_bs_n : int

Help: Choose how many chunks to use for the virtual batch size

  • Default: 1

--default_scale_factor : float

Help: Default scale factor for layer scaling if a single eigenvalue is present. 0 means no scaling, 1 means full scaling.

  • Default: 1

  • Choices: 0, 1

--reg_lambda : float

Help: Regularization weight (lambda in the paper)

  • Default: 500

--fisher_ft_proj_scaler : float

Help: Regularization scaling coeff. for the final linear projection

  • Default: 0.1

--fisher_norm_scaler : float

Help: Regularization scaling coeff. for inner feed-forward layers

  • Default: 10

--scheduler_ntk : str

Help: LR scheduler type

  • Default: cosine

  • Choices: none, cosine, cosine_plus, decay, step

--clip_grad_norm : float

Help: Gradient clipping norm value; used if >0 and not None

  • Default: None
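The "used if >0 and not None" semantics can be sketched with a minimal pure-Python version of global-norm clipping. In the actual model this would go through `torch.nn.utils.clip_grad_norm_`; the toy helper below, operating on a flat list of floats, is an assumption for illustration only.

```python
import math

# Pure-Python sketch of global-norm gradient clipping, applied only when
# clip_grad_norm is set and positive. Illustrative stand-in for
# torch.nn.utils.clip_grad_norm_, not the library's code.

def maybe_clip_gradients(grads, clip_grad_norm=None):
    """grads: list of floats (toy stand-in for gradient tensors)."""
    if clip_grad_norm is None or clip_grad_norm <= 0:
        return grads  # clipping disabled (default: None)
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= clip_grad_norm:
        return grads  # already within the allowed norm
    scale = clip_grad_norm / total_norm
    return [g * scale for g in grads]
```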

TAK kfac

--load_fisher : 0|1|True|False -> bool

Help: Load KFAC map from cache?

  • Default: 0

--fisher_cache : str

Help: Path on which to save or load KFAC maps. Supports local directories, HTTP(S) base URLs, and HuggingFace sources using hf://<owner>/<repo>/<optional/subpath>@<optional_revision>.

  • Default: fisher_cache
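A sketch of splitting the documented `hf://<owner>/<repo>/<optional/subpath>@<optional_revision>` form into its components. The split rules are inferred from the URI shape above, not taken from the library's parser, so treat this as an assumption.

```python
def parse_hf_cache_uri(uri):
    """Parse hf://<owner>/<repo>/<optional/subpath>@<optional_revision>.

    Returns (repo_id, subpath, revision). The split rules are an
    assumption based on the documented URI shape, not the library's code.
    """
    assert uri.startswith("hf://"), "not a HuggingFace cache URI"
    rest = uri[len("hf://"):]
    rest, _, revision = rest.partition("@")   # optional revision after '@'
    parts = rest.split("/")
    repo_id = "/".join(parts[:2])             # <owner>/<repo>
    subpath = "/".join(parts[2:])             # may be empty
    return repo_id, subpath, revision or None
```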

--train_percent : str

Help: Percentage of training data used to compute the Fisher information matrix. If a float, it represents a fraction of the training set; if an integer, it represents the number of samples used. Use 1.0 for the entire training set.

  • Default: 1.0
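The float-vs-integer semantics of this string flag can be illustrated with a small resolver. This helper and its decision rule (a value written with a decimal point is a fraction, otherwise an absolute count) are an assumption for illustration, not the library's own parsing code.

```python
def resolve_fisher_samples(train_percent, dataset_size):
    """Turn a --train_percent value into a concrete sample count.

    Assumption for illustration: a value containing '.' (e.g. "0.25",
    "1.0") is a fraction of the training set; an integer-looking value
    (e.g. "500") is an absolute number of samples.
    """
    spec = str(train_percent)
    value = float(spec)
    if "." in spec:
        # fractional spec; "1.0" selects the entire training set
        return max(1, int(round(value * dataset_size)))
    return min(int(value), dataset_size)  # absolute sample count
```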

--fisher_task_id : int

Help: Compute KFAC approx. on this specific task

  • Default: None

--fisher_ideal : 0|1|True|False -> bool

Help: Keep and use the Fisher of each task (ideal, Eq. 7) or only the accumulated one (Eq. 8)

  • Default: 0

--fisher_num_samples_expectation : int

Help: Compute the KFAC approximation on a fixed number of samples, a subset of the whole task

  • Default: 1

TAK ablation

--tangent : 0|1|True|False -> bool

Help: Enable or disable linearized training and inference (NTK regime)

  • Default: 1

TAK extra

--use_lora : 0|1|True|False -> bool

Help: None

  • Default: 0

--fp_precision : str

Help: Floating-point precision used during KFAC computations

  • Default: fp32

  • Choices: fp32, fp64

--resume : 0|1|True|False -> bool

Help: Resume previous training? NOTE: requires load_path

  • Default: 0

--load_path : str

Help: Path from which to load the previous task’s task vectors. Used with resume=1

  • Default: None

TAK evaluation

--alpha_sweep_start : float

Help: Starting merging alpha value for sensitivity analysis - used by compute_metrics_by_alpha

  • Default: 0.1

--alpha_sweep_end : float

Help: Final merging alpha value for sensitivity analysis - used by compute_metrics_by_alpha

  • Default: 1.5

--alpha_sweep_step : float

Help: Step alpha value for sensitivity analysis - used by compute_metrics_by_alpha

  • Default: 0.1
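Together, the three sweep flags define the grid of alpha values that compute_metrics_by_alpha evaluates. A sketch of generating that grid with the defaults (whether the endpoint is inclusive is an assumption here; the rounding guards against float accumulation drift):

```python
def alpha_sweep(start=0.1, end=1.5, step=0.1):
    """Generate the alpha grid for the merging sensitivity analysis.

    Defaults match the flags above. Endpoint inclusivity is an
    assumption; rounding avoids float drift (0.1 + 0.1 + 0.1 would
    otherwise give 0.30000000000000004).
    """
    alphas = []
    alpha = start
    while alpha <= end + 1e-9:       # small epsilon keeps 1.5 in the grid
        alphas.append(round(alpha, 10))
        alpha += step
    return alphas
```

With the defaults this produces 15 alpha values, 0.1 through 1.5.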

Classes#

class models.tak.TAK(backbone, loss, args, transform, dataset)[source]#

Bases: ContinualModel

Task Arithmetic with KFAC regularization

COMPATIBILITY: List[str] = ['class-il', 'domain-il', 'task-il', 'general-continual']#
NAME: str = 'tak'#
begin_task(dataset)[source]#
compute_metrics_by_alpha()[source]#
create_functional(inputs, delta_names)[source]#
create_lora_param_like(fin, fout, requires_grad, r1=None, r2=None)[source]#
create_param_like(param, requires_grad)[source]#
end_eval(dataset, accs)[source]#
end_task(dataset)[source]#
forward(x)[source]#
get_all_parameters_from_dict()[source]#
get_debug_iters()[source]#
get_parameter_from_dict(name)[source]#
static get_parser(parser)[source]#
Return type:

ArgumentParser

load_task_vectors()[source]#
Returns:

True if loading was successful, False otherwise

Return type:

bool

net: Backbone#
observe(inputs, labels, not_aug_inputs, epoch=None)[source]#
penalty_weight()[source]#