# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains model classes to enable classification in Vanguard.
"""
import warnings
from typing import Any, Optional, Union
import gpytorch
import torch
from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.means import ZeroMean
from gpytorch.models import ExactGP
from gpytorch.utils.warnings import GPInputWarning
from linear_operator import LinearOperator
from linear_operator.operators import DiagLinearOperator
from torch import Tensor
from typing_extensions import override
from vanguard.models import ExactGPModel
from vanguard.utils import DummyDistribution
class DummyKernelDistribution(DummyDistribution):
    """
    A dummy distribution to hold a kernel matrix and some one-hot labels.

    When the product ``kernel @ labels`` can be computed, it is stored as :attr:`mean`, and
    :attr:`covariance_matrix` is an all-zeros tensor of shape ``(num_classes, num_classes, n_points, n_points)``
    with ``n_points = kernel.shape[0]``. When that product raises a :class:`RuntimeError` (for example because the
    shapes do not match), the labels and the kernel are stored unchanged as :attr:`mean` and
    :attr:`covariance_matrix` respectively.
    """

    # TODO: Lying to the type checker here feels like bad code, and should only be a very temporary measure. Should
    # probably just inherit from Distribution, and type hint downstream code to expect an arbitrary Distribution.
    # https://github.com/gchq/Vanguard/issues/394
    __class__ = MultivariateNormal

    def __init__(self, labels: Union[Tensor, LinearOperator], kernel: Union[Tensor, LinearOperator]) -> None:
        """
        Initialise self.

        :param labels: The one-hot labels, shape: torch.Size([n_points, num_classes]). For the product
            ``kernel @ labels`` to succeed, ``n_points`` must match the last dimension of ``kernel``.
        :param kernel: The kernel matrix.
        """
        self.labels = labels
        self.kernel = kernel
        try:
            self.mean = self.kernel @ self.labels.to_dense()
        except RuntimeError:
            # The kernel and labels cannot be multiplied (for example because their shapes do not match),
            # so keep both exactly as they were given.
            self.mean = labels
            self.covariance_matrix = kernel
        else:
            # Kept outside the ``try`` so that a failure here (e.g. running out of memory) is not mistaken for
            # incompatible inputs and silently turned into the fallback above.
            num_classes = self.mean.shape[-1]
            num_points = self.kernel.shape[0]
            # The last two dimensions represent the pairwise covariances between the test points.
            # The first two dimensions represent the covariances between the classes for each pair of test points.
            self.covariance_matrix = torch.zeros(num_classes, num_classes, num_points, num_points)

    def add_jitter(self, jitter: float = 1e-3) -> "DummyKernelDistribution":
        """
        Add a small constant diagonal to the covariance matrix for numerical stability.

        The jitter is placed on the diagonal of the last two dimensions and broadcast over any leading dimensions.
        This covers both the ``(num_classes, num_classes, n_points, n_points)`` covariance and a two-dimensional
        kernel stored directly as the covariance matrix. The jitter is created with the covariance matrix's dtype
        and device.

        The covariance matrix is replaced rather than modified in place, so a kernel supplied by the caller is
        never mutated.

        :param jitter: The size of the constant diagonal.
        :return: The instance with the updated covariance matrix.
        """
        covariance_matrix = self.covariance_matrix
        jitter_matrix = jitter * torch.eye(
            covariance_matrix.shape[-1], dtype=covariance_matrix.dtype, device=covariance_matrix.device
        )
        # Broadcasting adds the (n, n) jitter to every trailing (n, n) matrix, whatever the number of batch dimensions.
        self.covariance_matrix = covariance_matrix + jitter_matrix
        return self
class InertKernelModel(ExactGPModel):
    """
    An inert model wrapping a kernel matrix.

    Uses a given kernel for prior and posterior and returns a dummy distribution holding the
    kernel matrix, together with the one-hot encoded training targets.
    """

    def __init__(
        self,
        train_inputs: Optional[torch.Tensor],
        train_targets: Optional[torch.Tensor],
        covar_module: gpytorch.kernels.Kernel,
        mean_module: Optional[gpytorch.means.Mean],
        likelihood: gpytorch.likelihoods.Likelihood,
        num_classes: int,
        **_: Any,
    ) -> None:
        """
        Initialise self.

        Note that while arbitrary keyword arguments are accepted, they are not inspected or used. This is to allow
        passing keyword parameters that are required by other GP models (e.g. `rng`) without raising a `TypeError`,
        which allows more generic code.

        :param train_inputs: (n_samples, n_features) The training inputs (features). This can be a single tensor or a
            list/tuple of tensors. One-dimensional tensors are reshaped to a single column. If ``None``, no training
            data is stored, and ``train_targets`` is ignored.
        :param train_targets: (n_samples,) The training targets (response).
        :param covar_module: The prior kernel function to use.
        :param mean_module: Not used, remaining in the signature for compatibility. A
            :class:`~gpytorch.means.ZeroMean` is always used instead.
        :param likelihood: Likelihood to use with model.
        :param num_classes: The number of classes to use.
        :raises TypeError: If ``train_inputs`` is not a tensor or a list/tuple of tensors.
        """
        # Start the ``__init__`` chain *after* ExactGP in the MRO, skipping ExactGP.__init__ (and any __init__ that
        # precedes it). The attributes this class needs are assigned manually below.
        super(ExactGP, self).__init__()
        if train_inputs is None:
            self.train_inputs = None
            self.train_targets = None
        else:
            if torch.is_tensor(train_inputs):
                train_inputs = (train_inputs,)
            try:
                self.train_inputs = tuple(tri.unsqueeze(-1) if tri.ndimension() == 1 else tri for tri in train_inputs)
            except AttributeError as exc:
                raise TypeError("Train inputs must be a tensor, or a list/tuple of tensors") from exc
            self.train_targets = train_targets
        self.prediction_strategy = None
        self.n_classes = num_classes
        self.covar_module = covar_module
        self.mean_module = ZeroMean()
        self.likelihood = likelihood

    def train(self, mode: bool = True) -> ExactGPModel:
        """
        Set to training mode (``mode=True``) or evaluation mode (``mode=False``).

        :param mode: Whether to enter training mode.
        :return: The result of the parent class's ``train``.
        :raises RuntimeError: If entering training mode while the training inputs or targets are ``None``.
        """
        if mode is True and (self.train_inputs is None or self.train_targets is None):
            raise RuntimeError(
                "train_inputs, train_targets cannot be None in training mode. "
                "Call .eval() for prior predictions, or call .set_train_data() to add training data."
            )
        return super().train(mode)

    def _label_tensor(self, targets: torch.Tensor, *, dtype: Optional[torch.dtype] = None) -> LinearOperator:
        """
        One-hot encode class targets by selecting rows of an identity operator.

        :param targets: Class labels, expected to be integers in ``[0, num_classes)``.
        :param dtype: The dtype of the result. Defaults to the PyTorch default dtype.
        :return: The one-hot labels, with one row per target and ``num_classes`` columns, on the same device as
            ``targets``.
        """
        # The identity is built on the targets' device so that indexing works for non-CPU targets. It is built in the
        # requested dtype so that the labels can be multiplied by a kernel of that dtype. A dtype or device mismatch
        # would make ``kernel @ labels`` raise, and DummyKernelDistribution would swallow that error silently.
        identity = DiagLinearOperator(torch.ones(self.n_classes, dtype=dtype, device=targets.device))
        return identity[targets.long()]

    @staticmethod
    def _inputs_match(train_inputs: list, inputs: list) -> bool:
        """Return whether each input equals the stored training input at the same position (shorter list wins)."""
        return all(torch.equal(train_input, test_input) for train_input, test_input in zip(train_inputs, inputs))

    @override
    def __call__(self, *args: Any, **kwargs: Any) -> DummyKernelDistribution:
        """
        Perform training or inference, depending on the current mode.

        The kernel is evaluated as follows:

        * In training mode, and in prior mode or when no training data is stored, it is evaluated on the inputs
          alone.
        * Otherwise, it is evaluated between the inputs and the stored training inputs.

        The kernel matrix is returned together with the one-hot encoded training targets.

        :param args: The input tensors. One-dimensional tensors are treated as a single feature column.
        :param kwargs: Ignored.
        :return: A dummy distribution holding the one-hot training labels and the kernel matrix.
        :raises RuntimeError: In training mode with ``gpytorch.settings.debug`` on, if the inputs differ from the
            stored training inputs.
        """
        # TODO: Why do we accept variable numbers of arguments here? It seems to throw errors if you provide too many
        # arguments, and the GPyTorch documentation seems very thin here. Also, `kwargs` is ignored entirely.
        # https://github.com/gchq/Vanguard/issues/292
        train_inputs = list(self.train_inputs) if self.train_inputs is not None else []
        inputs = [arg.unsqueeze(-1) if arg.ndimension() == 1 else arg for arg in args]

        # The input/training-input comparison is only needed for debug checks, so it is evaluated lazily.
        if self.training:
            if settings.debug.on() and not self._inputs_match(train_inputs, inputs):
                raise RuntimeError("You must train on the training inputs!")
            kernel_matrix = self.covar_module(*inputs)
        elif settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
            # TODO: Prior mode evaluation fails due to a shape mismatch, seemingly due to the reference to
            # train_targets in the return value.
            # https://github.com/gchq/Vanguard/issues/291
            kernel_matrix = self.covar_module(*inputs)
        else:
            if settings.debug.on() and self._inputs_match(train_inputs, inputs):
                warnings.warn(
                    "The input matches the stored training data. Did you forget to call model.train()?",
                    GPInputWarning,
                )
            kernel_matrix = self.covar_module(*inputs, *train_inputs)

        # TODO: This will fail if train_targets is None. (AttributeError: 'NoneType' object has no attribute 'long')
        # https://github.com/gchq/Vanguard/issues/291
        labels = self._label_tensor(self.train_targets, dtype=kernel_matrix.dtype)
        assert labels.shape == torch.Size([kernel_matrix.shape[-1], self.n_classes])
        assert kernel_matrix.shape == torch.Size([inputs[0].shape[0], train_inputs[0].shape[0]])
        return DummyKernelDistribution(labels=labels, kernel=kernel_matrix)