Source code for hyperion.torch.lr_schedulers.exp_lr

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""


import torch

from .lr_scheduler import LRScheduler


class ExponentialLR(LRScheduler):
    """Exponential learning rate scheduler.

    While ``step < hold_steps`` the base learning rates are returned as they
    are. From then on every learning rate is computed as
    ``base_lr * decay_rate ** ((step - hold_steps) / decay_steps)`` and is
    never allowed to fall below its matching minimum learning rate.

    Attributes:
      decay_rate: base of the exponential decay term.
      decay_steps: number of steps that scale the exponent, i.e. after
        ``decay_steps`` steps past the hold phase the factor is ``decay_rate``.
      hold_steps: step at which the decay starts; the larger of the
        ``hold_steps`` argument and ``self.warmup_steps``.
    """

    def __init__(
        self,
        optimizer,
        decay_rate,
        decay_steps,
        hold_steps,
        min_lr=0,
        warmup_steps=0,
        epoch=0,
        step=0,
        update_lr_on_opt_step=False,
    ):
        """Builds the scheduler.

        Arguments:
          optimizer: forwarded to ``LRScheduler.__init__``.
          decay_rate: base of the exponential decay term.
          decay_steps: divisor applied to the number of steps elapsed since
            the end of the hold phase before it is used as exponent.
          hold_steps: number of initial steps during which the base learning
            rates are kept.
          min_lr: forwarded to ``LRScheduler.__init__``; presumably the source
            of ``self.min_lrs`` used as lower bound in :meth:`get_lr`.
          warmup_steps: forwarded to ``LRScheduler.__init__``.
          epoch: forwarded to ``LRScheduler.__init__``.
          step: forwarded to ``LRScheduler.__init__``.
          update_lr_on_opt_step: forwarded to ``LRScheduler.__init__``.
        """
        super().__init__(
            optimizer, min_lr, warmup_steps, epoch, step, update_lr_on_opt_step
        )
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps
        # The decay never starts before ``self.warmup_steps``, which is
        # expected to have been set by the base class constructor above.
        self.hold_steps = max(hold_steps, self.warmup_steps)

    def get_lr(self, step):
        """Computes the learning rates for a given step.

        Arguments:
          step: step for which the learning rates are requested.

        Returns:
          ``self.base_lrs`` itself while ``step < self.hold_steps``; otherwise
          a new list with one decayed learning rate per pair taken from
          ``self.base_lrs`` and ``self.min_lrs``.
        """
        if step < self.hold_steps:
            return self.base_lrs

        elapsed = step - self.hold_steps
        lrs = []
        for base_lr, floor_lr in zip(self.base_lrs, self.min_lrs):
            decayed_lr = base_lr * self.decay_rate ** (elapsed / self.decay_steps)
            lrs.append(max(floor_lr, decayed_lr))
        return lrs

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Arguments:
          state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
        """
        # Only the training position is restored. The remaining entries of
        # state_dict are ignored on purpose (no ``self.__dict__`` update), so
        # the scheduler parameters can be changed during training.
        for key in ("step", "epoch"):
            setattr(self, key, state_dict[key])