Bayesian treatment of hyperparameters with Laplace approximations

[ ]:
# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[ ]:
# This notebook is not compiled into the documentation due to the time taken to run it to get
# a representative analysis. Please run this notebook locally if you wish to see the outputs.

Gaussian process models are already Bayesian, in that an unknown function is given a GP prior and then a posterior is inferred over that function. Typically, however, the GP prior is defined using hyperparameters that are contained in the prior mean and kernel functions. There can also be hyperparameters in the likelihood. Sometimes one might have a good idea of what these prior hyperparameters should be (e.g. if the kernel is periodic and the data have some known periodicity). Otherwise, it is standard to learn good hyperparameters by optimising the log marginal likelihood with respect to them. This process yields point estimates of the hyperparameters and is not Bayesian. In many practical applications there will be considerable prior uncertainty about the values of these hyperparameters, and point estimates obtained from likelihood maximisation can yield over-confident posteriors. The fully Bayesian approach is to place a prior over the hyperparameters themselves and then infer their posterior. The full posterior process is then

\[p(f \mid \mathcal{D}) = \int d\theta\, p(\theta \mid \mathcal{D})\, p(f \mid \mathcal{D}, \theta)\]

Here \(p(f \mid \mathcal{D}, \theta)\) is the usual GP posterior with fixed hyperparameters \(\theta\), and \(p(\theta\mid \mathcal{D})\) is the posterior distribution over the hyperparameters. Apart from the extra uncertainty accounted for by this approach, note also that even if the conditional GP posteriors \(p(f \mid \mathcal{D}, \theta)\) are Gaussian, the overall posterior above is a continuous mixture of Gaussians and will in general not be Gaussian, so the Bayesian treatment of hyperparameters allows for a much richer class of posteriors.
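To see why this mixture carries extra uncertainty, here is a minimal NumPy sketch (not vanguard code). It assumes a toy RBF GP with a fixed noise variance and a hypothetical log-normal posterior over the lengthscale, averages the conditional GP predictions over lengthscale samples, and computes the mixture moments by the law of total variance: the mixture variance is the mean conditional variance plus the variance of the conditional means.

```python
import numpy as np


def rbf(a, b, lengthscale, outputscale=1.0):
    # Squared-exponential kernel between 1-D input arrays a and b
    d = a[:, None] - b[None, :]
    return outputscale * np.exp(-0.5 * (d / lengthscale) ** 2)


def gp_predict(x_train, y_train, x_test, lengthscale, noise=0.1):
    # Exact GP posterior mean and variance of f at x_test for fixed hyperparameters
    K = rbf(x_train, x_train, lengthscale) + noise * np.eye(len(x_train))
    K_s = rbf(x_train, x_test, lengthscale)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    mean = K_s.T @ alpha
    v = np.linalg.solve(L, K_s)
    var = rbf(x_test, x_test, lengthscale).diagonal() - np.sum(v**2, axis=0)
    return mean, var


rng = np.random.default_rng(0)
x_train = np.linspace(0.0, 5.0, 20)
y_train = np.sin(x_train) + 0.1 * rng.standard_normal(20)
x_test = np.array([2.5, 7.0])  # one interpolation point, one extrapolation point

# Hypothetical hyperparameter posterior: log-lengthscale ~ N(0, 0.5^2)
log_ls = rng.normal(0.0, 0.5, size=200)
results = [gp_predict(x_train, y_train, x_test, np.exp(l)) for l in log_ls]
means = np.array([m for m, _ in results])      # shape (200, 2)
variances = np.array([v for _, v in results])  # shape (200, 2)

# Moments of the mixture p(f | D) = E_theta[p(f | D, theta)]
mix_mean = means.mean(axis=0)
mix_var = variances.mean(axis=0) + means.var(axis=0)
print(mix_mean, mix_var)
```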

We showcase how to approximate the intractable hyperparameter posterior \(p(\theta \mid \mathcal{D})\) using an approach based on the Laplace approximation.

Suppose that log marginal likelihood maximisation has produced optimised hyperparameters \(\theta_*\). The Laplace approximation to \(p(\theta \mid \mathcal{D})\) is then

\[p(\theta \mid \mathcal{D}) \approx \mathcal{N}(\theta \mid \theta_*, H^{-1})\]

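where the matrix \(H\) is the Hessian of the negative log marginal likelihood at \(\theta_*\), i.e.

\[H = \frac{\partial^2 L}{\partial \theta^2}\Bigg|_{\theta=\theta_*}\]

where \(L(\theta) = -\log p(\mathcal{D}\mid \theta)\) (under the GP model) is the loss minimised during training, so that \(H\) is positive semi-definite at an exact minimum. Strictly, the Laplace approximation uses the Hessian of the negative log posterior; using \(L\) alone amounts to treating the hyperprior as flat.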

The Laplace approximation has been found to be competitive in other settings, for example Bayesian neural networks. It has the great advantage of being simple, and for GP hyperparameters in particular it is very efficient compared with other approaches: the dimension of \(\theta\) is typically small, so exact Hessians can be computed cheaply using automatic differentiation.
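As an illustration of how cheap this is, the sketch below (plain PyTorch, not vanguard's implementation) fits the three hyperparameters of a toy RBF GP by minimising the negative log marginal likelihood with Adam, then obtains the exact \(3\times 3\) Hessian of that loss at the optimum with `torch.autograd.functional.hessian`. The toy data, the log-parametrisation and the optimiser settings are assumptions made for the example.

```python
import math

import torch

torch.manual_seed(0)
# Toy 1-D regression data (illustrative only)
x = torch.linspace(0.0, 5.0, 30, dtype=torch.float64)
y = torch.sin(x) + 0.1 * torch.randn(30, dtype=torch.float64)


def neg_log_marginal_likelihood(theta):
    """L(theta) = -log p(y | theta) for an RBF GP with Gaussian noise.

    theta = (log lengthscale, log outputscale, log noise variance).
    """
    lengthscale, outputscale, noise = torch.exp(theta)
    sq_dist = (x[:, None] - x[None, :]) ** 2
    K = outputscale * torch.exp(-0.5 * sq_dist / lengthscale**2)
    K = K + noise * torch.eye(len(x), dtype=x.dtype)
    alpha = torch.linalg.solve(K, y)
    _, log_det = torch.linalg.slogdet(K)
    return 0.5 * (y @ alpha) + 0.5 * log_det + 0.5 * len(x) * math.log(2 * math.pi)


# Point estimate theta_* by gradient-based minimisation of the loss
theta = torch.zeros(3, dtype=torch.float64, requires_grad=True)
optimiser = torch.optim.Adam([theta], lr=0.05)
for _ in range(1000):
    optimiser.zero_grad()
    loss = neg_log_marginal_likelihood(theta)
    loss.backward()
    optimiser.step()
theta_star = theta.detach()

# Exact Hessian of the loss at theta_* via automatic differentiation
H = torch.autograd.functional.hessian(neg_log_marginal_likelihood, theta_star)
print(theta_star)
print(torch.linalg.eigvalsh(H))
```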

A clear issue with this approach as described is the need to invert \(H\). In practice one finds that the log marginal likelihood surface contains at least some very flat directions around \(\theta_*\), i.e. \(H\) has some very small eigenvalues. These correspond to directions in which the Laplace approximation fails (or nearly fails), since it is only valid when \(H\) is positive definite. Note that the Hessian will not generically contain any negative eigenvalues, except very small ones in the almost-flat directions; this is to be expected if we assume that the optimisation procedure can escape obvious saddle points and only becomes stuck in approximate local minima. Naively inverting the Hessian in the presence of such small eigenvalues will, at best, result in a covariance matrix with some extremely large variances and, at worst, lead to an invalid indefinite matrix. The latter problem can be mitigated with covariance cleaning techniques such as linear shrinkage, i.e. replacing \(H\) by \((1-\beta)H + \beta I\) for some small \(\beta\in(0, 1)\); however, this does not solve the former problem.
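The difference between the two problems can be seen on a small example. The NumPy sketch below builds a hypothetical \(3\times 3\) Hessian with two well-determined directions and one almost-flat direction whose eigenvalue is slightly negative, then applies linear shrinkage with \(\beta = 10^{-3}\). Shrinkage makes the matrix positive definite, but inverting it still gives an enormous variance along the flat direction.

```python
import numpy as np

# Hypothetical Hessian: eigenvalues 50 and 5, plus a flat, slightly negative direction
rng = np.random.default_rng(0)
Q = np.linalg.qr(rng.standard_normal((3, 3)))[0]  # random orthonormal eigenvectors
H = Q @ np.diag([50.0, 5.0, -1e-4]) @ Q.T


def shrink(H, beta):
    # Linear shrinkage towards the identity: (1 - beta) H + beta I
    return (1.0 - beta) * H + beta * np.eye(H.shape[0])


H_shrunk = shrink(H, beta=1e-3)
print(np.linalg.eigvalsh(H))         # one negative eigenvalue: H^{-1} is indefinite
print(np.linalg.eigvalsh(H_shrunk))  # all eigenvalues positive after shrinkage ...
print(np.trace(np.linalg.inv(H_shrunk)))  # ... but the total variance is still huge
```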

Our aim here is not to provide a perfect representation of the hyperparameter posterior; rather, we aim simply to improve uncertainty quantification over the baseline approach of plain marginal likelihood maximisation with point estimates. We therefore accept that the full Laplace approximation is not available and restrict attention to directions in which the Hessian eigenvalues are not too small. More precisely, we use an eigendecomposition \(H = U^T\Lambda U\) and let \(r(\Lambda)\) be a diagonal matrix with \(r(\Lambda)_i = (\lambda_i)^{-1}\) if \(\lambda_i > \epsilon\) and \(r(\Lambda)_i = \eta\) otherwise, where \(\epsilon,\eta>0\) are small parameters. We then replace \(H^{-1}\) in the Laplace approximation by \(\Sigma = U^T r(\Lambda)U\).
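A minimal NumPy sketch of this regularised covariance, assuming illustrative values \(\epsilon = 10^{-2}\) and \(\eta = 10^{-6}\) (not necessarily vanguard's defaults) and a hypothetical Hessian. Note that `numpy.linalg.eigh` returns \(H = V \operatorname{diag}(\lambda) V^T\), so its \(V\) plays the role of \(U^T\) in the text.

```python
import numpy as np


def regularised_covariance(H, epsilon=1e-2, eta=1e-6):
    """Return Sigma = U^T r(Lambda) U for a symmetric matrix H."""
    lam, V = np.linalg.eigh(H)  # H = V diag(lam) V^T, i.e. V = U^T
    well_behaved = lam > epsilon
    # r(Lambda)_i = 1 / lambda_i in well-behaved directions, eta otherwise
    safe_lam = np.where(well_behaved, lam, 1.0)  # avoids dividing by ~0
    r = np.where(well_behaved, 1.0 / safe_lam, eta)
    return V @ np.diag(r) @ V.T


# Hypothetical Hessian: eigenvalues 50 and 5, plus a flat, slightly negative direction
rng = np.random.default_rng(0)
Q = np.linalg.qr(rng.standard_normal((3, 3)))[0]
H = Q @ np.diag([50.0, 5.0, -1e-4]) @ Q.T

Sigma = regularised_covariance(H)
print(np.linalg.eigvalsh(Sigma))  # ascending: eta, 1/50, 1/5
```

Unlike shrinkage, the flat direction now receives the tiny variance \(\eta\) rather than a huge one, so that direction is effectively held at its point estimate.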

Thus \(\Sigma\) preserves the covariance structure of the Laplace approximation in the well-behaved directions, while essentially using point estimates in the badly-behaved directions. Note however that the eigendirections of \(H\) are not the same as the coordinate directions corresponding to the individual hyperparameters themselves, so our approach is not the same as treating certain hyperparameters as point estimates (though in practice we expect that this may be approximately the case).

Even with the above regularisation of the Laplace approximation covariance matrix, the resulting posterior process \(\int d\theta\, p(\cdot \mid \theta, \mathcal{D})\, p(\theta \mid \mathcal{D})\) may still have impractically large posterior predictive uncertainty. Without a definitive way of telling whether this uncertainty is accurate, we must be pragmatic and seek to make the posterior predictive actually useful. In the Bayesian deep learning literature, it has been found that cold posteriors can give better predictive performance than the untempered posterior in Bayesian neural networks. In our case this amounts to

\[p(\theta \mid \mathcal{D}) \approx \mathcal{N}(\theta \mid \theta_*, T\Sigma)\]

where \(T>0\) is a temperature parameter, and \(T<1\) corresponds to the cold posterior regime.

With all this in place, we have a practical method for accounting for the uncertainty in a model’s hyperparameters. At \(T=0\) the hyperparameter posterior collapses to a point mass at \(\theta_*\), recovering exactly the point-estimate model, so by gradually increasing \(T\) the user can explore, in a principled and efficient way, the posterior uncertainty that is hidden by point estimation of the hyperparameters. The temperature is also an additional parameter that can be tuned to maximise the predictive log probability on held-out data, which may provide some extra robustness to overfitting.
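The effect of the temperature is easy to see when sampling. The NumPy sketch below draws hyperparameter samples from \(\mathcal{N}(\theta_*, T\Sigma)\) using an eigendecomposition square root of \(\Sigma\); the values of \(\theta_*\) and \(\Sigma\) are hypothetical. At \(T=0\) every sample equals \(\theta_*\), and the spread grows like \(\sqrt{T}\).

```python
import numpy as np


def sample_tempered(theta_star, Sigma, temperature, num_samples, rng):
    """Draw samples from N(theta_star, temperature * Sigma) for symmetric PSD Sigma."""
    lam, V = np.linalg.eigh(Sigma)
    # Square-root factor with root @ root.T == Sigma; clip rounding-level negatives
    root = V * np.sqrt(np.clip(lam, 0.0, None))
    z = rng.standard_normal((num_samples, len(theta_star)))
    return theta_star + np.sqrt(temperature) * z @ root.T


rng = np.random.default_rng(0)
theta_star = np.array([0.5, -1.0, 2.0])  # hypothetical optimised hyperparameters
Sigma = np.array([[0.2, 0.05, 0.0], [0.05, 0.1, 0.0], [0.0, 0.0, 1e-6]])  # hypothetical
for temperature in (0.0, 0.1, 1.0):
    samples = sample_tempered(theta_star, Sigma, temperature, 10_000, rng)
    print(temperature, samples.std(axis=0))
```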

[ ]:
random_seed = 1_989
num_iters = 100
[ ]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from gpytorch import constraints, kernels, likelihoods, means
from tqdm import tqdm

from vanguard.datasets.air_passengers import AirPassengers
from vanguard.datasets.synthetic import SyntheticDataset, complicated_f
from vanguard.hierarchical import BayesianHyperparameters, LaplaceHierarchicalHyperparameters
from vanguard.learning import LearnYNoise
from vanguard.normalise import NormaliseY
from vanguard.vanilla import GaussianGPController

Data

We will use SyntheticDataset to begin with.

[ ]:
DATASET = SyntheticDataset(functions=(complicated_f,), rng=np.random.default_rng(random_seed))
train_test_split_index = len(DATASET.train_x)

Modelling

Let’s start by constructing a standard GP model with point estimate hyperparameters for comparison.

[ ]:
class ScaledRBFKernel(kernels.ScaleKernel):
    def __init__(self, active_dims=None, batch_shape=torch.Size([])):
        super().__init__(
            kernels.RBFKernel(active_dims=active_dims, batch_shape=batch_shape),
            batch_shape=batch_shape,
        )


class ScaledMaternKernel(kernels.ScaleKernel):
    def __init__(self, active_dims=None, batch_shape=torch.Size([])):
        super().__init__(
            kernels.MaternKernel(nu=0.5, active_dims=active_dims, batch_shape=batch_shape),
            batch_shape=batch_shape,
        )


class ScaledPeriodicKernel(kernels.ScaleKernel):
    def __init__(self, active_dims=None, batch_shape=torch.Size([])):
        super().__init__(
            kernels.PeriodicKernel(active_dims=active_dims, batch_shape=batch_shape),
            batch_shape=batch_shape,
        )


class Kernel(kernels.ProductKernel):
    def __init__(self, batch_shape=torch.Size([])):
        super().__init__(
            ScaledRBFKernel(batch_shape=batch_shape),
            kernels.PeriodicKernel(batch_shape=batch_shape),
        )
[ ]:
@LearnYNoise(ignore_all=True)
class PointEstimateController(GaussianGPController):
    pass


gp = PointEstimateController(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=Kernel,
    y_std=DATASET.train_y_std,
    optim_kwargs={"lr": 0.5},
    rng=np.random.default_rng(random_seed),
)

with gp.metrics_tracker.print_metrics(every=20):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        gp.fit(n_sgd_iters=num_iters)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    posterior = gp.posterior_over_point(DATASET.test_x)
    likelihood = gp.predictive_likelihood(DATASET.test_x)

mu, lower, upper = posterior.confidence_interval()
l_mu, l_lower, l_upper = likelihood.confidence_interval()

# Convert from tensors to numpy arrays for plotting
l_mu = l_mu.detach().cpu().numpy()
l_lower = l_lower.detach().cpu().numpy()
l_upper = l_upper.detach().cpu().numpy()
plt_x = DATASET.test_x.ravel().detach().cpu().numpy()

plt.figure(figsize=(15, 7))
plt.plot(plt_x, l_mu, label="likelihood")
plt.fill_between(plt_x, l_lower, l_upper, alpha=0.2, label="likelihood CI")
plt.plot(plt_x, DATASET.test_y.detach().cpu().numpy(), "x", label="data")
plt.grid(which="both")
plt.legend()
print(f"Log probability: {likelihood.log_probability(DATASET.test_y)}")

Now we will convert this model to use Bayesian inference over its hyperparameters.

Any kernels or means that are to be given Bayesian hyperparameters must be decorated with BayesianHyperparameters. This may seem clunky, but it allows for fine-grained control over which hyperparameters are made Bayesian.

[ ]:
@BayesianHyperparameters()
class BayesianRBFKernel(kernels.RBFKernel):
    pass


@BayesianHyperparameters()
class BayesianPeriodicKernel(kernels.PeriodicKernel):
    pass


@BayesianHyperparameters()
class BayesianScaleKernel(kernels.ScaleKernel):
    pass


class BayesianScaledRBFKernel(BayesianScaleKernel):
    def __init__(self, active_dims=None, batch_shape=torch.Size([])):
        super().__init__(
            BayesianRBFKernel(active_dims=active_dims, batch_shape=batch_shape),
            batch_shape=batch_shape,
        )


class BayesianScaledPeriodicKernel(BayesianScaleKernel):
    def __init__(self, batch_shape=torch.Size([]), active_dims=None):
        super().__init__(
            BayesianPeriodicKernel(active_dims=active_dims, batch_shape=batch_shape),
            batch_shape=batch_shape,
        )


class BayesianKernel(kernels.ProductKernel):
    def __init__(self, batch_shape=torch.Size([])):
        super().__init__(
            BayesianScaledRBFKernel(batch_shape=batch_shape),
            BayesianPeriodicKernel(batch_shape=batch_shape),
        )


@BayesianHyperparameters()
class BayesianConstantMean(means.ConstantMean):
    pass


@BayesianHyperparameters()
class BayesianFixedNoiseGaussianLikelihood(likelihoods.FixedNoiseGaussianLikelihood):
    pass

The decorator LaplaceHierarchicalHyperparameters converts a controller to approximate the hyperparameter posterior using a Laplace approximation. The argument num_mc_samples defines the number of samples to draw from the approximate hyperparameter posterior when approximating integrals over the hyperparameters using Monte Carlo integration.
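For intuition about what the Monte Carlo step does, here is a minimal NumPy sketch of one such integral, the predictive log probability \(\log p(y \mid \mathcal{D}) \approx \log \frac{1}{S}\sum_{s=1}^{S} p(y \mid \mathcal{D}, \theta_s)\) with \(S\) hyperparameter samples \(\theta_s\). It assumes, for simplicity, that each conditional predictive is Gaussian with independent test points, and it works in log space for numerical stability. It illustrates the technique only and is not vanguard's implementation.

```python
import numpy as np


def log_mean_exp(log_values):
    # Numerically stable log((1 / S) * sum_s exp(log_values[s]))
    m = np.max(log_values)
    return m + np.log(np.mean(np.exp(log_values - m)))


def gaussian_log_density(y, mean, var):
    # Joint log density of independent Gaussians over the test points
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (y - mean) ** 2 / var)


def mc_predictive_log_probability(y, means, variances):
    """Monte Carlo estimate of log p(y | D); one row of means/variances per sample theta_s."""
    per_sample = np.array([gaussian_log_density(y, m, v) for m, v in zip(means, variances)])
    return log_mean_exp(per_sample)


rng = np.random.default_rng(0)
y_test = rng.standard_normal(5)
means = rng.normal(0.0, 0.3, size=(100, 5))      # hypothetical per-sample predictive means
variances = rng.uniform(0.5, 1.5, size=(100, 5))  # hypothetical per-sample predictive variances
print(mc_predictive_log_probability(y_test, means, variances))
```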

We can specify a temperature for the hyperparameter posterior in the LaplaceHierarchicalHyperparameters decorator; if we leave it unset, the temperature is chosen automatically using a heuristic (which gives a tempered covariance matrix with unit trace).
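If the heuristic is exactly as described, the automatic temperature is \(T = 1/\operatorname{tr}(\Sigma)\), since \(\operatorname{tr}(T\Sigma) = T\operatorname{tr}(\Sigma)\). A short NumPy sketch with a hypothetical \(\Sigma\) follows; we have not checked it against vanguard's auto_temperature.

```python
import numpy as np


def unit_trace_temperature(Sigma):
    # Choose T so that the tempered covariance T * Sigma has trace 1
    return 1.0 / np.trace(Sigma)


Sigma = np.array([[4.0, 0.5], [0.5, 1.0]])  # hypothetical hyperparameter covariance
T = unit_trace_temperature(Sigma)
print(T, np.trace(T * Sigma))
```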

[ ]:
@LaplaceHierarchicalHyperparameters(num_mc_samples=100, ignore_all=True)
class FullBayesianController(PointEstimateController):
    pass


gp = FullBayesianController(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=BayesianKernel,
    y_std=DATASET.train_y_std,
    mean_class=BayesianConstantMean,
    likelihood_class=BayesianFixedNoiseGaussianLikelihood,
    optim_kwargs={"lr": 0.5},
    rng=np.random.default_rng(random_seed),
)

with gp.metrics_tracker.print_metrics(every=20):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        gp.fit(n_sgd_iters=num_iters)

Let’s have a look at the hyperparameter posterior mean and covariance.

[ ]:
plt.imshow(gp.hyperparameter_posterior.covariance_matrix.detach().cpu().numpy())
plt.colorbar()
print(gp.hyperparameter_posterior.mean.detach().cpu().numpy())
[ ]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    posterior = gp.posterior_over_point(DATASET.test_x)
    likelihood = gp.predictive_likelihood(DATASET.test_x)

mu, lower, upper = posterior.confidence_interval()
l_mu, l_lower, l_upper = likelihood.confidence_interval()
plt_x = DATASET.test_x.ravel()

# Convert from tensors to numpy arrays for plotting
plt_x = plt_x.detach().cpu().numpy()
l_mu = l_mu.detach().cpu().numpy()
l_lower = l_lower.detach().cpu().numpy()
l_upper = l_upper.detach().cpu().numpy()
plt_y = DATASET.test_y.detach().cpu().numpy()

plt.figure(figsize=(10, 4))
plt.plot(plt_x, l_mu, label="likelihood")
plt.fill_between(plt_x, l_lower, l_upper, alpha=0.2, label="likelihood CI")
plt.plot(plt_x, plt_y, "x", label="data")
plt.grid(which="both")
plt.legend()
print(f"Log probability: {likelihood.log_probability(DATASET.test_y)}")

Let’s look at some posterior samples.

[ ]:
plt.figure(figsize=(10, 4))
plt.plot(posterior.sample(500).T.detach().cpu().numpy())
plt.show()

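Let’s try varying the temperature to see which value gives the best predictive log probability on the test set. Because predictions are estimated by Monte Carlo sampling over the hyperparameters, we repeat the sweep 20 times to gauge the sampling noise.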

[ ]:
temps = np.logspace(-5, 0, 20)
log_probs = []
for _ in tqdm(range(20)):
    lp = []
    for temperature in temps:
        gp.temperature = temperature
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            likelihood = gp.predictive_likelihood(DATASET.test_x)
        lp.append(likelihood.log_probability(DATASET.test_y).detach().cpu().numpy())
    log_probs.append(lp)

log_probs = np.array(log_probs)
plt.plot(temps, log_probs.T)
plt.xscale("log")
plt.grid()
plt.show()
[ ]:
mean_log_probs = np.mean(log_probs, axis=0)
plt.plot(temps, mean_log_probs, label="empirical mean")
plt.vlines(
    [gp.auto_temperature()],
    [min(mean_log_probs)],
    [max(mean_log_probs)],
    linestyles="--",
    color="r",
    label="auto temperature",
)
plt.xscale("log")
plt.ylabel("log probability")
plt.xlabel("temperature")
plt.legend()
plt.grid()
plt.show()

It appears that the automatically selected temperature is pretty good.

Real data: airline passengers

This dataset is taken from the Kats repository by Facebook Research; see [Jiang_KATS_2022].

[ ]:
data = AirPassengers()
df = data._load_data()

train_test_split_index = 100
x = df.index.values.astype(float)
y = df.y.values.astype(float)
train_x, train_y = x[:train_test_split_index], y[:train_test_split_index]
test_x, test_y = x[train_test_split_index:], y[train_test_split_index:]

We’ll build a kernel suitable for time series. We’ll constrain the variance of the linear kernel to lie between 0 and 1, since large values can easily lead to posterior blow-up.
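Such a constraint keeps the variance bounded by passing an unconstrained raw parameter through a squashing transform. To our understanding, GPyTorch's Interval constraint uses a sigmoid by default; the NumPy sketch below reproduces that idea (the function name is ours, not GPyTorch's). However far the optimiser pushes the raw value, the resulting variance stays strictly between the bounds.

```python
import numpy as np


def interval_transform(raw, lower=0.0, upper=1.0):
    """Map an unconstrained raw value into (lower, upper) with a sigmoid."""
    return lower + (upper - lower) / (1.0 + np.exp(-raw))


raw_values = np.array([-20.0, -2.0, 0.0, 2.0, 20.0])
print(interval_transform(raw_values))
```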

[ ]:
linear_co_constraint = constraints.Interval(0.0, 1.0)


class AirlineKernel(kernels.AdditiveKernel):
    def __init__(self, batch_shape=torch.Size([])):
        local_period = ScaledRBFKernel(batch_shape=batch_shape)
        local_period *= kernels.PeriodicKernel(batch_shape=batch_shape)
        linear = kernels.LinearKernel(
            batch_shape=batch_shape,
            variance_constraint=linear_co_constraint,
        )
        rbf = ScaledRBFKernel(batch_shape=batch_shape)
        super().__init__(local_period, linear, rbf)

We’ll rescale the targets with NormaliseY to prevent numerical issues.
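An affine rescaling of the targets is easy to sketch. The NumPy example below standardises training targets by their mean and standard deviation and maps predictions back. It assumes standardisation is the rescaling used, which may differ in detail from what NormaliseY does internally, and the target values are illustrative. Predictive means undergo the inverse affine map, while predictive standard deviations are only rescaled.

```python
import numpy as np


def fit_standardiser(y):
    # Parameters of the affine map y -> (y - mean) / std, fitted on training targets
    return float(np.mean(y)), float(np.std(y))


def standardise(y, mean, std):
    return (y - mean) / std


def unstandardise(pred_mean, pred_std, mean, std):
    # Means undergo the inverse affine map; standard deviations are only rescaled
    return pred_mean * std + mean, pred_std * std


train_y = np.array([100.0, 150.0, 200.0, 250.0, 300.0])  # illustrative targets
mean, std = fit_standardiser(train_y)
z = standardise(train_y, mean, std)
print(z.mean(), z.std())
```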

[ ]:
@NormaliseY()
@LearnYNoise(ignore_all=True)
class PointEstimateController(GaussianGPController):
    pass


gp = PointEstimateController(
    train_x=train_x,
    train_y=train_y,
    kernel_class=AirlineKernel,
    y_std=0,
    optim_kwargs={"lr": 0.1},
    rng=np.random.default_rng(random_seed),
)

with gp.metrics_tracker.print_metrics(every=20):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        gp.fit(n_sgd_iters=num_iters)
[ ]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    posterior = gp.posterior_over_point(x)
    likelihood = gp.predictive_likelihood(x)

mu, lower, upper = posterior.confidence_interval()
l_mu, l_lower, l_upper = likelihood.confidence_interval()

plt_x = x.ravel()

# Convert from tensors to numpy arrays for plotting
l_mu = l_mu.detach().cpu().numpy()
l_lower = l_lower.detach().cpu().numpy()
l_upper = l_upper.detach().cpu().numpy()

plt.figure(figsize=(15, 7))
plt.plot(plt_x, l_mu, label="likelihood")
plt.fill_between(plt_x, l_lower, l_upper, alpha=0.2, label="likelihood CI")
plt.plot(train_x, train_y, "x", label="train data")
plt.plot(test_x, test_y, "o", label="test data")
plt.grid(which="both")
plt.legend()
print(f"Log probability: {likelihood.log_probability(torch.tensor(y))}")

Below we directly convert the kernel into a Bayesian one.

[ ]:
@BayesianHyperparameters()
class BayesianLinearKernel(kernels.LinearKernel):
    pass


class BayesianAirlineKernel(kernels.AdditiveKernel):
    def __init__(self, batch_shape=torch.Size([])):
        periodic = BayesianPeriodicKernel(batch_shape=batch_shape)
        local_period = BayesianScaledRBFKernel(batch_shape=batch_shape) * periodic
        linear = BayesianLinearKernel(
            batch_shape=batch_shape,
            variance_constraint=linear_co_constraint,
        )
        rbf = BayesianScaledRBFKernel(batch_shape=batch_shape)
        super().__init__(local_period, linear, rbf)
[ ]:
@LaplaceHierarchicalHyperparameters(num_mc_samples=100, ignore_all=True)
class FullBayesianController(PointEstimateController):
    pass


laplace_gp = FullBayesianController(
    train_x=train_x,
    train_y=train_y,
    kernel_class=BayesianAirlineKernel,
    y_std=0,
    mean_class=BayesianConstantMean,
    likelihood_class=BayesianFixedNoiseGaussianLikelihood,
    optim_kwargs={"lr": 0.1},
    rng=np.random.default_rng(random_seed),
)

with laplace_gp.metrics_tracker.print_metrics(every=20):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        laplace_gp.fit(n_sgd_iters=num_iters)
[ ]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    posterior = laplace_gp.posterior_over_point(x)
    laplace_likelihood = laplace_gp.predictive_likelihood(x)

laplace_mu, laplace_lower, laplace_upper = laplace_likelihood.confidence_interval()

# Convert from tensors to numpy arrays for plotting
laplace_mu = laplace_mu.detach().cpu().numpy()
laplace_lower = laplace_lower.detach().cpu().numpy()
laplace_upper = laplace_upper.detach().cpu().numpy()

plt_x = x.ravel()
plt.figure(figsize=(15, 7))
plt.plot(plt_x, laplace_mu, label="likelihood")
plt.fill_between(plt_x, laplace_lower, laplace_upper, alpha=0.2, label="likelihood CI")
plt.plot(train_x, train_y, "x", label="train data")
plt.plot(test_x, test_y, "o", label="test data")
plt.grid(which="both")
plt.legend()
print(f"Log probability: {laplace_likelihood.log_probability(torch.tensor(y))}")

Let’s have a look at the raw hyperparameter posterior covariance matrix.

[ ]:
plt.imshow(laplace_gp.hyperparameter_posterior.covariance_matrix.detach().cpu().numpy())
plt.colorbar()
print(laplace_gp.hyperparameter_posterior.mean.detach().cpu().numpy())
[ ]:
temps = np.logspace(-5, 0, 20)
log_probs = []
for run_index in tqdm(range(20)):
    lp = []
    for temperature in temps:
        try:
            laplace_gp.temperature = temperature
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                likelihood = laplace_gp.predictive_likelihood(test_x)
            lp.append(likelihood.log_probability(torch.tensor(test_y)).detach().cpu().numpy())
        except Exception:
            print(f"Skipping temperature {temperature} run {run_index + 1} due to numerical issues")
            lp.append(np.nan)

    log_probs.append(lp)

log_probs = np.array(log_probs)
plt.plot(temps, log_probs.T)
plt.xscale("log")
plt.grid()
plt.show()
[ ]:
# Average over runs, ignoring temperatures skipped due to numerical issues (NaN)
mean_log_probs = np.nanmean(log_probs, axis=0)
plt.plot(temps, mean_log_probs, label="empirical mean")
plt.vlines(
    [laplace_gp.auto_temperature()],
    [np.nanmin(mean_log_probs)],
    [np.nanmax(mean_log_probs)],
    linestyles="--",
    color="r",
    label="auto temperature",
)
plt.xscale("log")
plt.ylabel("log probability")
plt.xlabel("temperature")
plt.legend()
plt.grid()
plt.show()