# Source code for models.tak

from copy import deepcopy
import json
import os
import numpy as np
import torch.nn as nn

from typing import Any

from utils import binary_to_boolean_type

from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from utils.args import ArgumentParser
import torch.func as func
import gc
import open_clip

from typing import Tuple, List

import torch
from models.tak_utils.backbone import Backbone, create_clip
from models.tak_utils.backbone import build_classification_head
from models.tak_utils.utils import FisherLoader
from models.tak_utils.utils import OptimizerBuilder
from models.tak_utils.utils import compute_acc_on_last_task

from models.tak_utils.fisher_kfac import KFACComputer

from models.tak_utils.merging import get_merging_function
from models.tak_utils.utils import get_parameter
from utils.evaluate import evaluate

import wandb


class TAK(ContinualModel):
    """Task Arithmetic with KFAC regularization.

    Fine-tunes a CLIP visual encoder per task through additive task vectors
    (optionally in the linearized/NTK regime via ``torch.func.jvp``),
    regularizes training with a Kronecker-factored (KFAC) approximation of
    past-task Fisher information, and merges task vectors at task boundaries.
    """

    NAME = "tak"
    COMPATIBILITY = ["class-il", "domain-il", "task-il", "general-continual"]

    # Wrapper around the CLIP visual encoder (see models.tak_utils.backbone).
    net: Backbone
    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        """Register TAK's command-line arguments on ``parser`` and return it.

        Groups: CLIP backbone/fine-tuning flags, task-vector merging options,
        main training hyper-parameters, KFAC/Fisher options, ablations,
        extras (LoRA, precision, resume), and alpha-sweep evaluation options.
        """
        parser.set_defaults(
            optimizer="adamw",
            lr=0.0003,
            optim_wd=0.1,
        )
        clip_group = parser.add_argument_group("TAK CLIP")
        clip_group.add_argument(
            "--clip_backbone",
            type=str,
            default="ViT-B/16",
            help="Backbone architecture for CLIP",
            choices=["ViT-B/16", "ViT-B/32", "ViT-L/14"],
        )
        clip_group.add_argument(
            "--ft_linears",
            type=binary_to_boolean_type,
            default=1,
            help="Set to 1 fine-tune linear layers",
        )
        clip_group.add_argument(
            "--ft_attention",
            type=binary_to_boolean_type,
            default=1,
            help="Set to 1 fine-tune attention layers",
        )
        clip_group.add_argument(
            "--ft_ln",
            type=binary_to_boolean_type,
            default=1,
            help="Set to 1 fine-tune layer norm",
        )
        clip_group.add_argument(
            "--ft_class_embed",
            type=binary_to_boolean_type,
            default=1,
            help="Set to 1 fine-tune class embedding layers",
        )
        clip_group.add_argument(
            "--ft_proj",
            type=binary_to_boolean_type,
            default=1,
            help="Set to 1 fine-tune projection layers",
        )
        clip_group.add_argument(
            "--ft_pos_embed",
            type=binary_to_boolean_type,
            default=0,
            help="Set to 1 fine-tune posistional embedding",
        )
        clip_group.add_argument(
            "--ft_conv",
            type=binary_to_boolean_type,
            default=0,
            help="Set to 1 fine-tune convolutional layers",
        )
        merging_group = parser.add_argument_group("TAK Merging")
        merging_group.add_argument(
            "--merging",
            type=str,
            default="ta",
            choices=["ta", "dare", "iso", "ties", "tsv"],
            help="Merging strategy for task vectors",
        )
        # NOTE: parsed as str so the sentinel "one" is accepted; resolved to a
        # numeric value later by _ensure_merging_alpha().
        merging_group.add_argument(
            "--alpha_merging",
            type=str,
            default="one",
            help="Alpha used during merge."
            "NOTE: some merging strategy (ta, dare, ties) "
            "rescale the `alpha_merging` by the total number of tasks. "
            "To avoid errors, the value 'one' ensures that alpha=1 for each task",
        )
        main_group = parser.add_argument_group("TAK Main")
        main_group.add_argument(
            "--save_task_vectors",
            type=binary_to_boolean_type,
            default=0,
            help="Save computed task vectors?",
        )
        main_group.add_argument(
            "--virtual_bs_n",
            type=int,
            default=1,
            help="chose how many chunks for vitual batch size",
        )
        main_group.add_argument(
            "--default_scale_factor",
            type=float,
            default=1,
            choices=[0, 1],
            help="Default scale factor for layer scaling if a single eigenvalue is present. 0 means no scaling, 1 means full scaling.",
        )
        main_group.add_argument(
            "--reg_lambda",
            type=float,
            default=500,
            help="Regularization weight (lambda in the paper)",
        )
        main_group.add_argument(
            "--fisher_ft_proj_scaler",
            type=float,
            default=0.1,
            help="Regularization scaling coeff. for the final linear projection",
        )
        main_group.add_argument(
            "--fisher_norm_scaler",
            type=float,
            default=10,
            help="Regularization scaling coeff. for inner feed-forward layers",
        )
        main_group.add_argument(
            "--scheduler_ntk",
            type=str,
            default="cosine",
            choices=["none", "cosine", "cosine_plus", "decay", "step"],
            help="LR scheduler type",
        )
        main_group.add_argument(
            "--clip_grad_norm",
            type=float,
            default=None,
            required=False,
            help="Gradient clipping norm value - used if >0 and not None",
        )
        kfac_group = parser.add_argument_group("TAK KFAC")
        kfac_group.add_argument(
            "--load_fisher",
            type=binary_to_boolean_type,
            default=0,
            help="Load KFAC map from cache?",
        )
        kfac_group.add_argument(
            "--fisher_cache",
            type=str,
            default="fisher_cache",
            help="Path on which to save or load KFAC maps. "
            "Supports local directories, HTTP(S) base URLs, and HuggingFace sources "
            "using `hf://<owner>/<repo>/<optional/subpath>@<optional_revision>`.",
        )
        # NOTE: parsed as str; _parse_train_percent() converts to float (ratio)
        # or int (sample count).
        kfac_group.add_argument(
            "--train_percent",
            type=str,
            default="1.0",
            help="Percentage of training data used to compute the fisher information matrix. \
                If float, it represents the percentage of the training set. \
                If integer, it represents the number of samples used. \
                Put 1.0 to use the entire training set.",
        )
        kfac_group.add_argument(
            "--fisher_task_id",
            type=int,
            default=None,
            required=False,
            help="Compute KFAC approx. on this specific task",
        )
        kfac_group.add_argument(
            "--fisher_ideal",
            type=binary_to_boolean_type,
            default=0,
            help="Keep and use the fisher of each task (ideal - Eq. 7) or just the accumulated one (Eq. 8)",
        )
        kfac_group.add_argument(
            "--fisher_num_samples_expectation",
            type=int,
            default=1,
            help="Compute KFAC approx. on a fixed number of samples, subset of all the task",
        )
        ablation_group = parser.add_argument_group("TAK Ablation")
        ablation_group.add_argument(
            "--tangent",
            type=binary_to_boolean_type,
            default=1,
            help="Use or disable linearized training and inference (NTK regime)",
        )
        extra_group = parser.add_argument_group("TAK Extra")
        extra_group.add_argument("--use_lora", type=binary_to_boolean_type, default=0)
        extra_group.add_argument(
            "--fp_precision",
            type=str,
            default="fp32",
            choices=["fp32", "fp64"],
            help="Floating point fp_precision used during KFAC computations",
        )
        extra_group.add_argument(
            "--resume",
            type=binary_to_boolean_type,
            default=0,
            help="Resume previous training? NOTE: requires `load_path`",
        )
        extra_group.add_argument(
            "--load_path",
            type=str,
            default=None,
            required=False,
            help="Path from which load the previous task's task vectors. Used with `resume=1`",
        )
        eval_group = parser.add_argument_group("TAK Evaluation")
        eval_group.add_argument(
            "--alpha_sweep_start",
            type=float,
            default=0.1,
            help="Starting merging alpha value for sensitivity analysis - used by `compute_metrics_by_alpha`",
        )
        eval_group.add_argument(
            "--alpha_sweep_end",
            type=float,
            default=1.5,
            help="Final merging alpha value for sensitivity analysis - used by `compute_metrics_by_alpha`",
        )
        eval_group.add_argument(
            "--alpha_sweep_step",
            type=float,
            default=0.1,
            help="Step alpha value for sensitivity analysis - used by `compute_metrics_by_alpha`",
        )
        return parser
    def _parse_train_percent(self):
        """Normalize ``args.train_percent`` from its string form.

        A string containing '.', 'e' or 'E' becomes a float in (0, 1]
        (fraction of the training set); otherwise an int >= 1 (absolute
        number of samples). Returns the (mutated) ``self.args``.
        """
        if not isinstance(self.args.train_percent, str):
            # Already parsed (e.g. called twice) — nothing to do.
            return self.args
        raw = self.args.train_percent.strip()
        if ("." in raw) or ("e" in raw) or ("E" in raw):
            val = float(raw)
            assert 0 < val <= 1, "If float, train_percent must be in (0,1]"
            self.args.train_percent = val
        else:
            val = int(raw)
            assert val >= 1, "If integer, train_percent must be >= 1"
            self.args.train_percent = val
        return self.args

    def _ensure_merging_alpha(self):
        """Resolve the ``alpha_merging`` sentinel into a numeric value.

        For strategies that rescale alpha by the task count (ta/dare/ties),
        "one" maps to N_TASKS so the effective per-task alpha is 1.
        Any other string must parse as a float.
        """
        if self.args.alpha_merging == "one":
            if self.args.merging in ("ta", "dare", "ties"):
                self.args.alpha_merging = self.dataset.N_TASKS
            # NOTE(review): for iso/tsv the value stays the string "one" —
            # presumably handled by the merging function; confirm.
        else:
            try:
                self.args.alpha_merging = float(self.args.alpha_merging)
            except ValueError:
                raise ValueError("alpha_merging must be 'one' or a float value")

    def __init__(self, backbone, loss, args, transform, dataset):
        """Build the CLIP model, the TAK backbone wrapper, and KFAC helpers.

        Args:
            backbone: ignored; the CLIP backbone is created from ``args.clip_backbone``.
            loss: task classification loss.
            args: parsed command-line arguments (see ``get_parser``).
            transform: data transform (superclass contract).
            dataset: the continual dataset; must not be None.
        """
        assert dataset is not None
        # open_clip supplies the train/eval preprocessing pipelines; the actual
        # model used is built by create_clip below.
        _, train_preprocess, val_preprocess = open_clip.create_model_and_transforms(
            args.clip_backbone, pretrained="openai", device=torch.device("cpu")
        )
        clip_model = create_clip(args.clip_backbone, torch.device(args.device))
        super().__init__(clip_model, loss, args, transform, dataset=dataset)
        self._ensure_merging_alpha()
        self._parse_train_percent()
        if self.args.resume:
            assert self.args.load_path is not None, (
                "Must provide load_path when resume is set to True"
            )
        if self.args.save_task_vectors:
            # Persist the run configuration next to the task vectors.
            os.makedirs(self.args.checkpoint_path, exist_ok=True)
            tv_path = os.path.join(
                self.args.checkpoint_path,
                f"{self.args.conf_jobnum}_{dataset.NAME}_args.json",
            )
            with open(tv_path, "w") as f:
                json_args = deepcopy(vars(args))
                del json_args["device"]  # torch.device is not JSON-serializable
                json.dump(json_args, f)
        self.net = Backbone(clip_model, dataset, args)
        self.param_names = [
            name for name, _ in self.net.visual_encoder.named_parameters()
        ]
        # The backbone itself is frozen; training happens on the delta
        # (task-vector) parameters created in begin_task.
        for name, param in self.net.named_parameters():
            param.requires_grad = False
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        clip_model = clip_model.to(dtype=torch.float32)
        clip_model.eval()
        self.clip_model = clip_model
        self.clip_transform = train_preprocess
        self.clip_eval_transform = val_preprocess
        self.fisher_computer = KFACComputer(
            self.device,
            self.args.debug_mode == 1,
            train_percent=self.args.train_percent,
            num_samples_expectation=self.args.fisher_num_samples_expectation,
            fp_precision=self.args.fp_precision,
        )
        self.fisher_loader = FisherLoader(
            self.args.fisher_cache, dataset.NAME, self.device, fp_precision="fp32"
        )
        self.optimizer_builder = OptimizerBuilder(cmd_args=self.args)
        # Per-task state, (re)initialized in begin_task.
        self.cur_offset = None
        self.cls_head: nn.Module = None
        self.delta_w_dict: dict[str, Any] = None
        self.delta_w_names: list[str] = None
        self.delta_w_shapes: dict[str, Any] = None
        self.scheduler1 = None
        # KFAC factors accumulated over past tasks, keyed by parameter name.
        self.tasks_ggT = {}
        self.tasks_ffT = {}
        self.tasks_aaT = {}
        self.coeffs = []
        self.layer_scale_factors: dict[str, float] = {}
        self.num_batches = 0
        self.merging = get_merging_function(self.args, self.device)
        self.merged_task_vector = []
        self.num_total_tasks: int = dataset.N_TASKS
        self.dataset_name: str = dataset.NAME
        # Per-task accuracies right after training, used for normalization.
        self.individual_acc, self.individual_mask_acc = [], []
        self.norm_acc, self.norm_mask_acc = [], []
        self.task_loaded = False
[docs] def create_param_like(self, param, requires_grad): return [ torch.zeros_like( param, dtype=torch.float32, requires_grad=requires_grad, device=self.args.device, ) ]
[docs] def create_lora_param_like(self, fin, fout, requires_grad, r1=None, r2=None): r1 = 16 if r1 is None else r1 r2 = 16 if r2 is None else r2 config = ("kaiming", "zeros") return get_parameter( (fout, r2), self.device, config[1], False, requires_grad ), get_parameter((r1, fin), self.device, config[0], False, requires_grad)
    def begin_task(self, dataset):
        """Prepare per-task state: delta parameters, Fisher factors, optimizer.

        Resets the visual encoder to the pretrained CLIP weights, allocates a
        zero task vector (or LoRA pair) per tunable parameter, then either
        computes+stores KFAC factors for this task or loads and accumulates
        the cached factors of the other tasks for the regularization penalty.
        """
        torch.cuda.empty_cache()
        # Swap in CLIP's own preprocessing pipelines for this task's loaders.
        dataset.test_loaders[-1].dataset.transform = self.clip_eval_transform
        dataset.train_loader.dataset.transform = self.clip_transform
        self.cur_offset = self.get_offsets(self.current_task)
        if isinstance(dataset.N_CLASSES_PER_TASK, int):
            self.cpt = dataset.N_CLASSES_PER_TASK
        else:
            self.cpt = dataset.N_CLASSES_PER_TASK[-1]
        if self.current_task != 0:
            self.net.task_id += 1
        self.cls_head = build_classification_head(
            self.clip_model, dataset, self.cur_offset
        )
        print("\nRELOADING CLIP VISUAL ENCODER")
        # Training always starts from the pretrained encoder; learned deltas
        # live separately in delta_w_dict.
        self.net.copy_visual_encoder(self.clip_model)
        for param in self.net.visual_encoder.parameters():
            param.requires_grad = False
        print("\nCLIP VISUAL ENCODER RELOADED\n\n")
        self.delta_w_dict = {}
        self.delta_w_shapes = {}
        for name, param in self.net.visual_encoder.named_parameters():
            self.delta_w_shapes[name] = param.shape
            if self.args.use_lora == 1 and len(param.shape) == 2:
                # LoRA: store a (B, A) low-rank pair instead of a full delta.
                fout, fin = param.shape[0], param.shape[1]
                if "mlp" in name:
                    B, A = self.create_lora_param_like(
                        fin, fout, self.args.ft_linears == 1
                    )
                    self.delta_w_dict[name] = [B, A]
                elif "attn" in name:
                    # Fused qkv weight -> triple the LoRA rank.
                    B, A = self.create_lora_param_like(
                        fin, fout, self.args.ft_attention == 1, r1=16 * 3, r2=16 * 3
                    )
                    self.delta_w_dict[name] = [B, A]
                elif "proj" in name:
                    if name == "proj":
                        # skip, this is the projection layer of the visual encoder which has beeen replaced
                        continue
                    B, A = self.create_lora_param_like(
                        fin, fout, self.args.ft_proj == 1
                    )
                    self.delta_w_dict[name] = [B, A]
            else:
                # Full (dense) zero delta; requires_grad follows the ft_* flags.
                if "mlp" in name:
                    self.delta_w_dict[name] = self.create_param_like(
                        param, requires_grad=self.args.ft_linears == 1
                    )
                elif "attn" in name:
                    self.delta_w_dict[name] = self.create_param_like(
                        param, requires_grad=self.args.ft_attention == 1
                    )
                elif "proj" in name:
                    if name == "proj":
                        # skip, this is the projection layer of the visual encoder which has beeen replaced
                        continue
                    self.delta_w_dict[name] = self.create_param_like(
                        param, requires_grad=self.args.ft_proj == 1
                    )
                elif "ln" in name:
                    self.delta_w_dict[name] = self.create_param_like(
                        param, requires_grad=self.args.ft_ln == 1
                    )
                elif "class_embed" in name:
                    self.delta_w_dict[name] = self.create_param_like(
                        param, requires_grad=self.args.ft_class_embed == 1
                    )
                elif "conv" in name:
                    self.delta_w_dict[name] = self.create_param_like(
                        param, requires_grad=self.args.ft_conv == 1
                    )
                elif "positional_embedding" in name:
                    self.delta_w_dict[name] = self.create_param_like(
                        param, requires_grad=self.args.ft_pos_embed == 1
                    )
        self.delta_w_names = list(self.delta_w_dict.keys())
        if not self.args.load_fisher:
            # Compute and cache this task's KFAC factors (no penalty is used).
            if (
                self.args.fisher_task_id is None
                or self.current_task == self.args.fisher_task_id
            ):
                # Fisher is estimated with the eval transform (no augmentation).
                dataset.train_loader.dataset.transform = self.clip_eval_transform
                ggT, aaT, ffT, num_ggT, num_aaT = self.fisher_computer.compute(
                    self.net, None, self.delta_w_names, dataset, use_head=False
                )
                dataset.train_loader.dataset.transform = self.clip_transform
                self.fisher_loader.store_kfac(
                    self.current_task, ggT, aaT, ffT, num_ggT, num_aaT
                )
            else:
                print(
                    f"Skipping Fisher computation for task {self.current_task} as "
                    f"it is not the specified task {self.args.fisher_task_id}."
                )
        else:
            # Load cached factors of all OTHER tasks and accumulate them into
            # the penalty buffers (tasks_ggT/tasks_aaT/tasks_ffT).
            counts = [
                self.fisher_loader.load_kfac(t, only_counts=True)
                for t in range(dataset.N_TASKS)
            ]
            tot_ggT = sum(
                [
                    cnt[0]
                    for idx_cnt, cnt in enumerate(counts)
                    if idx_cnt != self.current_task
                ]
            )
            tot_aaT = sum(
                [
                    cnt[1]
                    for idx_cnt, cnt in enumerate(counts)
                    if idx_cnt != self.current_task
                ]
            )
            assert tot_ggT == tot_aaT
            self.coeffs = []
            # fisher_ideal keeps one penalty term per past task (Eq. 7);
            # otherwise a single accumulated term is used (Eq. 8).
            num_penalties = dataset.N_TASKS - 1 if self.args.fisher_ideal else 1
            for t in range(dataset.N_TASKS):
                ggT, aaT, ffT, cur_num_ggT, cur_num_aaT = self.fisher_loader.load_kfac(
                    t
                )
                assert cur_num_ggT == cur_num_aaT
                coeff = cur_num_ggT / tot_ggT
                if t == 0:
                    # First iteration: (re)initialize accumulation buffers.
                    for key in aaT.keys():
                        if key in self.tasks_ggT.keys():
                            for p_l in range(num_penalties):
                                self.tasks_aaT[key][p_l].zero_()
                                self.tasks_ggT[key][p_l].zero_()
                        else:
                            self.tasks_aaT[key] = [
                                torch.zeros_like(aaT[key]) for _ in range(num_penalties)
                            ]
                            self.tasks_ggT[key] = [
                                torch.zeros_like(ggT[key]) for _ in range(num_penalties)
                            ]
                    for key in ffT.keys():
                        if key in self.tasks_ffT.keys():
                            self.tasks_ffT[key].zero_()
                        else:
                            self.tasks_ffT[key] = torch.zeros_like(ffT[key])
                if t != self.current_task:
                    self.coeffs.append(coeff)
                    for key in ffT.keys():
                        self.tasks_ffT[key].add_(ffT[key] / tot_ggT)
                    for key in aaT.keys():
                        # Normalize each factor by its sample count before use.
                        aaT[key].div_(cur_num_aaT)
                        ggT[key].div_(cur_num_ggT)
                        if self.args.fisher_ideal == 0:
                            self.tasks_aaT[key][0].add_(
                                (cur_num_aaT / tot_aaT) * aaT[key]
                            )
                            self.tasks_ggT[key][0].add_(ggT[key])
                        else:
                            # Slot past tasks contiguously, skipping the current one.
                            t_hat = t if t <= self.current_task else t - 1
                            self.tasks_ggT[key][t_hat].copy_(ggT[key])
                            self.tasks_aaT[key][t_hat].copy_(aaT[key])
                del aaT, ggT, ffT
        # Optimize the delta parameters only (flattened LoRA pairs included).
        all_params = [
            p for param_list in self.delta_w_dict.values() for p in param_list
        ]
        num_batches: int = len(dataset.train_loader)
        self.opt, self.scheduler1 = self.optimizer_builder.build_opt_and_sched(
            all_params, num_batches
        )
        if self.args.resume:
            self.task_loaded = self.load_task_vectors()
            if self.task_loaded:
                # Skip training entirely for tasks restored from disk.
                print(f"Task vectors for {self.current_task} loaded successfully")
                self.args.n_epochs = 0
                self.n_epochs = 0
        self.train()
[docs] def get_parameter_from_dict(self, name): assert name in self.delta_w_names list_params = self.delta_w_dict[name] if len(list_params) == 1: return list_params[0] elif len(list_params) == 2: return list_params[0] @ list_params[1] else: raise ValueError
[docs] def get_all_parameters_from_dict(self): return [self.get_parameter_from_dict(k) for k in self.delta_w_names]
    def end_task(
        self, dataset: ContinualDataset
    ) -> None:  # TODO set the model in eval mode
        """Finalize the task: record individual accuracy, save and merge vectors.

        Evaluates the just-trained task vector in isolation (for normalized
        metrics), optionally saves it to disk, freezes it, adds it to the
        merger, recomputes the merged task vector, and tears down the
        optimizer state before the next task.
        """
        print(f"Current task: {self.current_task}")
        self.eval()
        # Use the raw (unmerged) current-task vector for individual evaluation.
        self.merged_task_vector = []
        for i, key in enumerate(self.delta_w_names):
            self.merged_task_vector.append(
                torch.clone(self.get_parameter_from_dict(key))
            )
        # Temporarily expose all classes to measure last-task accuracy.
        actual_seen_classes = self.n_seen_classes
        self.cls_head = build_classification_head(
            self.clip_model, dataset, self.cur_offset, all_heads=True
        )
        self._n_seen_classes = dataset.N_CLASSES
        acc, acc_mask_classes = compute_acc_on_last_task(self, dataset)
        self.individual_acc.append(acc)
        self.individual_mask_acc.append(acc_mask_classes)
        self._n_seen_classes = actual_seen_classes
        if self.args.save_task_vectors:
            save_data = []
            for key in self.delta_w_names:
                param_to_save = self.get_parameter_from_dict(key)
                if isinstance(param_to_save, list):
                    save_data.append([p.clone().cpu() for p in param_to_save])
                else:
                    save_data.append(param_to_save.clone().cpu())
            base_path = os.path.join(
                self.args.checkpoint_path,
                self.args.fisher_cache,
                f"{self.args.conf_jobnum}_{dataset.NAME}_task_{self.current_task}",
            )
            tv_path = base_path + ".pt"
            os.makedirs(os.path.dirname(tv_path), exist_ok=True)
            torch.save(save_data, tv_path)
            torch.save(self.cls_head.state_dict(), base_path + "_cls_head.pt")
            torch.save({"delta_w_names": self.delta_w_names}, base_path + "_meta.pt")
            print(f"Task vector saved to {tv_path}")
        del self.merged_task_vector[:]
        del self.merged_task_vector
        self.cls_head = build_classification_head(
            self.clip_model, dataset, self.cur_offset, eval=True
        )
        # Freeze this task's deltas before handing them to the merger.
        for i, key in enumerate(self.delta_w_names):
            num_params = len(self.delta_w_dict[key])
            for p_l in range(num_params):
                self.delta_w_dict[key][p_l].requires_grad = False
        self.merging.add(
            {key: self.get_parameter_from_dict(key) for key in self.delta_w_names}
        )
        self.merged_task_vector = self.merging.merge(self.delta_w_names)
        self.opt.zero_grad()
        self.opt = None
        self.net.copy_visual_encoder(self.clip_model)
        torch.cuda.empty_cache()
        # NOTE(review): removes the attributes entirely; begin_task recreates
        # them for the next task — confirm nothing reads them in between.
        del self.opt, self.scheduler1, self.delta_w_dict
        gc.collect()
        return super().end_task(dataset)
[docs] def end_eval(self, dataset: ContinualDataset, accs: Tuple[List, List]) -> None: def safe_den(y, eps=1e-8): return y if abs(y) >= eps else y + eps self.norm_acc = [ acc / safe_den(self.individual_acc[t]) for t, acc in enumerate(accs[0]) ] self.norm_mask_acc = [ acc / safe_den(self.individual_mask_acc[t]) for t, acc in enumerate(accs[1]) ] if self.args.nowand == 0: wandb.log( { "RESULT_mean_norm_acc": sum(self.norm_acc) / len(self.norm_acc), "RESULT_mean_norm_mask_acc": sum(self.norm_mask_acc) / len(self.norm_mask_acc), "Task": self.current_task, } ) if self.current_task == self.num_total_tasks-1: self.compute_metrics_by_alpha()
    def penalty_weight(self):
        """Compute the KFAC quadratic penalties over the current task deltas.

        For matrix-shaped parameters with cached Kronecker factors, the penalty
        is trace(ggT @ dW @ aaT @ dW^T); bias deltas are appended as an extra
        column of dW. Vector-shaped parameters use the full factor ffT via
        trace(dW @ ffT @ dW^T).

        Returns:
            Tuple of (weight penalty, vector-parameter penalty,
            final-projection penalty, class-embedding penalty) — scaled
            separately by the caller.
        """
        loss_reg, loss_reg_ffT, loss_ft_proj = 0, 0, 0
        loss_reg_cls_emb = 0
        for name in self.delta_w_names:
            if name in self.tasks_aaT.keys():
                delta_W = self.get_parameter_from_dict(name)
                bias_name = name.replace("weight", "bias")
                if name.replace("weight", "bias") in self.delta_w_names:
                    assert "weight" in name
                    # Fold the bias delta in as an extra column (homogeneous coords).
                    delta_bias = self.get_parameter_from_dict(bias_name)
                    delta_W = torch.cat((delta_W, delta_bias.unsqueeze(1)), 1)
                # One penalty per past task (fisher_ideal) or a single
                # accumulated one (list length 1).
                for task_id in range(len(self.tasks_aaT[name])):
                    aaT_past = self.tasks_aaT[name][task_id]
                    ggT_past = self.tasks_ggT[name][task_id]
                    norm_coeff = self.coeffs[task_id] if self.args.fisher_ideal else 1
                    loss_ = torch.trace(ggT_past @ delta_W @ aaT_past @ delta_W.T)
                    if name == "lin_proj.weight":
                        # Final projection gets its own scaling coefficient.
                        loss_ft_proj += norm_coeff * loss_
                    else:
                        loss_reg += norm_coeff * loss_
            if name in self.tasks_ffT.keys():
                # Vector-shaped parameters (e.g. norms/embeddings): full factor.
                delta_W = self.get_parameter_from_dict(name).unsqueeze(0)
                ffT_past = self.tasks_ffT[name]
                reg_w = torch.trace(delta_W @ ffT_past @ delta_W.T)
                if "class_embedding" in name:
                    loss_reg_cls_emb += reg_w
                else:
                    loss_reg_ffT += reg_w
        return loss_reg, loss_reg_ffT, loss_ft_proj, loss_reg_cls_emb
[docs] def create_functional(self, inputs, delta_names): def func_network(param_values): param = {name: param for name, param in zip(delta_names, param_values)} features = func.functional_call(self.net.visual_encoder, param, inputs) return nn.functional.normalize(features, dim=-1) return func_network
    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Single training step on a batch; returns the (scaled) task loss.

        In the tangent (NTK) regime, features are linearized around the frozen
        encoder weights: f(w0) + J(w0) @ delta_w, computed with ``func.jvp``.
        Otherwise, the deltas are added to the weights and a plain functional
        forward pass is used. Gradients are accumulated over ``virtual_bs_n``
        chunks before the optimizer step.
        """
        if self.args.resume and self.task_loaded:
            # Task vectors were restored from disk — nothing to train.
            return 0.0
        if self.args.tangent:
            forward_fun = self.create_functional(inputs, self.delta_w_names)
            params = [
                param
                for name, param in self.net.visual_encoder.named_parameters()
                if name in self.delta_w_names
            ]
            # jvp returns (f(w0), J(w0) @ delta_w); their sum is the
            # first-order (linearized) feature prediction.
            image_features, jvp = func.jvp(
                forward_fun,
                (tuple(params),),
                (tuple(self.get_all_parameters_from_dict()),),
            )
            image_features = image_features + jvp
        else:
            # Standard fine-tuning: evaluate the encoder at w0 + delta_w.
            tunable_params = [
                p
                for n, p in self.net.visual_encoder.named_parameters()
                if n in self.delta_w_names
            ]
            dict_param = {
                name: param + net_param
                for name, param, net_param in zip(
                    self.delta_w_names,
                    self.get_all_parameters_from_dict(),
                    tunable_params,
                )
            }
            image_features = func.functional_call(
                self.net.visual_encoder, dict_param, inputs
            )
            image_features = nn.functional.normalize(image_features, dim=-1)
        similarity = self.cls_head(image_features)
        # Labels are shifted to the current task's local class range.
        loss_task = self.loss(similarity, labels - self.n_past_classes)
        loss = loss_task / self.args.virtual_bs_n
        loss.backward()
        # KFAC penalty is only available when cached factors were loaded;
        # applied once per virtual batch.
        if (
            (self.args.load_fisher)
            and (self.task_iteration > 0)
            and (self.task_iteration % self.args.virtual_bs_n == 0)
        ):
            loss_penalty, loss_reg_ffT, loss_ft_proj, loss_reg_cls_emb = (
                self.penalty_weight()
            )
            loss_reg = (
                self.args.reg_lambda * loss_penalty
                + self.args.reg_lambda * self.args.fisher_norm_scaler * loss_reg_ffT
                + self.args.reg_lambda * self.args.fisher_ft_proj_scaler * loss_ft_proj
                + self.args.reg_lambda * self.args.fisher_norm_scaler * loss_reg_cls_emb
            )
            loss_reg.backward()
        # Optimizer step at virtual-batch boundaries only.
        if (
            self.task_iteration > 0
        ) and self.task_iteration % self.args.virtual_bs_n == 0:
            if self.scheduler1:
                self.scheduler1(self.task_iteration // self.args.virtual_bs_n)
            if self.args.clip_grad_norm is not None and self.args.clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    (p for group in self.opt.param_groups for p in group["params"]),
                    self.args.clip_grad_norm,
                )
            self.opt.step()
            self.opt.zero_grad()
        return loss.item()
    @torch.no_grad()
    def forward(self, x):
        """Inference pass using the MERGED task vector.

        Mirrors ``observe``'s two regimes (linearized via jvp, or direct
        functional call with deltas added to the weights), but substitutes
        ``self.merged_task_vector`` instead of the current task's deltas.
        Returns logits restricted to the classes seen so far.
        """
        if self.args.tangent:
            forward_fun = self.create_functional(x, self.delta_w_names)
            params = [
                param
                for name, param in self.net.visual_encoder.named_parameters()
                if name in self.delta_w_names
            ]
            image_features, jvp = func.jvp(
                forward_fun,
                (tuple(params),),
                (tuple(self.merged_task_vector),),
            )
            image_features = image_features + jvp
        else:
            tunable_params = {
                n: p
                for n, p in self.net.visual_encoder.named_parameters()
                if n in self.delta_w_names
            }
            dict_param = {}
            for i, key in enumerate(self.delta_w_names):
                dict_param[key] = tunable_params[key] + self.merged_task_vector[i]
            image_features = func.functional_call(
                self.net.visual_encoder, dict_param, x
            )
            image_features = nn.functional.normalize(image_features, dim=-1)
        similarity = self.cls_head(image_features)
        return similarity[:, : self.n_seen_classes]
[docs] def get_debug_iters(self): return 5
[docs] def load_task_vectors(self): """ Returns: bool: True if loading was successful, False otherwise """ print( f"Loading task vector number: {self.current_task} from: {self.args.load_path}" ) tv_path = self.args.load_path.replace( "_args.json", f"_task_{self.current_task}.pt" ) if os.path.exists(tv_path): print(f"Found task vector {self.current_task}: {tv_path}") else: print(f"WARNING: Missing task vector for task {self.current_task}") return False # Load task vectors task_vectors = torch.load(tv_path, map_location=self.device) # Load metadata from the last task's task vector meta_path = tv_path.replace(".pt", "_meta.pt") if os.path.exists(meta_path): meta_data = torch.load(meta_path, map_location=self.device) delta_w_names = meta_data["delta_w_names"] print(f"Loaded metadata with {len(self.delta_w_names)} parameters") else: print("WARNING: Metadata file not found") return False # self.delta_w_dict = {name: param for name, param in zip(self.delta_w_names, task_vectors)} for name, param in zip(delta_w_names, task_vectors): self.delta_w_dict[name] = [param.clone().detach().to(self.device)] print(f"Added task vector {self.current_task} to merging") return True
    def compute_metrics_by_alpha(self):
        """Sensitivity analysis: re-evaluate the merged model over a sweep of alphas.

        For each alpha in [alpha_sweep_start, alpha_sweep_end] (step
        alpha_sweep_step, each multiplied by the total number of tasks), the
        task vectors are re-merged and evaluated; raw and normalized masked
        accuracies are printed and optionally logged to wandb under a
        per-sweep metric name.
        """

        def safe_den(y, eps=1e-8):
            # Guard against (near-)zero individual accuracies.
            return y if abs(y) >= eps else y + eps

        # Give each invocation its own wandb metric names so repeated sweeps
        # do not overwrite one another.
        if not hasattr(self, "_alpha_sweep_next_id"):
            self._alpha_sweep_next_id = 1
        sweep_id = self._alpha_sweep_next_id
        self._alpha_sweep_next_id += 1
        if sweep_id == 1:
            metric_name = "alpha_sweep"
            norm_metric_name = "norm_alpha_sweep"
        else:
            metric_name = f"alpha_sweep_{sweep_id}"
            norm_metric_name = f"norm_alpha_sweep_{sweep_id}"
        if self.args.nowand == 0:
            if not hasattr(self, "_alpha_metric_defined"):
                wandb.define_metric("alpha")
                self._alpha_metric_defined = True
            wandb.define_metric(metric_name, step_metric="alpha")
            wandb.define_metric(norm_metric_name, step_metric="alpha")
        alphas = np.arange(
            self.args.alpha_sweep_start,
            self.args.alpha_sweep_end + self.args.alpha_sweep_step,
            self.args.alpha_sweep_step,
        ).tolist()
        # Scale by task count (merging strategies divide alpha by #tasks).
        alphas = [a * self.num_total_tasks for a in alphas]
        # alpha sweep
        for alpha in alphas:
            self.merging.set_alpha(alpha)
            self.merged_task_vector = self.merging.merge(self.delta_w_names)
            accs, accs_mask_classes = evaluate(self, self.dataset)
            norm_mask_acc = [
                acc / safe_den(self.individual_mask_acc[t])
                for t, acc in enumerate(accs_mask_classes)
            ]
            print(
                f"Alpha: {alpha} - Acc: {sum(accs_mask_classes) / len(accs_mask_classes):.4f} - Norm Acc: {sum(norm_mask_acc) / len(norm_mask_acc):.4f}"
            )
            print(f"single tasks accs: {accs_mask_classes}")
            if self.args.nowand == 0:
                wandb.log(
                    {
                        "alpha": alpha,
                        metric_name: sum(accs_mask_classes) / len(accs_mask_classes),
                        norm_metric_name: sum(norm_mask_acc) / len(norm_mask_acc),
                    }
                )