from typing import Literal, overload
import torch
import math
import os
import hashlib
import json
import logging
import re
import warnings
import urllib.request
from urllib.error import HTTPError
from urllib.parse import quote, urljoin
from utils import binary_to_boolean_type
from models.utils.continual_model import ContinualModel
from datasets.utils.continual_dataset import ContinualDataset
from tqdm import tqdm
try:
import clip # noqa: F401
except ImportError:
raise ImportError(
"Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git"
)
import numpy as np
from utils.conf import get_checkpoint_path
[docs]
def set_requires_grad_to(model, namevars, mode: bool):
for n, p in model.named_parameters():
if n in namevars:
p.requires_grad = mode
[docs]
def add_clip_args(parser):
parser.add_argument(
"--clip_backbone",
type=str,
default="ViT-B/16",
help="Backbone architecture for CLIP",
choices=["ViT-B/16", "ViT-B/32", "ViT-L/14"],
)
parser.add_argument(
"--ft_linears",
type=binary_to_boolean_type,
default=1,
help="Set to 1 fine-tune linear layers",
)
parser.add_argument(
"--ft_attention",
type=binary_to_boolean_type,
default=1,
help="Set to 1 fine-tune attention layers",
)
parser.add_argument(
"--ft_ln",
type=binary_to_boolean_type,
default=1,
help="Set to 1 fine-tune layer norm",
)
parser.add_argument(
"--ft_class_embed",
type=binary_to_boolean_type,
default=1,
help="Set to 1 fine-tune class embedding layers",
)
parser.add_argument(
"--ft_proj",
type=binary_to_boolean_type,
default=1,
help="Set to 1 fine-tune projection layers",
)
parser.add_argument(
"--ft_pos_embed",
type=binary_to_boolean_type,
default=0,
help="Set to 1 fine-tune posistional embedding",
)
parser.add_argument(
"--ft_conv",
type=binary_to_boolean_type,
default=0,
help="Set to 1 fine-tune convolutional layers",
)
[docs]
class OptimizerBuilder:
def __init__(self, cmd_args):
self.args = cmd_args
[docs]
def build_opt_and_sched(self, all_params, num_batches):
opt, sched = None, None
if self.args.optimizer == "adamw":
opt = torch.optim.AdamW(
all_params, lr=self.args.lr, weight_decay=self.args.optim_wd
)
elif self.args.optimizer == "sgd":
opt = torch.optim.SGD(
all_params,
lr=self.args.lr,
momentum=self.args.optim_mom,
weight_decay=self.args.optim_wd,
)
else:
raise ValueError
reduction_factor = getattr(self.args, "epochs_factor_reduction", 1)
if self.args.scheduler_ntk == "none":
pass
elif self.args.scheduler_ntk == "cosine":
num_total_steps = self.args.n_epochs * (
num_batches // self.args.virtual_bs_n
)
sched = cosine_lr(
opt, self.args.lr, 500 / reduction_factor, num_total_steps, 0
)
elif self.args.scheduler_ntk == "cosine_talos":
num_total_steps = self.args.n_epochs * (
num_batches // self.args.virtual_bs_n
)
sched = cosine_lr(opt, self.args.lr, 200, num_total_steps, 0)
elif self.args.scheduler_ntk == "cosine_plus":
num_total_steps = self.args.n_epochs * (
num_batches // self.args.virtual_bs_n
)
warmup_steps = int(0.1 * num_total_steps)
sched = cosine_lr(
opt, self.args.lr, warmup_steps, num_total_steps, 0.1 * self.args.lr
)
elif self.args.scheduler_ntk == "decay":
sched = cosine_lr(opt, self.args.lr, 0, self.args.n_epochs * num_batches, 0)
elif self.args.scheduler_ntk == "step":
num_steps = self.args.n_epochs * num_batches // self.args.virtual_bs_n
warmup_steps = int(0.1 * num_steps)
sched = step_lr_decay(opt, self.args.lr, warmup_steps, num_steps)
else:
raise ValueError
return opt, sched
[docs]
def build_opt_and_sched_multiple_lr(
self, params_group_1, params_group_2, num_batches
):
opt, sched = None, None
lr_group_1 = self.args.lr
lr_group_2 = getattr(self.args, "lr2", None)
if lr_group_2 is None:
lr_group_2 = getattr(self.args, "lr_lin", None)
if lr_group_2 is None:
lr_group_2 = getattr(self.args, "lr_second", None)
if lr_group_2 is None or lr_group_2 == 0:
lr_group_2 = self.args.lr
param_groups = [
{"params": params_group_1, "lr": lr_group_1},
{"params": params_group_2, "lr": lr_group_2},
]
if self.args.optimizer == "adamw":
opt = torch.optim.AdamW(
param_groups, lr=lr_group_1, weight_decay=self.args.optim_wd
)
elif self.args.optimizer == "sgd":
opt = torch.optim.SGD(
param_groups,
lr=lr_group_1,
momentum=self.args.optim_mom,
weight_decay=self.args.optim_wd,
)
else:
raise ValueError
base_lrs = [lr_group_1, lr_group_2]
if self.args.scheduler_ntk == "none":
pass
elif self.args.scheduler_ntk == "cosine":
num_total_steps = self.args.n_epochs * (
num_batches // self.args.virtual_bs_n
)
sched = cosine_lr(opt, base_lrs, 500, num_total_steps, 0)
elif self.args.scheduler_ntk == "cosine_talos":
num_total_steps = self.args.n_epochs * (
num_batches // self.args.virtual_bs_n
)
sched = cosine_lr(opt, base_lrs, 200, num_total_steps, 0)
elif self.args.scheduler_ntk == "cosine_plus":
num_total_steps = self.args.n_epochs * (
num_batches // self.args.virtual_bs_n
)
warmup_steps = int(0.1 * num_total_steps)
min_lrs = [0.1 * lr for lr in base_lrs]
sched = cosine_lr(opt, base_lrs, warmup_steps, num_total_steps, min_lrs)
elif self.args.scheduler_ntk == "decay":
sched = cosine_lr(opt, base_lrs, 0, self.args.n_epochs * num_batches, 0)
elif self.args.scheduler_ntk == "step":
num_steps = self.args.n_epochs * num_batches // self.args.virtual_bs_n
warmup_steps = int(0.1 * num_steps)
sched = step_lr_decay(opt, base_lrs, warmup_steps, num_steps)
else:
raise ValueError
return opt, sched
[docs]
@torch.no_grad()
def compute_acc_on_last_task(model: ContinualModel, dataset: ContinualDataset):
test_loader = dataset.test_loaders[-1]
total_len = len(test_loader) if hasattr(test_loader, "__len__") else None
pbar = tqdm(
test_loader, total=total_len, desc="Evaluating", disable=model.args.non_verbose
)
correct, correct_mask_classes, total = 0.0, 0.0, 0.0
test_iter = iter(test_loader)
i = 0
num_classes = dataset.N_CLASSES
while True:
try:
data = next(test_iter)
except StopIteration:
break
if model.args.debug_mode and i > model.get_debug_iters():
break
inputs, labels = data[0], data[1]
inputs, labels = inputs.to(model.device), labels.to(model.device)
outputs = model.forward(inputs)
assert outputs.shape[1] == num_classes
_, pred = torch.max(outputs, 1)
correct += torch.sum(pred == labels).item()
total += labels.shape[0]
i += 1
pbar.set_postfix(
{f"acc_task_{model.current_task + 1}": max(0, correct / total * 100)},
refresh=False,
)
pbar.set_description(f"Evaluating Task {model.current_task + 1}", refresh=False)
pbar.update(1)
start_c, end_c = dataset.get_offsets(model.current_task)
outputs[:, :start_c] = -float("inf")
outputs[:, end_c:num_classes] = -float("inf")
_, pred = torch.max(outputs.data, 1)
correct_mask_classes += torch.sum(pred == labels).item()
acc = correct / total * 100
acc_mask_classes = correct_mask_classes / total * 100
pbar.close()
return acc, acc_mask_classes
[docs]
def make_psd(x, to64=False):
orig_dtype = x.dtype
if to64:
x = x.to(torch.float64)
eigvals, eigvecs = torch.linalg.eigh(x)
eigvals_clamped = torch.clamp(eigvals, min=0.0)
x_psd = (eigvecs * eigvals_clamped) @ eigvecs.t()
return x_psd.to(orig_dtype) if to64 else x_psd
[docs]
class FisherLoader:
def __init__(
self,
fisher_cache,
dataset_name,
device,
fp_precision="fp32",
fallback_dataset_name=None,
):
self.dataset_name = dataset_name
self.fallback_dataset_name = fallback_dataset_name
self.device = device
self.fisher_cache = fisher_cache
self.fp_precision = fp_precision
self.postprocessing = None
self._dataset_name_hints: list[str] | None = None
@staticmethod
def _append_unique(items: list[str], value: str | None) -> None:
if value and value not in items:
items.append(value)
def _extract_path_dataset_hints(self) -> list[str]:
hints: list[str] = []
if not isinstance(self.fisher_cache, str):
return hints
parts = self.fisher_cache.replace("@", "/").split("/")
for part in parts:
if not part.startswith("fisher_"):
continue
dataset_hint = part[len("fisher_") :].strip()
if dataset_hint in {"", "cache"}:
continue
self._append_unique(hints, dataset_hint)
if not dataset_hint.startswith("seq-"):
self._append_unique(hints, f"seq-{dataset_hint}")
return hints
def _list_cache_entries(self) -> list[str]:
if self._is_hf_source():
repo_id, base_path, revision = self._parse_hf_source(self.fisher_cache)
encoded_repo = quote(repo_id, safe="")
encoded_revision = quote(revision, safe="")
encoded_base_path = quote(base_path.strip("/"), safe="/")
endpoint = (
f"https://huggingface.co/api/models/{encoded_repo}/tree/{encoded_revision}"
)
if encoded_base_path:
endpoint = f"{endpoint}/{encoded_base_path}"
try:
with urllib.request.urlopen(endpoint) as response:
payload = json.loads(response.read().decode("utf-8"))
if not isinstance(payload, list):
return []
return [
item["path"]
for item in payload
if isinstance(item, dict)
and item.get("type") == "file"
and isinstance(item.get("path"), str)
]
except Exception:
return []
if self._is_http_source():
return []
try:
return [
os.path.join(self.fisher_cache, name)
for name in os.listdir(self.fisher_cache)
]
except Exception:
return []
def _extract_dataset_hints_from_entries(self) -> list[str]:
hints: list[str] = []
pattern = re.compile(
r"^(?P<dataset>.+)_task_\d+_(?:num_(?:aaT|ggT)|aaT|ggT|ffT)\.pt$"
)
for entry in self._list_cache_entries():
filename = os.path.basename(entry)
match = pattern.match(filename)
if not match:
continue
dataset_hint = match.group("dataset")
self._append_unique(hints, dataset_hint)
return hints
def _get_dataset_name_hints(self) -> list[str]:
if self._dataset_name_hints is not None:
return self._dataset_name_hints
hints: list[str] = []
for hint in self._extract_path_dataset_hints():
self._append_unique(hints, hint)
for hint in self._extract_dataset_hints_from_entries():
self._append_unique(hints, hint)
self._dataset_name_hints = hints
return hints
def _dataset_name_candidates(self) -> list[str]:
names = [self.dataset_name]
if (
self.fallback_dataset_name is not None
and self.fallback_dataset_name not in names
):
names.append(self.fallback_dataset_name)
for hint in self._get_dataset_name_hints():
self._append_unique(names, hint)
return names
def _try_resolve_file(self, filename: str) -> str | None:
try:
file_path = self._resolve_file(filename)
except FileNotFoundError:
return None
except HTTPError as e:
if e.code == 404:
return None
raise
except Exception as e:
if e.__class__.__name__ in {"EntryNotFoundError", "LocalEntryNotFoundError"}:
return None
raise
if os.path.exists(file_path):
return file_path
return None
def _resolve_count_paths(self, task_id: int) -> tuple[str, str, str]:
for dataset_name in self._dataset_name_candidates():
base_name = f"{dataset_name}_task_{task_id}"
aaT_count_path = self._try_resolve_file(f"{base_name}_num_aaT.pt")
ggT_count_path = self._try_resolve_file(f"{base_name}_num_ggT.pt")
if aaT_count_path is None and ggT_count_path is None:
continue
if aaT_count_path is None or ggT_count_path is None:
raise FileNotFoundError(
f"Incomplete Fisher counts for `{base_name}` in `{self.fisher_cache}`. "
"Expected both `_num_aaT.pt` and `_num_ggT.pt`."
)
return base_name, aaT_count_path, ggT_count_path
expected = [f"{name}_task_{task_id}" for name in self._dataset_name_candidates()]
raise FileNotFoundError(
f"Fisher cache for task {task_id} not found in `{self.fisher_cache}`. "
f"Tried dataset prefixes: {expected}."
)
[docs]
def has_task(self, task_id: int) -> bool:
try:
self._resolve_count_paths(task_id)
return True
except FileNotFoundError:
return False
[docs]
def get_available_task_ids(self, max_tasks: int) -> list[int]:
task_ids = []
for task_id in range(max_tasks):
if self.has_task(task_id):
task_ids.append(task_id)
return task_ids
def _is_hf_source(self) -> bool:
return isinstance(self.fisher_cache, str) and self.fisher_cache.startswith(
"hf://"
)
def _is_http_source(self) -> bool:
return isinstance(self.fisher_cache, str) and self.fisher_cache.startswith(
("http://", "https://")
)
def _is_remote_source(self) -> bool:
return self._is_hf_source() or self._is_http_source()
def _remote_cache_dir(self) -> str:
src_hash = hashlib.sha256(self.fisher_cache.encode("utf-8")).hexdigest()[:16]
cache_dir = os.path.join(
get_checkpoint_path(), "fisher_remote_cache", self.dataset_name, src_hash
)
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
@staticmethod
def _assert_not_lfs_pointer(file_path: str) -> None:
with open(file_path, "rb") as f:
header = f.read(256)
if header.startswith(b"version https://git-lfs.github.com/spec/v1"):
raise ValueError(
f"Downloaded file `{file_path}` is a Git LFS pointer, not the binary artifact. "
"Use a direct/raw artifact URL (or Hugging Face resolve URL) instead of a pointer URL."
)
@staticmethod
def _parse_hf_source(spec: str) -> tuple[str, str, str]:
assert spec.startswith("hf://")
payload = spec[len("hf://") :]
if "@" in payload:
payload, revision = payload.rsplit("@", 1)
else:
revision = "main"
parts = payload.split("/")
if len(parts) < 2:
raise ValueError(
"Invalid HF source format. Use `hf://<owner>/<repo>/<optional/subpath>@<optional_revision>`"
)
repo_id = "/".join(parts[:2])
base_path = "/".join(parts[2:])
return repo_id, base_path, revision
def _download_http_file(self, remote_url: str, local_path: str) -> str:
os.makedirs(os.path.dirname(local_path), exist_ok=True)
with (
urllib.request.urlopen(remote_url) as source,
open(local_path, "wb") as output,
):
output.write(source.read())
self._assert_not_lfs_pointer(local_path)
return local_path
def _resolve_file(self, filename: str) -> str:
if not self._is_remote_source():
return os.path.join(self.fisher_cache, filename)
cache_dir = self._remote_cache_dir()
local_path = os.path.join(cache_dir, filename)
if os.path.exists(local_path):
return local_path
if self._is_hf_source():
try:
from huggingface_hub import hf_hub_download
except ImportError as e:
raise ImportError(
"huggingface_hub is required for `hf://` Fisher cache sources. "
"Install it with `pip install huggingface_hub`."
) from e
repo_id, base_path, revision = self._parse_hf_source(self.fisher_cache)
hf_filename = f"{base_path}/{filename}" if base_path else filename
downloaded_path = hf_hub_download(
repo_id=repo_id,
filename=hf_filename,
revision=revision,
repo_type="model",
cache_dir=cache_dir,
)
self._assert_not_lfs_pointer(downloaded_path)
logging.info(
f"Downloaded Fisher file from HF: {repo_id}/{hf_filename}@{revision}"
)
return downloaded_path
remote_url = urljoin(self.fisher_cache.rstrip("/") + "/", filename)
logging.info(f"Downloading Fisher file from URL: {remote_url}")
return self._download_http_file(remote_url, local_path)
@overload
def load_kfac(
self, task_id: int, only_counts: Literal[True]
) -> tuple[int, int]: ...
@overload
def load_kfac(
self, task_id: int, only_counts: Literal[False] = False
) -> tuple[
dict[str, torch.Tensor],
dict[str, torch.Tensor],
dict[str, torch.Tensor],
int,
int,
]: ...
[docs]
def load_kfac(
self, task_id: int, only_counts: bool = False
) -> (
tuple[
dict[str, torch.Tensor],
dict[str, torch.Tensor],
dict[str, torch.Tensor],
int,
int,
]
| tuple[int, int]
):
base_name, fisher_cache_path_num_aaT, fisher_cache_path_num_ggT = (
self._resolve_count_paths(task_id)
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
cur_num_aaT: int = torch.load(
fisher_cache_path_num_aaT, map_location="cpu"
).item()
cur_num_ggT: int = torch.load(
fisher_cache_path_num_ggT, map_location="cpu"
).item()
if only_counts:
logging.info(
f"Loaded Fisher counts for task {task_id} from `{self.fisher_cache}`"
)
return cur_num_ggT, cur_num_aaT
fisher_cache_path_aaT = self._try_resolve_file(f"{base_name}_aaT.pt")
fisher_cache_path_ggT = self._try_resolve_file(f"{base_name}_ggT.pt")
fisher_cache_path_ffT = self._try_resolve_file(f"{base_name}_ffT.pt")
if (
fisher_cache_path_aaT is None
or fisher_cache_path_ggT is None
or fisher_cache_path_ffT is None
):
raise FileNotFoundError(
f"Incomplete Fisher tensors for `{base_name}` in `{self.fisher_cache}`. "
"Expected `_aaT.pt`, `_ggT.pt`, and `_ffT.pt`."
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
aaT: dict = torch.load(fisher_cache_path_aaT, map_location=self.device)
ggT: dict = torch.load(fisher_cache_path_ggT, map_location=self.device)
ffT: dict = torch.load(fisher_cache_path_ffT, map_location=self.device)
for key in aaT.keys():
if self.fp_precision == "fp64":
aaT[key] = aaT[key].to(torch.float64)
ggT[key] = ggT[key].to(torch.float64)
elif self.fp_precision == "fp32":
aaT[key] = aaT[key].to(torch.float32)
ggT[key] = ggT[key].to(torch.float32)
else:
raise NotImplementedError
for key in ffT.keys():
if self.fp_precision == "fp64":
# ffT[key] = ffT[key].to(torch.float64)
ffT[key] = ffT[key].to(torch.float64)
elif self.fp_precision == "fp32":
# ffT[key] = ffT[key].to(torch.float32)
ffT[key] = ffT[key].to(torch.float32)
else:
raise NotImplementedError
logging.info(
f"Loaded Fisher tensors for task {task_id} from `{self.fisher_cache}`"
)
return ggT, aaT, ffT, cur_num_ggT, cur_num_aaT
[docs]
def store_kfac(self, task_id, ggT, aaT, ffT, num_ggT, num_aaT):
if self._is_remote_source():
raise ValueError(
"Cannot store Fisher cache to remote sources. "
"Set `--fisher_cache` to a local directory when `--load_fisher=0`."
)
os.makedirs(self.fisher_cache, exist_ok=True)
fisher_cache_path = f"{self.fisher_cache}/{self.dataset_name}_task_{task_id}.pt"
torch.save(ggT, fisher_cache_path.replace(".pt", "_ggT.pt"))
torch.save(aaT, fisher_cache_path.replace(".pt", "_aaT.pt"))
torch.save(ffT, fisher_cache_path.replace(".pt", "_ffT.pt"))
torch.save(
torch.tensor([num_ggT]), fisher_cache_path.replace(".pt", "_num_ggT.pt")
)
torch.save(
torch.tensor([num_aaT]), fisher_cache_path.replace(".pt", "_num_aaT.pt")
)
[docs]
def load_ekfac(
self, task_id, only_counts=False
) -> tuple[
dict[str, torch.Tensor],
dict[str, torch.Tensor],
dict[str, torch.Tensor],
int,
int,
]:
fisher_cache_path = f"{self.fisher_cache}/{self.dataset_name}_task_{task_id}.pt"
fisher_cache_path_num_of_examples = fisher_cache_path.replace(
".pt", "_num_of_examples.pt"
)
assert os.path.exists(fisher_cache_path_num_of_examples), (
f"File {fisher_cache_path_num_of_examples} not found."
)
num_of_examples = torch.load(
fisher_cache_path_num_of_examples, map_location="cpu"
).item()
if only_counts:
return num_of_examples
fisher_cache_path_UA = fisher_cache_path.replace(".pt", "_UA.pt")
fisher_cache_path_UG = fisher_cache_path.replace(".pt", "_UG.pt")
fisher_cache_path_D = fisher_cache_path.replace(".pt", "_D.pt")
fisher_cache_path_ffT = fisher_cache_path.replace(".pt", "_ffT.pt")
assert os.path.exists(fisher_cache_path_UA)
assert os.path.exists(fisher_cache_path_UG)
assert os.path.exists(fisher_cache_path_D)
assert os.path.exists(fisher_cache_path_ffT)
UA = torch.load(fisher_cache_path_UA, map_location=self.device)
UG = torch.load(fisher_cache_path_UG, map_location=self.device)
D = torch.load(fisher_cache_path_D, map_location=self.device)
ffT = torch.load(fisher_cache_path_ffT, map_location=self.device)
assert UA.keys() == UG.keys() == D.keys()
for key in UA.keys():
if self.fp_precision == "fp64":
UA[key] = UA[key].to(torch.float64)
UG[key] = UG[key].to(torch.float64)
D[key] = D[key].to(torch.float64)
elif self.fp_precision == "fp32":
UA[key] = UA[key].to(torch.float32)
UG[key] = UG[key].to(torch.float32)
D[key] = D[key].to(torch.float32)
else:
raise NotImplementedError
for key in ffT.keys():
if self.fp_precision == "fp64":
ffT[key] = ffT[key].to(torch.float64)
elif self.fp_precision == "fp32":
ffT[key] = ffT[key].to(torch.float32)
else:
raise NotImplementedError
return UA, UG, D, ffT, num_of_examples
[docs]
def load_diff_ekfac(
self, task_id, only_counts=False
) -> tuple[
dict[str, torch.Tensor],
dict[str, torch.Tensor],
dict[str, torch.Tensor],
int,
int,
]:
fisher_cache_path = f"{self.fisher_cache}/{self.dataset_name}_task_{task_id}.pt"
fisher_cache_path_num_of_examples = fisher_cache_path.replace(
".pt", "_num_of_examples.pt"
)
assert os.path.exists(fisher_cache_path_num_of_examples)
num_of_examples = torch.load(
fisher_cache_path_num_of_examples, map_location="cpu"
).item()
if only_counts:
return num_of_examples
fisher_cache_universe_path = (
f"{self.fisher_cache}/{self.dataset_name}_universe.pt"
)
fisher_cache_path_UA = fisher_cache_universe_path.replace(".pt", "_UA.pt")
fisher_cache_path_UG = fisher_cache_universe_path.replace(".pt", "_UG.pt")
fisher_cache_path_D = fisher_cache_path.replace(".pt", "_D.pt")
fisher_cache_path_ffT = fisher_cache_path.replace(".pt", "_ffT.pt")
assert os.path.exists(fisher_cache_path_UA)
assert os.path.exists(fisher_cache_path_UG)
assert os.path.exists(fisher_cache_path_D)
assert os.path.exists(fisher_cache_path_ffT)
if task_id == 0:
UA = torch.load(fisher_cache_path_UA, map_location=self.device)
UG = torch.load(fisher_cache_path_UG, map_location=self.device)
else:
UA = {}
UG = {}
D = torch.load(fisher_cache_path_D, map_location=self.device)
ffT = torch.load(fisher_cache_path_ffT, map_location=self.device)
if task_id == 0:
assert UA.keys() == UG.keys() == D.keys()
for key in UA.keys():
if self.fp_precision == "fp64":
UA[key] = UA[key].to(torch.float64)
UG[key] = UG[key].to(torch.float64)
D[key] = D[key].to(torch.float64)
elif self.fp_precision == "fp32":
UA[key] = UA[key].to(torch.float32)
UG[key] = UG[key].to(torch.float32)
D[key] = D[key].to(torch.float32)
else:
raise NotImplementedError
for key in ffT.keys():
if self.fp_precision == "fp64":
ffT[key] = ffT[key].to(torch.float64)
elif self.fp_precision == "fp32":
ffT[key] = ffT[key].to(torch.float32)
else:
raise NotImplementedError
return UA, UG, D, ffT, num_of_examples
[docs]
def get_parameter(
shape,
device,
type_init: str = "orto",
transpose: bool = False,
requires_grad: bool = True,
):
param = torch.zeros(*shape, dtype=torch.float32, device=device)
if type_init == "orto":
torch.nn.init.orthogonal_(param)
if type_init == "gaussian":
torch.nn.init.normal_(param, mean=0.0, std=0.1)
if type_init == "kernel":
torch.nn.init.normal_(param, mean=0.0, std=0.036)
if type_init == "attn":
torch.nn.init.normal_(param, mean=1.0, std=0.03)
if type_init == "kaiming":
torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
if type_init == "ones":
torch.nn.init.ones_(param)
if transpose:
param = torch.transpose(param, 1, 2)
return torch.nn.Parameter(param, requires_grad=requires_grad)
[docs]
def get_params(
net, features=True, classifier=False, offset_1=-1, offset_2=-1
) -> torch.Tensor:
params = []
for name, param in net.named_parameters():
if "head" in name:
if classifier:
assert offset_1 > -1 and offset_2 > -1
params.append(param[offset_1:offset_2].view(-1))
elif features:
params.append(param.view(-1))
if len(params):
return torch.cat(params)
else:
return torch.tensor([0.0])
[docs]
def set_params(
net,
new_params: torch.Tensor,
features=True,
classifier=False,
offset_1=-1,
offset_2=-1,
) -> None:
progress = 0
for name, param in net.named_parameters():
if "head" in name:
if classifier:
assert offset_1 > -1 and offset_2 > -1
cur_size = torch.tensor(param.data[offset_1:offset_2].size()).prod()
param.data[offset_1:offset_2] = new_params[
progress : progress + cur_size
].view(param.data[offset_1:offset_2].size())
progress += cur_size
elif features:
cur_size = torch.tensor(param.size()).prod()
cand_params = new_params[progress : progress + cur_size].view(param.size())
param.data = cand_params
progress += cur_size
[docs]
def get_delta_w_backbone(named_params, delta_w, delta_w_names, training_type, device):
params = []
for name, param in named_params():
name = name.replace("visual_encoder.", "")
if "head" not in name:
if name in delta_w_names:
index = delta_w_names.index(name)
cur_delta_w = delta_w[index]
params.append(cur_delta_w.view(-1).to(device))
elif name == "logit_scale":
# else:
# params.append(torch.zeros_like(param).view(-1).to(device))
print(name)
print("ops siamo finiti in sto posto strano ma non facciamo nulla")
# params.append(torch.clone(param).view(-1).to(device))
if len(params):
return torch.cat(params)
else:
return torch.tensor([0.0]).to(device)
[docs]
def get_delta_w_parameterlist(named_params, delta_w, delta_w_names, peft_type, device):
params = []
for name, param in named_params():
if name in delta_w_names:
index = delta_w_names.index(name)
cur_delta_w = None
if peft_type == "lora":
cur_delta_w = delta_w[index][0] @ delta_w[index][1]
elif peft_type == "full":
cur_delta_w = delta_w[index]
assert cur_delta_w
params.append(cur_delta_w.to(device))
else:
params.append(torch.zeros_like(param).to(device))
return params
[docs]
def replace_non_dynamically_quantizable_linear(module):
"""Recursively replace all NonDynamicallyQuantizableLinear layers with Linear layers in a model."""
for name, child in module.named_children():
if isinstance(child, torch.nn.modules.linear.NonDynamicallyQuantizableLinear):
# Replace with an equivalent Linear layer
new_layer = torch.nn.Linear(
child.in_features, child.out_features, bias=child.bias is not None
)
new_layer.weight = torch.nn.Parameter(child.weight.clone()) # Copy weights
if child.bias is not None:
new_layer.bias = torch.nn.Parameter(child.bias.clone()) # Copy bias
setattr(module, name, new_layer)
else:
replace_non_dynamically_quantizable_linear(
child
) # Recursively process children
return module
[docs]
def assign_learning_rate(param_group, new_lr):
param_group["lr"] = new_lr
def _warmup_lr(base_lr, warmup_length, step):
return base_lr * (step + 1) / warmup_length
[docs]
def cosine_lr(optimizer, base_lrs, warmup_length, steps, min_lr):
if not isinstance(base_lrs, list):
base_lrs = [base_lrs for _ in optimizer.param_groups]
if not isinstance(min_lr, list):
min_lr_list = [min_lr for _ in optimizer.param_groups]
else:
min_lr_list = min_lr
assert len(base_lrs) == len(optimizer.param_groups) == len(min_lr_list)
def _lr_adjuster(step):
for param_group, base_lr, group_min_lr in zip(
optimizer.param_groups, base_lrs, min_lr_list
):
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
e = step - warmup_length
es = steps - warmup_length
lr = group_min_lr + 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
assign_learning_rate(param_group, lr)
return _lr_adjuster
[docs]
def step_lr_decay(optimizer, base_lrs, warmup_length, steps):
if not isinstance(base_lrs, list):
base_lrs = [base_lrs for _ in optimizer.param_groups]
assert len(base_lrs) == len(optimizer.param_groups)
def _lr_adjuster(step):
for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
progress = step / steps
if progress < 0.70:
lr = base_lr
elif progress < 0.90:
lr = base_lr * 0.5
else:
lr = base_lr * 0.1
assign_learning_rate(param_group, lr)
return _lr_adjuster