Source code for vanguard.multitask.decorator

# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Enabling multitask Gaussian processes.

The :class:`~vanguard.multitask.decorator.Multitask` decorator
converts a controller class into a multitask controller.
"""

import warnings
from typing import Any, Optional, TypeVar

import torch
from gpytorch.kernels import Kernel, MultitaskKernel
from gpytorch.means import ConstantMean, Mean, MultitaskMean
from torch import Tensor
from typing_extensions import override

from vanguard import utils
from vanguard.base import GPController
from vanguard.decoratorutils import Decorator, process_args, wraps_class
from vanguard.decoratorutils.errors import BadCombinationWarning
from vanguard.multitask.kernel import BatchCompatibleMultitaskKernel
from vanguard.multitask.models import (
    independent_variational_multitask_model,
    lmc_variational_multitask_model,
    multitask_model,
)
from vanguard.variational import VariationalInference

ControllerT = TypeVar("ControllerT", bound=GPController)
T = TypeVar("T")


class Multitask(Decorator):
    """
    Make a GP multitask.

    :Example:
        >>> from vanguard.base import GPController
        >>>
        >>> @Multitask(num_tasks=2)
        ... class MyController(GPController):
        ...     pass
    """

    def __init__(self, num_tasks: int, lmc_dimension: Optional[int] = None, rank: int = 1, **kwargs: Any) -> None:
        """
        Initialise self.

        :param num_tasks: The number of tasks (i.e. y-value dimension).
        :param lmc_dimension: If using LMC (linear model of co-regionalisation), how many latent dimensions to use.
            Bigger means a more complicated model. Should probably be at least as big as the number of tasks,
            unless you want to specifically make low-rank assumptions about the relationship between tasks.
            Default (None) means LMC is not used at all.
        :param rank: The rank of the task-task covar matrix in a Kronecker product multitask kernel.
            Only relevant for exact GP inference.
        :param kwargs: Forwarded unchanged to the parent decorator's initialiser.
        """
        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)
        self.num_tasks = num_tasks
        self.lmc_dimension = lmc_dimension
        self.rank = rank

    @property
    @override
    def safe_updates(self) -> dict[type, set[str]]:
        """Return the inherited safe updates extended with the decorator/method-name entries listed below."""
        # Imported lazily inside the property rather than at module level
        # (presumably to avoid circular imports -- TODO confirm).
        # pylint: disable=import-outside-toplevel
        from vanguard.hierarchical import LaplaceHierarchicalHyperparameters, VariationalHierarchicalHyperparameters
        from vanguard.learning import LearnYNoise
        from vanguard.normalise import NormaliseY
        from vanguard.standardise import DisableStandardScaling
        from vanguard.warps import SetInputWarp, SetWarp

        # pylint: enable=import-outside-toplevel

        return self._add_to_safe_updates(
            super().safe_updates,
            {
                DisableStandardScaling: {"_input_standardise_modules"},
                LaplaceHierarchicalHyperparameters: {
                    "__init__",
                    "_compute_hyperparameter_laplace_approximation",
                    "_compute_loss_hessian",
                    "_fuzzy_predictive_likelihood",
                    "_get_posterior_over_fuzzy_point_in_eval_mode",
                    "_get_posterior_over_point",
                    "_gp_forward",
                    "_predictive_likelihood",
                    "_sample_and_set_hyperparameters",
                    "_sgd_round",
                    "_update_hyperparameter_posterior",
                    "auto_temperature",
                },
                LearnYNoise: {"__init__"},
                NormaliseY: {"__init__", "warn_normalise_y"},
                SetInputWarp: {"__init__"},
                SetWarp: {"__init__", "_loss", "_sgd_round", "warn_normalise_y", "_unwarp_values"},
                VariationalHierarchicalHyperparameters: {
                    "__init__",
                    "_fuzzy_predictive_likelihood",
                    "_get_posterior_over_fuzzy_point_in_eval_mode",
                    "_get_posterior_over_point",
                    "_gp_forward",
                    "_loss",
                    "_predictive_likelihood",
                },
                VariationalInference: {"__init__", "_predictive_likelihood", "_fuzzy_predictive_likelihood"},
            },
        )

    @override
    def verify_decorated_class(self, cls: type[T]) -> None:
        """
        Run the parent verification, then warn if ``cls`` has already been decorated with ``Multitask``.

        :param cls: The class about to be decorated.
        """
        super().verify_decorated_class(cls)

        decorators = getattr(cls, "__decorators__", [])
        if any(issubclass(decorator, Multitask) for decorator in decorators):
            warnings.warn(
                "Multiple instances of `@Multitask` not supported."
                " Please only apply one instance of `@Multitask` at once.",
                BadCombinationWarning,
                stacklevel=3,
            )

    def _decorate_class(self, cls: type[ControllerT]) -> type[ControllerT]:
        """
        Build and return the multitask subclass of ``cls``.

        :param cls: The controller class to convert.
        :returns: A subclass of ``cls`` whose kernel, mean, likelihood and GP model are set up
            for ``self.num_tasks`` tasks.
        """
        # The nested class has its own ``self``; keep a handle on the decorator instance.
        decorator = self
        is_variational = VariationalInference in cls.__decorators__

        @wraps_class(cls, decorator_source=self)
        class InnerClass(cls):
            """
            A wrapper for converting a controller class to multitask.

            It is fairly lightweight, extracting the necessary information like number of tasks
            from the supplied data, converting means to multitask means and slightly modifying
            a few methods to deal with multitask Gaussian's etc.
            """

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                """
                Initialise self.

                Positional and keyword arguments are normalised to keyword form, the kernel,
                mean and GP model classes are swapped for multitask-capable versions, and
                ``num_tasks`` is injected into the likelihood and GP keyword arguments before
                delegating to the parent initialiser.
                """
                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)
                self.rng = utils.optional_random_generator(all_parameters_as_kwargs.pop("rng", None))

                # NOTE(review): self.gp_model_class is read before super().__init__ has run, so it is
                # presumably defined at class level on the wrapped controller -- confirm.
                original_gp_model_class = self.gp_model_class  # pylint: disable=access-member-before-definition

                if is_variational:
                    if decorator.lmc_dimension is not None:
                        gp_model_class = lmc_variational_multitask_model(original_gp_model_class)
                    else:
                        gp_model_class = independent_variational_multitask_model(original_gp_model_class)
                else:
                    gp_model_class = original_gp_model_class

                # Pyright cannot resolve dynamic base class
                @multitask_model
                class MultitaskGPModelClass(gp_model_class):  # pyright: ignore[reportGeneralTypeIssues]
                    """Multitask version of gp_model_class."""

                self.gp_model_class = MultitaskGPModelClass
                self.num_tasks = decorator.num_tasks

                mean_class = all_parameters_as_kwargs.pop("mean_class", ConstantMean)
                kernel_class = all_parameters_as_kwargs.pop("kernel_class")
                # ``get`` rather than ``pop``: these are still forwarded to the parent initialiser below.
                kernel_kwargs = all_parameters_as_kwargs.get("kernel_kwargs", {})
                mean_kwargs = all_parameters_as_kwargs.get("mean_kwargs", {})

                if is_variational:
                    kernel_class = _batchify(kernel_class, kernel_kwargs, decorator.num_tasks, decorator.lmc_dimension)
                    mean_class = _batchify(mean_class, mean_kwargs, decorator.num_tasks, decorator.lmc_dimension)
                else:
                    kernel_class = _multitaskify_kernel(kernel_class, decorator.num_tasks, decorator.rank)

                try:
                    mean_class = self._match_mean_shape_to_kernel(mean_class, kernel_class, mean_kwargs, kernel_kwargs)
                except TypeError as exc:
                    # Check for batch shape mismatches and reraise with a more informative message.
                    if "batch_shape" in mean_kwargs:
                        batch_shape = mean_kwargs["batch_shape"]
                        if not isinstance(batch_shape, torch.Size):
                            msg = (
                                f"Expected mean_kwargs['batch_shape'] to be of type `torch.Size`; "
                                f"got `{batch_shape.__class__.__name__}` instead"
                            )
                            raise TypeError(msg) from exc
                    # If it's some other TypeError, just re-raise it.
                    raise

                # Build fresh dicts so the caller's ``likelihood_kwargs``/``gp_kwargs`` objects are not
                # mutated when ``num_tasks`` is injected.
                likelihood_kwargs = {
                    **all_parameters_as_kwargs.pop("likelihood_kwargs", {}),
                    "num_tasks": decorator.num_tasks,
                }
                gp_kwargs = {
                    **all_parameters_as_kwargs.pop("gp_kwargs", {}),
                    "num_tasks": decorator.num_tasks,
                }

                super().__init__(
                    kernel_class=kernel_class,
                    mean_class=mean_class,
                    likelihood_kwargs=likelihood_kwargs,
                    gp_kwargs=gp_kwargs,
                    rng=self.rng,
                    **all_parameters_as_kwargs,
                )

            @property
            def likelihood_noise(self) -> Tensor:
                """Return the fixed noise of the likelihood."""
                try:
                    return self._likelihood.fixed_noise
                except AttributeError as exc:
                    raise AttributeError(
                        "'fixed_noise' appears to have not been set yet. This can be set "
                        "with the `likelihood_noise` method"
                    ) from exc

            @likelihood_noise.setter
            def likelihood_noise(self, value: Tensor) -> None:
                """Set the fixed noise of the likelihood."""
                self._likelihood.fixed_noise = value

            @staticmethod
            def _match_mean_shape_to_kernel(
                mean_class: type[Mean],
                kernel_class: type[Kernel],
                mean_kwargs: dict[str, Any],
                kernel_kwargs: dict[str, Any],
            ) -> type[Mean]:
                """
                Construct a mean class suitable for multitask GPs that matches the form of the kernel, if possible.

                Both classes are instantiated once with their keyword arguments purely to inspect them.
                If the kernel instance is a :class:`gpytorch.kernels.MultitaskKernel`, the mean class is
                wrapped into a multitask mean; otherwise the mean class is returned unchanged provided
                its batch shape is compatible with the kernel's.

                :param mean_class: An uninstantiated :class:`gpytorch.means.Mean`.
                :param kernel_class: An uninstantiated :class:`gpytorch.kernels.Kernel`.
                :param mean_kwargs: Keyword arguments to be passed to the mean_class constructor.
                :param kernel_kwargs: Keyword arguments to be passed to the kernel_class constructor.
                :returns: An uninstantiated :class:`gpytorch.means.Mean` like mean_class but modified to have the
                    same form/shape as kernel_class, if possible.
                :raises ValueError: If the kernel has a non-empty batch_shape and the batch_shape of the mean
                    does not match it.
                """
                example_kernel = kernel_class(**kernel_kwargs)
                example_mean = mean_class(**mean_kwargs)

                if isinstance(example_kernel, MultitaskKernel):
                    return _multitaskify_mean(mean_class, decorator.num_tasks)

                if len(example_kernel.batch_shape) > 0 and example_mean.batch_shape != example_kernel.batch_shape:
                    msg = (
                        f"The provided mean has batch_shape {example_mean.batch_shape} but the "
                        f"provided kernel has batch_shape {example_kernel.batch_shape}. "
                        f"They must match."
                    )
                    raise ValueError(msg)

                return mean_class

        # Pyright does not detect that wraps_class renames InnerClass
        return InnerClass  # pyright: ignore [reportReturnType]
def _batchify(module_class: type[T], _kwargs: dict[str, Any], num_tasks: int, lmc_dimension: Optional[int]) -> type[T]:
    """
    Wrap a class so that every instance gains one extra trailing batch dimension.

    The size of the extra dimension is ``lmc_dimension`` when that is given, and
    ``num_tasks`` otherwise. Any ``batch_shape`` keyword supplied at construction
    time is kept and the extra dimension is appended to it.

    :param module_class: The class to batchify, typically a kernel or mean.
    :param _kwargs: Unused; kept in the signature for compatibility.
    :param num_tasks: The number of tasks for the multitask GP.
    :param lmc_dimension: The number of LMC dimensions (if using LMC).
    :returns: The adapted class.
    """
    if lmc_dimension is None:
        extra_batch_dimension = num_tasks
    else:
        extra_batch_dimension = lmc_dimension

    @wraps_class(module_class)
    class InnerClass(module_class):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            supplied_batch_shape = kwargs.pop("batch_shape", torch.Size([]))
            kwargs["batch_shape"] = supplied_batch_shape + torch.Size([extra_batch_dimension])
            super().__init__(*args, **kwargs)

    # Pyright does not detect that wraps_class renames InnerClass
    return InnerClass  # pyright: ignore [reportReturnType]


def _multitaskify_kernel(kernel_class: type[Kernel], num_tasks: int, rank: int = 1) -> type[MultitaskKernel]:
    """
    Return a multitask kernel class built around ``kernel_class``.

    A class that already derives from the GPyTorch multitask kernel is handed back
    untouched. Anything else is wrapped in a :class:`BatchCompatibleMultitaskKernel`
    subclass that instantiates ``kernel_class`` on construction.

    :param kernel_class: The kernel to multitaskify.
    :param num_tasks: The number of tasks for the multitask GP.
    :param rank: The rank of the task-task covariance matrix; capped at ``num_tasks``.
    :returns: The adapted kernel class.
    """
    if issubclass(kernel_class, MultitaskKernel):
        return kernel_class

    capped_rank = min(num_tasks, rank)

    class InnerKernelClass(BatchCompatibleMultitaskKernel):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            # The keyword arguments go to the wrapped kernel and to the wrapper alike.
            data_kernel = kernel_class(*args, **kwargs)
            super().__init__(data_kernel, num_tasks=num_tasks, rank=capped_rank, **kwargs)

    return InnerKernelClass


def _multitaskify_mean(mean_class: type[Mean], num_tasks: int) -> type[MultitaskMean]:
    """
    Return a multitask mean class built around ``mean_class``.

    A class that already derives from the GPyTorch multitask mean is handed back
    untouched. Anything else is wrapped in a :class:`MultitaskMean` subclass that
    instantiates ``mean_class`` on construction.

    :param mean_class: The mean to multitaskify.
    :param num_tasks: The number of tasks for the multitask GP.
    :returns: The adapted mean class.
    """
    if issubclass(mean_class, MultitaskMean):
        return mean_class

    class InnerMeanClass(MultitaskMean):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            base_mean = mean_class(*args, **kwargs)
            super().__init__(base_mean, num_tasks=num_tasks)

    return InnerMeanClass