# Source code for vanguard.distribute.aggregators

# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A suite of aggregators to be used with the :class:`~vanguard.distribute.decorator.Distributed` decorator.

These are responsible for combining the predictions of several independent expert controllers.
"""

from typing import Optional

import torch


class BadPriorVarShapeError(ValueError):
    """Raised when ``prior_var`` has a shape that is incompatible with the experts' variances."""
class BaseAggregator:
    r"""
    Aggregate experts' posteriors to an approximate predictive posterior.

    All aggregators should inherit from this class.

    :param means: List with one element per expert, each element being a tensor of that expert's
        predictive mean at the :math:`d` evaluation points.
    :param covars: List with one element per expert, each element being a :math:`d \times d` tensor
        holding that expert's posterior predictive covariance at the evaluation points.
    :param prior_var: Optional tensor holding the diagonal of the test kernel with added noise.
        If it has at least as many dimensions as the stacked expert variances, its shape must match
        theirs exactly; lower-dimensional tensors are accepted without a shape check.
    :raises BadPriorVarShapeError: If ``prior_var`` fails the shape check described above.
    """

    def __init__(
        self, means: list[torch.Tensor], covars: list[torch.Tensor], prior_var: Optional[torch.Tensor] = None
    ) -> None:
        """
        Initialise the BaseAggregator class.
        """
        # Stack the per-expert tensors so that the leading dimension indexes the experts.
        self.means = torch.stack(means).type(torch.float32)
        self.covars = torch.stack(covars).type(torch.float32)
        # Each expert's marginal variances are the diagonal of its covariance matrix.
        self.variances = self.covars.diagonal(dim1=1, dim2=2)
        self.n_experts = self.means.shape[0]

        if prior_var is None:
            self.prior_var = None
        else:
            self.prior_var = torch.as_tensor(prior_var).type(torch.float32)
            # Only priors with at least as many dimensions as the variances must match them exactly;
            # lower-dimensional priors are left for the aggregators to broadcast.
            if self.prior_var.dim() >= self.variances.dim() and self.prior_var.shape != self.variances.shape:
                raise BadPriorVarShapeError(
                    f"Prior var shape {self.prior_var.shape} doesn't match variances shape {self.variances.shape}"
                )

    def aggregate(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the predictions of the individual experts into a single prediction.

        Subclasses must override this method.

        :return: The mean and variance of the combined experts.
        :raises NotImplementedError: Always, on the base class.
        """
        # TODO: should this be an abstract method?
        # https://github.com/gchq/Vanguard/issues/241
        raise NotImplementedError

    def _beta_correction(self, delta_diff: torch.Tensor, delta_val: torch.Tensor) -> torch.Tensor:
        """
        Implement the correction to experts' weights.

        .. note::
            Delta is the difference in differential entropy between prior and posterior :cite:`Deisenroth15`.
            ``delta_diff`` and ``delta_val`` are the same in :class:`XBCMAggregator`.

        :param delta_diff: The delta used to determine if correction is applied
            (proxy for in-vs-out of training data).
        :param delta_val: The delta value to be corrected. Must be the same shape as ``delta_diff``.
        :return: The corrected expert weights, the same shape as ``delta_diff``.
        """
        in_training_data = delta_diff > 1
        not_in_training_data = delta_diff <= 1
        # Weights where delta_diff > 1 are kept as-is; the others are moved a fraction 1/M of the way towards 1.
        corrected_weights = in_training_data * delta_val
        corrected_weights += not_in_training_data * (delta_val + ((1 - delta_val) / self.n_experts))
        return corrected_weights

    @staticmethod
    def _make_pseudo_covar(variance: torch.Tensor) -> torch.Tensor:
        """
        Convert a variance to a covariance matrix, where all entries except the diagonal are zeros.

        :param variance: Tensor of variances for each point of interest.
        :return: Covariance matrix, where all non-diagonal elements are zero. It has the same dtype
            and device as ``variance``.
        """
        dim = variance.size(-1)
        # Allocate on the input's device so the diagonal assignment also works for non-CPU tensors.
        covar = torch.zeros((dim, dim), dtype=variance.dtype, device=variance.device)
        covar[range(dim), range(dim)] = variance
        return covar
class POEAggregator(BaseAggregator):
    r"""
    Implements the Product-of-Experts method of :cite:`Deisenroth15`.

    The formulae for covariances are taken from :cite:`Cao14`.

    Given the posteriors of the experts :math:`p_{i}(y|x) = N(\mu_{i}(x), \sigma_{i}^{2}(x))` for
    :math:`i=1, 2, ..., M`, we define the joint posterior as a Gaussian with moments

    .. math ::
        \mu &= \sigma^{2} \sum_{i} \sigma_{i}^{-2}(x) \mu_{i}(x) \\
        \sigma^{-2} &= \sum_{i} \sigma_{i}^{-2}(x)
    """

    def aggregate(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the experts' predictions with the Product-of-Experts rule.

        :return: The mean and the (full) covariance matrix of the combined experts.
        """
        # Precision (inverse covariance) matrix of every expert.
        precisions = torch.stack([torch.inverse(expert_covar) for expert_covar in self.covars])
        # The joint precision is the sum of the experts' precisions.
        joint_covar = torch.inverse(torch.sum(precisions, dim=0))

        # Each expert contributes its mean (as a row vector) weighted by its precision.
        precision_weighted_means = [
            torch.tensordot(expert_mean.reshape(1, -1), precision, dims=1)
            for expert_mean, precision in zip(self.means, precisions)
        ]
        summed = torch.sum(torch.stack(precision_weighted_means), dim=0)
        joint_mean = torch.tensordot(summed, joint_covar, dims=1)
        return joint_mean.reshape(-1), joint_covar
class EKPOEAggregator(POEAggregator):
    r"""
    Implements a correction to the Product-of-Experts method.

    Given the posteriors of the experts :math:`p_{i}(y|x) = N(\mu_{i}(x), \sigma_{i}^{2}(x))` for
    :math:`i=1, 2, ..., M`, we define the joint posterior as a Gaussian with moments

    .. math ::
        \mu &= \frac{1}{M} \sigma^{2} \sum_{i} \sigma_{i}^{-2}(x) \mu_{i}(x) \\
        \sigma^{-2} &= \frac{1}{M} \sum_{i} \sigma_{i}^{-2}(x)

    That is, the mean equals the :class:`POEAggregator` mean and the covariance is the
    :class:`POEAggregator` covariance scaled by :math:`M`.
    """

    def aggregate(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the predictions of the individual experts into a single EKPoE prediction.

        :return: The mean and covariance of the combined experts.
        """
        mean, covar = super().aggregate()
        # Scaling the joint precision by 1/M multiplies the PoE covariance by M; the mean is
        # unchanged because the 1/M factor cancels against the M in the scaled covariance.
        return mean, covar * self.n_experts
class GPOEAggregator(BaseAggregator):
    r"""
    Implements the Generalised Product-of-Experts method of :cite:`Deisenroth15`.

    Given the posteriors of the experts :math:`p_{i}(y|x) = N(\mu_{i}(x), \sigma_{i}^{2}(x))` for
    :math:`i=1, 2, ..., M`, we define the joint posterior as a Gaussian with moments

    .. math ::
        \mu &= \sigma^{2} \sum_{i} \beta_{i} \sigma_{i}^{-2}(x) \mu_{i}(x) \\
        \sigma^{-2} &= \sum_{i} \beta_{i} \sigma_{i}^{-2}(x)

    with uniform expert weights :math:`\beta_{i}=\frac{1}{M}`.
    """

    def aggregate(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the experts' predictions with the generalised Product-of-Experts rule.

        :return: The mean and a diagonal covariance matrix of the combined experts.
        """
        # Every expert gets the same weight 1/M.
        weights = torch.ones_like(self.means) / self.n_experts
        weighted_precisions = weights / self.variances

        joint_precision = torch.sum(weighted_precisions, dim=0)
        precision_weighted_mean = torch.sum(weighted_precisions * self.means, dim=0)

        joint_mean = precision_weighted_mean / joint_precision
        return joint_mean, self._make_pseudo_covar(1 / joint_precision)
class BCMAggregator(BaseAggregator):
    r"""
    Implements the Bayesian Committee Machine method of :cite:`Deisenroth15`.

    Given the posteriors of the experts :math:`p_{i}(y|x) = N(\mu_{i}(x), \sigma_{i}^{2}(x))` for
    :math:`i=1, 2, ..., M`, we define the joint posterior as a Gaussian with moments

    .. math ::
        \mu &= \sigma^{2} \sum_{i} \sigma_{i}^{-2}(x) \mu_{i}(x) \\
        \sigma^{-2} &= \sum_{i} \sigma_{i}^{-2}(x) + \bigg( 1 - M \bigg) \sigma_{**}^{-2}

    where :math:`\sigma_{**}^{2}` is the prior variance (``prior_var``): the diagonal of the covariance
    matrix formed by applying the kernel on all pairs of points in :math:`x`, and
    :math:`\sigma_{**}^{-2}` is its reciprocal.
    """

    def aggregate(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the predictions of the individual experts into a single BCM prediction.

        Requires ``prior_var`` to have been supplied at construction.

        :return: The mean and a diagonal covariance matrix of the combined experts.
        """
        beta = torch.ones_like(self.means)
        mean = torch.sum((beta / self.variances) * self.means, dim=0)
        precision = torch.sum(beta / self.variances, dim=0)
        # The (1 - M) prior-precision term from the formula above.
        precision -= (self.n_experts - 1) / self.prior_var
        return mean / precision, self._make_pseudo_covar(1 / precision)
class RBCMAggregator(BaseAggregator):
    r"""
    Implements the Robust Bayesian Committee Machine method of :cite:`Deisenroth15`.

    Given the posteriors of the experts :math:`p_{i}(y|x) = N(\mu_{i}(x), \sigma_{i}^{2}(x))` for
    :math:`i=1, 2, ..., M`, we define the joint posterior as a Gaussian with moments

    .. math ::
        \mu &= \sigma^{2} \sum_{i} \beta_{i} \sigma_{i}^{-2}(x) \mu_{i}(x) \\
        \sigma^{-2} &= \sum_{i} \beta_{i} \sigma_{i}^{-2}(x) + \bigg( 1 - \sum_{i} \beta_{i} \bigg) \sigma_{**}^{-2}

    where :math:`\beta_{i}=0.5(\log \sigma_{**}^{2} - \log \sigma_{i}^{2}(x))` is the difference in differential
    entropy between the prior and the posterior of expert :math:`i`, and :math:`\sigma_{**}^{2}` is the prior
    variance (``prior_var``): the diagonal of the covariance matrix formed by applying the kernel on all pairs
    of points in :math:`x`.
    """

    def aggregate(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the predictions of the individual experts into a single RBCM prediction.

        Requires ``prior_var`` to have been supplied at construction.

        :return: The mean and a diagonal covariance matrix of the combined experts.
        """
        # Entropy difference between the prior and each expert's posterior, one row per expert.
        beta = 0.5 * (torch.log(self.prior_var) - torch.log(self.variances)).reshape(self.n_experts, -1)
        mean = torch.sum((beta / self.variances) * self.means, dim=0)
        # sum_i beta_i / sigma_i^2 + (1 - sum_i beta_i) / sigma_**^2, with the prior term expanded.
        precision = torch.sum((beta / self.variances) - (beta / self.prior_var), dim=0) + (1 / self.prior_var)
        return mean / precision, self._make_pseudo_covar(1 / precision)
class XBCMAggregator(BaseAggregator):
    r"""
    Implements the Corrected Bayesian Committee Machine method.

    The joint posterior is the one used by :class:`RBCMAggregator`, except that the expert weights
    :math:`\beta` are first passed through :meth:`BaseAggregator._beta_correction`, using the
    prior-vs-posterior entropy difference both to decide where to correct and as the value corrected.
    """

    def aggregate(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the experts' predictions with the corrected BCM rule.

        :return: The mean and a diagonal covariance matrix of the combined experts.
        """
        log_prior = torch.log(self.prior_var)
        log_posteriors = torch.log(self.variances)
        # Entropy difference between the prior and each expert's posterior, one row per expert.
        delta = 0.5 * (log_prior - log_posteriors).reshape(self.n_experts, -1)
        weights = self._beta_correction(delta, delta)

        weighted_precisions = weights / self.variances
        precision_weighted_mean = torch.sum(weighted_precisions * self.means, dim=0)
        joint_precision = torch.sum(weighted_precisions - (weights / self.prior_var), dim=0) + (1 / self.prior_var)

        joint_mean = precision_weighted_mean / joint_precision
        return joint_mean, self._make_pseudo_covar(1 / joint_precision)
class GRBCMAggregator(BaseAggregator):
    r"""
    Implements the Generalised Robust Bayesian Committee Machine method of :cite:`Liu18`.

    The first expert acts as the communication expert. Given the posteriors of the experts
    :math:`p_{i}(y|x) = N(\mu_{i}(x), \sigma_{i}^{2}(x))` for :math:`i=1, 2, ..., M`, we define the joint
    posterior as a Gaussian with moments

    .. math ::
        \mu &= \sigma^{2} \bigg[ \sum_{i=2}^{M} \beta_{i} \sigma_{i}^{-2}(x) \mu_{i}(x) +
        \bigg( 1 - \sum_{i=2}^{M} \beta_{i} \bigg) \sigma_{1}^{-2}(x) \mu_{1}(x) \bigg] \\
        \sigma^{-2} &= \sum_{i=2}^{M} \beta_{i} \sigma_{i}^{-2}(x) +
        \bigg( 1 - \sum_{i=2}^{M} \beta_{i} \bigg) \sigma_{1}^{-2}(x)

    where

    .. math ::
        \beta_{i} = \begin{cases}
            1, & i=2 \\
            0.5(\log \sigma_{1}^{2}(x) - \log \sigma_{i}^{2}(x)), & 3 \leq i \leq M
        \end{cases}
    """

    def aggregate(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the experts' predictions with the GRBCM rule.

        :return: The mean and a diagonal covariance matrix of the combined experts.
        """
        # Split the communication expert (index 0) from the remaining local experts.
        comm_mean, local_means = self.means[0], self.means[1:]
        comm_var, local_vars = self.variances[0], self.variances[1:]
        n_local = self.n_experts - 1

        weights = 0.5 * (torch.log(comm_var) - torch.log(local_vars)).reshape(n_local, -1)
        # The first local expert always receives full weight.
        weights[0, :] = 1

        weighted_precisions = weights / local_vars
        comm_precision_mean = comm_mean / comm_var
        joint_numerator = (
            torch.sum(weighted_precisions * local_means - (weights * comm_mean / comm_var), dim=0) + comm_precision_mean
        )
        joint_precision = torch.sum(weighted_precisions - (weights / comm_var), dim=0) + (1 / comm_var)

        joint_mean = joint_numerator / joint_precision
        return joint_mean, self._make_pseudo_covar(1 / joint_precision)
class XGRBCMAggregator(BaseAggregator):
    r"""
    Implements the Corrected Generalised Robust Bayesian Committee Machine method.

    We define the joint posterior as in :class:`GRBCMAggregator`, but with a correction on :math:`\beta`.
    Whether a local expert's weight is corrected is decided by the entropy difference between the prior
    (``prior_var``) and that expert's posterior, whereas :math:`\beta` itself is computed relative to the
    communication expert's posterior. (For further details see :meth:`BaseAggregator._beta_correction`.)
    """

    def aggregate(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the predictions of the individual experts into a single XGRBCM prediction.

        Requires ``prior_var`` to have been supplied at construction.

        :return: The mean and a diagonal covariance matrix of the combined experts.
        """
        # The first expert is the communication expert; the rest are local experts.
        comm_mean = self.means[0]
        comm_var = self.variances[0]
        means = self.means[1:]
        variances = self.variances[1:]

        # Uncorrected GRBCM weights: entropy difference relative to the communication expert,
        # with the first local expert fixed to full weight.
        delta = 0.5 * (torch.log(comm_var) - torch.log(variances)).reshape(self.n_experts - 1, -1)
        delta[0, :] = 1
        # Entropy difference relative to the prior, used only to decide where the correction applies.
        delta_diff = 0.5 * (torch.log(self.prior_var) - torch.log(variances)).reshape(self.n_experts - 1, -1)
        beta = self._beta_correction(delta_diff, delta)

        mean = torch.sum((beta / variances) * means - (beta * comm_mean / comm_var), dim=0) + (comm_mean / comm_var)
        precision = torch.sum((beta / variances) - (beta / comm_var), dim=0) + (1 / comm_var)
        return mean / precision, self._make_pseudo_covar(1 / precision)