from typing import Literal, overload
import torch
import math
import os
import hashlib
import logging
import warnings
import urllib.request
from urllib.parse import urljoin
from utils import binary_to_boolean_type
from models.utils.continual_model import ContinualModel
from datasets.utils.continual_dataset import ContinualDataset
from tqdm import tqdm
try:
import clip # noqa: F401
except ImportError:
raise ImportError(
"Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git"
)
import numpy as np
from utils.conf import get_checkpoint_path
[docs]
def set_requires_grad_to(model, namevars, mode: bool):
    """Set ``requires_grad = mode`` on every parameter of *model* whose name is in *namevars*.

    Args:
        model: a ``torch.nn.Module``; its ``named_parameters()`` names are matched.
        namevars: container of parameter names (membership is tested with ``in``).
        mode: value assigned to ``requires_grad`` for the matching parameters.
    """
    selected = (param for pname, param in model.named_parameters() if pname in namevars)
    for param in selected:
        param.requires_grad = mode
[docs]
def add_clip_args(parser):
    """Register the CLIP backbone choice and the per-component fine-tuning flags.

    Args:
        parser: an ``argparse.ArgumentParser`` (or argument group) to extend.

    Every ``--ft_*`` flag is parsed with ``binary_to_boolean_type`` and, per its
    help text, toggles fine-tuning of one family of CLIP parameters.
    """
    parser.add_argument(
        "--clip_backbone",
        type=str,
        default="ViT-B/16",
        help="Backbone architecture for CLIP",
        choices=["ViT-B/16", "ViT-B/32", "ViT-L/14"],
    )
    # (flag, default, what it fine-tunes). Order here is the order shown by --help.
    finetune_flags = (
        ("--ft_linears", 1, "linear layers"),
        ("--ft_attention", 1, "attention layers"),
        ("--ft_ln", 1, "layer norm"),
        ("--ft_class_embed", 1, "class embedding layers"),
        ("--ft_proj", 1, "projection layers"),
        ("--ft_pos_embed", 0, "positional embedding"),
        ("--ft_conv", 0, "convolutional layers"),
    )
    for flag, default, target in finetune_flags:
        parser.add_argument(
            flag,
            type=binary_to_boolean_type,
            default=default,
            help=f"Set to 1 to fine-tune {target}",
        )
[docs]
class OptimizerBuilder:
    """Build an optimizer and an optional per-step learning-rate adjuster from CLI args.

    The wrapped namespace is read for ``optimizer`` (``"adamw"`` or ``"sgd"``),
    ``lr``, ``optim_wd``, ``optim_mom`` (SGD only), ``scheduler_ntk`` and, for the
    schedules other than ``"none"``, ``n_epochs`` and ``virtual_bs_n``.
    """

    def __init__(self, cmd_args):
        self.args = cmd_args

    def build_opt_and_sched(self, all_params, num_batches):
        """Create an optimizer over *all_params* and its scheduler.

        Args:
            all_params: anything accepted by ``torch.optim`` optimizers
                (parameters or param-group dicts).
            num_batches: number of batches per epoch.

        Returns:
            ``(optimizer, sched)``. ``sched`` is ``None`` for
            ``scheduler_ntk == "none"``, otherwise the callable returned by
            ``cosine_lr`` / ``step_lr_decay``.

        Raises:
            ValueError: on an unknown ``optimizer`` or ``scheduler_ntk``.
        """
        opt = self._build_optimizer(all_params)
        # Only this single-group path shortens the "cosine" warmup by this factor.
        reduction_factor = getattr(self.args, "epochs_factor_reduction", 1)
        sched = self._build_scheduler(
            opt, self.args.lr, num_batches, reduction_factor=reduction_factor
        )
        return opt, sched

    def build_opt_and_sched_multiple_lr(
        self, params_group_1, params_group_2, num_batches
    ):
        """Like ``build_opt_and_sched`` but with two param groups and per-group lrs.

        Group 1 uses ``args.lr``. Group 2 uses the first of ``args.lr2``,
        ``args.lr_lin``, ``args.lr_second`` that exists and is not None, and falls
        back to ``args.lr`` when none is set or the chosen value is 0.

        Returns:
            ``(optimizer, sched)`` as in ``build_opt_and_sched``.

        Raises:
            ValueError: on an unknown ``optimizer`` or ``scheduler_ntk``.
        """
        lr_group_1 = self.args.lr
        lr_group_2 = None
        for attr in ("lr2", "lr_lin", "lr_second"):
            lr_group_2 = getattr(self.args, attr, None)
            if lr_group_2 is not None:
                break
        if lr_group_2 is None or lr_group_2 == 0:
            lr_group_2 = self.args.lr
        param_groups = [
            {"params": params_group_1, "lr": lr_group_1},
            {"params": params_group_2, "lr": lr_group_2},
        ]
        opt = self._build_optimizer(param_groups)
        # NOTE(review): unlike build_opt_and_sched, the "cosine" warmup here is not
        # divided by `epochs_factor_reduction` -- confirm this is intended.
        sched = self._build_scheduler(opt, [lr_group_1, lr_group_2], num_batches)
        return opt, sched

    def _build_optimizer(self, params):
        """Instantiate the optimizer named by ``args.optimizer`` with default lr ``args.lr``."""
        if self.args.optimizer == "adamw":
            return torch.optim.AdamW(
                params, lr=self.args.lr, weight_decay=self.args.optim_wd
            )
        if self.args.optimizer == "sgd":
            return torch.optim.SGD(
                params,
                lr=self.args.lr,
                momentum=self.args.optim_mom,
                weight_decay=self.args.optim_wd,
            )
        raise ValueError(
            f"Unsupported optimizer `{self.args.optimizer}`; expected 'adamw' or 'sgd'."
        )

    def _optimizer_steps(self, num_batches):
        """Total optimizer steps: one step every ``virtual_bs_n`` batches, per epoch."""
        return self.args.n_epochs * (num_batches // self.args.virtual_bs_n)

    def _build_scheduler(self, opt, base_lrs, num_batches, reduction_factor=None):
        """Return the lr adjuster named by ``args.scheduler_ntk`` (``None`` for "none").

        Args:
            opt: optimizer whose param groups the adjuster updates.
            base_lrs: one peak lr for all groups, or a list with one per group.
            num_batches: batches per epoch.
            reduction_factor: if given, the "cosine" warmup is ``500 / reduction_factor``
                instead of ``500``.
        """
        kind = self.args.scheduler_ntk
        if kind == "none":
            return None
        if kind == "cosine":
            warmup = 500 if reduction_factor is None else 500 / reduction_factor
            return cosine_lr(opt, base_lrs, warmup, self._optimizer_steps(num_batches), 0)
        if kind == "cosine_talos":
            return cosine_lr(opt, base_lrs, 200, self._optimizer_steps(num_batches), 0)
        if kind == "cosine_plus":
            total_steps = self._optimizer_steps(num_batches)
            if isinstance(base_lrs, list):
                min_lrs = [0.1 * lr for lr in base_lrs]
            else:
                min_lrs = 0.1 * base_lrs
            return cosine_lr(
                opt, base_lrs, int(0.1 * total_steps), total_steps, min_lrs
            )
        if kind == "decay":
            # NOTE(review): counts raw batches (no division by virtual_bs_n),
            # unlike the other schedules -- confirm this is intended.
            return cosine_lr(opt, base_lrs, 0, self.args.n_epochs * num_batches, 0)
        if kind == "step":
            # Rounds differently from _optimizer_steps: floor(epochs * batches / vbs).
            total_steps = self.args.n_epochs * num_batches // self.args.virtual_bs_n
            return step_lr_decay(opt, base_lrs, int(0.1 * total_steps), total_steps)
        raise ValueError(
            f"Unsupported scheduler_ntk `{kind}`; expected one of 'none', 'cosine', "
            "'cosine_talos', 'cosine_plus', 'decay', 'step'."
        )
[docs]
@torch.no_grad()
def compute_acc_on_last_task(model: ContinualModel, dataset: ContinualDataset):
    """Evaluate *model* on the test loader of the most recent task.

    Args:
        model: continual model; this function uses ``forward``, ``device``,
            ``current_task``, ``get_debug_iters()`` and
            ``args.non_verbose`` / ``args.debug_mode``.
        dataset: continual dataset; this function uses ``test_loaders``,
            ``N_CLASSES`` and ``get_offsets(task)``.

    Returns:
        ``(acc, acc_mask_classes)`` in percent. ``acc`` is the argmax accuracy over
        all ``N_CLASSES`` logits. ``acc_mask_classes`` is the argmax accuracy after
        logits outside ``get_offsets(model.current_task)`` are set to ``-inf``.
        Both are ``0.0`` when the loader yields no batch.
    """
    test_loader = dataset.test_loaders[-1]
    total_len = len(test_loader) if hasattr(test_loader, "__len__") else None
    num_classes = dataset.N_CLASSES
    task_label = model.current_task + 1
    # The current task's class range does not change during evaluation.
    start_c, end_c = dataset.get_offsets(model.current_task)
    correct, correct_mask_classes, total = 0.0, 0.0, 0.0
    # NOTE(review): the model is not switched to eval() here; presumably the
    # caller does that -- confirm.
    with tqdm(
        test_loader, total=total_len, desc="Evaluating", disable=model.args.non_verbose
    ) as pbar:
        for i, data in enumerate(test_loader):
            if model.args.debug_mode and i > model.get_debug_iters():
                break
            inputs, labels = data[0].to(model.device), data[1].to(model.device)
            outputs = model.forward(inputs)
            assert outputs.shape[1] == num_classes
            _, pred = torch.max(outputs, 1)
            correct += torch.sum(pred == labels).item()
            total += labels.shape[0]
            pbar.set_postfix(
                {f"acc_task_{task_label}": correct / total * 100}, refresh=False
            )
            pbar.set_description(f"Evaluating Task {task_label}", refresh=False)
            pbar.update(1)
            # Restrict predictions to the current task's class range.
            outputs[:, :start_c] = -float("inf")
            outputs[:, end_c:num_classes] = -float("inf")
            _, pred = torch.max(outputs.data, 1)
            correct_mask_classes += torch.sum(pred == labels).item()
    if total == 0:
        # Previously a ZeroDivisionError; report it and return neutral values.
        logging.warning(
            "compute_acc_on_last_task: the last test loader yielded no samples."
        )
        return 0.0, 0.0
    acc = correct / total * 100
    acc_mask_classes = correct_mask_classes / total * 100
    return acc, acc_mask_classes
[docs]
def make_psd(x, to64=False):
    """Project a real symmetric matrix (or a batch of them) onto the PSD cone.

    Negative eigenvalues are clamped to zero and the matrix is rebuilt as
    ``V @ diag(clamp(w)) @ V^T``. ``torch.linalg.eigh`` reads only one triangle
    of its input, so *x* is assumed to be symmetric.

    Args:
        x: real tensor of shape ``(..., n, n)``.
        to64: if True, decompose in float64 and cast the result back to
            ``x.dtype``.

    Returns:
        Tensor with the same shape as *x*: ``x.dtype`` when ``to64`` is True,
        otherwise the dtype ``eigh`` produced.
    """
    orig_dtype = x.dtype
    work = x.to(torch.float64) if to64 else x
    eigvals, eigvecs = torch.linalg.eigh(work)
    eigvals = torch.clamp(eigvals, min=0.0)
    # Scale column j of V by w_j. unsqueeze(-2) keeps this correct for batched input.
    x_psd = (eigvecs * eigvals.unsqueeze(-2)) @ eigvecs.transpose(-2, -1)
    return x_psd.to(orig_dtype) if to64 else x_psd
[docs]
class FisherLoader:
def __init__(self, fisher_cache, dataset_name, device, fp_precision="fp32"):
self.dataset_name = dataset_name
self.device = device
self.fisher_cache = fisher_cache
self.fp_precision = fp_precision
self.postprocessing = None
def _is_hf_source(self) -> bool:
return isinstance(self.fisher_cache, str) and self.fisher_cache.startswith(
"hf://"
)
def _is_http_source(self) -> bool:
return isinstance(self.fisher_cache, str) and self.fisher_cache.startswith(
("http://", "https://")
)
def _is_remote_source(self) -> bool:
return self._is_hf_source() or self._is_http_source()
def _remote_cache_dir(self) -> str:
src_hash = hashlib.sha256(self.fisher_cache.encode("utf-8")).hexdigest()[:16]
cache_dir = os.path.join(
get_checkpoint_path(), "fisher_remote_cache", self.dataset_name, src_hash
)
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
@staticmethod
def _assert_not_lfs_pointer(file_path: str) -> None:
with open(file_path, "rb") as f:
header = f.read(256)
if header.startswith(b"version https://git-lfs.github.com/spec/v1"):
raise ValueError(
f"Downloaded file `{file_path}` is a Git LFS pointer, not the binary artifact. "
"Use a direct/raw artifact URL (or Hugging Face resolve URL) instead of a pointer URL."
)
@staticmethod
def _parse_hf_source(spec: str) -> tuple[str, str, str]:
assert spec.startswith("hf://")
payload = spec[len("hf://") :]
if "@" in payload:
payload, revision = payload.rsplit("@", 1)
else:
revision = "main"
parts = payload.split("/")
if len(parts) < 2:
raise ValueError(
"Invalid HF source format. Use `hf://<owner>/<repo>/<optional/subpath>@<optional_revision>`"
)
repo_id = "/".join(parts[:2])
base_path = "/".join(parts[2:])
return repo_id, base_path, revision
def _download_http_file(self, remote_url: str, local_path: str) -> str:
os.makedirs(os.path.dirname(local_path), exist_ok=True)
with (
urllib.request.urlopen(remote_url) as source,
open(local_path, "wb") as output,
):
output.write(source.read())
self._assert_not_lfs_pointer(local_path)
return local_path
def _resolve_file(self, filename: str) -> str:
if not self._is_remote_source():
return os.path.join(self.fisher_cache, filename)
cache_dir = self._remote_cache_dir()
local_path = os.path.join(cache_dir, filename)
if os.path.exists(local_path):
return local_path
if self._is_hf_source():
try:
from huggingface_hub import hf_hub_download
except ImportError as e:
raise ImportError(
"huggingface_hub is required for `hf://` Fisher cache sources. "
"Install it with `pip install huggingface_hub`."
) from e
repo_id, base_path, revision = self._parse_hf_source(self.fisher_cache)
hf_filename = f"{base_path}/{filename}" if base_path else filename
downloaded_path = hf_hub_download(
repo_id=repo_id,
filename=hf_filename,
revision=revision,
repo_type="model",
cache_dir=cache_dir,
)
self._assert_not_lfs_pointer(downloaded_path)
logging.info(
f"Downloaded Fisher file from HF: {repo_id}/{hf_filename}@{revision}"
)
return downloaded_path
remote_url = urljoin(self.fisher_cache.rstrip("/") + "/", filename)
logging.info(f"Downloading Fisher file from URL: {remote_url}")
return self._download_http_file(remote_url, local_path)
@overload
def load_kfac(
self, task_id: int, only_counts: Literal[True]
) -> tuple[int, int]: ...
@overload
def load_kfac(
self, task_id: int, only_counts: Literal[False] = False
) -> tuple[
dict[str, torch.Tensor],
dict[str, torch.Tensor],
dict[str, torch.Tensor],
int,
int,
]: ...
[docs]
def load_kfac(
self, task_id: int, only_counts: bool = False
) -> (
tuple[
dict[str, torch.Tensor],
dict[str, torch.Tensor],
dict[str, torch.Tensor],
int,
int,
]
| tuple[int, int]
):
base_name = f"{self.dataset_name}_task_{task_id}"
fisher_cache_path_num_aaT = self._resolve_file(f"{base_name}_num_aaT.pt")
fisher_cache_path_num_ggT = self._resolve_file(f"{base_name}_num_ggT.pt")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if os.path.exists(fisher_cache_path_num_aaT):
assert os.path.exists(fisher_cache_path_num_ggT)
cur_num_aaT: int = torch.load(
fisher_cache_path_num_aaT, map_location="cpu"
).item()
cur_num_ggT: int = torch.load(
fisher_cache_path_num_ggT, map_location="cpu"
).item()
else:
raise FileNotFoundError(
f"Fisher cache file {fisher_cache_path_num_aaT} or {fisher_cache_path_num_ggT} not found. "
)
if only_counts:
logging.info(
f"Loaded Fisher counts for task {task_id} from `{self.fisher_cache}`"
)
return cur_num_ggT, cur_num_aaT
fisher_cache_path_aaT = self._resolve_file(f"{base_name}_aaT.pt")
fisher_cache_path_ggT = self._resolve_file(f"{base_name}_ggT.pt")
fisher_cache_path_ffT = self._resolve_file(f"{base_name}_ffT.pt")
assert os.path.exists(fisher_cache_path_aaT)
assert os.path.exists(fisher_cache_path_ggT)
assert os.path.exists(fisher_cache_path_ffT)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
aaT: dict = torch.load(fisher_cache_path_aaT, map_location=self.device)
ggT: dict = torch.load(fisher_cache_path_ggT, map_location=self.device)
ffT: dict = torch.load(fisher_cache_path_ffT, map_location=self.device)
for key in aaT.keys():
if self.fp_precision == "fp64":
aaT[key] = aaT[key].to(torch.float64)
ggT[key] = ggT[key].to(torch.float64)
elif self.fp_precision == "fp32":
aaT[key] = aaT[key].to(torch.float32)
ggT[key] = ggT[key].to(torch.float32)
else:
raise NotImplementedError
for key in ffT.keys():
if self.fp_precision == "fp64":
# ffT[key] = ffT[key].to(torch.float64)
ffT[key] = ffT[key].to(torch.float64)
elif self.fp_precision == "fp32":
# ffT[key] = ffT[key].to(torch.float32)
ffT[key] = ffT[key].to(torch.float32)
else:
raise NotImplementedError
logging.info(
f"Loaded Fisher tensors for task {task_id} from `{self.fisher_cache}`"
)
return ggT, aaT, ffT, cur_num_ggT, cur_num_aaT
[docs]
def store_kfac(self, task_id, ggT, aaT, ffT, num_ggT, num_aaT):
if self._is_remote_source():
raise ValueError(
"Cannot store Fisher cache to remote sources. "
"Set `--fisher_cache` to a local directory when `--load_fisher=0`."
)
os.makedirs(self.fisher_cache, exist_ok=True)
fisher_cache_path = f"{self.fisher_cache}/{self.dataset_name}_task_{task_id}.pt"
torch.save(ggT, fisher_cache_path.replace(".pt", "_ggT.pt"))
torch.save(aaT, fisher_cache_path.replace(".pt", "_aaT.pt"))
torch.save(ffT, fisher_cache_path.replace(".pt", "_ffT.pt"))
torch.save(
torch.tensor([num_ggT]), fisher_cache_path.replace(".pt", "_num_ggT.pt")
)
torch.save(
torch.tensor([num_aaT]), fisher_cache_path.replace(".pt", "_num_aaT.pt")
)
[docs]
def load_ekfac(
self, task_id, only_counts=False
) -> tuple[
dict[str, torch.Tensor],
dict[str, torch.Tensor],
dict[str, torch.Tensor],
int,
int,
]:
fisher_cache_path = f"{self.fisher_cache}/{self.dataset_name}_task_{task_id}.pt"
fisher_cache_path_num_of_examples = fisher_cache_path.replace(
".pt", "_num_of_examples.pt"
)
assert os.path.exists(fisher_cache_path_num_of_examples), (
f"File {fisher_cache_path_num_of_examples} not found."
)
num_of_examples = torch.load(
fisher_cache_path_num_of_examples, map_location="cpu"
).item()
if only_counts:
return num_of_examples
fisher_cache_path_UA = fisher_cache_path.replace(".pt", "_UA.pt")
fisher_cache_path_UG = fisher_cache_path.replace(".pt", "_UG.pt")
fisher_cache_path_D = fisher_cache_path.replace(".pt", "_D.pt")
fisher_cache_path_ffT = fisher_cache_path.replace(".pt", "_ffT.pt")
assert os.path.exists(fisher_cache_path_UA)
assert os.path.exists(fisher_cache_path_UG)
assert os.path.exists(fisher_cache_path_D)
assert os.path.exists(fisher_cache_path_ffT)
UA = torch.load(fisher_cache_path_UA, map_location=self.device)
UG = torch.load(fisher_cache_path_UG, map_location=self.device)
D = torch.load(fisher_cache_path_D, map_location=self.device)
ffT = torch.load(fisher_cache_path_ffT, map_location=self.device)
assert UA.keys() == UG.keys() == D.keys()
for key in UA.keys():
if self.fp_precision == "fp64":
UA[key] = UA[key].to(torch.float64)
UG[key] = UG[key].to(torch.float64)
D[key] = D[key].to(torch.float64)
elif self.fp_precision == "fp32":
UA[key] = UA[key].to(torch.float32)
UG[key] = UG[key].to(torch.float32)
D[key] = D[key].to(torch.float32)
else:
raise NotImplementedError
for key in ffT.keys():
if self.fp_precision == "fp64":
ffT[key] = ffT[key].to(torch.float64)
elif self.fp_precision == "fp32":
ffT[key] = ffT[key].to(torch.float32)
else:
raise NotImplementedError
return UA, UG, D, ffT, num_of_examples
[docs]
def load_diff_ekfac(
self, task_id, only_counts=False
) -> tuple[
dict[str, torch.Tensor],
dict[str, torch.Tensor],
dict[str, torch.Tensor],
int,
int,
]:
fisher_cache_path = f"{self.fisher_cache}/{self.dataset_name}_task_{task_id}.pt"
fisher_cache_path_num_of_examples = fisher_cache_path.replace(
".pt", "_num_of_examples.pt"
)
assert os.path.exists(fisher_cache_path_num_of_examples)
num_of_examples = torch.load(
fisher_cache_path_num_of_examples, map_location="cpu"
).item()
if only_counts:
return num_of_examples
fisher_cache_universe_path = (
f"{self.fisher_cache}/{self.dataset_name}_universe.pt"
)
fisher_cache_path_UA = fisher_cache_universe_path.replace(".pt", "_UA.pt")
fisher_cache_path_UG = fisher_cache_universe_path.replace(".pt", "_UG.pt")
fisher_cache_path_D = fisher_cache_path.replace(".pt", "_D.pt")
fisher_cache_path_ffT = fisher_cache_path.replace(".pt", "_ffT.pt")
assert os.path.exists(fisher_cache_path_UA)
assert os.path.exists(fisher_cache_path_UG)
assert os.path.exists(fisher_cache_path_D)
assert os.path.exists(fisher_cache_path_ffT)
if task_id == 0:
UA = torch.load(fisher_cache_path_UA, map_location=self.device)
UG = torch.load(fisher_cache_path_UG, map_location=self.device)
else:
UA = {}
UG = {}
D = torch.load(fisher_cache_path_D, map_location=self.device)
ffT = torch.load(fisher_cache_path_ffT, map_location=self.device)
if task_id == 0:
assert UA.keys() == UG.keys() == D.keys()
for key in UA.keys():
if self.fp_precision == "fp64":
UA[key] = UA[key].to(torch.float64)
UG[key] = UG[key].to(torch.float64)
D[key] = D[key].to(torch.float64)
elif self.fp_precision == "fp32":
UA[key] = UA[key].to(torch.float32)
UG[key] = UG[key].to(torch.float32)
D[key] = D[key].to(torch.float32)
else:
raise NotImplementedError
for key in ffT.keys():
if self.fp_precision == "fp64":
ffT[key] = ffT[key].to(torch.float64)
elif self.fp_precision == "fp32":
ffT[key] = ffT[key].to(torch.float32)
else:
raise NotImplementedError
return UA, UG, D, ffT, num_of_examples
[docs]
def get_parameter(
    shape,
    device,
    type_init: str = "orto",
    transpose: bool = False,
    requires_grad: bool = True,
):
    """Create a float32 ``torch.nn.Parameter`` of *shape* on *device*.

    ``type_init`` picks the in-place initializer applied to a zero tensor:
    ``"orto"`` (orthogonal), ``"gaussian"`` (normal, mean 0, std 0.1),
    ``"kernel"`` (normal, mean 0, std 0.036), ``"attn"`` (normal, mean 1,
    std 0.03), ``"kaiming"`` (Kaiming-uniform with ``a=sqrt(5)``) or ``"ones"``.
    Any other value leaves the tensor at zeros. With ``transpose=True``, dims 1
    and 2 are swapped afterwards, so *shape* must then have at least 3 dims.
    """
    tensor = torch.zeros(*shape, dtype=torch.float32, device=device)
    initializers = {
        "orto": lambda t: torch.nn.init.orthogonal_(t),
        "gaussian": lambda t: torch.nn.init.normal_(t, mean=0.0, std=0.1),
        "kernel": lambda t: torch.nn.init.normal_(t, mean=0.0, std=0.036),
        "attn": lambda t: torch.nn.init.normal_(t, mean=1.0, std=0.03),
        "kaiming": lambda t: torch.nn.init.kaiming_uniform_(t, a=math.sqrt(5)),
        "ones": lambda t: torch.nn.init.ones_(t),
    }
    init_fn = initializers.get(type_init)
    if init_fn is not None:
        init_fn(tensor)
    if transpose:
        tensor = tensor.transpose(1, 2)
    return torch.nn.Parameter(tensor, requires_grad=requires_grad)
[docs]
def get_params(
    net, features=True, classifier=False, offset_1=-1, offset_2=-1
) -> torch.Tensor:
    """Flatten selected parameters of *net* into one 1-D tensor.

    Parameters whose name contains ``"head"`` are included (rows
    ``offset_1:offset_2`` only) when *classifier* is True. All other parameters
    are included in full when *features* is True. Order follows
    ``net.named_parameters()``. Returns ``tensor([0.])`` when nothing is selected.
    """
    chunks = []
    for name, param in net.named_parameters():
        is_head = "head" in name
        if is_head and classifier:
            assert offset_1 > -1 and offset_2 > -1
            chunks.append(param[offset_1:offset_2].view(-1))
        elif not is_head and features:
            chunks.append(param.view(-1))
    return torch.cat(chunks) if chunks else torch.tensor([0.0])
[docs]
def set_params(
net,
new_params: torch.Tensor,
features=True,
classifier=False,
offset_1=-1,
offset_2=-1,
) -> None:
progress = 0
for name, param in net.named_parameters():
if "head" in name:
if classifier:
assert offset_1 > -1 and offset_2 > -1
cur_size = torch.tensor(param.data[offset_1:offset_2].size()).prod()
param.data[offset_1:offset_2] = new_params[
progress : progress + cur_size
].view(param.data[offset_1:offset_2].size())
progress += cur_size
elif features:
cur_size = torch.tensor(param.size()).prod()
cand_params = new_params[progress : progress + cur_size].view(param.size())
param.data = cand_params
progress += cur_size
[docs]
def get_delta_w_backbone(named_params, delta_w, delta_w_names, training_type, device):
    """Flatten and concatenate the weight deltas of the non-head parameters.

    Args:
        named_params: zero-argument callable yielding ``(name, param)`` pairs,
            e.g. a bound ``module.named_parameters``.
        delta_w: sequence of tensors aligned index-by-index with *delta_w_names*.
        delta_w_names: parameter names (without the ``"visual_encoder."`` part)
            that have a delta; the first occurrence of a name wins.
        training_type: unused; kept for call-site compatibility.
        device: device the flattened deltas are moved to.

    Returns:
        1-D tensor of the matching deltas in ``named_params()`` order, or
        ``tensor([0.])`` on *device* when nothing matches. Head parameters and
        parameters without a delta (such as ``logit_scale``) contribute nothing.
    """
    index_of = {}
    for idx, dname in enumerate(delta_w_names):
        index_of.setdefault(dname, idx)  # same as list.index: first match wins
    params = []
    for name, _param in named_params():
        name = name.replace("visual_encoder.", "")
        if "head" in name:
            continue
        idx = index_of.get(name)
        if idx is not None:
            params.append(delta_w[idx].view(-1).to(device))
        elif name == "logit_scale":
            logging.debug("get_delta_w_backbone: no delta for `%s`; skipping.", name)
    if params:
        return torch.cat(params)
    return torch.tensor([0.0]).to(device)
[docs]
def get_delta_w_parameterlist(named_params, delta_w, delta_w_names, peft_type, device):
    """Build one delta tensor per parameter, zeros where no delta is available.

    Args:
        named_params: zero-argument callable yielding ``(name, param)`` pairs.
        delta_w: per-name entries aligned with *delta_w_names*. For
            ``"lora"`` each entry is indexable as ``[0]``/``[1]`` and the delta is
            ``entry[0] @ entry[1]``; for ``"full"`` the entry is the delta itself.
        delta_w_names: names that have a delta; the first occurrence wins.
        peft_type: ``"lora"`` or ``"full"``.
        device: device every returned tensor is moved to.

    Returns:
        List with one tensor per yielded parameter, in ``named_params()`` order.

    Raises:
        ValueError: if a delta is needed and *peft_type* is not supported.
    """
    index_of = {}
    for idx, dname in enumerate(delta_w_names):
        index_of.setdefault(dname, idx)  # same as list.index: first match wins
    params = []
    for name, param in named_params():
        idx = index_of.get(name)
        if idx is None:
            params.append(torch.zeros_like(param).to(device))
            continue
        if peft_type == "lora":
            cur_delta_w = delta_w[idx][0] @ delta_w[idx][1]
        elif peft_type == "full":
            cur_delta_w = delta_w[idx]
        else:
            # The old `assert cur_delta_w` tested tensor truthiness, which raises
            # for any multi-element delta; check the peft_type explicitly instead.
            raise ValueError(
                f"Unsupported peft_type `{peft_type}`; expected 'lora' or 'full'."
            )
        params.append(cur_delta_w.to(device))
    return params
[docs]
def replace_non_dynamically_quantizable_linear(module):
    """Recursively replace all NonDynamicallyQuantizableLinear layers with Linear layers in a model.

    The children of *module* (not *module* itself) are swapped in place for a
    plain ``torch.nn.Linear`` with cloned weight/bias. Each parameter keeps its
    device, dtype and ``requires_grad`` flag, and the new layer keeps the old
    child's train/eval mode.

    Returns:
        The same *module* object, modified in place.
    """
    for name, child in module.named_children():
        if isinstance(child, torch.nn.modules.linear.NonDynamicallyQuantizableLinear):
            new_layer = torch.nn.Linear(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            # Keep requires_grad so frozen weights stay frozen after the swap.
            new_layer.weight = torch.nn.Parameter(
                child.weight.detach().clone(),
                requires_grad=child.weight.requires_grad,
            )
            if child.bias is not None:
                new_layer.bias = torch.nn.Parameter(
                    child.bias.detach().clone(),
                    requires_grad=child.bias.requires_grad,
                )
            new_layer.train(child.training)
            setattr(module, name, new_layer)
        else:
            replace_non_dynamically_quantizable_linear(child)
    return module
[docs]
def assign_learning_rate(param_group, new_lr):
    """Overwrite the ``"lr"`` entry of one optimizer param-group dict."""
    param_group.update(lr=new_lr)
def _warmup_lr(base_lr, warmup_length, step):
    """Linear warmup value for *step* (0-based): reaches *base_lr* at ``warmup_length - 1``."""
    scaled = base_lr * (step + 1)
    return scaled / warmup_length
[docs]
def cosine_lr(optimizer, base_lrs, warmup_length, steps, min_lr):
    """Build a linear-warmup + cosine-decay learning-rate adjuster.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated.
        base_lrs: peak lr; one value for every group or a list/tuple per group.
        warmup_length: steps of linear warmup (see ``_warmup_lr``).
        steps: total number of steps; the cosine phase covers
            ``steps - warmup_length`` of them.
        min_lr: lr reached at ``step == steps``; scalar or list/tuple per group.

    Returns:
        Callable ``adjust(step)`` that writes each group's lr for *step*. It does
        not step the optimizer.

    Raises:
        ValueError: if the per-group lists do not match the number of groups.
    """
    n_groups = len(optimizer.param_groups)
    if not isinstance(base_lrs, (list, tuple)):
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    if isinstance(min_lr, (list, tuple)):
        min_lr_list = list(min_lr)
    else:
        min_lr_list = [min_lr for _ in optimizer.param_groups]
    if not (len(base_lrs) == n_groups == len(min_lr_list)):
        raise ValueError(
            f"Expected {n_groups} base/min lrs, got {len(base_lrs)} and {len(min_lr_list)}."
        )
    decay_steps = steps - warmup_length
    if decay_steps == 0:
        # Whole schedule is warmup; avoid dividing by zero past its end.
        decay_steps = 1

    def _lr_adjuster(step):
        for param_group, base_lr, group_min_lr in zip(
            optimizer.param_groups, base_lrs, min_lr_list
        ):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                e = step - warmup_length
                # Anneal from base_lr down to group_min_lr. The old form scaled
                # base_lr alone, peaking at base_lr + min_lr right after warmup.
                cosine = 0.5 * (1 + np.cos(np.pi * e / decay_steps))
                lr = group_min_lr + cosine * (base_lr - group_min_lr)
            assign_learning_rate(param_group, lr)

    return _lr_adjuster
[docs]
def step_lr_decay(optimizer, base_lrs, warmup_length, steps):
    """Build a linear-warmup + step-decay learning-rate adjuster.

    After warmup, each group's lr is ``base_lr`` while ``step / steps < 0.70``,
    ``0.5 * base_lr`` while it is below ``0.90``, and ``0.1 * base_lr`` afterwards.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated.
        base_lrs: peak lr; one value for all groups or a list with one per group.
        warmup_length: number of linear-warmup steps (see ``_warmup_lr``).
        steps: total number of steps used to compute the decay fraction.

    Returns:
        Callable ``adjust(step)`` that writes each group's lr for *step*.
    """
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs] * len(optimizer.param_groups)
    assert len(base_lrs) == len(optimizer.param_groups)

    def _lr_at(base_lr, step):
        if step < warmup_length:
            return _warmup_lr(base_lr, warmup_length, step)
        fraction_done = step / steps
        if fraction_done < 0.70:
            return base_lr
        if fraction_done < 0.90:
            return base_lr * 0.5
        return base_lr * 0.1

    def _lr_adjuster(step):
        for group, group_base_lr in zip(optimizer.param_groups, base_lrs):
            assign_learning_rate(group, _lr_at(group_base_lr, step))

    return _lr_adjuster