# Source code for vanguard.classification.categorical

# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Contains the CategoricalClassification decorator.
"""

from typing import Any, TypeVar, Union

import numpy as np
import numpy.typing
from torch import Tensor
from typing_extensions import override

from vanguard import utils
from vanguard.base import GPController
from vanguard.base.posteriors.posterior import Posterior
from vanguard.classification.mixin import Classification, ClassificationMixin
from vanguard.decoratorutils import Decorator, process_args, wraps_class
from vanguard.multitask import Multitask
from vanguard.variational import VariationalInference

ControllerT = TypeVar("ControllerT", bound=GPController)


class CategoricalClassification(Decorator):
    """
    Enable categorical classification with more than two classes.

    .. note::
        Although the ``y_std`` parameter is not currently used in classification, it must still be passed.
        This is likely to change in the future, and so the type must still be correct.
        Passing ``y_std=0`` is suggested.

    .. note::
        The :class:`~vanguard.variational.VariationalInference` and
        :class:`~vanguard.multitask.decorator.Multitask` decorators are required for this
        decorator to be applied.

    :Example:
        >>> from gpytorch.likelihoods import BernoulliLikelihood
        >>> from gpytorch.kernels import RBFKernel
        >>> from gpytorch.mlls import VariationalELBO
        >>> import numpy as np
        >>> import torch
        >>> from vanguard.vanilla import GaussianGPController
        >>> from vanguard.classification.likelihoods import MultitaskBernoulliLikelihood
        >>>
        >>> @CategoricalClassification(num_classes=3)
        ... @Multitask(num_tasks=3)
        ... @VariationalInference()
        ... class CategoricalClassifier(GaussianGPController):
        ...     pass
        >>>
        >>> train_x = np.array([0, 0.5, 0.9, 1])
        >>> train_y = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])
        >>> gp = CategoricalClassifier(train_x, train_y, RBFKernel, y_std=0.0,
        ...                            likelihood_class=MultitaskBernoulliLikelihood,
        ...                            marginal_log_likelihood_class=VariationalELBO)
        >>> loss = gp.fit(100)
        >>>
        >>> test_x = np.array([0.05, 0.95])
        >>> predictions, probs = gp.classify_points(test_x)
        >>> predictions.tolist()
        [0, 2]
    """

    def __init__(self, num_classes: int, **kwargs: Any) -> None:
        """
        Initialise self.

        :param num_classes: The number of target classes.
        :param kwargs: Keyword arguments passed to :class:`~vanguard.decoratorutils.basedecorator.Decorator`.
        """
        super().__init__(framework_class=GPController, required_decorators={VariationalInference, Multitask}, **kwargs)
        self.num_classes = num_classes

    @property
    @override
    def safe_updates(self) -> dict[type, set[str]]:
        """
        Extend the inherited safe updates with the methods of other decorators listed below.

        :returns: A mapping from decorator class to the names of its methods added here,
            merged with ``super().safe_updates`` via ``_add_to_safe_updates``.
        """
        # Imported here rather than at module level (hence the pylint override).
        # pylint: disable=import-outside-toplevel
        from vanguard.learning import LearnYNoise
        from vanguard.normalise import NormaliseY
        from vanguard.standardise import DisableStandardScaling
        from vanguard.warps import SetInputWarp, SetWarp
        # pylint: enable=import-outside-toplevel

        return self._add_to_safe_updates(
            super().safe_updates,
            {
                VariationalInference: {"__init__", "_predictive_likelihood", "_fuzzy_predictive_likelihood"},
                Multitask: {"__init__", "_match_mean_shape_to_kernel"},
                DisableStandardScaling: {"_input_standardise_modules"},
                LearnYNoise: {"__init__"},
                NormaliseY: {"__init__", "warn_normalise_y"},
                SetInputWarp: {"__init__"},
                SetWarp: {"__init__", "_loss", "_sgd_round", "warn_normalise_y", "_unwarp_values"},
            },
        )

    def _decorate_class(self, cls: type[ControllerT]) -> type[ControllerT]:
        """
        Build the categorical-classification subclass of ``cls``.

        :param cls: The controller class being decorated.
        :returns: A subclass of ``cls`` (mixed with ``ClassificationMixin``) exposing
            ``classify_points`` and ``classify_fuzzy_points``.
        """
        # Captured by the inner class so it can read ``num_classes`` from this decorator instance.
        decorator = self

        @Classification(ignore_all=True)
        @wraps_class(cls, decorator_source=self)
        class InnerClass(cls, ClassificationMixin):
            """
            A wrapper for implementing categorical classification.
            """

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                """
                Initialise self, injecting ``num_classes`` into the likelihood keyword arguments.

                :param args: Positional arguments for the wrapped controller's ``__init__``.
                :param kwargs: Keyword arguments for the wrapped controller's ``__init__``.
                """
                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)
                self.rng = utils.optional_random_generator(all_parameters_as_kwargs.pop("rng", None))

                likelihood_class = all_parameters_as_kwargs.pop("likelihood_class")
                # Work on a copy: the popped mapping is whatever the caller passed in, and
                # writing ``num_classes`` into it directly would mutate the caller's own dict.
                # An explicit ``None`` is treated the same as "no extra likelihood kwargs".
                likelihood_kwargs = dict(all_parameters_as_kwargs.pop("likelihood_kwargs", None) or {})
                likelihood_kwargs["num_classes"] = decorator.num_classes

                super().__init__(
                    likelihood_class=likelihood_class,
                    likelihood_kwargs=likelihood_kwargs,
                    rng=self.rng,
                    **all_parameters_as_kwargs,
                )

            def classify_points(
                self, x: Union[float, numpy.typing.NDArray[np.floating], Tensor]
            ) -> tuple[Tensor, Tensor]:
                """
                Classify points.

                :param x: The points to classify.
                :returns: The predicted class labels, and the certainty probabilities.
                """
                predictive_likelihood = super().predictive_likelihood(x)
                return self._get_predictions_from_posterior(predictive_likelihood)

            def classify_fuzzy_points(
                self,
                x: Union[float, numpy.typing.NDArray[np.floating], Tensor],
                x_std: Union[float, numpy.typing.NDArray[np.floating], Tensor],
            ) -> tuple[Tensor, Tensor]:
                """
                Classify fuzzy points.

                :param x: The points to classify.
                :param x_std: Passed with ``x`` to ``fuzzy_predictive_likelihood``
                    (presumably the input uncertainty -- confirm against the base controller).
                :returns: The predicted class labels, and the certainty probabilities.
                """
                predictive_likelihood = super().fuzzy_predictive_likelihood(x, x_std)
                return self._get_predictions_from_posterior(predictive_likelihood)

            @staticmethod
            def _get_predictions_from_posterior(
                posterior: Posterior,
            ) -> tuple[Tensor, Tensor]:
                """
                Get predictions from a posterior distribution.

                :param posterior: The posterior distribution.
                :returns: The predicted class labels, and the certainty probabilities.
                """
                probs: Tensor = posterior.distribution.probs
                if probs.ndim == 3:
                    # TODO: unsure why this is here? Document this, and then test it if it's intentional
                    # https://github.com/gchq/Vanguard/issues/234
                    # Collapses the leading axis by averaging over it.
                    probs = probs.mean(0)
                # Rescale each row so that it sums to one before taking the arg-max.
                normalised_probs = probs / probs.sum(dim=-1).reshape((-1, 1))
                # ``max`` along dim 1 yields (values, indices); the indices are the class labels.
                prediction_values, predictions = normalised_probs.max(dim=1)
                return predictions, prediction_values

            @staticmethod
            def warn_normalise_y() -> None:
                """Override base warning because classification renders y normalisation irrelevant."""

        return InnerClass