Source code for utils.schedulers

from argparse import Namespace
import math
import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import torch.optim.lr_scheduler as scheds

from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel


def get_scheduler(model: ContinualModel, args: Namespace, reload_optim=True) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Returns the scheduler to be used for the current dataset.
    If `reload_optim` is True, the optimizer is reloaded from the model. This should be done at least ONCE every task
    to ensure that the learning rate is reset to the initial value.
    """
    if args.lr_scheduler is not None:
        if reload_optim or not hasattr(model, 'opt'):
            model.opt = model.get_optimizer()
        # check if lr_scheduler is in torch.optim.lr_scheduler
        supported_scheds = {sched_name.lower(): sched_name for sched_name in dir(scheds)
                            if sched_name.lower() in ContinualDataset.AVAIL_SCHEDS}
        sched = None
        if args.lr_scheduler.lower() in supported_scheds:
            if args.lr_scheduler.lower() == 'multisteplr':
                assert args.lr_milestones is not None, 'MultiStepLR requires `--lr_milestones`'
                sched = getattr(scheds, supported_scheds[args.lr_scheduler.lower()])(
                    model.opt, milestones=args.lr_milestones, gamma=args.sched_multistep_lr_gamma)

        if sched is None:
            raise ValueError('Unknown scheduler: {}'.format(args.lr_scheduler))
        return sched
    return None
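A minimal usage sketch (not part of the module), assuming a hypothetical `_StubModel` standing in for a real `ContinualModel` and that `'multisteplr'` appears in `ContinualDataset.AVAIL_SCHEDS`; only the attributes `get_scheduler` actually touches are stubbed:

# Hypothetical sketch: drive get_scheduler with hand-built args.
from argparse import Namespace

import torch
from torch import nn

from utils.schedulers import get_scheduler


class _StubModel:
    """Stands in for a ContinualModel; exposes only what get_scheduler needs."""

    def __init__(self):
        self.net = nn.Linear(4, 2)

    def get_optimizer(self):
        return torch.optim.SGD(self.net.parameters(), lr=0.1)


args = Namespace(lr_scheduler='multisteplr',
                 lr_milestones=[35, 45],
                 sched_multistep_lr_gamma=0.1)

model = _StubModel()
sched = get_scheduler(model, args)  # (re)creates model.opt and wraps it in MultiStepLR
for _ in range(50):
    model.opt.step()
    sched.step()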
class CosineSchedule(_LRScheduler):

    def __init__(self, optimizer, K):
        """
        Apply cosine learning rate schedule to all the parameters in the optimizer.
        """
        assert K > 1, "K must be greater than 1"
        self.K = K
        super().__init__(optimizer)

    def cosine(self, base_lr):
        if self.last_epoch == 0:
            return base_lr
        # Cosine decay: the argument sweeps from 0 to 99*pi/200 (just under pi/2)
        # over K - 1 epochs, so the lr decays toward (but not exactly to) zero.
        return base_lr * math.cos((99 * math.pi * (self.last_epoch)) / (200 * (self.K - 1)))

    def get_lr(self):
        return [self.cosine(base_lr) for base_lr in self.base_lrs]
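A minimal sketch of how `CosineSchedule` could be attached to a plain optimizer (the toy model and training loop are placeholders, not part of the module):

# Hypothetical sketch: cosine schedule over K = 50 epochs on a toy optimizer.
import torch
from torch import nn

from utils.schedulers import CosineSchedule

net = nn.Linear(10, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.03)
sched = CosineSchedule(opt, K=50)

for epoch in range(50):
    # ... one epoch of training with `opt` would go here ...
    opt.step()
    sched.step()  # lr decays along the cosine curve defined by CosineSchedule.cosine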
class CosineSchedulerWithLinearWarmup(_LRScheduler):

    def __init__(self, optimizer: Optimizer, base_lrs: list | float, warmup_length: int, steps: int):
        """
        Apply cosine learning rate schedule with warmup to all the parameters in the optimizer.
        If more than one param_group is passed, the learning rate must either be a list of the same length or a float.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to which the learning rate will be applied.
            base_lrs (list | float): Initial learning rate.
            warmup_length (int): Number of warmup steps. The learning rate will linearly increase from 0 to base_lr during this period.
            steps (int): Total number of steps.
        """
        self.warmup_length = warmup_length
        self.steps = steps

        # Normalize `base_lrs` to one value per param_group before the length check.
        if not isinstance(base_lrs, list):
            base_lrs = [base_lrs for _ in optimizer.param_groups]
        assert len(base_lrs) == len(optimizer.param_groups), \
            "`base_lrs` must have the same length as `optimizer.param_groups`"

        super().__init__(optimizer)
        # `_LRScheduler.__init__` resets `self.base_lrs` from the optimizer's initial
        # learning rates, so restore the explicitly requested values afterwards.
        self.base_lrs = base_lrs

    def get_lr(self):
        ret_lrs = []
        for base_lr in self.base_lrs:
            if self.last_epoch < self.warmup_length:
                # Linear warmup: ramp from base_lr / warmup_length up to base_lr.
                lr = base_lr * (self.last_epoch + 1) / self.warmup_length
            else:
                # Cosine decay from base_lr down to 0 over the remaining steps.
                e = self.last_epoch - self.warmup_length
                es = self.steps - self.warmup_length
                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
            ret_lrs.append(lr)
        return ret_lrs
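And a similar sketch for `CosineSchedulerWithLinearWarmup`, stepping once per iteration over a hypothetical 100-step run (again, the toy model and loop are placeholders):

# Hypothetical sketch: 10 linear-warmup steps followed by cosine decay to 0.
import torch
from torch import nn

from utils.schedulers import CosineSchedulerWithLinearWarmup

net = nn.Linear(10, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.01)
sched = CosineSchedulerWithLinearWarmup(opt, base_lrs=0.01, warmup_length=10, steps=100)

for step in range(100):
    # ... one optimization step on a batch would go here ...
    opt.step()
    sched.step()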