Sparse variational inference for GPs

[ ]:
# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[ ]:
# This notebook is not compiled into the documentation due to the time taken to run it to get
# a representative analysis. Please run this notebook locally if you wish to see the outputs.

This notebook demonstrates the use of sparse variational GP approximations in Vanguard and how easily they combine with other techniques such as warping and input uncertainty.

[ ]:
random_seed = 1_989
[ ]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.mlls import VariationalELBO

from vanguard.datasets.bike import BikeDataset
from vanguard.uncertainty import GaussianUncertaintyGPController
from vanguard.vanilla import GaussianGPController
from vanguard.variational import VariationalInference
from vanguard.warps import SetWarp, warpfunctions

Introduction

Exact inference for GPs is elegant and simple to explain, formulate and code up. But, as ever, there’s no free lunch: exact inference has cubic complexity in the size of the training dataset and can be applied only to Gaussian likelihoods. There are practical cases where one cannot reasonably assume a Gaussian likelihood, most obviously in classification problems, where the likelihood must be discrete.

Variational inference is one solution to these problems. Instead of computing the true posterior process \(p(f \mid \text{data})\), one introduces a variational approximation \(q(f, u) = q(f \mid u)q(u)\), where \(u\) are the function values at a set of “inducing points” \(U\) and \(q(u)\) is a variational distribution (typically Gaussian) over those values, whose parameters are learned during fitting. In the simplest case, one can think of inducing points as synthetic data points which can be used in conjunction with the Nyström approximation \(K(X, X) \approx K(X, U)K(U, U)^{-1} K(U, X)\). This avoids inverting the \(N\times N\) matrix \(K(X,X)\) and requires inverting only the \(M\times M\) matrix \(K(U, U)\), where \(M\) is the user-specified number of inducing points. The actual method used by default in GPyTorch and Vanguard is a little more complicated and can be found in [Hensman15].
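The saving from the Nyström approximation can be sketched numerically. The snippet below is a minimal NumPy illustration, not the method used by GPyTorch or Vanguard: it assumes a Gaussian noise variance `noise_var`, a hand-written Matérn-3/2 kernel `matern32`, and randomly placed inducing points, all of which are choices made here for illustration. It solves the same linear system twice, once by forming the \(N\times N\) Nyström matrix and once via the Woodbury identity, where the only linear solve is \(M\times M\).

```python
import numpy as np


def matern32(a, b, lengthscale=0.3):
    """Matern-3/2 kernel matrix between the rows of a and the rows of b."""
    dists = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    scaled = np.sqrt(3.0) * dists / lengthscale
    return (1.0 + scaled) * np.exp(-scaled)


rng = np.random.default_rng(0)
n, m, noise_var = 300, 25, 0.1  # illustrative sizes, not the notebook's settings
x = rng.uniform(size=(n, 2))  # training inputs X
u = rng.uniform(size=(m, 2))  # inducing point locations U
y = rng.normal(size=n)

k_xu = matern32(x, u)
k_uu = matern32(u, u) + 1e-6 * np.eye(m)  # small jitter for numerical stability

# Direct route: form the N x N Nystrom matrix and solve an N x N system.
k_nystrom = k_xu @ np.linalg.solve(k_uu, k_xu.T)
alpha_direct = np.linalg.solve(k_nystrom + noise_var * np.eye(n), y)

# Woodbury route: the only linear solve involves an M x M matrix.
inner = noise_var * k_uu + k_xu.T @ k_xu
alpha_woodbury = (y - k_xu @ np.linalg.solve(inner, k_xu.T @ y)) / noise_var

# The two routes agree up to floating-point error.
print(np.max(np.abs(alpha_direct - alpha_woodbury)))
```

Both routes compute \((K_{\text{Nyström}} + \sigma^2 I)^{-1} y\); the second never factorises anything larger than \(M\times M\), which is where the computational saving comes from.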

From the point of view of the user, all one needs to know is

  • Increasing the number of inducing points increases the size of approximation and so can lead to better approximations. The trade-off is the extra computational effort required to fit the greater number of parameters.

  • Variational GPs are typically more fiddly to fit than exact GPs. One might have to adjust learning rates, scheduling, etc.

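  • Overfitting is usually less of a concern than the number of extra parameters might suggest, because the variational objective penalises approximations that stray far from the prior.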
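The first point in the list above can be illustrated with the plain Nyström approximation. The following is a rough sketch under simplifying assumptions: the inducing points are fixed, nested subsets of the data rather than learned locations, and the kernel is a hand-written Matérn-3/2 (`matern32` is a helper defined here, not a Vanguard or GPyTorch function). It measures how the relative error of the approximate kernel matrix changes as \(M\) grows.

```python
import numpy as np


def matern32(a, b, lengthscale=0.3):
    """Matern-3/2 kernel matrix between the rows of a and the rows of b."""
    dists = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    scaled = np.sqrt(3.0) * dists / lengthscale
    return (1.0 + scaled) * np.exp(-scaled)


rng = np.random.default_rng(0)
x = rng.uniform(size=(400, 2))
k_xx = matern32(x, x)

errors = {}
for m in (5, 20, 80):
    u = x[:m]  # nested subsets of the data stand in for inducing points
    k_xu = matern32(x, u)
    k_uu = matern32(u, u) + 1e-8 * np.eye(m)  # jitter for numerical stability
    k_approx = k_xu @ np.linalg.solve(k_uu, k_xu.T)
    errors[m] = np.linalg.norm(k_xx - k_approx) / np.linalg.norm(k_xx)

for m, err in errors.items():
    print(f"M = {m:3d}: relative Frobenius error = {err:.3f}")
```

The error shrinks as inducing points are added, while the \(M\times M\) solve inside the loop becomes more expensive; this is the trade-off described above.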

Data

We will use the BikeDataset, which has 13 input features. The main reason for choosing this dataset for this example notebook is that it is large: with a 90/10 train/test split, we have ~15.5k training points, so exact GP inference would be very expensive. This dataset is taken from [FanaeeT2013] and was accessed and copied to GitHub LFS within this repo on 1st July 2024.
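To put "very expensive" into rough numbers, here is a back-of-the-envelope sketch. It assumes exactly 15,500 training points (the text only gives an approximate figure), double-precision storage, the standard \(N^3/3\) leading-order operation count for a Cholesky factorisation, and an \(O(NM^2)\) leading term for a sparse method with \(M = 750\) inducing points. These are order-of-magnitude estimates, not measurements.

```python
n_train = 15_500  # approximate training-set size quoted above (assumed exact here)
n_inducing = 750  # the larger inducing-point setting used later in this notebook
bytes_per_float = 8  # float64

# Memory needed just to store the dense N x N kernel matrix.
kernel_matrix_gb = n_train**2 * bytes_per_float / 1e9

# Leading-order operation counts, ignoring constants for the sparse method.
exact_flops = n_train**3 / 3
sparse_flops = n_train * n_inducing**2

print(f"N x N kernel matrix: {kernel_matrix_gb:.2f} GB")
print(f"exact / sparse operation ratio: {exact_flops / sparse_flops:.0f}")
```

Under these assumptions the dense kernel matrix alone needs just under 2 GB, and a single exact factorisation costs on the order of a hundred times more than one sparse pass over the data, before accounting for the many optimisation steps needed to fit hyperparameters.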

[ ]:
DATASET = BikeDataset(rng=np.random.default_rng(random_seed))
[ ]:
plt.hist(DATASET.train_y.detach().cpu().numpy())
plt.xlabel("$y$", fontsize=15)
plt.show()

The regressand is non-negative, so warping could be useful.

We’ll start with a simple visualisation of the concept of inducing points. We’ll restrict to only 2 of the bike features so we can plot without using dimensionality reduction. The Vanguard code below will be introduced later in the notebook; for now we simply run it and look at the inducing points.

[ ]:
N_DATA_POINTS = 500
N_INDUCING_POINTS = 20
DATASET = BikeDataset(num_samples=N_DATA_POINTS, rng=np.random.default_rng(random_seed))


@VariationalInference(n_inducing_points=N_INDUCING_POINTS, ignore_methods=("__init__",))
class GaussianVariationalGPController(GaussianGPController):
    """Does variational inference."""

    pass


class ScaledMaternKernel(ScaleKernel):
    """A scaled matern kernel."""

    def __init__(self):
        super().__init__(MaternKernel(nu=1.5, ard_num_dims=2))


# TODO: Include a batch_size argument in this example when functionality resolved
# https://github.com/gchq/Vanguard/issues/377
gp = GaussianVariationalGPController(
    train_x=DATASET.train_x[:, [4, 7]],
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=0.001 * torch.mean(torch.abs(DATASET.train_y)),
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_kwargs={"learn_additional_noise": True},
    optim_kwargs={"lr": 0.01},
    rng=np.random.default_rng(random_seed),
)

with gp.metrics_tracker.print_metrics(every=200):
    gp.fit(n_sgd_iters=2000)
[ ]:
inducing_points = gp._gp.variational_strategy.inducing_points.detach().cpu().numpy()
plt_x = DATASET.train_x[:, [4, 7]].detach().cpu().numpy()

plt.scatter(plt_x[:, 0], plt_x[:, 1])
plt.scatter(inducing_points[:, 0], inducing_points[:, 1], marker="x")
plt.show()

Modelling

Let’s do some plain VI. First we need to define the class using a VariationalInference decorator. We have to specify a model that subclasses ApproximateGP; in this case we’ll use the stock SVGPModel, which uses GPyTorch’s default variational strategy and distribution. We also need to specify the marginal log-likelihood; in this case we use the standard VariationalELBO from GPyTorch. The VariationalInference decorator deals with the rest.
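For orientation, the evidence lower bound that VariationalELBO is named after has the following textbook form in the sparse variational GP literature, such as [Hensman15]. This is a sketch of the standard expression rather than a transcription of GPyTorch's implementation, which may scale or normalise the terms differently. The symbols \(p(u)\) (the GP prior evaluated at the inducing points) and \(q(f_i)\) (the marginal of the approximation at the \(i\)-th training input) are notation introduced here for illustration.

\[
\mathcal{L}(q) = \sum_{i=1}^{N} \mathbb{E}_{q(f_i)}\left[\log p(y_i \mid f_i)\right] - \mathrm{KL}\left[q(u) \,\|\, p(u)\right]
\]

The first term rewards fitting the data and, being a sum over training points, can be estimated from minibatches. The second term penalises variational distributions that move far from the prior.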

A large number of inducing points will produce excellent results but take quite a while (we suggest using a GPU if you are going to do this). Setting SLOW = False reduces the number of inducing points and the number of training iterations so that this notebook runs quickly for the purposes of demonstration.

[ ]:
SLOW = False
[ ]:
N_INDUCING_POINTS = 750 if SLOW else 20


@VariationalInference(n_inducing_points=N_INDUCING_POINTS, ignore_methods=("__init__",))
class GaussianVariationalGPController(GaussianGPController):
    """Does variational inference."""

    pass
[ ]:
# TODO: Include a batch_size argument in this example when functionality resolved
# https://github.com/gchq/Vanguard/issues/377
# BATCH_SIZE = 256
# NUM_ITERS = max(len(DATASET.train_x) // BATCH_SIZE, 15) * (100 if SLOW else 10)
NUM_ITERS = max(len(DATASET.train_x), 15) * (100 if SLOW else 10)
print(NUM_ITERS)
[ ]:
class ScaledMaternKernel(ScaleKernel):
    """A scaled matern kernel."""

    def __init__(self):
        super().__init__(MaternKernel(nu=1.5, ard_num_dims=DATASET.train_x.shape[1]))
[ ]:
# TODO: Include a batch_size argument in this example when functionality resolved
# https://github.com/gchq/Vanguard/issues/377
gp = GaussianVariationalGPController(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=0.001 * torch.mean(torch.abs(DATASET.train_y)),
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_kwargs={"learn_additional_noise": True},
    optim_kwargs={"lr": 0.01},
    rng=np.random.default_rng(random_seed),
)

with gp.metrics_tracker.print_metrics(every=150):
    gp.fit(n_sgd_iters=NUM_ITERS)
[ ]:
posterior = gp.predictive_likelihood(DATASET.test_x)
DATASET.plot_prediction(*posterior.confidence_interval())
plt.show()

Now let’s look at SVGP combined with compositional warping. We’ll use an affine-log warp to reflect the non-negativity of the data: \(\phi(y) = a + b\log(y)\).
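The effect of the affine-log warp can be sketched with plain NumPy, independently of Vanguard. The function names and the values of `a`, `b` and `warped_std` below are hypothetical, chosen only for illustration. The sketch shows that a Box-Cox transform with a tiny \(\lambda\) is numerically close to the logarithm, and that any Gaussian interval in warped space maps back to a strictly positive interval in the original space.

```python
import numpy as np


def box_cox(y, lam):
    """Box-Cox transform; lam = 0 is defined as the natural logarithm."""
    return np.log(y) if lam == 0 else (y**lam - 1.0) / lam


def affine_log_warp(y, a, b):
    """phi(y) = a + b * log(y), defined for y > 0."""
    return a + b * box_cox(y, 0)


def inverse_affine_log_warp(z, a, b):
    """phi^{-1}(z) = exp((z - a) / b), which is positive for every real z."""
    return np.exp((z - a) / b)


a, b = 0.5, 2.0  # hypothetical warp parameters, chosen for illustration
y = np.array([0.01, 0.1, 1.0, 5.0])

# Box-Cox with a tiny lambda is numerically close to log(y).
box_cox_gap = np.max(np.abs(box_cox(y, 1e-7) - np.log(y)))
print(box_cox_gap)

# The warp is invertible on the positive reals.
round_trip = inverse_affine_log_warp(affine_log_warp(y, a, b), a, b)

# A wide Gaussian interval in warped space around a small observation ...
warped_mean, warped_std = affine_log_warp(0.01, a, b), 3.0
warped_interval = warped_mean + 1.96 * warped_std * np.array([-1.0, 1.0])

# ... still maps back to a strictly positive interval in the original space.
original_interval = inverse_affine_log_warp(warped_interval, a, b)
print(original_interval)
```

An unwarped Gaussian interval of the same width centred near \(y = 0.01\) would extend well below zero, which is the behaviour warping is intended to avoid.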

The code to create this GP model in Vanguard is simple. Use a SetWarp decorator to apply the warp and a VariationalInference decorator to make the GP variational, then specify the variational model type (SVGPModel here) and the variational objective (VariationalELBO here).

[ ]:
warp = warpfunctions.AffineWarpFunction() @ warpfunctions.BoxCoxWarpFunction(lambda_=0)


@SetWarp(warp_function=warp, ignore_methods=("fit", "__init__"))
@VariationalInference(n_inducing_points=N_INDUCING_POINTS, ignore_methods=("__init__",))
class WarpedGaussianVariationalGPController(GaussianGPController):
    """Does variational inference with an affine-log output warp."""

    pass
[ ]:
# TODO: Include a batch_size argument in this example when functionality resolved
# https://github.com/gchq/Vanguard/issues/377
gp = WarpedGaussianVariationalGPController(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=0.001 * torch.mean(torch.abs(DATASET.train_y)),
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_kwargs={"learn_additional_noise": True},
    optim_kwargs={"lr": 0.01},
    rng=np.random.default_rng(random_seed),
)

with gp.metrics_tracker.print_metrics(every=150):
    gp.fit(n_sgd_iters=NUM_ITERS)
[ ]:
warp_posterior = gp.predictive_likelihood(DATASET.test_x)
DATASET.plot_prediction(*warp_posterior.confidence_interval())
plt.show()

Warping improves the RMSE overall but is likely to be most useful for smaller \(y\) values, so let’s filter by the true \(y\) value and compare warping to no warping.

[ ]:
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
DATASET.plot_prediction(*warp_posterior.confidence_interval(), y_upper_bound=0.5)
plt.title("Warping. " + plt.gca().title.get_text())
plt.subplot(1, 2, 2)
DATASET.plot_prediction(*posterior.confidence_interval(), y_upper_bound=0.5)
plt.title("No warping. " + plt.gca().title.get_text())
plt.show()

This demonstrates nicely that the warping is working where it matters, preventing impossible negative predictions.

Finally, we can combine this with input uncertainty as well, using some dummy input noise.

[ ]:
warp = warpfunctions.AffineWarpFunction() @ warpfunctions.BoxCoxWarpFunction(lambda_=0)


@SetWarp(warp_function=warp, ignore_all=True)
@VariationalInference(n_inducing_points=N_INDUCING_POINTS, ignore_all=True)
class WarpedGaussianUncertaintyVariationalGPController(GaussianUncertaintyGPController):
    """Does variational inference with output warping and input uncertainty."""

    pass
[ ]:
# TODO: Include a batch_size argument in this example when functionality resolved
# https://github.com/gchq/Vanguard/issues/377
gp = WarpedGaussianUncertaintyVariationalGPController(
    train_x=DATASET.train_x,
    train_x_std=0.1,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=0.001 * torch.mean(torch.abs(DATASET.train_y)),
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_kwargs={"learn_additional_noise": True},
    optim_kwargs={"lr": 0.01},
    rng=np.random.default_rng(random_seed),
)

with gp.metrics_tracker.print_metrics(every=150):
    gp.fit(n_sgd_iters=NUM_ITERS)
[ ]:
posterior = gp.predictive_likelihood(DATASET.test_x)
# Plot the predictions of the warped, input-uncertainty model fitted above.
DATASET.plot_prediction(*posterior.confidence_interval())
plt.show()

Conclusions

This short example demonstrates that compositional warping can be combined with sparse variational GP inference in Vanguard using very little code. We have demonstrated good results on a real-world dataset with ~15.5k training points. We have compared a plain SVGP with a warped SVGP and found similar overall performance. Other datasets may exhibit a stronger preference for warping, but we have shown that, for low values of the regressand (close to zero), the warped GP is much better, as it makes no impossible negative predictions. We have also shown that combining warping and variational inference is feasible and that training is no more difficult than for a plain SVGP.

In addition, we have provided a proof-of-concept demonstration that warping, input uncertainty and variational GPs can be combined simply within Vanguard and trained successfully.