# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the HyperparameterCollection class.
"""
from collections.abc import Iterator
from typing import Any, TypeVar
import gpytorch
import torch
from gpytorch.distributions import MultivariateNormal
from vanguard.hierarchical.hyperparameter import BayesianHyperparameter
# Type variable for any subclass of BayesianHyperparameter.
HyperparameterT = TypeVar(
    "HyperparameterT",
    bound=BayesianHyperparameter,
)

# Type variable for any subclass of gpytorch's Module.
ModuleT = TypeVar(
    "ModuleT",
    bound=gpytorch.module.Module,
)

# Type variable for any subclass of gpytorch's (private) variational distribution base class.
VariationalDistributionT = TypeVar(
    "VariationalDistributionT",
    bound=gpytorch.variational._VariationalDistribution,  # pylint: disable=protected-access
)
class HyperparameterCollection:
    """
    Represents a collection of hyperparameters for a controller.

    This class will delete the original torch parameters for the hyperparameters
    so that they can be replaced by batches of parameters representing samples
    from a distribution over those hyperparameters.
    """

    def __init__(
        self,
        module_hyperparameter_pairs: list[tuple[ModuleT, HyperparameterT]],
        sample_shape: torch.Size,
        variational_distribution_class: type[VariationalDistributionT],
    ) -> None:
        """
        Initialise self.

        :param module_hyperparameter_pairs: A list of (module, hyperparameter) pairs.
        :param sample_shape: The shape of the sample tensor.
        :param variational_distribution_class: The variational
            distribution to use for the raw hyperparameters' posterior.
        """
        self.sample_shape = sample_shape
        self.module_hyperparameter_pairs = module_hyperparameter_pairs

        self.variational_dimension = sum(
            self._parameter_index_size(hyperparameter) for _, hyperparameter in module_hyperparameter_pairs
        )
        self.variational_distribution = variational_distribution_class(self.variational_dimension)

        # Defaults (zero mean, unit variance); the per-hyperparameter values are
        # written into these tensors by _initialise_variational_parameters_and_constants.
        self.prior_mean = torch.zeros(self.variational_dimension)
        self.prior_variance = torch.ones(self.variational_dimension)

        self.sample_tensor = None
        self._hyperparameter_to_index = {}

        self._delete_point_estimate_hyperparameters()
        self._sample()
        self._initialise_variational_parameters_and_constants()

        self.prior_mean.requires_grad = False
        self.prior_variance.requires_grad = False

        # The prior must be built only once prior_mean/prior_variance hold their final
        # values: torch.diag returns a new tensor, so a covariance built before the
        # in-place fill above would remain the identity and the hyperparameters'
        # prior variances would never reach kl_term.
        prior_covariance_matrix = torch.diag(self.prior_variance)
        self._inverse_prior_covariance_matrix = torch.diag(1 / self.prior_variance)
        self.prior = MultivariateNormal(self.prior_mean, prior_covariance_matrix)

    def sample_and_update(self) -> None:
        """Sample from the collection, and update the hyperparameters."""
        self._sample()
        for owner_module, hyperparameter in self.module_hyperparameter_pairs:
            self._update_hyperparameter_value(owner_module, hyperparameter)

    def kl_term(self) -> torch.Tensor:
        """
        Compute the KL divergence term in the ELBO.

        This is the closed-form KL divergence between two multivariate normals, from the
        variational distribution (mean ``mu``, covariance ``sigma``) to the prior
        (mean ``mu_0``, covariance ``sigma_0``).

        :return: The KL divergence as a scalar tensor.
        """
        mu = self.variational_distribution.variational_mean
        sigma = self.variational_distribution().covariance_matrix
        mu_0 = self.prior.mean
        sigma_0 = self.prior.covariance_matrix
        sigma_0_inv = self._inverse_prior_covariance_matrix

        trace_term = torch.trace(sigma_0_inv @ sigma)
        mean_diff = mu_0 - mu
        mean_term = mean_diff.t() @ sigma_0_inv @ mean_diff
        # Difference of log-determinants rather than the log of a ratio of determinants:
        # the raw determinants under/overflow long before their logarithms do.
        det_term = torch.logdet(sigma_0) - torch.logdet(sigma)
        return (trace_term + mean_term + det_term - mu.shape[0]) / 2

    def _sample(self) -> None:
        """Draw a reparameterised sample of shape ``sample_shape`` and store it in ``sample_tensor``."""
        distribution = self.variational_distribution()
        self.sample_tensor = distribution.rsample(self.sample_shape)

    def _initialise_variational_parameters_and_constants(self) -> None:
        """Infer an index into the sample tensor for each hyperparameter, and initialise accordingly."""
        start = 0
        for owner_module, hyperparameter in self.module_hyperparameter_pairs:
            stop = start + self._parameter_index_size(hyperparameter)
            columns = slice(start, stop)

            # Index into the sample tensor: every entry of the first axis, this hyperparameter's columns.
            self._hyperparameter_to_index[(owner_module, hyperparameter.raw_name)] = (slice(None), columns)
            self._update_hyperparameter_value(owner_module, hyperparameter)

            self.prior_mean[columns] = hyperparameter.prior_mean
            self.prior_variance[columns] = hyperparameter.prior_variance
            start = stop

    def _parameter_index_size(self, hyperparameter: HyperparameterT) -> int:
        """
        Get the size of the index into the sample tensor corresponding to the hyperparameter.

        In order to ensure that all hyperparameters fit into the sample tensor (whose size can vary),
        this method will scale the size of the hyperparameter in order to return the correct
        proportional index size.

        :param hyperparameter: The hyperparameter whose index size is required.
        :return: ``hyperparameter.numel()`` integer-divided by the first entry of ``sample_shape``.
        """
        return hyperparameter.numel() // self.sample_shape[0]

    def _update_hyperparameter_value(self, owner_module: ModuleT, hyperparameter: HyperparameterT) -> None:
        """
        Update the value of a hyperparameter within its owner module.

        The hyperparameter's slice of the current sample tensor is reshaped to the
        hyperparameter's ``raw_shape`` and set on the module under ``raw_name``.
        """
        index = self._hyperparameter_to_index[(owner_module, hyperparameter.raw_name)]
        sliced_tensor = self.sample_tensor[index].reshape(hyperparameter.raw_shape)
        setattr(owner_module, hyperparameter.raw_name, sliced_tensor)

    def _delete_point_estimate_hyperparameters(self) -> None:
        """Delete each hyperparameter's raw attribute from its owner module, skipping any that are absent."""
        for owner_module, hyperparameter in self.module_hyperparameter_pairs:
            try:
                delattr(owner_module, hyperparameter.raw_name)
            except AttributeError:
                # Nothing to delete on this module; carry on with the remaining pairs.
                continue
class OnePointHyperparameterCollection:
    """
    Represents a collection of hyperparameters for a controller.

    This class keeps hyperparameters in their original shape and just manages
    the representation of the hyperparameters as a single combined tensor.
    It also manages the prior placed over the hyperparameters.
    """

    def __init__(self, module_hyperparameter_pairs: list[tuple[ModuleT, HyperparameterT]]) -> None:
        """
        Initialise self.

        :param module_hyperparameter_pairs: A list of (module, hyperparameter) pairs.
        """
        self.module_hyperparameter_pairs = module_hyperparameter_pairs
        self.hyperparameter_dimension = sum(
            self._parameter_index_size(hyperparameter) for _, hyperparameter in module_hyperparameter_pairs
        )

        # Defaults (zero mean, unit variance), overwritten per hyperparameter below.
        self.prior_mean = torch.zeros(self.hyperparameter_dimension)
        self.prior_variance = torch.ones(self.hyperparameter_dimension)

        self._hyperparameter_to_index = {}
        self._initialise_hyperparameter_indices()

        self.prior_mean.requires_grad = False
        self.prior_variance.requires_grad = False

        # Built after the indices are initialised so the prior uses the filled-in mean and variance.
        prior_covariance_matrix = torch.diag(self.prior_variance)
        self.prior = MultivariateNormal(self.prior_mean, prior_covariance_matrix)
        # Log density at the prior mean; subtracted again in log_prior_term.
        self.log_partition_function = self.prior.log_prob(self.prior_mean)

    def __iter__(self) -> Iterator[Any]:
        """Iterate over the current raw value of each hyperparameter, in pair order."""
        return (getattr(module, hyperparameter.raw_name) for module, hyperparameter in self.module_hyperparameter_pairs)

    def __len__(self) -> int:
        """Return the number of (module, hyperparameter) pairs."""
        return len(self.module_hyperparameter_pairs)

    @property
    def hyperparameter_tensor(self) -> torch.Tensor:
        """Return the representation of the hyperparameters as a single combined tensor."""
        tensor = torch.zeros(self.hyperparameter_dimension)
        for owner_module, hyperparameter in self.module_hyperparameter_pairs:
            index = self._hyperparameter_to_index[(owner_module, hyperparameter.raw_name)]
            # Flatten first: the slice is one-dimensional with numel() entries, so a raw
            # value of any other shape could not otherwise be assigned into it.
            tensor[index] = getattr(owner_module, hyperparameter.raw_name).reshape(-1)
        return tensor

    @hyperparameter_tensor.setter
    def hyperparameter_tensor(self, value: torch.Tensor) -> None:
        """
        Update the hyperparameters from a single combined tensor.

        :param value: The combined tensor; each hyperparameter takes its own slice,
            reshaped to the shape of the value it replaces.
        """
        for owner_module, hyperparameter in self.module_hyperparameter_pairs:
            name = hyperparameter.raw_name
            index = self._hyperparameter_to_index[(owner_module, name)]
            shape = getattr(owner_module, name).shape
            new_value = value[index].reshape(shape)
            try:
                setattr(owner_module, name, new_value)
            except TypeError:
                # The module refused the assignment over the existing attribute;
                # remove that attribute and assign afresh.
                delattr(owner_module, name)
                setattr(owner_module, name, new_value)

    def log_prior_term(self) -> torch.Tensor:
        """
        Compute the log un-normalised prior density.

        The partition function has in principle no effect on the optimisation
        but can skew the loss values unhelpfully, so we remove it.
        """
        return self.prior.log_prob(self.hyperparameter_tensor) - self.log_partition_function

    def _initialise_hyperparameter_indices(self) -> None:
        """Assign each hyperparameter a slice of the combined tensor, and record its prior mean and variance."""
        start = 0
        for owner_module, hyperparameter in self.module_hyperparameter_pairs:
            stop = start + self._parameter_index_size(hyperparameter)
            index = slice(start, stop)
            self._hyperparameter_to_index[(owner_module, hyperparameter.raw_name)] = index
            self.prior_mean[index] = hyperparameter.prior_mean
            self.prior_variance[index] = hyperparameter.prior_variance
            start = stop

    def _parameter_index_size(self, hyperparameter: HyperparameterT) -> int:
        """
        Get the number of entries of the combined tensor occupied by the hyperparameter.

        :param hyperparameter: The hyperparameter whose index size is required.
        :return: The hyperparameter's total number of elements.
        """
        return hyperparameter.numel()