# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Enable variational inference in a controller.
The :class:`VariationalInference` decorator primes a :class:`~vanguard.base.gpcontroller.GPController` class
for variational inference.
"""
import warnings
from typing import Any, Generic, Optional, TypeVar, Union
import gpytorch.settings
import numpy as np
import numpy.typing
import torch
from torch import Tensor
from typing_extensions import override
from vanguard import utils
from vanguard.base import GPController
from vanguard.base.posteriors import Posterior
from vanguard.decoratorutils import Decorator, process_args, wraps_class
from vanguard.decoratorutils.basedecorator import T
from vanguard.decoratorutils.errors import BadCombinationWarning
from vanguard.variational.models import SVGPModel
ControllerT = TypeVar("ControllerT", bound=GPController)
# pylint: disable-next=protected-access
StrategyT = TypeVar("StrategyT", bound=gpytorch.variational._VariationalStrategy)
# pylint: disable-next=protected-access
DistributionT = TypeVar("DistributionT", bound=gpytorch.variational._VariationalDistribution)
[docs]
class VariationalInference(Decorator, Generic[StrategyT, DistributionT]):
"""
Set-up a :class:`~vanguard.base.gpcontroller.GPController` class for variational inference.
This is best used when:
* the posterior can not be calculated as a closed-form, or
* there are too many points to train a model in a reasonable time (see :cite:`Cheng17`).
.. note::
This decorator does not take the standard parameters in the
:class:`~vanguard.decoratorutils.basedecorator.Decorator`
class, as it only affects the input.
.. warning::
This decorator will force the wrapped controller class to only accept compatible
``gp_model_class`` and ``marginal_log_likelihood_class`` arguments. The former should
be a subclass of :class:`vanguard.variational.models.SVGPModel`, and the latter must take a ``num_data``
:class:`int` argument (e.g. a subclass of :ref:`one of the following
</marginal_log_likelihoods.rst#approximate-gp-inference>`).
:Example:
>>> @VariationalInference(n_inducing_points=100)
... class NewController(GPController):
... pass
"""
[docs]
def __init__(
self,
n_inducing_points: Optional[int] = None,
n_likelihood_samples: int = 10,
variational_strategy_class: Optional[type[StrategyT]] = None,
variational_distribution_class: Optional[type[DistributionT]] = None,
**kwargs: Any,
) -> None:
"""
Initialise self.
:param n_inducing_points: The size of the inducing point approximation. Defaults to None, meaning
that the number of inducing points will be set to the number of points.
:param n_likelihood_samples: If the marginal likelihood cannot be computed exactly (which is usually the
case when using variational inference), it is approximated using
MC integration by sampling from the variational posterior and averaging over
the likelihood values for each sample. This is the number of samples to use.
:param variational_strategy_class: The class for the variational strategy to use.
Default behaviour is defined in
:class:`gpytorch.variational.VariationalStrategy`
(:cite:`Hensman15`).
:param variational_distribution_class: The class for the variational distribution to use.
Default behaviour is defined in
:class:`gpytorch.variational.CholeskyVariationalDistribution` (Cholesky).
"""
super().__init__(framework_class=GPController, required_decorators={}, **kwargs)
self.n_inducing_points = n_inducing_points
self.n_likelihood_samples = n_likelihood_samples
self.variational_strategy_class = variational_strategy_class
self.gp_model_class = self._build_gp_model_class(variational_distribution_class, variational_strategy_class)
@property
@override
def safe_updates(self) -> dict[type, set[str]]:
# pylint: disable=import-outside-toplevel
from vanguard.classification import DirichletMulticlassClassification
from vanguard.classification.mixin import Classification, ClassificationMixin
from vanguard.features import HigherRankFeatures
from vanguard.hierarchical import LaplaceHierarchicalHyperparameters, VariationalHierarchicalHyperparameters
from vanguard.learning import LearnYNoise
from vanguard.multitask import Multitask
from vanguard.normalise import NormaliseY
from vanguard.standardise import DisableStandardScaling
from vanguard.warps import SetInputWarp, SetWarp
# pylint: enable=import-outside-toplevel
return self._add_to_safe_updates(
super().safe_updates,
{
Classification: {
"posterior_over_point",
"posterior_over_fuzzy_point",
"fuzzy_predictive_likelihood",
"predictive_likelihood",
},
ClassificationMixin: {"classify_points", "classify_fuzzy_points"},
DirichletMulticlassClassification: {
"__init__",
"_loss",
"_noise_transform",
"classify_points",
"classify_fuzzy_points",
"_get_predictions_from_prediction_means",
"warn_normalise_y",
},
DisableStandardScaling: {"_input_standardise_modules"},
HigherRankFeatures: {"__init__"},
LaplaceHierarchicalHyperparameters: {
"__init__",
"_compute_hyperparameter_laplace_approximation",
"_compute_loss_hessian",
"_fuzzy_predictive_likelihood",
"_get_posterior_over_fuzzy_point_in_eval_mode",
"_get_posterior_over_point",
"_gp_forward",
"_predictive_likelihood",
"_sample_and_set_hyperparameters",
"_sgd_round",
"_update_hyperparameter_posterior",
"auto_temperature",
},
LearnYNoise: {"__init__"},
Multitask: {"__init__", "_match_mean_shape_to_kernel"},
NormaliseY: {"__init__", "warn_normalise_y"},
SetInputWarp: {"__init__"},
SetWarp: {"__init__", "_loss", "_sgd_round", "warn_normalise_y", "_unwarp_values"},
VariationalHierarchicalHyperparameters: {
"__init__",
"_fuzzy_predictive_likelihood",
"_get_posterior_over_fuzzy_point_in_eval_mode",
"_get_posterior_over_point",
"_gp_forward",
"_loss",
"_predictive_likelihood",
},
},
)
[docs]
@override
def verify_decorated_class(self, cls: type[T]) -> None:
super().verify_decorated_class(cls)
decorators = getattr(cls, "__decorators__", [])
if any(issubclass(decorator, VariationalInference) for decorator in decorators):
warnings.warn(
"Multiple instances of `@VariationalInference` not supported."
" Please only apply one instance of `@VariationalInference` at once.",
BadCombinationWarning,
stacklevel=3,
)
def _build_gp_model_class(
self,
variational_distribution_class: Optional[type[DistributionT]],
variational_strategy_class: Optional[type[StrategyT]],
) -> type[SVGPModel]:
if variational_distribution_class is not None:
@wraps_class(SVGPModel)
class VDistGPModel(SVGPModel):
def _build_variational_distribution(self, n_inducing_points: int) -> DistributionT:
return variational_distribution_class(n_inducing_points)
else:
@wraps_class(SVGPModel)
class VDistGPModel(SVGPModel):
pass
if variational_strategy_class is not None:
variational_strategy_class.approximation_size = self.n_inducing_points
@wraps_class(VDistGPModel)
class NewGPModel(VDistGPModel):
def _build_base_variational_strategy(
self, inducing_points: Tensor, variational_distribution: DistributionT
) -> StrategyT:
return variational_strategy_class(self, inducing_points, variational_distribution)
else:
@wraps_class(VDistGPModel)
class NewGPModel(VDistGPModel):
pass
return NewGPModel
def _decorate_class(self, cls: type[ControllerT]) -> type[ControllerT]:
n_inducing_points = self.n_inducing_points
decorator = self
_gp_model_class = self.gp_model_class
@wraps_class(cls, decorator_source=self)
class InnerClass(cls):
"""
A wrapper for implementing variational inference.
"""
gp_model_class = _gp_model_class
def __init__(self, *args: Any, **kwargs: Any) -> None:
all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)
self.rng = utils.optional_random_generator(all_parameters_as_kwargs.pop("rng", None))
train_x = torch.as_tensor(all_parameters_as_kwargs.pop("train_x"))
train_y = torch.as_tensor(all_parameters_as_kwargs.pop("train_y"))
gp_kwargs = all_parameters_as_kwargs.pop("gp_kwargs", {})
gp_kwargs["n_inducing_points"] = n_inducing_points or train_x.shape[0]
mll_kwargs = all_parameters_as_kwargs.pop("mll_kwargs", {})
mll_kwargs["num_data"] = train_y.numel()
try:
super().__init__(
train_x=train_x,
train_y=train_y,
gp_kwargs=gp_kwargs,
mll_kwargs=mll_kwargs,
rng=self.rng,
**all_parameters_as_kwargs,
)
except TypeError as error:
if "__init__() got an unexpected keyword argument 'num_data'" in str(error):
raise TypeError(
"The class passed to `marginal_log_likelihood_class` must take a "
"`num_data: int` argument, since we run variational inference with SGD."
) from error
else:
raise
def _predictive_likelihood(self, x: Union[numpy.typing.NDArray[np.floating], float]) -> Posterior:
with gpytorch.settings.num_likelihood_samples(decorator.n_likelihood_samples):
return super()._predictive_likelihood(x)
def _fuzzy_predictive_likelihood(
self,
x: Union[numpy.typing.NDArray[np.floating], float],
x_std: Union[numpy.typing.NDArray[np.floating], float],
) -> Posterior:
with gpytorch.settings.num_likelihood_samples(decorator.n_likelihood_samples):
return super()._fuzzy_predictive_likelihood(x, x_std)
# Ignore type errors here - static type checkers don't understand that we dynamically inherit from `cls`, so
# `InnerClass` is always a subtype of `cls`
return InnerClass # pyright: ignore[reportReturnType]