# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the BaseHierarchicalHyperparameters decorator.
"""
import warnings
from collections.abc import Generator
from typing import Any, TypeVar, Union
import gpytorch
import numpy as np
import torch
from gpytorch.kernels import ScaleKernel
from numpy.typing import NDArray
from torch import Tensor
from typing_extensions import Self, override
from vanguard.base import GPController
from vanguard.base.posteriors import MonteCarloPosteriorCollection, Posterior
from vanguard.decoratorutils import Decorator, wraps_class
from vanguard.warnings import _JITTER_WARNING, NumericalWarning
ControllerT = TypeVar("ControllerT", bound=GPController)
DistributionT = TypeVar("DistributionT", bound=gpytorch.distributions.Distribution)
PosteriorT = TypeVar("PosteriorT", bound=Posterior)
ModuleT = TypeVar("ModuleT", bound=torch.nn.Module)
class BaseHierarchicalHyperparameters(Decorator):
    """
    Convert a controller so that Bayesian inference is performed over its hyperparameters.

    Note that only those hyperparameters specified using the
    :class:`~vanguard.hierarchical.module.BayesianHyperparameters` decorator will be included
    for Bayesian inference. The remaining hyperparameters will be inferred as point estimates.

    This base class only supplies the controller wrapping. The actual sampling is delegated to
    the four ``_infinite_*_samples`` static hooks, which raise :class:`NotImplementedError` here
    and must be overridden by concrete subclasses.
    """

    def __init__(self, num_mc_samples: int = 100, **kwargs: Any) -> None:
        """
        Initialise self.

        :param num_mc_samples: The number of Monte Carlo samples to use when approximating
                                intractable integrals in the variational ELBO and the
                                predictive posterior.
        :param kwargs: Passed through unchanged to the parent decorator initialiser.
        """
        # Stored as a torch.Size so it can be used directly as a sample/batch shape.
        self.sample_shape = torch.Size([num_mc_samples])
        # The empty collection means no other decorators are required beforehand.
        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)

    @override
    def verify_decorated_class(self, cls: type[ControllerT]) -> None:
        """
        Refuse to decorate a class that already has a hierarchical hyperparameters decorator.

        :param cls: The class about to be decorated.
        :raises TypeError: If any entry of ``cls.__decorators__`` is a subclass of
            :class:`BaseHierarchicalHyperparameters`.
        """
        super().verify_decorated_class(cls)
        for previous_decorator in cls.__decorators__:
            if issubclass(previous_decorator, BaseHierarchicalHyperparameters):
                msg = (
                    f"This class is already decorated with `{previous_decorator.__name__}`. "
                    "Please use only one hierarchical hyperparameters decorator at once."
                )
                raise TypeError(msg)

    def _decorate_class(self, cls: type[ControllerT]) -> type[ControllerT]:
        """
        Build the decorated controller class.

        :param cls: The controller class to wrap.
        :returns: A subclass of ``cls`` whose prediction methods draw their samples from
            this decorator's ``_infinite_*_samples`` hooks.
        """
        # Captured by closure so the inner methods can reach the decorator's sampling hooks.
        decorator = self

        # Not given a docstring of its own: ``wraps_class`` is applied to carry over metadata from ``cls``.
        @wraps_class(cls)
        class InnerClass(cls):
            @classmethod
            def new(cls, instance: Self, **kwargs: Any) -> Self:
                """Make sure that the hyperparameter collection is copied over."""
                new_instance = super().new(instance, **kwargs)
                # Same collection object is shared (assigned by reference, not copied).
                new_instance.hyperparameter_collection = instance.hyperparameter_collection
                return new_instance

            def _get_posterior_over_point(self, x: Union[Tensor, NDArray[np.floating]]) -> Posterior:
                """
                Predict the y-value of a single point.

                This method does not itself switch the model between train and eval mode.

                :param x: (n_predictions, n_features) The predictive inputs.
                :returns: The result of ``self.posterior_collection_class`` applied to a lazy
                    stream of posteriors, one per sample yielded by the decorator's
                    ``_infinite_posterior_samples`` hook.
                """
                x = torch.as_tensor(x)
                # Generator expression: nothing is sampled until the collection consumes it.
                posteriors = (
                    self.posterior_class(posterior_sample)
                    # pylint: disable=protected-access
                    for posterior_sample in decorator._infinite_posterior_samples(self, x)
                )
                return self.posterior_collection_class(posteriors)

            def _predictive_likelihood(self, x: Union[Tensor, NDArray[np.floating]]) -> Posterior:
                """
                Predict the likelihood value of a single point.

                This method does not itself switch the model between train and eval mode.

                :param x: (n_predictions, n_features) The predictive inputs.
                :returns: The result of ``self.posterior_collection_class`` applied to a lazy
                    stream of posteriors, one per sample yielded by the decorator's
                    ``_infinite_likelihood_samples`` hook.
                """
                x = torch.as_tensor(x)
                likelihoods = (
                    self.posterior_class(posterior_sample)
                    # pylint: disable=protected-access
                    for posterior_sample in decorator._infinite_likelihood_samples(self, x)
                )
                return self.posterior_collection_class(likelihoods)

            def _get_posterior_over_fuzzy_point_in_eval_mode(
                self, x: Union[Tensor, NDArray[np.floating]], x_std: Union[Tensor, NDArray[np.floating], float]
            ) -> MonteCarloPosteriorCollection:
                """
                Obtain Monte Carlo integration samples from the predictive posterior with Gaussian input noise.

                The controller is switched to evaluation mode before any samples are requested.

                .. warning::
                    The ``n_features`` must match with :attr:`self.dim`.

                :param x: (n_predictions, n_features) The predictive inputs.
                :param x_std: The input noise standard deviations:

                    * array_like[float]: (n_features,) The standard deviation per input dimension for the predictions,
                    * float: Assume homoskedastic noise.

                :returns: The result of ``self.posterior_collection_class`` applied to a lazy
                    stream of posteriors, one per sample yielded by the decorator's
                    ``_infinite_fuzzy_posterior_samples`` hook.
                """
                x = torch.as_tensor(x)
                x_std = torch.as_tensor(x_std)
                self.set_to_evaluation_mode()
                posteriors = (
                    self.posterior_class(posterior_sample)
                    # pylint: disable=protected-access
                    for posterior_sample in decorator._infinite_fuzzy_posterior_samples(self, x, x_std)
                )
                return self.posterior_collection_class(posteriors)

            def _fuzzy_predictive_likelihood(
                self, x: Union[Tensor, NDArray[np.floating]], x_std: Union[Tensor, NDArray[np.floating], float]
            ) -> MonteCarloPosteriorCollection:
                """
                Obtain Monte Carlo integration samples from the predictive likelihood with Gaussian input noise.

                The controller is switched to evaluation mode before any samples are requested.

                .. warning::
                    The ``n_features`` must match with :attr:`self.dim`.

                :param x: (n_predictions, n_features) The predictive inputs.
                :param x_std: The input noise standard deviations:

                    * array_like[float]: (n_features,) The standard deviation per input dimension for the predictions,
                    * float: Assume homoskedastic noise.

                :returns: The result of ``self.posterior_collection_class`` applied to a lazy
                    stream of posteriors, one per sample yielded by the decorator's
                    ``_infinite_fuzzy_likelihood_samples`` hook.
                """
                x = torch.as_tensor(x)
                x_std = torch.as_tensor(x_std)
                self.set_to_evaluation_mode()
                likelihoods = (
                    self.posterior_class(posterior_sample)
                    # pylint: disable=protected-access
                    for posterior_sample in decorator._infinite_fuzzy_likelihood_samples(self, x, x_std)
                )
                return self.posterior_collection_class(likelihoods)

            def _gp_forward(self, x: Union[Tensor, NDArray[np.floating]]) -> DistributionT:
                """
                Run the forward method of the internal GP model.

                Overloading is necessary to remove fast_pred_var.
                See here: https://github.com/cornellius-gp/gpytorch/issues/864

                :param x: The inputs to pass to the internal GP model; coerced to a tensor first.
                :returns: Whatever ``self._gp`` returns for ``x``.
                """
                x = torch.as_tensor(x)
                with warnings.catch_warnings():
                    # Only the known jitter warning is silenced; other numerical warnings still surface.
                    warnings.filterwarnings("ignore", category=NumericalWarning, message=_JITTER_WARNING)
                    output = self._gp(x)
                return output

        return InnerClass

    @staticmethod
    def _infinite_posterior_samples(
        controller: ControllerT, x: Union[Tensor, NDArray[np.floating]]
    ) -> Generator[torch.Tensor, None, None]:
        """
        Yield posterior samples at ``x`` for the given controller.

        Abstract hook: must be overridden by subclasses.

        :param controller: The decorated controller instance requesting samples.
        :param x: The predictive inputs.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def _infinite_fuzzy_posterior_samples(
        controller: ControllerT, x: Union[Tensor, NDArray[np.floating]], x_std: Union[Tensor, NDArray[np.floating]]
    ) -> Generator[torch.Tensor, None, None]:
        """
        Yield posterior samples at ``x`` under input noise ``x_std`` for the given controller.

        Abstract hook: must be overridden by subclasses.

        :param controller: The decorated controller instance requesting samples.
        :param x: The predictive inputs.
        :param x_std: The input noise standard deviations.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def _infinite_likelihood_samples(
        controller: ControllerT, x: Union[Tensor, NDArray[np.floating]]
    ) -> Generator[torch.Tensor, None, None]:
        """
        Yield predictive likelihood samples at ``x`` for the given controller.

        Abstract hook: must be overridden by subclasses.

        :param controller: The decorated controller instance requesting samples.
        :param x: The predictive inputs.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def _infinite_fuzzy_likelihood_samples(
        controller: ControllerT, x: Union[Tensor, NDArray[np.floating]], x_std: Union[Tensor, NDArray[np.floating]]
    ) -> Generator[torch.Tensor, None, None]:
        """
        Yield predictive likelihood samples at ``x`` under input noise ``x_std`` for the given controller.

        Abstract hook: must be overridden by subclasses.

        :param controller: The decorated controller instance requesting samples.
        :param x: The predictive inputs.
        :param x_std: The input noise standard deviations.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
def _get_bayesian_hyperparameters(module: ModuleT) -> tuple[list, ...]:
    """
    Recursively collect the Bayesian hyperparameters of a GPyTorch module (mean, kernel or likelihood).

    The module and each of its sub-modules (at any depth) are inspected for a
    ``bayesian_hyperparameters`` attribute; every entry found is paired with the module that
    declares it. Along the way, any :class:`~gpytorch.kernels.ScaleKernel` that does *not*
    carry that attribute (i.e. one using standard point estimate hyperparameters) is recorded
    as well. These are needed to adjust batch_shapes.

    .. note::
        This function is designed to work with modules that have been decorated with
        :class:`~vanguard.hierarchical.module.BayesianHyperparameters`. If no module in the
        tree exposes ``bayesian_hyperparameters``, the list of pairs is simply empty.

    :param module: The module from which to extract the hyperparameters.
    :returns:
        * The ``(module, hyperparameter)`` pairs: this module's own first, then those of each child in order,
        * The ScaleKernels (at any depth) with point estimate hyperparameters: children's first, then
          ``module`` itself if it qualifies.
    """
    # A module without the attribute contributes no pairs of its own.
    own_hyperparameters = getattr(module, "bayesian_hyperparameters", [])
    collected_pairs = [(module, hyperparameter) for hyperparameter in own_hyperparameters]
    collected_scale_kernels = []

    # Depth-first walk over the direct children; each call handles its own subtree.
    for child in module.children():
        child_pairs, child_scale_kernels = _get_bayesian_hyperparameters(child)
        collected_pairs += child_pairs
        collected_scale_kernels += child_scale_kernels

    # This module itself is added after its descendants.
    uses_point_estimates = isinstance(module, ScaleKernel) and not hasattr(module, "bayesian_hyperparameters")
    if uses_point_estimates:
        collected_scale_kernels.append(module)

    return collected_pairs, collected_scale_kernels
def set_batch_shape(kwargs: Any, module_name: str, batch_shape: tuple[int, ...]) -> None:
    """
    Store ``batch_shape`` inside the ``<module_name>_kwargs`` entry of ``kwargs``.

    The nested entry is created if ``kwargs`` does not have it yet. An existing nested
    dictionary is updated in place and then re-inserted into ``kwargs``.

    :param kwargs: The keyword-argument dictionary to modify in place.
    :param module_name: Prefix of the nested entry, e.g. ``"kernel"`` for ``"kernel_kwargs"``.
    :param batch_shape: The value to store under the ``"batch_shape"`` key.
    """
    nested_key = f"{module_name}_kwargs"
    try:
        nested_kwargs = kwargs.pop(nested_key)
    except KeyError:
        # The entry may legitimately be absent; start from an empty one.
        nested_kwargs = {}
    nested_kwargs["batch_shape"] = batch_shape
    kwargs[nested_key] = nested_kwargs