Source code for vanguard.warps.decorator

# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Contains the SetWarp decorator.
"""

from typing import Any, TypeVar, Union

import numpy as np
import numpy.typing
import torch
from torch import Tensor
from typing_extensions import Self

from vanguard import utils
from vanguard.base import GPController
from vanguard.base.posteriors import Posterior
from vanguard.decoratorutils import Decorator, process_args, wraps_class
from vanguard.warps.basefunction import WarpFunction
from vanguard.warps.intermediate import is_intermediate_warp_function

# Type variable bound to GPController; used below to annotate that decorating a
# controller class yields a class of the same (sub)type that was passed in.
ControllerT = TypeVar("ControllerT", bound=GPController)


class SetWarp(Decorator):
    """
    Map a GP through a warp function.

    :Example:
        >>> from vanguard.base import GPController
        >>> from vanguard.warps.warpfunctions import BoxCoxWarpFunction
        >>>
        >>> @SetWarp(BoxCoxWarpFunction(1))
        ... class MyController(GPController):
        ...     pass
    """

    def __init__(self, warp_function: WarpFunction, **kwargs: Any):
        """
        Initialise self.

        :param warp_function: The warp function to be applied to the GP.
        :param kwargs: Keyword arguments passed to :class:`~vanguard.decoratorutils.basedecorator.Decorator`.
        """
        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)
        self.warp_function = warp_function

    def _decorate_class(self, cls: type[ControllerT]) -> type[ControllerT]:
        """
        Build a subclass of ``cls`` whose targets are mapped through the warp function.

        :param cls: The controller class to decorate.
        :return: A subclass of ``cls`` which trains on warped targets and whose posterior
            classes are wrapped so that results are expressed in the original target space.
        """
        # Bind to a local so the nested class closes over the warp function rather than the decorator.
        warp_function = self.warp_function

        @wraps_class(cls)
        class InnerClass(cls):
            """
            A wrapper for applying a compositional warp to a controller class.
            """

            def __init__(self, *args: Any, **kwargs: Any):
                """
                Initialise the controller, then attach a copy of the warp and wrap the posterior classes.

                :param args: Positional arguments forwarded to the wrapped controller's ``__init__``.
                :param kwargs: Keyword arguments forwarded to the wrapped controller's ``__init__``.
                """
                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)
                self.rng = utils.optional_random_generator(all_parameters_as_kwargs.pop("rng", None))

                # Pop `rng` from kwargs to ensure we don't provide duplicate values to superclass
                kwargs.pop("rng", None)
                super().__init__(*args, rng=self.rng, **kwargs)

                # Intermediate warp components are activated with the controller's own parameters.
                for warp_component in warp_function.components:
                    if is_intermediate_warp_function(warp_component):
                        warp_component.activate(**all_parameters_as_kwargs)

                # Each controller instance works with its own copy of the warp.
                warp_copy = warp_function.copy().float()
                self.warp = warp_copy
                self._smart_optimiser.register_module(self.warp)

                self.train_y = self.train_y.to(self.device)

                def _unwarp_values(
                    *values: Union[Tensor, numpy.typing.NDArray[np.floating]],
                ) -> tuple[Tensor, ...]:
                    """
                    Map values back through the warp.

                    :param values: Values to reverse warping on
                    :return: Values warped back onto original space
                    """
                    values_as_tensors = (
                        torch.as_tensor(value, dtype=self.dtype, device=self.device) for value in values
                    )
                    unwarped_values_as_tensors = (warp_copy.inverse(tensor).squeeze() for tensor in values_as_tensors)
                    return tuple(unwarped_values_as_tensors)

                def _warp_values(
                    *values: Union[Tensor, numpy.typing.NDArray[np.floating]],
                ) -> tuple[Tensor, ...]:
                    """
                    Map values through the warp.

                    :param values: Values to warp on
                    :return: Values warp onto new space
                    """
                    values_as_tensors = (
                        torch.as_tensor(value, dtype=self.dtype, device=self.device) for value in values
                    )
                    warped_values_as_tensors = (warp_copy(tensor).squeeze() for tensor in values_as_tensors)
                    return tuple(warped_values_as_tensors)

                def _warp_derivative_values(
                    *values: Union[Tensor, numpy.typing.NDArray[np.floating]],
                ) -> tuple[Tensor, ...]:
                    """
                    Map values through the derivative of the warp.

                    :param values: Values to compute derivatives of warp for
                    :return: Derivatives of warp for each input value
                    """
                    values_as_tensors = (
                        torch.as_tensor(value, dtype=self.dtype, device=self.device) for value in values
                    )
                    warped_values_as_tensors = (warp_copy.deriv(tensor).squeeze() for tensor in values_as_tensors)
                    return tuple(warped_values_as_tensors)

                def warp_posterior_class(posterior_class: type[Posterior]) -> type[Posterior]:
                    """
                    Wrap a posterior class to enable warping.

                    :param posterior_class: The posterior class to wrap.
                    :return: A subclass of ``posterior_class`` using the warp helpers defined above.
                    """

                    @wraps_class(posterior_class)
                    class WarpedPosterior(posterior_class):
                        """
                        Posterior whose outputs are expressed in the original (un-warped) target space.
                        """

                        def prediction(self) -> torch.tensor:  # pyright: ignore [reportGeneralTypeIssues]
                            """
                            Refuse to return a prediction.

                            :raises TypeError: Always.
                            """
                            raise TypeError("The mean and covariance of a warped GP cannot be computed exactly.")

                        def confidence_interval(
                            self, alpha: float = 0.05
                        ) -> tuple[
                            numpy.typing.NDArray[np.floating],
                            numpy.typing.NDArray[np.floating],
                            numpy.typing.NDArray[np.floating],
                        ]:
                            """
                            Un-warp values.

                            :param alpha: Passed unchanged to the wrapped ``confidence_interval``.
                            :return: The wrapped (mean, lower, upper) values mapped through the inverse warp.
                            """
                            mean, lower, upper = super().confidence_interval(alpha)
                            return _unwarp_values(mean, lower, upper)

                        def log_probability(
                            self, y: tuple[numpy.typing.NDArray[np.floating]]
                        ) -> numpy.typing.NDArray[np.floating]:
                            """
                            Apply the change of variables to the density using the warp.

                            :param y: The values at which to evaluate the log-probability.
                            :return: The wrapped log-probability of the warped values plus the summed
                                log absolute derivative of the warp at ``y``.
                            """
                            warped_y = _warp_values(y)
                            warp_deriv_values = _warp_derivative_values(y)
                            jacobian = np.sum(np.log(np.abs(warp_deriv_values)))
                            return jacobian + super().log_probability(warped_y)

                    return WarpedPosterior

                self.posterior_class = warp_posterior_class(self.posterior_class)
                self.posterior_collection_class = warp_posterior_class(self.posterior_collection_class)

            @classmethod
            def new(cls, instance: Self, **kwargs: Any) -> Self:
                """
                Also apply warping to the new instance.

                :param instance: The instance to base the new one on; its warp is reused.
                :param kwargs: Keyword arguments forwarded to the wrapped ``new``.
                :return: The new instance, sharing ``instance.warp`` and with warped GP train targets.
                """
                new_instance = super().new(instance, **kwargs)
                new_instance.warp = instance.warp
                # pylint: disable=protected-access
                new_instance._gp.train_targets = new_instance.warp(new_instance._gp.train_targets).squeeze(dim=-1)
                return new_instance

            def _sgd_round(self, n_iters: int = 100, gradient_every: int = 100) -> torch.Tensor:
                """
                Calculate loss and warp train_y.

                :param n_iters: Passed unchanged to the wrapped ``_sgd_round``.
                :param gradient_every: Passed unchanged to the wrapped ``_sgd_round``.
                :return: The loss returned by the wrapped ``_sgd_round``.
                """
                loss = super()._sgd_round(n_iters=n_iters, gradient_every=gradient_every)
                # Refresh the GP's targets after the round, using the current state of the warp.
                warped_train_y = self.warp(self.train_y).squeeze(dim=-1)
                self._gp.train_targets = warped_train_y
                return loss

            def _unwarp_values(self, *values: Union[Tensor, numpy.typing.NDArray[np.floating]]) -> tuple[Tensor, ...]:
                """
                Map values back through the warp.

                :param values: Values to reverse warping on.
                :return: One flattened tensor per input, mapped through the inverse warp.
                """
                values_as_tensors = (torch.as_tensor(value) for value in values)
                unwarped_values_as_tensors = (self.warp.inverse(tensor).reshape(-1) for tensor in values_as_tensors)
                return tuple(unwarped_values_as_tensors)

            def _loss(self, train_x: torch.Tensor, train_y: torch.Tensor) -> torch.Tensor:
                """
                Subtract the log-derivative (change-of-variables) term from the negative mll.

                :param train_x: The training inputs.
                :param train_y: The un-warped training targets.
                :return: The wrapped loss evaluated on the warped targets, minus the sum of the
                    log absolute derivative of the warp at ``train_y``.
                """
                warped_train_y = self.warp(train_y).squeeze(dim=-1)
                self._gp.train_targets = warped_train_y
                nmll = super()._loss(train_x, warped_train_y)
                # The Jacobian contribution to the log-likelihood is sum(log|warp'(y)|), which is
                # what WarpedPosterior.log_probability adds; summing the raw derivative without
                # the logarithm would optimise a different objective.
                warp_derivative = self.warp.deriv(train_y).squeeze(dim=-1)
                log_jacobian = torch.log(torch.abs(warp_derivative)).sum()
                return nmll - log_jacobian

            @staticmethod
            def warn_normalise_y() -> None:
                """Override base warning because warping renders y normalisation unimportant."""

        return InnerClass