{ "cells": [ { "cell_type": "raw", "id": "44be0f69-fefb-40c4-b4ee-0bf9419cc0de", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Binary Classification in Vanguard\n", "=================================" ] }, { "cell_type": "code", "execution_count": null, "id": "a68a63dc10db329a", "metadata": {}, "outputs": [], "source": [ "# © Crown Copyright GCHQ\n", "#\n", "# Licensed under the GNU General Public License, version 3 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.gnu.org/licenses/gpl-3.0.en.html\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "raw", "id": "11c27eadf980cb1c", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "A showcase of the implementation of standard binary classification in Vanguard. The approach used for this implementation borrows heavily from `this GPyTorch example `_." 
] }, { "cell_type": "code", "execution_count": null, "id": "9c5dcc02a8ee496f", "metadata": {}, "outputs": [], "source": [ "random_seed = 1_989" ] }, { "cell_type": "code", "execution_count": null, "id": "3734997c", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from gpytorch.likelihoods import BernoulliLikelihood\n", "from gpytorch.mlls import VariationalELBO\n", "from matplotlib import pyplot as plt\n", "\n", "from vanguard.classification import BinaryClassification\n", "from vanguard.datasets.classification import BinaryGaussianClassificationDataset\n", "from vanguard.kernels import ScaledRBFKernel\n", "from vanguard.vanilla import GaussianGPController\n", "from vanguard.variational import VariationalInference" ] }, { "cell_type": "raw", "id": "4a3e17c1-866f-4952-8449-7d4e54dfe12e", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Introduction\n", "------------\n", "\n", "A standard binary classification problem can be mapped straightforwardly to a regression problem. Instead of treating the class labels as indices, treat them as the extreme points of the interval $[0, 1]$. By regressing on those points, the model can make class predictions by thresholding the predicted value, where $[0, 0.5)$ denotes one class, and $[0.5, 1]$ the other. This value can also be used to determine the model uncertainty, as values closer to the extremes imply more certainty." ] }, { "cell_type": "raw", "id": "979d5161", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Given that we are regressing on two classes, we make use of the :py:class:`~gpytorch.likelihoods.BernoulliLikelihood` in order to transform the latent posteriors into actual probabilities. Because the standard output from the model is a Gaussian distribution, this likelihood employs `probit regression `_ to give us proper probabilities, scaling via the standard Gaussian cumulative distribution function. 
In particular, the predictive probability under the probit likelihood is calculated in closed form by applying the following formula :cite:`Kuss05`:\n", "\n", ".. math::\n", " q(y_*=1\\mid\\mathcal{D},{\\pmb{\\theta}},{\\bf x_*})\n", " = \\int {\\bf\\Phi}(f_*)\\mathcal{N}(f_*\\mid\\mu_*,\\sigma_*^2)df_*\n", " = {\\bf\\Phi}\\left( \\frac{\\mu_*}{\\sqrt{1 + \\sigma_*^2}} \\right ).\n", " \n", "This means that the predictive uncertainty is properly taken into account." ] }, { "cell_type": "raw", "id": "ab428893", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Data\n", "----\n", "\n", "We use the :py:class:`~vanguard.datasets.classification.BinaryGaussianClassificationDataset` for this experiment, which creates two classes based on the distance to the centre of a two-dimensional Gaussian distribution. We use relatively few training points to prevent the model from overfitting at the outset." ] }, { "cell_type": "code", "execution_count": null, "id": "985ea6f0", "metadata": {}, "outputs": [], "source": [ "DATASET = BinaryGaussianClassificationDataset(\n", " num_train_points=10, num_test_points=100, rng=np.random.default_rng(random_seed)\n", ")" ] }, { "cell_type": "raw", "id": "45595d92-f9e5-449f-8261-53cabbb47cfd", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "We plot all of the truth data:" ] }, { "cell_type": "code", "execution_count": null, "id": "da3c98a7", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 8))\n", "DATASET.plot()\n", "plt.show()" ] }, { "cell_type": "raw", "id": "4c39f043", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Modelling\n", "---------\n", "\n", "Preparing a controller for binary classification is as straightforward as applying the :py:class:`~vanguard.classification.binary.BinaryClassification` decorator. 
Because the :py:class:`~gpytorch.likelihoods.BernoulliLikelihood` is non-Gaussian, the :py:class:`~vanguard.variational.VariationalInference` decorator is required in order to run approximate inference. It also has the added benefit of using a smaller number of inducing points, which will enable inference on the larger datasets traditionally used in classification tasks." ] }, { "cell_type": "code", "execution_count": null, "id": "1137b1fb", "metadata": {}, "outputs": [], "source": [ "@BinaryClassification(ignore_all=True)\n", "@VariationalInference(ignore_all=True)\n", "class BinaryClassifier(GaussianGPController):\n", " pass" ] }, { "cell_type": "raw", "id": "cacac906", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "We choose a standard kernel: the :py:class:`~vanguard.kernels.ScaledRBFKernel`. The likelihood class must also be a subclass of :py:class:`~gpytorch.likelihoods.BernoulliLikelihood`, and the marginal log likelihood class needs to accept a ``num_data`` parameter, so the safest bet is the :py:class:`~gpytorch.mlls.VariationalELBO` class." ] }, { "cell_type": "code", "execution_count": null, "id": "0241318d", "metadata": {}, "outputs": [], "source": [ "controller = BinaryClassifier(\n", " DATASET.train_x,\n", " DATASET.train_y,\n", " kernel_class=ScaledRBFKernel,\n", " y_std=0,\n", " likelihood_class=BernoulliLikelihood,\n", " marginal_log_likelihood_class=VariationalELBO,\n", " rng=np.random.default_rng(random_seed),\n", ")" ] }, { "cell_type": "raw", "id": "ec04ba4d", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Before we try fitting, let's see how well the classifier does without any hyperparameter training. We cannot use the :py:meth:`~vanguard.base.gpcontroller.GPController.posterior_over_point` method, as the model posteriors need to be passed through the likelihood to be properly scaled. Instead, classifiers in Vanguard have a special ``classify_points`` method to do this." 
] }, { "cell_type": "code", "execution_count": null, "id": "70b9d9dd", "metadata": {}, "outputs": [], "source": [ "predictions, probs = controller.classify_points(DATASET.test_x)" ] }, { "cell_type": "raw", "id": "cfb0ef81-c9cb-415e-9a9c-3cf988e3d677", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Recall that the output from the model is being scaled by the likelihood to the interval $[0, 1]$. In fact, the uncertainty from that output is ignored, as the means of those distributions imply the model uncertainty based on their distance from the extrema.\n", "\n", "The plot below shows the predicted classes. A circle represents a correct prediction, whereas a cross represents an incorrect prediction." ] }, { "cell_type": "code", "execution_count": null, "id": "a7672925", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 8))\n", "DATASET.plot_prediction(predictions)\n", "plt.show()" ] }, { "cell_type": "raw", "id": "5fc6c1f8-567b-4aeb-bd03-d7771f956698", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Now we fit the hyperparameters to see if this improves the performance." ] }, { "cell_type": "code", "execution_count": null, "id": "069e63ca", "metadata": {}, "outputs": [], "source": [ "loss = controller.fit(100)\n", "print(f\"Loss: {loss:.5f}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "466049c3", "metadata": {}, "outputs": [], "source": [ "predictions, probs = controller.classify_points(DATASET.test_x)" ] }, { "cell_type": "code", "execution_count": null, "id": "1bcce54c", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 8))\n", "DATASET.plot_prediction(predictions)\n", "plt.show()" ] }, { "cell_type": "raw", "id": "5b8c4215-59f4-49bf-9fc2-00f3496e9aeb", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Conclusions\n", "-----------\n", "\n", "We have successfully demonstrated binary classification in Vanguard. 
For classification tasks with more than two classes, check out the `multiclass example notebook `_." ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 }