Source code for vanguard.warps.input

# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Contains the Python decorators for applying input warping.
"""

from typing import Any, TypeVar

import torch
from typing_extensions import Self, override

from vanguard import utils
from vanguard.base import GPController
from vanguard.classification.mixin import Classification, ClassificationMixin
from vanguard.decoratorutils import Decorator, process_args, wraps_class
from vanguard.variational import VariationalInference
from vanguard.warps.basefunction import WarpFunction

ControllerT = TypeVar("ControllerT", bound=GPController)
ModuleT = TypeVar("ModuleT", bound=torch.nn.Module)


class _SetModuleInputWarp:
    """
    Set the input warp for a `torch.nn.Module` instance.

    Input warping is formulated so that the index (input) space of the GP must be transformed using the input warp.
    As such, to obtain the desired model with the chosen mean and kernel in the warped space, the mean and kernel
    functions must be composed with the inverse warp.

    Since kernels and means are implemented as subclasses of `torch.nn.Module` in GPyTorch, we can apply the inverse
    warping to both using this class alone.
    """

    def __init__(self, warp: WarpFunction) -> None:
        self.warp = warp

    def __call__(self, module_class: type[ModuleT]) -> type[ModuleT]:
        warp = self.warp

        @wraps_class(module_class)
        class InnerClass(module_class):
            """Apply the inner warp."""

            def forward(self, *args: Any, **kwargs: Any):
                """Map all inputs through the warp inverse."""
                inverse_warped_inputs = [warp.inverse(x) for x in args]
                return super().forward(*inverse_warped_inputs, **kwargs)

        return InnerClass


[docs] class SetInputWarp(Decorator): """ Apply input warping to a GP to achieve non-Gaussian input uncertainty. :Example: >>> from vanguard.base import GPController >>> from vanguard.warps.warpfunctions import BoxCoxWarpFunction >>> >>> @SetInputWarp(BoxCoxWarpFunction(1)) ... class MyController(GPController): ... pass """
[docs] def __init__(self, warp_function: WarpFunction, **kwargs: Any) -> None: """ Initialise self. :param warp_function: The warp function to be applied to the GP inputs. :param kwargs: Keyword arguments passed to :class:`~vanguard.decoratorutils.basedecorator.Decorator`. """ super().__init__(framework_class=GPController, required_decorators={}, **kwargs) self.warp_function = warp_function
@property @override def safe_updates(self) -> dict[type, set[str]]: # pylint: disable=import-outside-toplevel from vanguard.classification import ( BinaryClassification, CategoricalClassification, DirichletMulticlassClassification, ) from vanguard.classification.kernel import DirichletKernelMulticlassClassification from vanguard.features import HigherRankFeatures from vanguard.hierarchical import LaplaceHierarchicalHyperparameters, VariationalHierarchicalHyperparameters from vanguard.learning import LearnYNoise from vanguard.multitask import Multitask from vanguard.normalise import NormaliseY from vanguard.standardise import DisableStandardScaling from vanguard.warps import SetWarp # pylint: enable=import-outside-toplevel return self._add_to_safe_updates( super().safe_updates, { BinaryClassification: { "__init__", "classify_points", "classify_fuzzy_points", "_get_predictions_from_prediction_means", "warn_normalise_y", }, CategoricalClassification: { "__init__", "classify_points", "classify_fuzzy_points", "_get_predictions_from_posterior", "warn_normalise_y", }, ClassificationMixin: {"classify_points", "classify_fuzzy_points"}, Classification: { "posterior_over_point", "posterior_over_fuzzy_point", "fuzzy_predictive_likelihood", "predictive_likelihood", }, DirichletKernelMulticlassClassification: { "__init__", "classify_points", "classify_fuzzy_points", "_get_predictions_from_prediction_means", }, DirichletMulticlassClassification: { "__init__", "_loss", "_noise_transform", "classify_points", "classify_fuzzy_points", "_get_predictions_from_prediction_means", "warn_normalise_y", }, DisableStandardScaling: {"_input_standardise_modules"}, HigherRankFeatures: {"__init__"}, LaplaceHierarchicalHyperparameters: { "__init__", "_compute_hyperparameter_laplace_approximation", "_compute_loss_hessian", "_fuzzy_predictive_likelihood", "_get_posterior_over_fuzzy_point_in_eval_mode", "_get_posterior_over_point", "_gp_forward", "_predictive_likelihood", "_sample_and_set_hyperparameters", "_sgd_round", "_update_hyperparameter_posterior", "auto_temperature", }, LearnYNoise: {"__init__"}, Multitask: {"__init__", "_match_mean_shape_to_kernel"}, NormaliseY: {"__init__", "warn_normalise_y"}, SetInputWarp: {"__init__"}, SetWarp: {"__init__", "_loss", "_sgd_round", "warn_normalise_y", "_unwarp_values"}, VariationalHierarchicalHyperparameters: { "__init__", "_fuzzy_predictive_likelihood", "_get_posterior_over_fuzzy_point_in_eval_mode", "_get_posterior_over_point", "_gp_forward", "_loss", "_predictive_likelihood", }, VariationalInference: {"__init__", "_predictive_likelihood", "_fuzzy_predictive_likelihood"}, }, ) def _decorate_class(self, cls: type[ControllerT]) -> type[ControllerT]: warp_function = self.warp_function @wraps_class(cls, decorator_source=self) class InnerClass(cls): """ A wrapper for applying a warp to inputs for non-Gaussian input uncertainty. """ def __init__(self, *args: Any, **kwargs: Any) -> None: all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs) self.rng = utils.optional_random_generator(all_parameters_as_kwargs.pop("rng", None)) module_decorator = _SetModuleInputWarp(warp_function) mean_class = all_parameters_as_kwargs.pop("mean_class") kernel_class = all_parameters_as_kwargs.pop("kernel_class") super().__init__( kernel_class=module_decorator(kernel_class), mean_class=module_decorator(mean_class), rng=self.rng, **all_parameters_as_kwargs, ) self.input_warp = warp_function @classmethod def new(cls, instance: Self, **kwargs: Any) -> Self: """Also apply warping to the new instance.""" new_instance = super().new(instance, **kwargs) new_instance.input_warp = instance.input_warp return new_instance return InnerClass