class ApplyLearningRateScheduler(Generic[LRSchedulerT]):
    """
    Apply a torch learning rate scheduler to a torch optimiser.

    An instance is used as a class decorator: calling it with an optimiser class
    returns a subclass of that optimiser which creates a scheduler bound to itself
    on construction. The scheduler is stepped at each step of the optimiser.

    The returned class's ``step`` takes the current loss as its first argument;
    the loss is only forwarded to the scheduler if the scheduler's ``step``
    has a ``metrics`` parameter.
    """

    def __init__(self, scheduler_class: type[LRSchedulerT], *args: Any, **kwargs: Any) -> None:
        """
        Store the scheduler class and the arguments used to instantiate it later.

        :param scheduler_class: The (uninstantiated) torch learning rate scheduler to be used.
        :param args: Positional arguments passed to ``scheduler_class`` after the optimiser.
        :param kwargs: Keyword arguments passed to ``scheduler_class``.
        """
        self.scheduler_class = scheduler_class
        self.scheduler_kwargs = kwargs
        self.scheduler_args = args
        # Schedulers whose ``step`` declares a ``metrics`` parameter need the loss
        # handed to them; all others are stepped without arguments.
        self.scheduler_takes_loss = "metrics" in inspect.signature(scheduler_class.step).parameters

    def __call__(self, cls: type[OptimiserT]) -> type[OptimiserT]:
        """
        Apply scheduler to optimiser.

        :param cls: The optimiser class to be wrapped.
        :returns: A subclass of ``cls`` that owns a scheduler instance and steps it
            after every optimiser step.
        """
        # Bind everything the inner class needs to locals, so that its methods
        # close over plain values rather than over this decorator instance.
        scheduler_class = self.scheduler_class
        scheduler_kwargs = self.scheduler_kwargs
        scheduler_args = self.scheduler_args
        scheduler_step_func = self._step_scheduler_with_loss if self.scheduler_takes_loss else self._step_scheduler

        # Can't use @wraps_class here as it causes a unit test failure?
        class InnerClass(cls):
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                """Initialise the optimiser, then build the scheduler around it."""
                super().__init__(*args, **kwargs)
                self._applied_scheduler = scheduler_class(self, *scheduler_args, **scheduler_kwargs)

            # ``closure`` defaults to ``None`` in this overload so that a plain
            # ``step(loss)`` call matches an overload, mirroring the default on
            # the implementation below.
            @overload
            def step(self, loss: Union[float, torch.Tensor], closure: None = None) -> None: ...  # pragma: no cover

            @overload
            def step(
                self, loss: Union[float, torch.Tensor], closure: Callable[[], float]
            ) -> Union[float, torch.Tensor]: ...  # pragma: no cover

            def step(
                self, loss: Union[float, torch.Tensor], closure: Optional[Callable[[], float]] = None
            ) -> Optional[Union[float, torch.Tensor]]:
                """
                Step the optimiser and then the scheduler.

                :param loss: The current loss, handed to the scheduler's step function.
                :param closure: Passed through to the parent optimiser's ``step``.
                :returns: Whatever the parent optimiser's ``step`` returns.
                """
                ret = super().step(closure=closure)
                scheduler_step_func(self._applied_scheduler, loss)
                return ret

        return InnerClass

    @staticmethod
    def _step_scheduler(scheduler: LRSchedulerT, _: Union[float, torch.Tensor]) -> None:
        """Step a scheduler that takes no loss; the second argument is ignored."""
        scheduler.step()

    @staticmethod
    def _step_scheduler_with_loss(scheduler: LRSchedulerT, loss: Union[float, torch.Tensor]) -> None:
        """Step a scheduler, passing it the loss."""
        scheduler.step(loss)