# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the VariationalHierarchicalHyperparameters decorator.
"""
from collections.abc import Generator
from typing import Any, Optional, TypeVar, Union
import gpytorch
import numpy as np
import torch
from gpytorch.variational import CholeskyVariationalDistribution, _VariationalDistribution
from linear_operator import to_linear_operator
from numpy.typing import NDArray
from torch import Tensor
from vanguard import utils
from vanguard.decoratorutils import process_args, wraps_class
from vanguard.hierarchical.base import (
BaseHierarchicalHyperparameters,
GPController,
Posterior,
extract_bayesian_hyperparameters,
set_batch_shape,
)
from vanguard.hierarchical.collection import HyperparameterCollection
# Type variables used in the annotations of this module.
# ControllerT: any GPController subclass; annotates the decorated class and the ``controller`` arguments below.
ControllerT = TypeVar("ControllerT", bound=GPController)
# KernelT: any GPyTorch kernel; annotates the argument of ``_correct_point_estimate_shapes``.
KernelT = TypeVar("KernelT", bound=gpytorch.kernels.Kernel)
# LikelihoodT: bound to GaussianLikelihood. NOTE(review): not referenced in the visible part of this file.
LikelihoodT = TypeVar("LikelihoodT", bound=gpytorch.likelihoods.GaussianLikelihood)
# PosteriorT: bound to Posterior. NOTE(review): not referenced in the visible part of this file.
PosteriorT = TypeVar("PosteriorT", bound=Posterior)
# DistributionT: any GPyTorch distribution; ties the input and output types of
# ``_safe_index_batched_multivariate_normal`` together.
DistributionT = TypeVar("DistributionT", bound=gpytorch.distributions.Distribution)
[docs]
class VariationalHierarchicalHyperparameters(BaseHierarchicalHyperparameters):
    """
    Convert a controller so that variational inference is performed over its hyperparameters.

    Note that only those hyperparameters specified using the
    :class:`~vanguard.hierarchical.module.BayesianHyperparameters` decorator will be included
    for variational inference. The remaining hyperparameters will be inferred as point estimates.

    :Example:
        >>> from gpytorch.kernels import RBFKernel
        >>> import numpy as np
        >>> import torch
        >>> import torch.random
        >>> from vanguard.vanilla import GaussianGPController
        >>> from vanguard.hierarchical import (BayesianHyperparameters,
        ...                                    VariationalHierarchicalHyperparameters)
        >>>
        >>> @VariationalHierarchicalHyperparameters(num_mc_samples=50)
        ... class HierarchicalController(GaussianGPController):
        ...     pass
        >>>
        >>> @BayesianHyperparameters()
        ... class BayesianRBFKernel(RBFKernel):
        ...     pass
        >>>
        >>> train_x = torch.tensor([0, 0.5, 0.9, 1])
        >>> rng = torch.Generator(device=train_x.device).manual_seed(1234)
        >>> train_y = torch.normal(mean=1 / (1 + train_x), std=torch.ones_like(train_x) * 0.005, generator=rng)
        >>> gp = HierarchicalController(train_x, train_y, BayesianRBFKernel, y_std=0.0)
        >>> loss = gp.fit(100)
        >>>
        >>> test_x = torch.tensor([0.05, 0.95])
        >>> mean, lower, upper = gp.posterior_over_point(test_x).confidence_interval()
        >>> (upper > 1/(1 + test_x)).all().item(), (lower < 1/(1 + test_x)).all().item()
        (True, True)
    """

    def __init__(
        self,
        num_mc_samples: int = 100,
        variational_distribution_class: Optional[type[_VariationalDistribution]] = CholeskyVariationalDistribution,
        **kwargs: Any,
    ) -> None:
        """
        Initialise self.

        :param num_mc_samples: The number of Monte Carlo samples to use when approximating
            intractable integrals in the variational ELBO and the
            predictive posterior.
        :param variational_distribution_class:
            The variational distribution class (not an instance) to use for the raw hyperparameters'
            posterior. Defaults to :class:`~gpytorch.variational.CholeskyVariationalDistribution`.
            The value is handed unchanged to
            :class:`~vanguard.hierarchical.collection.HyperparameterCollection` when a decorated
            controller is instantiated.
        :param kwargs: Forwarded unchanged to the parent decorator's initialiser.
        """
        super().__init__(num_mc_samples=num_mc_samples, **kwargs)
        self.variational_distribution_class = variational_distribution_class

    def _decorate_class(self, cls: type[ControllerT]) -> type[ControllerT]:
        """
        Build the decorated controller class.

        The class returned by the parent decorator is subclassed so that a
        ``HyperparameterCollection`` is created on construction and so that the loss gains a KL term.

        :param cls: The controller class being decorated.
        :return: The subclass of the parent-decorated class with variational hyperparameters.
        """
        # Capture decorator state in locals so the inner class closes over plain values
        # rather than over the decorator instance.
        sample_shape = self.sample_shape
        variational_distribution_class = self.variational_distribution_class
        base_decorated_cls = super()._decorate_class(cls)

        @wraps_class(base_decorated_cls)
        class InnerClass(base_decorated_cls):
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Apply the Monte Carlo sample shape as the batch shape of each module's kwargs.
                for module_name in ("kernel", "mean", "likelihood"):
                    set_batch_shape(kwargs, module_name, sample_shape)

                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)
                self.rng = utils.optional_random_generator(all_parameters_as_kwargs.pop("rng", None))

                # Pop `rng` from kwargs to ensure we don't provide duplicate values to superclass init
                kwargs.pop("rng", None)
                super().__init__(*args, rng=self.rng, **kwargs)

                module_hyperparameter_pairs, point_estimate_kernels = extract_bayesian_hyperparameters(self)
                _correct_point_estimate_shapes(point_estimate_kernels)

                self.hyperparameter_collection = HyperparameterCollection(
                    module_hyperparameter_pairs, sample_shape, variational_distribution_class
                )

                # The point-estimate kernels had a parameter replaced above, so the optimiser's view of
                # the GP is refreshed before the variational distribution is registered alongside it.
                self._smart_optimiser.update_registered_module(self._gp)
                self._smart_optimiser.register_module(self.hyperparameter_collection.variational_distribution)

            def _loss(self, train_x: torch.Tensor, train_y: torch.Tensor) -> torch.Tensor:
                """Add KL term to loss and average over hyperparameter samples."""
                self.hyperparameter_collection.sample_and_update()
                nmll = super()._loss(train_x, train_y)
                return nmll.mean() + self.hyperparameter_collection.kl_term()

        return InnerClass

    @staticmethod
    def _infinite_posterior_samples(
        controller: ControllerT, x: Union[Tensor, NDArray[np.floating]]
    ) -> Generator[torch.Tensor, None, None]:
        """
        Yield posterior samples forever.

        Every pass of the loop calls ``sample_and_update`` on the controller's hyperparameter
        collection, evaluates the GP at ``x`` and yields the resulting batch one element at a time.

        :param controller: The controller from which to yield samples.
        :param x: (n_predictions, n_features) The predictive inputs.
        """
        # NOTE(review): the yielded items are the non-batch distributions rebuilt by
        # ``_safe_index_batched_multivariate_normal``; the return annotation may need revisiting.
        tx = torch.as_tensor(x, dtype=torch.float32, device=controller.device)
        while True:
            controller.hyperparameter_collection.sample_and_update()
            # pylint: disable-next=protected-access
            output = _safe_index_batched_multivariate_normal(controller._gp_forward(x=tx).add_jitter(1e-3))
            yield from output

    @staticmethod
    def _infinite_fuzzy_posterior_samples(
        controller: ControllerT,
        x: Union[Tensor, NDArray[np.floating]],
        x_std: Union[Tensor, NDArray[np.floating], float],
    ) -> Generator[torch.Tensor, None, None]:
        """
        Yield fuzzy posterior samples forever.

        :param controller: The controller from which to yield samples.
        :param x: (n_predictions, n_features) The predictive inputs.
        :param x_std: The input noise standard deviations:

            * array_like[float]: (n_features,) The standard deviation per input dimension for the predictions,
            * float: Assume homoskedastic noise.
        """
        # ``as_tensor`` (as in the non-fuzzy generators) avoids the needless copy and the
        # copy-construct warning that ``torch.tensor`` produces when ``x`` is already a tensor.
        # ``tx`` is only read below, never modified in place.
        tx = torch.as_tensor(x, dtype=torch.float32, device=controller.device)
        # pylint: disable-next=protected-access
        tx_std = controller._process_x_std(std=x_std).to(controller.device)
        while True:
            controller.hyperparameter_collection.sample_and_update()
            # This cunning trick matches the sampled x shape to the MC samples batch shape.
            # The result is that each sample from output comes from independent x samples
            # and from independent variational posterior samples.
            sample_shape = controller.hyperparameter_collection.sample_shape + tx.shape
            x_sample = torch.randn(size=sample_shape, device=controller.device) * tx_std + tx
            # pylint: disable-next=protected-access
            output = _safe_index_batched_multivariate_normal(controller._gp_forward(x=x_sample).add_jitter(1e-3))
            yield from output

    @staticmethod
    def _infinite_likelihood_samples(
        controller: ControllerT, x: Union[Tensor, NDArray[np.floating]]
    ) -> Generator[torch.Tensor, None, None]:
        """
        Yield likelihood samples forever.

        Each posterior sample is passed through the controller's likelihood with an all-zero
        ``noise`` tensor whose shape is chosen by the controller.

        :param controller: The controller from which to yield samples.
        :param x: (n_predictions, n_features) The predictive inputs.
        """
        tx = torch.as_tensor(x, dtype=torch.float32, device=controller.device)
        while True:
            controller.hyperparameter_collection.sample_and_update()
            # pylint: disable-next=protected-access
            output = _safe_index_batched_multivariate_normal(controller._gp_forward(x=tx).add_jitter(1e-3))
            for sample in output:
                # pylint: disable-next=protected-access
                shape = controller._decide_noise_shape(controller.posterior_class(sample), x=tx)
                noise = torch.zeros(shape, dtype=torch.float32, device=controller.device)
                # pylint: disable-next=protected-access
                likelihood_output = controller._likelihood(sample, noise=noise)
                yield likelihood_output

    @staticmethod
    def _infinite_fuzzy_likelihood_samples(
        controller: ControllerT,
        x: Union[Tensor, NDArray[np.floating]],
        x_std: Union[Tensor, NDArray[np.floating], float],
    ) -> Generator[torch.Tensor, None, None]:
        """
        Yield fuzzy likelihood samples forever.

        :param controller: The controller from which to yield samples.
        :param x: (n_predictions, n_features) The predictive inputs.
        :param x_std: The input noise standard deviations:

            * array_like[float]: (n_features,) The standard deviation per input dimension for the predictions,
            * float: Assume homoskedastic noise.
        """
        # ``as_tensor`` (as in the non-fuzzy generators) avoids the needless copy and the
        # copy-construct warning that ``torch.tensor`` produces when ``x`` is already a tensor.
        # ``tx`` is only read below, never modified in place.
        tx = torch.as_tensor(x, dtype=torch.float32, device=controller.device)
        # pylint: disable-next=protected-access
        tx_std = controller._process_x_std(x_std).to(controller.device)
        while True:
            controller.hyperparameter_collection.sample_and_update()
            # This cunning trick matches the sampled x shape to the MC samples batch shape.
            # The result is that each sample from output comes from independent x samples
            # and from independent variational posterior samples.
            sample_shape = controller.hyperparameter_collection.sample_shape + tx.shape
            x_sample = torch.randn(size=sample_shape, device=controller.device) * tx_std + tx
            # pylint: disable-next=protected-access
            output = _safe_index_batched_multivariate_normal(controller._gp_forward(x=x_sample).add_jitter(1e-3))
            for sample in output:
                # pylint: disable-next=protected-access
                shape = controller._decide_noise_shape(controller.posterior_class(sample), x=tx)
                noise = torch.zeros(shape, dtype=torch.float32, device=controller.device)
                # pylint: disable-next=protected-access
                likelihood_output = controller._likelihood(sample, noise=noise)
                yield likelihood_output
def _correct_point_estimate_shapes(point_estimate_kernels: list[KernelT]) -> None:
    """
    Reset the ``raw_outputscale`` parameter of every point estimate kernel.

    GPyTorch's batch shape handling leaves these constants with the wrong shape, so each
    one is removed and registered again as a fresh single-element, zero-valued parameter.

    :param point_estimate_kernels: The kernels whose ``raw_outputscale`` should be replaced.
        Each kernel is modified in place.
    """
    for kernel in point_estimate_kernels:
        # The existing attribute is dropped first, then a fresh parameter takes its place.
        del kernel.raw_outputscale
        replacement = torch.nn.Parameter(torch.zeros([1]))
        kernel.register_parameter(name="raw_outputscale", parameter=replacement)
def _safe_index_batched_multivariate_normal(
    distribution: DistributionT,
) -> Generator[DistributionT, None, None]:
    """
    Split a batched distribution into non-batch distributions, one per batch element.

    Indexing into the batch dimension of a batch :class:`~gpytorch.distributions.MultivariateNormal`
    is somewhat brittle when the underlying covariance matrix is lazy (which happens when the
    covariance matrix is larger than an obscure threshold). To work around that, the batched
    covariance matrix is delazified once, and each element is rebuilt as its own distribution of
    the same type from the matching mean and a relazified covariance matrix.

    Delazifying the batch covariance matrix doesn't cause any inefficiencies because the individual
    covariance matrices would be delazified later anyway. Relazifying the individual matrices just
    delays any Cholesky issues, which is good because we have handling for them downstream.

    :param distribution: The batched distribution to split along its leading dimension.
    :return: A generator of distributions of the same type as ``distribution``.
    """
    rebuild = type(distribution)
    # Reading ``covariance_matrix`` is the delazifying step described above.
    dense_covariances = distribution.covariance_matrix
    batch_means = distribution.mean
    for element_mean, element_covariance in zip(batch_means, dense_covariances):
        relazified_covariance = to_linear_operator(element_covariance)
        yield rebuild(element_mean, relazified_covariance)