# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the SetWarp decorator.
"""
from typing import Any, NoReturn, TypeVar, Union

import numpy as np
import numpy.typing
import torch
from torch import Tensor
from typing_extensions import Self, override

from vanguard import utils
from vanguard.base import GPController
from vanguard.base.posteriors import Posterior
from vanguard.decoratorutils import Decorator, process_args, wraps_class
from vanguard.warps.basefunction import WarpFunction
from vanguard.warps.intermediate import is_intermediate_warp_function
# Generic controller type: ``_decorate_class`` accepts a subclass of GPController
# and returns a class of that same type.
ControllerT = TypeVar(
    "ControllerT",
    bound=GPController,
)
class SetWarp(Decorator):
    """
    Map a GP through a warp function.

    The decorated controller fits its GP to the warped targets ``warp(y)``.
    The loss is corrected with the log-Jacobian of the warp. Posterior confidence
    intervals are mapped back through the inverse warp.

    :Example:
        >>> from vanguard.base import GPController
        >>> from vanguard.warps.warpfunctions import BoxCoxWarpFunction
        >>>
        >>> @SetWarp(BoxCoxWarpFunction(1))
        ... class MyController(GPController):
        ...     pass
    """

    def __init__(self, warp_function: WarpFunction, **kwargs: Any):
        """
        Initialise self.

        :param warp_function: The warp function to be applied to the GP.
        :param kwargs: Keyword arguments passed to :class:`~vanguard.decoratorutils.basedecorator.Decorator`.
        """
        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)
        self.warp_function = warp_function

    @property
    @override
    def safe_updates(self) -> dict[type, set[str]]:
        """Return the parent's ``safe_updates`` extended with the decorator-class to method-name entries below."""
        # pylint: disable=import-outside-toplevel
        from vanguard.classification import (
            BinaryClassification,
            CategoricalClassification,
            DirichletMulticlassClassification,
        )
        from vanguard.classification.mixin import Classification, ClassificationMixin
        from vanguard.features import HigherRankFeatures
        from vanguard.hierarchical import LaplaceHierarchicalHyperparameters, VariationalHierarchicalHyperparameters
        from vanguard.learning import LearnYNoise
        from vanguard.multitask import Multitask
        from vanguard.normalise import NormaliseY
        from vanguard.standardise import DisableStandardScaling
        from vanguard.variational import VariationalInference
        from vanguard.warps import SetInputWarp

        # pylint: enable=import-outside-toplevel
        return self._add_to_safe_updates(
            super().safe_updates,
            {
                BinaryClassification: {
                    "__init__",
                    "classify_points",
                    "classify_fuzzy_points",
                    "_get_predictions_from_prediction_means",
                    "warn_normalise_y",
                },
                CategoricalClassification: {
                    "__init__",
                    "classify_points",
                    "classify_fuzzy_points",
                    "_get_predictions_from_posterior",
                    "warn_normalise_y",
                },
                ClassificationMixin: {"classify_points", "classify_fuzzy_points"},
                Classification: {
                    "posterior_over_point",
                    "posterior_over_fuzzy_point",
                    "fuzzy_predictive_likelihood",
                    "predictive_likelihood",
                },
                DisableStandardScaling: {"_input_standardise_modules"},
                DirichletMulticlassClassification: {
                    "__init__",
                    "_loss",
                    "_noise_transform",
                    "classify_points",
                    "classify_fuzzy_points",
                    "_get_predictions_from_prediction_means",
                    "warn_normalise_y",
                },
                HigherRankFeatures: {"__init__"},
                LaplaceHierarchicalHyperparameters: {
                    "__init__",
                    "_compute_hyperparameter_laplace_approximation",
                    "_compute_loss_hessian",
                    "_fuzzy_predictive_likelihood",
                    "_get_posterior_over_fuzzy_point_in_eval_mode",
                    "_get_posterior_over_point",
                    "_gp_forward",
                    "_predictive_likelihood",
                    "_sample_and_set_hyperparameters",
                    "_sgd_round",
                    "_update_hyperparameter_posterior",
                    "auto_temperature",
                },
                LearnYNoise: {"__init__"},
                Multitask: {"__init__", "_match_mean_shape_to_kernel"},
                NormaliseY: {"__init__", "warn_normalise_y"},
                SetInputWarp: {"__init__"},
                SetWarp: {"__init__", "_loss", "_sgd_round", "warn_normalise_y", "_unwarp_values"},
                VariationalHierarchicalHyperparameters: {
                    "__init__",
                    "_fuzzy_predictive_likelihood",
                    "_get_posterior_over_fuzzy_point_in_eval_mode",
                    "_get_posterior_over_point",
                    "_gp_forward",
                    "_loss",
                    "_predictive_likelihood",
                },
                VariationalInference: {"__init__", "_predictive_likelihood", "_fuzzy_predictive_likelihood"},
            },
        )

    def _decorate_class(self, cls: type[ControllerT]) -> type[ControllerT]:
        """
        Wrap ``cls`` so that its GP models the warped targets.

        :param cls: The controller class to decorate.
        :return: A subclass of ``cls`` that applies this decorator's warp function.
        """
        warp_function = self.warp_function

        @wraps_class(cls, decorator_source=self)
        class InnerClass(cls):
            """
            A wrapper for applying a compositional warp to a controller class.
            """

            def __init__(self, *args: Any, **kwargs: Any):
                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)
                self.rng = utils.optional_random_generator(all_parameters_as_kwargs.pop("rng", None))

                # Pop `rng` from kwargs to ensure we don't provide duplicate values to superclass
                kwargs.pop("rng", None)
                super().__init__(*args, rng=self.rng, **kwargs)

                # Intermediate warp components are activated with the controller's init parameters.
                for warp_component in warp_function.components:
                    if is_intermediate_warp_function(warp_component):
                        warp_component.activate(**all_parameters_as_kwargs)

                # Each controller gets its own float copy of the warp, registered for optimisation.
                self.warp = warp_function.copy().float()
                self._smart_optimiser.register_module(self.warp)
                self.train_y = self.train_y.to(self.device)

                # The helpers below close over the *controller*. Inside WarpedPosterior,
                # ``self`` is the posterior. They read ``controller.warp`` at call time, so a
                # warp reassigned after construction (as ``new`` does) is also the one used
                # by the posterior classes, not a stale copy captured here.
                controller = self

                def _unwarp_values(
                    *values: Union[Tensor, numpy.typing.NDArray[np.floating]],
                ) -> tuple[Tensor, ...]:
                    """
                    Map values back through the warp.

                    :param values: Values to reverse warping on
                    :return: Values warped back onto original space, one squeezed tensor per input
                    """
                    return tuple(
                        controller.warp.inverse(
                            torch.as_tensor(value, dtype=controller.dtype, device=controller.device)
                        ).squeeze()
                        for value in values
                    )

                def _warp_values(
                    *values: Union[Tensor, numpy.typing.NDArray[np.floating]],
                ) -> tuple[Tensor, ...]:
                    """
                    Map values through the warp.

                    :param values: Values to warp
                    :return: Warped values, one squeezed tensor per input
                    """
                    return tuple(
                        controller.warp(
                            torch.as_tensor(value, dtype=controller.dtype, device=controller.device)
                        ).squeeze()
                        for value in values
                    )

                def _warp_derivative_values(
                    *values: Union[Tensor, numpy.typing.NDArray[np.floating]],
                ) -> tuple[Tensor, ...]:
                    """
                    Map values through the derivative of the warp.

                    :param values: Values to compute derivatives of warp for
                    :return: Derivatives of warp for each input value, one squeezed tensor per input
                    """
                    return tuple(
                        controller.warp.deriv(
                            torch.as_tensor(value, dtype=controller.dtype, device=controller.device)
                        ).squeeze()
                        for value in values
                    )

                def warp_posterior_class(posterior_class: type[Posterior]) -> type[Posterior]:
                    """Wrap a posterior class to enable warping."""

                    @wraps_class(posterior_class)
                    class WarpedPosterior(posterior_class):
                        """
                        Posterior whose outputs are mapped from warped space back to the original space.
                        """

                        def prediction(self) -> NoReturn:
                            """
                            Refuse to return a mean and covariance.

                            :raises TypeError: Always; the mean and covariance of a warped GP
                                cannot be computed exactly.
                            """
                            raise TypeError("The mean and covariance of a warped GP cannot be computed exactly.")

                        def confidence_interval(self, alpha: float = 0.05) -> tuple[Tensor, ...]:
                            """
                            Compute the parent confidence interval and map it back through the inverse warp.

                            Note that, for a non-linear warp, the inverse-warped warped-space mean
                            is not in general the mean in the original space.

                            :param alpha: Passed unchanged to the parent ``confidence_interval``.
                            :return: The un-warped ``(mean, lower, upper)`` tensors.
                            """
                            mean, lower, upper = super().confidence_interval(alpha)
                            return _unwarp_values(mean, lower, upper)

                        def log_probability(
                            self, y: Union[Tensor, numpy.typing.NDArray[np.floating]]
                        ) -> numpy.typing.NDArray[np.floating]:
                            """
                            Apply the change of variables to the density using the warp.

                            ``log p(y) = log p_warped(warp(y)) + sum(log|warp'(y)|)``

                            :param y: Values in the original (un-warped) space.
                            :return: The parent log-probability of ``warp(y)`` plus the log-Jacobian.
                            """
                            # The helpers are variadic and return a tuple; unpack the single result
                            # so the parent receives a tensor rather than a 1-tuple.
                            (warped_y,) = _warp_values(y)
                            (warp_deriv_values,) = _warp_derivative_values(y)
                            # Use torch rather than numpy, since numpy conversion fails for tensors that
                            # require grad or live on a GPU. ``.item()`` yields a plain float scalar.
                            jacobian = torch.log(torch.abs(warp_deriv_values)).sum().item()
                            return jacobian + super().log_probability(warped_y)

                    return WarpedPosterior

                self.posterior_class = warp_posterior_class(self.posterior_class)
                self.posterior_collection_class = warp_posterior_class(self.posterior_collection_class)

            @classmethod
            def new(cls, instance: Self, **kwargs: Any) -> Self:
                """Also apply warping to the new instance."""
                new_instance = super().new(instance, **kwargs)
                # Reuse the source instance's warp. The posterior helpers read ``self.warp``
                # at call time, so they pick this up too.
                new_instance.warp = instance.warp
                # pylint: disable=protected-access
                new_instance._gp.train_targets = new_instance.warp(new_instance._gp.train_targets).squeeze(dim=-1)
                return new_instance

            def _sgd_round(self, n_iters: int = 100, gradient_every: int = 100) -> torch.Tensor:
                """Calculate loss and warp train_y."""
                loss = super()._sgd_round(n_iters=n_iters, gradient_every=gradient_every)
                warped_train_y = self.warp(self.train_y).squeeze(dim=-1)
                self._gp.train_targets = warped_train_y
                return loss

            def _unwarp_values(self, *values: Union[Tensor, numpy.typing.NDArray[np.floating]]) -> tuple[Tensor, ...]:
                """
                Map values back through the warp.

                :param values: Values in warped space.
                :return: One flattened tensor per input, mapped through the inverse warp.
                """
                # Match the controller's dtype/device (as the posterior helpers do) so
                # numpy/CPU inputs work against warp parameters on another device.
                return tuple(
                    self.warp.inverse(torch.as_tensor(value, dtype=self.dtype, device=self.device)).reshape(-1)
                    for value in values
                )

            def _loss(self, train_x: torch.Tensor, train_y: torch.Tensor) -> torch.Tensor:
                """
                Compute the parent loss on warped targets, corrected by the warp's log-Jacobian.

                The GP is fitted to ``warp(train_y)``. By the change of variables, the term
                ``sum(log|warp'(train_y)|)`` is subtracted from the parent's loss, which the
                name ``nmll`` suggests is a negative marginal log likelihood.
                """
                warped_train_y = self.warp(train_y).squeeze(dim=-1)
                self._gp.train_targets = warped_train_y
                nmll = super()._loss(train_x, warped_train_y)
                # NOTE(review): if the parent loss is averaged over data points, this summed
                # log-Jacobian may need the same normalisation; confirm against the parent ``_loss``.
                log_jacobian = self.warp.deriv(train_y).squeeze(dim=-1).abs().log().sum()
                return nmll - log_jacobian

            @staticmethod
            def warn_normalise_y() -> None:
                """Override base warning because warping renders y normalisation unimportant."""

        return InnerClass