# Source code for vanguard.features

# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Contains decorators to deal with input features that aren't vectors.
"""

import warnings
from functools import partial
from typing import Any, TypeVar, Union

import numpy as np
import torch
from gpytorch.models import GP
from typing_extensions import override

from vanguard import utils
from vanguard.base import GPController
from vanguard.classification.mixin import Classification, ClassificationMixin
from vanguard.decoratorutils import Decorator, process_args, wraps_class
from vanguard.decoratorutils.basedecorator import T
from vanguard.decoratorutils.errors import BadCombinationWarning
from vanguard.variational import VariationalInference
from vanguard.warnings import warn_experimental

# Type variable for the GPyTorch model classes wrapped by ``_HigherRankFeaturesModel``.
GPModelT = TypeVar("GPModelT", bound=GP)
# Type variable for the controller classes decorated by ``HigherRankFeatures``.
ControllerT = TypeVar("ControllerT", bound=GPController)


class HigherRankFeatures(Decorator):
    """
    Make a :class:`~vanguard.base.gpcontroller.GPController` compatible with higher rank features.

    GPyTorch assumes that input features are rank-1 (vectors) and a variety of
    RuntimeErrors are thrown from different places in the code if this is not true.
    This decorator modifies the gp model class to make it compatible with higher
    rank features.

    .. warning::
        This decorator is EXPERIMENTAL. It may cause errors or give incorrect results,
        and may have breaking changes without warning.

    :Example:
        >>> @HigherRankFeatures(2)
        ... class NewController(GPController):
        ...     pass
    """

    def __init__(self, rank: int, **kwargs: Any) -> None:
        """
        :param rank: The rank of the input features. Should be a positive integer.
        :raises ValueError: If ``rank`` is not a positive integer.
        """
        # ``rank`` is later used as ``train_x.shape[-rank:]``; with ``rank == 0`` that is
        # ``shape[-0:]``, i.e. the *whole* shape (data-point axis included), which silently
        # produces a nonsensical model. Reject anything that is not a positive integer.
        if isinstance(rank, bool) or not isinstance(rank, (int, np.integer)) or rank < 1:
            raise ValueError(f"`rank` should be a positive integer; got {rank!r}.")
        warn_experimental("The HigherRankFeatures decorator")
        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)
        self.rank = int(rank)

    @override
    def verify_decorated_class(self, cls: type[T]) -> None:
        """
        Run the parent checks, then warn if ``cls`` is already decorated with this decorator.

        :param cls: The class about to be decorated.
        """
        super().verify_decorated_class(cls)
        decorators = getattr(cls, "__decorators__", [])
        if any(issubclass(decorator, HigherRankFeatures) for decorator in decorators):
            warnings.warn(
                "Multiple instances of `@HigherRankFeatures` not supported."
                " Please only apply one instance of `@HigherRankFeatures` at once.",
                BadCombinationWarning,
                stacklevel=3,
            )

    @property
    @override
    def safe_updates(self) -> dict[type, set[str]]:
        """
        Return the parent's safe updates extended with the listed methods of other vanguard classes.
        """
        # Imported lazily, presumably to avoid import cycles between vanguard modules.
        # pylint: disable=import-outside-toplevel
        from vanguard.classification import BinaryClassification
        from vanguard.learning import LearnYNoise
        from vanguard.normalise import NormaliseY
        from vanguard.standardise import DisableStandardScaling
        from vanguard.warps import SetInputWarp, SetWarp

        # pylint: enable=import-outside-toplevel
        return self._add_to_safe_updates(
            super().safe_updates,
            {
                ClassificationMixin: {"classify_points", "classify_fuzzy_points"},
                Classification: {
                    "posterior_over_point",
                    "posterior_over_fuzzy_point",
                    "fuzzy_predictive_likelihood",
                    "predictive_likelihood",
                },
                VariationalInference: {"__init__", "_predictive_likelihood", "_fuzzy_predictive_likelihood"},
                DisableStandardScaling: {"_input_standardise_modules"},
                LearnYNoise: {"__init__"},
                NormaliseY: {"__init__", "warn_normalise_y"},
                SetInputWarp: {"__init__"},
                SetWarp: {"__init__", "_loss", "_sgd_round", "warn_normalise_y", "_unwarp_values"},
                BinaryClassification: {
                    "__init__",
                    "classify_points",
                    "classify_fuzzy_points",
                    "_get_predictions_from_prediction_means",
                    "warn_normalise_y",
                },
            },
        )

    def _decorate_class(self, cls: type[ControllerT]) -> type[ControllerT]:
        """
        Return a subclass of ``cls`` whose model class is wrapped to accept higher rank features.

        The last ``rank`` dimensions of ``train_x`` are taken as the native shape of a single
        data point, and ``gp_model_class`` is wrapped with :class:`_HigherRankFeaturesModel`
        for that shape before the parent ``__init__`` runs.

        :param cls: The controller class to decorate.
        :returns: The decorated controller class.
        """
        rank = self.rank

        @wraps_class(cls, decorator_source=self)
        class InnerClass(cls):
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)
                # Normalise the optional ``rng`` argument via ``utils.optional_random_generator``
                # and pass the result on explicitly.
                self.rng = utils.optional_random_generator(all_parameters_as_kwargs.pop("rng", None))

                train_x = all_parameters_as_kwargs["train_x"]
                item_shape = train_x.shape[-rank:]
                self.gp_model_class = _HigherRankFeaturesModel(item_shape)(self.gp_model_class)

                super().__init__(rng=self.rng, **all_parameters_as_kwargs)

        return InnerClass
class _HigherRankFeaturesModel:
    """
    A decorator for a model, enabling higher rank features.

    GPyTorch assumes that input features are rank-1 (vectors) and a variety of
    RuntimeErrors are thrown from different places in the code if this is not true.
    This decorator can be applied to a GPyTorch model and deals with the feature
    shapes to avoid these issues.

    The decorator intercepts the training data and any positional data passed to
    ``__call__``, flattening it so that the shapes work out correctly. The data are
    then returned to their native shape before any actual computation (e.g. inside
    kernels) is performed.
    """

    def __init__(self, shape: Union[tuple[int, ...], torch.Size]) -> None:
        """
        :param shape: The native shape of a single data point.
        """
        self.shape = tuple(shape)
        # Number of elements in a single data point; ``np.prod(())`` is 1 for an empty shape.
        self.flat_shape = int(np.prod(self.shape))

    def __call__(self, model_cls: "type[GPModelT]") -> "type[GPModelT]":
        """
        Wrap a model class so that it sees flat features externally and native features in ``forward``.

        :param model_cls: The model class to wrap.
        :returns: A subclass of ``model_cls`` that flattens ``train_x`` and positional
            ``__call__`` arguments, and unflattens the input of ``forward``.
        """
        shape = self.shape
        flat_shape = self.flat_shape
        _flatten = partial(self._flatten, item_shape=shape, item_flat_shape=flat_shape)
        _unflatten = partial(self._unflatten, item_shape=shape)

        @wraps_class(model_cls)
        class InnerClass(model_cls):
            def __init__(self, train_x: torch.Tensor, *args: Any, **kwargs: Any) -> None:
                super().__init__(_flatten(train_x), *args, **kwargs)

            def __call__(self, *args: Any, **kwargs: Any) -> Any:
                # Only positional arguments are treated as inputs and flattened.
                flat_args = [_flatten(arg) for arg in args]
                return super().__call__(*flat_args, **kwargs)

            def forward(self, x: torch.Tensor) -> Any:
                return super().forward(_unflatten(x))

        return InnerClass

    @staticmethod
    def _flatten(tensor: torch.Tensor, item_shape: tuple[int, ...], item_flat_shape: int) -> torch.Tensor:
        """
        Reshape tensors to flat (rank-1) features, keeping all leading (batch) dimensions.

        :param tensor: The tensor to reshape.
        :param item_shape: The native shape of a single item.
        :param item_flat_shape: The flattened length of a single item.
        :returns: Reshaped tensor.
        """
        # Slice with an explicit end index: ``shape[: -len(item_shape)]`` degenerates to
        # ``shape[:0]`` (dropping every batch dimension) when ``item_shape`` is empty.
        batch_shape = tuple(tensor.shape[: len(tensor.shape) - len(item_shape)])
        return tensor.reshape(batch_shape + (item_flat_shape,))

    @staticmethod
    def _unflatten(tensor: torch.Tensor, item_shape: tuple[int, ...]) -> torch.Tensor:
        """
        Reshape flattened tensors back to the native feature shape.

        :param tensor: The tensor to reshape.
        :param item_shape: The native shape of a single item.
        :returns: Reshaped tensor.
        """
        batch_shape = tuple(tensor.shape[:-1])
        return tensor.reshape(batch_shape + tuple(item_shape))