{ "cells": [ { "cell_type": "raw", "id": "40b658c2", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Sparse variational inference for GPs\n", "====================================" ] }, { "cell_type": "code", "execution_count": null, "id": "ed0e22ddc65fb904", "metadata": { "jupyter": { "is_executing": true } }, "outputs": [], "source": [ "# © Crown Copyright GCHQ\n", "#\n", "# Licensed under the GNU General Public License, version 3 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.gnu.org/licenses/gpl-3.0.en.html\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "code", "execution_count": null, "id": "af017f255a3131a5", "metadata": {}, "outputs": [], "source": [ "# This notebook is not compiled into the documentation due to the time taken to run it to get\n", "# a representative analysis. Please run this notebook locally if you wish to see the outputs." ] }, { "cell_type": "raw", "id": "f5591922dba11883", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "This notebook demonstrates the use of sparse variational GP approximations in Vanguard and the ease of combination with other techniques such as warping and input uncertainty." 
] }, { "cell_type": "code", "execution_count": null, "id": "ba1aca0388745a99", "metadata": {}, "outputs": [], "source": [ "random_seed = 1_989" ] }, { "cell_type": "code", "execution_count": null, "id": "b54906cf", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "from gpytorch.kernels import MaternKernel, ScaleKernel\n", "from gpytorch.mlls import VariationalELBO\n", "\n", "from vanguard.datasets.bike import BikeDataset\n", "from vanguard.uncertainty import GaussianUncertaintyGPController\n", "from vanguard.vanilla import GaussianGPController\n", "from vanguard.variational import VariationalInference\n", "from vanguard.warps import SetWarp, warpfunctions" ] }, { "cell_type": "raw", "id": "e4437858", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Introduction\n", "------------\n", "\n", "Exact inference for GPs is elegant and simple to explain, formulate and code up. But, as ever, there's no free lunch: exact inference has cubic complexity in the size of the training dataset and can be applied only to Gaussian likelihoods. There are practical cases where one cannot reasonably assume a Gaussian likelihood, most obviously in classification problems where the likelihood must be discrete.\n", "\n", "Variational inference is one solution to these problems. Instead of computing the true posterior process :math:`p(f \\mid \\text{data})`, one introduces a variational approximation :math:`q(f, u) = q(f \\mid u)q(u)`, where the :math:`u` are the function values at a set of \"inducing points\" and :math:`q(u)` is a variational distribution over those values. 
In the simplest case, one can think of inducing points as synthetic data points which can be used in conjunction with the Nyström approximation :math:`K(X, X) \\approx K(X, U)K(U, U)^{-1} K(U, X)`, so that one inverts only the :math:`M\\times M` matrix :math:`K(U, U)` rather than the :math:`N\\times N` matrix :math:`K(X, X)`, where :math:`M` is the user-specified number of inducing points, i.e. the size of the approximation. The actual method used by default in GPyTorch and Vanguard is a little more complicated and can be found in :cite:`Hensman15`.\n", "\n", "From the point of view of the user, all one needs to know is:\n", "\n", "* Increasing the number of inducing points increases the size of the approximation and so can lead to better approximations. The trade-off is the extra computational effort required to fit the greater number of parameters.\n", "* Variational GPs are typically more fiddly to fit than exact GPs. One might have to tune learning rates, scheduling, etc.\n", "* Overfitting shouldn't be an issue due to the variational formulation." ] }, { "cell_type": "raw", "id": "a662efe2", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Data\n", "----\n", "\n", "We will use the :py:class:`~vanguard.datasets.bike.BikeDataset`, with 13 input features. The main point of this dataset for this example notebook is that it's large. With a 90/10 train/test split, we have ~15.5k training points. Exact GP inference in this case would be very expensive. This dataset is taken from :cite:`FanaeeT2013` and was accessed and copied to GitHub LFS within this repo on 1st July 2024." ] }, { "cell_type": "code", "execution_count": null, "id": "ea718384", "metadata": {}, "outputs": [], "source": [ "DATASET = BikeDataset(rng=np.random.default_rng(random_seed))" ] }, { "cell_type": "code", "execution_count": null, "id": "3e4d5b0d", "metadata": {}, "outputs": [], "source": [ "plt.hist(DATASET.train_y.detach().cpu().numpy())\n", "plt.xlabel(\"$y$\", fontsize=15)\n", "plt.show()" ] }, { "cell_type": "raw", "id": "83db11821544a15", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The regressand is non-negative, so warping could be useful.\n", "\n", "We'll start with a simple visualisation of the concept of inducing points. We'll restrict to only two of the bike features so we can plot without using dimensionality reduction. The Vanguard code below will be introduced later in the notebook; for now we simply run it and look at the inducing points." ] }, { "cell_type": "code", "execution_count": null, "id": "508c4b66", "metadata": {}, "outputs": [], "source": [ "N_DATA_POINTS = 500\n", "N_INDUCING_POINTS = 20\n", "DATASET = BikeDataset(num_samples=N_DATA_POINTS, rng=np.random.default_rng(random_seed))\n", "\n", "\n", "@VariationalInference(n_inducing_points=N_INDUCING_POINTS, ignore_methods=(\"__init__\",))\n", "class GaussianVariationalGPController(GaussianGPController):\n", "    \"\"\"Does variational inference.\"\"\"\n", "\n", "    pass\n", "\n", "\n", "class ScaledMaternKernel(ScaleKernel):\n", "    \"\"\"A scaled Matérn kernel.\"\"\"\n", "\n", "    def __init__(self):\n", "        super().__init__(MaternKernel(nu=1.5, ard_num_dims=2))\n", "\n", "\n", "# TODO: Include a batch_size argument in this example when functionality resolved\n", "# https://github.com/gchq/Vanguard/issues/377\n", "gp = GaussianVariationalGPController(\n", "    train_x=DATASET.train_x[:, [4, 7]],\n", "    train_y=DATASET.train_y,\n", "    kernel_class=ScaledMaternKernel,\n", "    y_std=0.001 * torch.mean(torch.abs(DATASET.train_y)),\n", "    
marginal_log_likelihood_class=VariationalELBO,\n", "    likelihood_kwargs={\"learn_additional_noise\": True},\n", "    optim_kwargs={\"lr\": 0.01},\n", "    rng=np.random.default_rng(random_seed),\n", ")\n", "\n", "with gp.metrics_tracker.print_metrics(every=200):\n", "    gp.fit(n_sgd_iters=2000)" ] }, { "cell_type": "code", "execution_count": null, "id": "eb2e3b3f", "metadata": {}, "outputs": [], "source": [ "inducing_points = gp._gp.variational_strategy.inducing_points.detach().cpu().numpy()\n", "plt_x = DATASET.train_x[:, [4, 7]].detach().cpu().numpy()\n", "\n", "plt.scatter(plt_x[:, 0], plt_x[:, 1])\n", "plt.scatter(inducing_points[:, 0], inducing_points[:, 1], marker=\"x\")\n", "plt.show()" ] }, { "cell_type": "raw", "id": "0e4260b8", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Modelling\n", "---------\n", "\n", "Let's do some plain VI. First we need to define the class using a :py:class:`~vanguard.variational.VariationalInference` decorator. We have to specify a model that subclasses :py:class:`~gpytorch.models.ApproximateGP`; in this case we'll use the stock :py:class:`~vanguard.variational.models.SVGPModel`, which uses GPyTorch's default variational strategy and distribution. We also need to specify the marginal log likelihood; in this case we use the standard :py:class:`~gpytorch.mlls.VariationalELBO` from GPyTorch. The :py:class:`~vanguard.variational.VariationalInference` decorator deals with the rest." ] }, { "cell_type": "raw", "id": "4ef6f3433d6b30a0", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "A large number of inducing points will produce excellent results but will take quite a while to train (we suggest using a GPU if you are going to do this). Setting ``SLOW = False`` will reduce the number of inducing points used and the number of training epochs to make this notebook run quickly for the purposes of demonstration." 
] }, { "cell_type": "code", "execution_count": null, "id": "a26cc8f2", "metadata": {}, "outputs": [], "source": [ "SLOW = False" ] }, { "cell_type": "code", "execution_count": null, "id": "8ae8a3a0", "metadata": {}, "outputs": [], "source": [ "N_INDUCING_POINTS = 750 if SLOW else 20\n", "\n", "\n", "@VariationalInference(n_inducing_points=N_INDUCING_POINTS, ignore_methods=(\"__init__\",))\n", "class GaussianVariationalGPController(GaussianGPController):\n", "    \"\"\"Does variational inference.\"\"\"\n", "\n", "    pass" ] }, { "cell_type": "code", "execution_count": null, "id": "7a193412", "metadata": {}, "outputs": [], "source": [ "# TODO: Include a batch_size argument in this example when functionality resolved\n", "# https://github.com/gchq/Vanguard/issues/377\n", "# BATCH_SIZE = 256\n", "# NUM_ITERS = max(len(DATASET.train_x) // BATCH_SIZE, 15) * (100 if SLOW else 10)\n", "NUM_ITERS = max(len(DATASET.train_x), 15) * (100 if SLOW else 10)\n", "print(NUM_ITERS)" ] }, { "cell_type": "code", "execution_count": null, "id": "66458b1a", "metadata": {}, "outputs": [], "source": [ "class ScaledMaternKernel(ScaleKernel):\n", "    \"\"\"A scaled Matérn kernel.\"\"\"\n", "\n", "    def __init__(self):\n", "        super().__init__(MaternKernel(nu=1.5, ard_num_dims=DATASET.train_x.shape[1]))" ] }, { "cell_type": "code", "execution_count": null, "id": "b4abfd4e", "metadata": {}, "outputs": [], "source": [ "# TODO: Include a batch_size argument in this example when functionality resolved\n", "# https://github.com/gchq/Vanguard/issues/377\n", "gp = GaussianVariationalGPController(\n", "    train_x=DATASET.train_x,\n", "    train_y=DATASET.train_y,\n", "    kernel_class=ScaledMaternKernel,\n", "    y_std=0.001 * torch.mean(torch.abs(DATASET.train_y)),\n", "    marginal_log_likelihood_class=VariationalELBO,\n", "    likelihood_kwargs={\"learn_additional_noise\": True},\n", "    optim_kwargs={\"lr\": 0.01},\n", "    rng=np.random.default_rng(random_seed),\n", ")\n", "\n", "with 
gp.metrics_tracker.print_metrics(every=150):\n", "    gp.fit(n_sgd_iters=NUM_ITERS)" ] }, { "cell_type": "code", "execution_count": null, "id": "86cb0faa", "metadata": {}, "outputs": [], "source": [ "posterior = gp.predictive_likelihood(DATASET.test_x)\n", "DATASET.plot_prediction(*posterior.confidence_interval())\n", "plt.show()" ] }, { "cell_type": "raw", "id": "fe55f1f8", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Now let's look at SVGP combined with compositional warping. We'll use an affine-log warp to reflect the non-negativity of the data: :math:`\\phi(y) = a + b\\log(y)`.\n", "\n", "The code to create this GP model in Vanguard is simple. Use a :py:class:`~vanguard.warps.SetWarp` decorator to apply the warp and a :py:class:`~vanguard.variational.VariationalInference` decorator to make the GP variational, then specify the variational model type (:py:class:`~vanguard.variational.models.SVGPModel` here) and variational objective (:py:class:`~gpytorch.mlls.VariationalELBO` here)." 
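, "\n", "\n", "As a sketch of why composing an affine warp with a Box-Cox warp at :math:`\\lambda = 0` gives the affine-log warp above, assume the conventional Box-Cox definition (this is an assumption about the convention; check the Vanguard documentation for the exact form it implements):\n", "\n", ".. math::\n", "\n", "    \\psi_\\lambda(y) = \\frac{y^\\lambda - 1}{\\lambda}, \\qquad \\lim_{\\lambda \\to 0} \\psi_\\lambda(y) = \\log(y),\n", "\n", "so that applying the affine map :math:`z \\mapsto a + bz` afterwards yields\n", "\n", ".. math::\n", "\n", "    \\phi(y) = a + b\\,\\psi_0(y) = a + b\\log(y), \\qquad y > 0.\n", "\n", "Because :math:`\\log(y) \\to -\\infty` as :math:`y \\to 0^+`, the warp maps the positive half-line onto the whole real line (for :math:`b \\neq 0`), which is what lets a Gaussian model in the warped space produce only positive predictions once the warp is inverted."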
] }, { "cell_type": "code", "execution_count": null, "id": "79459a90", "metadata": {}, "outputs": [], "source": [ "warp = warpfunctions.AffineWarpFunction() @ warpfunctions.BoxCoxWarpFunction(lambda_=0)\n", "\n", "\n", "@SetWarp(warp_function=warp, ignore_methods=(\"fit\", \"__init__\"))\n", "@VariationalInference(n_inducing_points=N_INDUCING_POINTS, ignore_methods=(\"__init__\",))\n", "class WarpedGaussianVariationalGPController(GaussianGPController):\n", "    \"\"\"Does variational inference.\"\"\"\n", "\n", "    pass" ] }, { "cell_type": "code", "execution_count": null, "id": "9c0a063f", "metadata": {}, "outputs": [], "source": [ "# TODO: Include a batch_size argument in this example when functionality resolved\n", "# https://github.com/gchq/Vanguard/issues/377\n", "gp = WarpedGaussianVariationalGPController(\n", "    train_x=DATASET.train_x,\n", "    train_y=DATASET.train_y,\n", "    kernel_class=ScaledMaternKernel,\n", "    y_std=0.001 * torch.mean(torch.abs(DATASET.train_y)),\n", "    marginal_log_likelihood_class=VariationalELBO,\n", "    likelihood_kwargs={\"learn_additional_noise\": True},\n", "    optim_kwargs={\"lr\": 0.01},\n", "    rng=np.random.default_rng(random_seed),\n", ")\n", "\n", "with gp.metrics_tracker.print_metrics(every=150):\n", "    gp.fit(n_sgd_iters=NUM_ITERS)" ] }, { "cell_type": "code", "execution_count": null, "id": "91dc5d33", "metadata": {}, "outputs": [], "source": [ "warp_posterior = gp.predictive_likelihood(DATASET.test_x)\n", "DATASET.plot_prediction(*warp_posterior.confidence_interval())\n", "plt.show()" ] }, { "cell_type": "raw", "id": "f0aa3b428113c674", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Warping improves the RMSE overall but is likely to be most useful for smaller :math:`y` values, so let's filter by the true :math:`y` value and compare warping to no warping." 
] }, { "cell_type": "code", "execution_count": null, "id": "9ec875fd", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(12, 4))\n", "plt.subplot(1, 2, 1)\n", "DATASET.plot_prediction(*warp_posterior.confidence_interval(), y_upper_bound=0.5)\n", "plt.title(\"Warping. \" + plt.gca().title.get_text())\n", "plt.subplot(1, 2, 2)\n", "DATASET.plot_prediction(*posterior.confidence_interval(), y_upper_bound=0.5)\n", "plt.title(\"No warping. \" + plt.gca().title.get_text())\n", "plt.show()" ] }, { "cell_type": "raw", "id": "ae2d8daa0552b1f2", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "This demonstrates nicely that the warping is working where it matters, preventing impossible negative predictions.\n", "\n", "Finally, we can demonstrate combining with input uncertainty as well, using some dummy input noise." ] }, { "cell_type": "code", "execution_count": null, "id": "20f3902e", "metadata": {}, "outputs": [], "source": [ "warp = warpfunctions.AffineWarpFunction() @ warpfunctions.BoxCoxWarpFunction(lambda_=0)\n", "\n", "\n", "@SetWarp(warp_function=warp, ignore_all=True)\n", "@VariationalInference(n_inducing_points=N_INDUCING_POINTS, ignore_all=True)\n", "class WarpedGaussianUncertaintyVariationalGPController(GaussianUncertaintyGPController):\n", "    \"\"\"Does variational inference.\"\"\"\n", "\n", "    pass" ] }, { "cell_type": "code", "execution_count": null, "id": "0a0a02a0", "metadata": {}, "outputs": [], "source": [ "# TODO: Include a batch_size argument in this example when functionality resolved\n", "# https://github.com/gchq/Vanguard/issues/377\n", "gp = WarpedGaussianUncertaintyVariationalGPController(\n", "    train_x=DATASET.train_x,\n", "    train_x_std=0.1,\n", "    train_y=DATASET.train_y,\n", "    kernel_class=ScaledMaternKernel,\n", "    y_std=0.001 * torch.mean(torch.abs(DATASET.train_y)),\n", "    marginal_log_likelihood_class=VariationalELBO,\n", "    likelihood_kwargs={\"learn_additional_noise\": True},\n", "    optim_kwargs={\"lr\": 
0.01},\n", "    rng=np.random.default_rng(random_seed),\n", ")\n", "\n", "with gp.metrics_tracker.print_metrics(every=150):\n", "    gp.fit(n_sgd_iters=NUM_ITERS)" ] }, { "cell_type": "code", "execution_count": null, "id": "0e0c7d70", "metadata": {}, "outputs": [], "source": [ "posterior = gp.predictive_likelihood(DATASET.test_x)\n", "DATASET.plot_prediction(*posterior.confidence_interval())\n", "plt.show()" ] }, { "cell_type": "raw", "id": "edf293f40dc4d3b2", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Conclusions\n", "-----------\n", "\n", "This short example demonstrates that compositional warping can be combined with sparse variational GP inference in Vanguard using very little code. We have demonstrated good results on a real-world dataset with ~15.5k training items. We have compared plain SVGP with a warped SVGP and found similar performance with the warped model. Other datasets may exhibit a stronger preference for warping, but we have shown that, for low values of the regressand (close to zero), the warped GP is much better, as it makes no impossible negative predictions. We have shown that combining warping and variational inference is feasible and that the training is no more difficult than for a plain SVGP.\n", "\n", "In addition, we have provided a proof-of-concept demonstration that warping, input uncertainty and variational GPs can be combined simply within Vanguard and trained successfully." ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" }, "nbsphinx": { "execute": "never" } }, "nbformat": 4, "nbformat_minor": 5 }