# Source code for vanguard.optimise.optimiser

# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Vanguard defines its own optimiser wrapper to enable additional features.
"""

import inspect
from collections import deque
from collections.abc import Generator
from functools import total_ordering
from heapq import heappush, heappushpop, nlargest
from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload

import numpy as np
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

OptimiserT = TypeVar("OptimiserT", bound=Optimizer)


class SmartOptimiser(Generic[OptimiserT]):
    """
    A smart wrapper around the standard optimisers found in PyTorch which can enable early stopping.

    The wrapper keeps a copy of the state dict of every registered module as it was at
    registration time, so that :meth:`reset` can restore those values, and it keeps a
    window of recent losses which :meth:`step` uses to decide when to stop early.

    .. warning::
        When setting the learning rate, using the :meth:`learning_rate` property,
        the parameters for each registered module are re-initialised.
    """

    # Snapshot of each registered module's state dict, taken when the module was cached.
    _stored_initial_state_dicts: dict[Module, dict[str, Tensor]]
    # Window of the most recent losses, padded at the front with ``nan`` values.
    last_n_losses: deque[float]
    # The wrapped PyTorch optimiser; rebuilt from scratch whenever the wrapper is reset.
    _internal_optimiser: OptimiserT

    def __init__(
        self,
        optimiser_class: type[OptimiserT],
        *initial_modules: Module,
        early_stop_patience: Optional[int] = None,
        **optimiser_kwargs: Any,
    ) -> None:
        """
        Initialise self.

        :param optimiser_class: An uninstantiated subclass of :class:`torch.optim.Optimizer` to be used
            to create the internal optimiser.
        :param initial_modules: Initial modules whose parameters will be added to the internal optimiser.
        :param early_stop_patience: How many consecutive gradient steps of worsening loss to allow before
            stopping early. Defaults to ``None`` which disables early stopping.
        :param optimiser_kwargs: Additional keyword arguments to be passed to the internal optimiser.
            The ``lr`` keyword is extracted and used as the learning rate (``0.1`` if not given).
        :raises ValueError: If ``early_stop_patience`` is negative.
        """
        self._internal_optimiser_class = optimiser_class
        self._internal_optimiser_kwargs = optimiser_kwargs
        # The learning rate is held separately so that the property setter can change it later.
        self._learning_rate = self._internal_optimiser_kwargs.pop("lr", 0.1)
        self._early_stop_patience = early_stop_patience

        self._stored_initial_state_dicts = {}
        self.last_n_losses = self._get_last_n_losses_structure(self._early_stop_patience)

        initial_parameters = []
        for module in initial_modules:
            self._cache_module_parameters(module)
            initial_parameters.append({"params": module.parameters()})

        self._internal_optimiser = self._internal_optimiser_class(
            initial_parameters, lr=self._learning_rate, **self._internal_optimiser_kwargs
        )
        self._set_step_method()

    @property
    def learning_rate(self) -> float:
        """Return the learning rate."""
        return self._learning_rate

    @learning_rate.setter
    def learning_rate(self, value: float) -> None:
        """Set the value of the learning rate, then reset the modules and the internal optimiser."""
        self._learning_rate = value
        self.reset()

    def parameters(self) -> Generator[Any, None, None]:
        """Get all parameters known to the optimiser."""
        for param_group in self._internal_optimiser.param_groups:
            yield from param_group["params"]

    def reset(self) -> None:
        """Reset the module parameters, the internal optimiser and the record of recent losses."""
        self._reset_module_parameters()
        self._reset_internal_optimiser()
        self.last_n_losses = self._get_last_n_losses_structure(self._early_stop_patience)

    def zero_grad(self, set_to_none: bool = False) -> None:
        """
        Set the gradients of all optimized :class:`torch.Tensor` s to zero.

        :param set_to_none: Passed straight through to the internal optimiser's ``zero_grad``.
        """
        self._internal_optimiser.zero_grad(set_to_none=set_to_none)

    @overload
    def step(self, loss: Union[float, torch.Tensor], closure: None = ...) -> None: ...  # pragma: no cover

    @overload
    def step(
        self, loss: Union[float, torch.Tensor], closure: Callable[[], float]
    ) -> Union[float, torch.Tensor]: ...  # pragma: no cover

    def step(
        self, loss: Union[float, torch.Tensor], closure: Optional[Callable[[], float]] = None
    ) -> Optional[Union[float, torch.Tensor]]:
        """
        Perform a single optimisation step.

        :param loss: The loss for this step. It is recorded (as a float) for early stopping and is
            forwarded to the internal optimiser if that optimiser's ``step`` accepts it.
        :param closure: Optional closure forwarded to the internal optimiser's ``step``.
        :returns: Whatever the internal optimiser's ``step`` returned.
        :raises NoImprovementError: If the oldest loss in the window is no greater than
            every loss in the window, i.e. no later step improved on it.
        """
        step_result = self._step(loss, closure=closure)
        self.last_n_losses.append(float(loss))

        # While the window still starts with a ``nan`` this comparison is always ``False``.
        no_improvement = self.last_n_losses[0] <= min(self.last_n_losses)
        if no_improvement:
            print_friendly_losses = ", ".join(f"{recent_loss:.3f}" for recent_loss in self.last_n_losses)
            raise NoImprovementError(
                f"Stopping early due to no improvement on {len(self.last_n_losses) - 1} "
                f"consecutive steps: [{print_friendly_losses}]"
            )
        return step_result

    def register_module(self, module: Module) -> None:
        """Register the parameters for a module."""
        self._cache_module_parameters(module)
        parameters = {"params": module.parameters()}
        self._internal_optimiser.add_param_group(parameters)

    def update_registered_module(self, module: Module) -> None:
        """
        Update the parameters of a registered module if the module has been modified.

        :param module: A module that has previously been registered.
        :raises KeyError: If the module has not been registered.
        """
        if module not in self._stored_initial_state_dicts:
            raise KeyError(
                f"{module!r} - Trying to update a module that isn't registered. Use `register_module` instead."
            )
        self._cache_module_parameters(module)
        self._reset_internal_optimiser()

    def set_parameters(self) -> None:
        """Tidy up after optimisation is completed."""

    def _reset_module_parameters(self) -> None:
        """
        Load afresh the stored initialisation values for all registered modules' parameters into the module.

        .. note::
            Calling this in isolation will restore the initialised values for all parameters,
            but it does not reset the optimiser. To do this, call
            :meth:`_reset_internal_optimiser` additionally.
        """
        for module, state_dict in self._stored_initial_state_dicts.items():
            module.load_state_dict(state_dict)

    def _reset_internal_optimiser(self) -> None:
        """
        Reset the internal optimiser.

        .. note::
            Calling this in isolation will not affect the current value of the parameters
            as learned thus far. To reset these, call :meth:`_reset_module_parameters` additionally.
        """
        parameters = [{"params": module.parameters()} for module in self._stored_initial_state_dicts]
        self._internal_optimiser = self._internal_optimiser_class(
            parameters, lr=self._learning_rate, **self._internal_optimiser_kwargs
        )

    @overload
    def _step(self, loss: Union[torch.Tensor, float], closure: None = ...) -> None: ...  # pragma: no cover

    @overload
    def _step(self, loss: Union[torch.Tensor, float], closure: Callable[[], float]) -> float: ...  # pragma: no cover

    def _step(self, loss: Union[torch.Tensor, float], closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """
        Perform a single optimisation step.

        This placeholder is shadowed on each instance by :meth:`_set_step_method`.
        """
        raise NotImplementedError

    def _cache_module_parameters(self, module: Module) -> None:
        """Cache a detached copy of the current state dict of a module."""
        state_dict = module.state_dict()
        for parameter_name, parameter in state_dict.items():
            # Clone so that later in-place updates to the module do not alter the cached values.
            state_dict[parameter_name] = parameter.detach().clone()
        self._stored_initial_state_dicts[module] = state_dict

    def _set_step_method(self) -> None:
        """Create and set the :meth:`_step` method according to the internal optimiser."""
        internal_step_signature = inspect.signature(self._internal_optimiser.step)

        def new_step_with_loss(loss, closure=None):
            """Pass the loss to the step function."""
            return self._internal_optimiser.step(loss, closure=closure)

        def new_step_without_loss(loss, closure=None):
            """Don't pass the loss to the step function."""
            try:
                return self._internal_optimiser.step(closure=closure)
            except TypeError as e:
                # This is in case the internal step signature is just (*args, **kwargs).
                if "missing 1 required positional argument: 'loss'" in str(e):
                    result = self._internal_optimiser.step(loss, closure=closure)
                    # If we got here, the above still worked, so set the step method to the one that uses the loss
                    # by default to avoid the expensive exception catching on each step
                    self._step = new_step_with_loss
                    return result
                else:
                    raise

        if "loss" in internal_step_signature.parameters:
            self._step = new_step_with_loss
        else:
            self._step = new_step_without_loss

    @staticmethod
    def _get_last_n_losses_structure(n: Optional[int]) -> deque[float]:
        """
        Get the structure which will contain the last :math:`n` losses.

        Returns an instance of :class:`collections.deque`. This is always initialised with at least one ``nan`` value.
        Whilst ``nan`` values occur in the structure, the minimum value will also be ``nan`` meaning that
        the minimum value will not be equal to the first element (because ``nan <= nan`` is ALWAYS ``False``).
        If ``n`` is a finite integer then these ``nan`` values will be popped from the structure after ``n+1``
        additions. If ``n`` is ``None`` then the structure is infinite and this will never happen.

        :param n: The early stopping patience, or ``None`` to disable early stopping.
        :returns: A deque holding ``n + 1`` ``nan`` values with ``maxlen=n + 1``, or an unbounded deque
            holding a single ``nan`` if ``n`` is ``None``.
        :raises ValueError: If ``n`` is negative.

        :Example:
            >>> x = SmartOptimiser._get_last_n_losses_structure(2)
            >>> x
            deque([nan, nan, nan], maxlen=3)
            >>> for loss in range(2):
            ...     x.append(loss)
            ...     print(x[0], min(x), bool(x[0] <= min(x)))
            nan nan False
            nan nan False
            >>> x
            deque([nan, 0, 1], maxlen=3)
            >>> x.append(2)
            >>> print(x[0], min(x), bool(x[0] <= min(x)))
            0 0 True
            >>>
            >>> y = SmartOptimiser._get_last_n_losses_structure(None)
            >>> y
            deque([nan])
            >>> for loss in range(100):
            ...     y.append(loss)
            >>> print(y[0], min(y), bool(y[0] <= min(y)))
            nan nan False
        """
        if n is None:
            return deque([float("nan")], maxlen=None)
        if n < 0:
            # A patience of -1 would give an empty, zero-length window, so the very first
            # ``step`` would fail with an opaque IndexError; reject it here instead.
            raise ValueError(f"early_stop_patience must be a non-negative integer or None, got {n!r}.")
        max_length = n + 1
        return deque([float("nan")] * max_length, maxlen=max_length)
@total_ordering
class Parameters:
    """
    A snapshot of module state dicts paired with a value used to rank it.

    Instances compare (``<``, ``==`` and the orderings derived from them) on
    ``priority_value`` alone; the stored state dicts play no part in comparisons.
    """

    def __init__(self, module_state_dicts: dict[Module, dict[str, Tensor]], value: float = np.inf) -> None:
        """
        Initialise self.

        :param module_state_dicts: Mapping from each module to its state dict. Every tensor is
            detached and cloned, so the snapshot is independent of the originals.
        :param value: The value by which this snapshot is ranked against others.
        """
        snapshots = {}
        for module, state_dict in module_state_dicts.items():
            snapshots[module] = self._clone_state_dict(state_dict)
        self.module_state_dicts = snapshots
        self.priority_value = value

    @staticmethod
    def _clone_state_dict(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        """Detach and clone a state_dict so its tensors are not changed external to this class."""
        cloned = {}
        for key, tensor in state_dict.items():
            cloned[key] = tensor.detach().clone()
        return cloned

    def __eq__(self, other: "Parameters") -> bool:
        if isinstance(other, Parameters):
            return self.priority_value == other.priority_value
        return NotImplemented

    def __lt__(self, other: "Parameters") -> bool:
        if isinstance(other, Parameters):
            return self.priority_value < other.priority_value
        return NotImplemented
T = TypeVar("T")


class MaxLengthHeapQ(Generic[T]):
    """
    A heapq of fixed maximum length.

    Once the heap holds ``max_length`` items, each further push goes through
    :func:`heapq.heappushpop`, so the smallest item is dropped and the largest are kept.
    """

    def __init__(self, max_length: int) -> None:
        """
        Initialise self.

        :param max_length: The number of items after which pushes start evicting the smallest item.
        """
        self.heap = []
        self.max_length = max_length

    def __contains__(self, item):
        return item in self.heap

    def push(self, item: T) -> None:
        """Push to the heapq, evicting the smallest item if the heap is already full."""
        is_full = len(self.heap) >= self.max_length
        if is_full:
            heappushpop(self.heap, item)
            return
        heappush(self.heap, item)

    def nlargest(self, n: int) -> list[T]:
        """Get the top elements on the heapq, largest first."""
        largest_items = nlargest(n, self.heap)
        return largest_items

    def best(self) -> T:
        """Get the top element."""
        top_items = self.nlargest(1)
        return top_items[0]
class GreedySmartOptimiser(SmartOptimiser[OptimiserT], Generic[OptimiserT]):
    """
    Always choose parameters with the minimum loss value, regardless of the iteration at which they occur.

    After every step a snapshot of the registered modules' state dicts is recorded together with
    the loss given for that step; :meth:`set_parameters` loads the snapshot with the lowest loss.

    .. note::
        This is the default smart optimiser for some :class:`vanguard.vanilla.GaussianGPController`.
        To disable the greedy loss behaviour and revert to keeping the parameters at the final
        iteration of training, use :class:`vanguard.optimise.optimiser.SmartOptimiser` or a
        different subclass thereof.
    """

    # How many of the best parameter snapshots to retain.
    N_RETAINED_PARAMETERS = 1

    def __init__(
        self,
        optimiser_class: type[OptimiserT],
        *initial_modules: Module,
        early_stop_patience: Optional[int] = None,
        **optimiser_kwargs: Any,
    ) -> None:
        """
        Initialise self.

        :param optimiser_class: An uninstantiated subclass of :class:`torch.optim.Optimizer` to be used
            to create the internal optimiser.
        :param initial_modules: Initial modules whose parameters will be added to the internal optimiser.
        :param early_stop_patience: How many consecutive gradient steps of worsening loss to allow before
            stopping early. Defaults to ``None`` which disables early stopping.
        :param optimiser_kwargs: Additional keyword arguments to be passed to the internal optimiser.
        """
        super().__init__(optimiser_class, *initial_modules, early_stop_patience=early_stop_patience, **optimiser_kwargs)
        self._top_n_parameters: MaxLengthHeapQ[Parameters] = MaxLengthHeapQ(self.N_RETAINED_PARAMETERS)

    def step(
        self, loss: Union[float, torch.Tensor], closure: Optional[Callable[[], float]] = None
    ) -> Optional[Union[float, torch.Tensor]]:
        """
        Step the optimiser and update the record of best parameters.

        :param loss: The loss for this step.
        :param closure: Optional closure forwarded to the internal optimiser's ``step``.
        :returns: Whatever the parent class's ``step`` returned, so that a closure's result is not lost.
        :raises NoImprovementError: Propagated from the parent class's ``step``; in that case no
            snapshot is recorded for this step.
        """
        step_result = super().step(loss, closure=closure)

        # NOTE(review): the snapshot below is read after the step has been applied, whereas ``loss`` was
        # supplied by the caller before it - confirm that pairing these two is intended.
        state_dicts = {module: module.state_dict() for module in self._stored_initial_state_dicts}
        loss_at_current_step = self.last_n_losses[-1]
        # The heap retains the LARGEST items, so the loss is negated to retain the smallest losses.
        parameters = Parameters(state_dicts, -loss_at_current_step)
        self._top_n_parameters.push(parameters)
        return step_result

    def set_parameters(self) -> None:
        """
        Tidy up after optimisation by setting the parameters to the best.

        If no step has been recorded there is nothing to restore, so the modules are left untouched.
        """
        if not self._top_n_parameters.heap:
            return
        best_parameters = self._top_n_parameters.best()
        for module, state_dict in best_parameters.module_state_dicts.items():
            module.load_state_dict(state_dict)

    def reset(self) -> None:
        """Reset all parameters to start values and clear record of best parameters."""
        super().reset()
        self._top_n_parameters = MaxLengthHeapQ(self.N_RETAINED_PARAMETERS)
class NoImprovementError(RuntimeError):
    """
    Raised when the loss of the model is consistently increasing.

    The error message lists the recent losses that triggered the early stop.
    """