Multiclass Classification with Dirichlet Distributions
[1]:
# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This notebook demonstrates an alternative implementation of multiclass classification in Vanguard. The methodology is based on this example notebook. If you are new to multiclass classification in Vanguard, start with the multi-class classification example.
[2]:
random_seed = 1_989
[3]:
import numpy as np
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gpytorch.means import ZeroMean
from matplotlib import pyplot as plt
from vanguard.classification import DirichletMulticlassClassification
from vanguard.classification.kernel import DirichletKernelMulticlassClassification
from vanguard.classification.likelihoods import (
    DirichletKernelClassifierLikelihood,
    GenericExactMarginalLogLikelihood,
)
from vanguard.datasets.classification import MulticlassGaussianClassificationDataset
from vanguard.kernels import ScaledRBFKernel
from vanguard.vanilla import GaussianGPController
Introduction
Recall that in standard multi-class classification, one is essentially training a binary classifier for each distinct class, taking advantage of Vanguard components to ensure that covariance between them is properly ascertained. In this example, we consider a different method for multi-class classification, where we regress directly onto the target label probability distributions. The theory behind this is explored fully in sections 4 and 4.1 of [Milios18].
Consider a classification problem over \(m\) classes. The aim is to infer a posterior distribution for each prediction, from which we draw a multinomial over the classes \(\pi = (\pi_1,\dots,\pi_m)\), corresponding to the class probabilities. This distribution takes the form of a Dirichlet model parameterised by some \(m\)-dimensional vector \(\alpha\):
\[\pi \sim \text{Dir}(\alpha).\]
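A common way to generate a sample from a Dirichlet distribution is to instead consider \(m\) independent Gamma-distributed random variables. If we have that
\[x_i \sim \text{Gamma}(\alpha_i, 1), \quad i = 1, \dots, m,\]
and we define
\[\pi_i = \frac{x_i}{\sum_{j=1}^{m} x_j},\]
then
\[\pi = (\pi_1, \dots, \pi_m) \sim \text{Dir}(\alpha).\]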
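The Gamma construction above can be checked numerically. The following is a minimal sketch using only NumPy; the values in `alpha` are illustrative assumptions and do not come from the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([0.5, 1.0, 2.5])  # illustrative concentration parameters (assumed)

# Draw independent Gamma(alpha_i, 1) variables, one column per class.
x = rng.gamma(shape=alpha, scale=1.0, size=(50_000, alpha.size))

# Normalising each row gives a point on the simplex distributed as Dir(alpha).
pi = x / x.sum(axis=1, keepdims=True)

# The empirical mean should be close to the Dirichlet mean alpha / sum(alpha).
print(pi.mean(axis=0), alpha / alpha.sum())
```

Each row of `pi` is a valid probability vector, which is what we want to regress onto.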
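Recall that the output from a Gaussian process is a normal distribution, and it is difficult to "transition" samples from such a distribution to a Gamma distribution in order to generate our Dirichlet sample. Rather than using a Gamma directly, we approximate it with a log-normal distribution. Consider our approximate random variable \(\bar{x}_i\):
\[\bar{x}_i \sim \text{LogNormal}(\bar{\mu}_i, \bar{\sigma}_i^2).\]
The mean and variance of \(\bar{x}_i\) are given by:
\[\text{E}[\bar{x}_i] = \exp\left(\bar{\mu}_i + \frac{\bar{\sigma}_i^2}{2}\right), \qquad \text{Var}[\bar{x}_i] = \left(\exp(\bar{\sigma}_i^2) - 1\right)\exp\left(2\bar{\mu}_i + \bar{\sigma}_i^2\right).\]
Given that \(\text{E}[x_i] = \text{Var}[x_i] = \alpha_i\), it is possible to deduce the values for \(\bar{\mu}_i\) and \(\bar{\sigma}_i^2\) to match these:
\[\bar{\sigma}_i^2 = \log\left(\frac{1}{\alpha_i} + 1\right), \qquad \bar{\mu}_i = \log \alpha_i - \frac{\bar{\sigma}_i^2}{2}.\]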
This can be verified with the following plot:
[4]:
alpha_i = 0.86858729
sigma_squared_i = np.log(1 / alpha_i + 1)
# mu = log(alpha) - sigma^2 / 2, so that the log-normal mean equals alpha.
mu_i = np.log(alpha_i) - sigma_squared_i / 2
n_samples = 10_000
random_generator = np.random.Generator(np.random.PCG64(seed=random_seed))
gamma_samples = random_generator.gamma(shape=alpha_i, scale=1.0, size=n_samples)
lognormal_samples = random_generator.lognormal(mean=mu_i, sigma=np.sqrt(sigma_squared_i), size=n_samples)
plt.figure(figsize=(10, 5))
n_bins = 150
plt.hist(gamma_samples, bins=n_bins, density=True, alpha=0.6, label="gamma")
plt.hist(lognormal_samples, bins=n_bins, density=True, alpha=0.6, label="lognormal")
plt.xlim(right=8)
plt.legend()
plt.show()
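The moment matching can also be checked analytically rather than by eye. The sketch below uses the standard closed-form mean and variance of a log-normal distribution; the helper name `lognormal_params` and the test values of `alpha` are hypothetical and chosen only for illustration.

```python
import numpy as np


def lognormal_params(alpha):
    """Return (mu, sigma_squared) of the log-normal matched to Gamma(alpha, 1)."""
    sigma_squared = np.log(1.0 / alpha + 1.0)
    mu = np.log(alpha) - sigma_squared / 2.0
    return mu, sigma_squared


alpha = np.array([0.3, 1.0, 4.0])  # illustrative values (assumed)
mu, sigma_squared = lognormal_params(alpha)

# Closed-form moments of a log-normal distribution with parameters (mu, sigma_squared).
mean = np.exp(mu + sigma_squared / 2.0)
variance = (np.exp(sigma_squared) - 1.0) * np.exp(2.0 * mu + sigma_squared)

# Both should reproduce alpha, the mean and variance of Gamma(alpha, 1).
print(mean, variance)
```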
If a random variable \(X\) is normally distributed, then \(\exp(X)\) is log-normally distributed. Since the output of the Gaussian process model is a Gaussian distribution, we can follow the above steps in reverse to sample approximately from \(\text{Dir}(\alpha)\): exponentiate the Gaussian samples and normalise them. By maximising the likelihood, we infer the parameters to regress onto the correct probability distributions \(\pi\). This is all taken care of within the DirichletClassificationLikelihood.
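The reverse procedure can be sketched in a few lines of NumPy. Here the Gaussian parameters are computed directly from an assumed `alpha`, standing in for the posterior mean and variance that a fitted Gaussian process would provide; none of these values come from the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([0.5, 1.0, 2.5])  # illustrative concentration parameters (assumed)

# Gaussian parameters in log-space; in practice these would come from the GP posterior.
sigma_squared = np.log(1.0 / alpha + 1.0)
mu = np.log(alpha) - sigma_squared / 2.0

# Sample Gaussians, exponentiate to get log-normal variables, then normalise.
f = rng.normal(loc=mu, scale=np.sqrt(sigma_squared), size=(50_000, alpha.size))
x_bar = np.exp(f)
pi_approx = x_bar / x_bar.sum(axis=1, keepdims=True)

# Approximate class probabilities, to be compared with alpha / sum(alpha).
print(pi_approx.mean(axis=0), alpha / alpha.sum())
```

Exponentiating and normalising is the same as applying a softmax to the Gaussian samples, so the result is only an approximation to a Dirichlet sample and the two printed vectors will not agree exactly.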
Data
We start with the MulticlassGaussianClassificationDataset for this experiment, which creates multiple classes based on the distance to the centre of a two-dimensional Gaussian distribution.
[5]:
NUM_CLASSES = 4
DATASET = MulticlassGaussianClassificationDataset(
    num_train_points=100,
    num_test_points=500,
    num_classes=NUM_CLASSES,
    covariance_scale=1,
    rng=np.random.default_rng(random_seed),
)
[6]:
plt.figure(figsize=(8, 8))
DATASET.plot()
plt.show()
Instead of the CategoricalClassification decorator we used before, we use the DirichletMulticlassClassification decorator. Note that we no longer require the VariationalInference decorator, because this method works with exact inference.
[7]:
@DirichletMulticlassClassification(num_classes=NUM_CLASSES, ignore_methods=("__init__",))
class MulticlassGaussianClassifier(GaussianGPController):
    pass
We require the DirichletClassificationLikelihood, and also need to ensure that the correct batch shape is passed:
[8]:
controller = MulticlassGaussianClassifier(
    DATASET.train_x,
    DATASET.train_y,
    ScaledRBFKernel,
    y_std=0,
    mean_class=ZeroMean,
    likelihood_class=DirichletClassificationLikelihood,
    mean_kwargs={"batch_shape": (NUM_CLASSES,)},
    kernel_kwargs={"batch_shape": (NUM_CLASSES,)},
    likelihood_kwargs={"alpha_epsilon": 0.3, "learn_additional_noise": True},
    optim_kwargs={"lr": 0.05},
    rng=np.random.default_rng(random_seed),
)
/home/docs/checkouts/readthedocs.org/user_builds/vanguard/envs/stable/lib/python3.13/site-packages/torch/utils/_device.py:104: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
return func(*args, **kwargs)
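To give some intuition for the `alpha_epsilon` parameter, the sketch below shows how integer class labels could be turned into log-space regression targets, assuming the transformation described in [Milios18]: each observation gets \(\alpha_i = 1 + \alpha_\epsilon\) for its observed class and \(\alpha_i = \alpha_\epsilon\) otherwise. The function `dirichlet_regression_targets` is a hypothetical helper written for illustration; it is not part of Vanguard or GPyTorch, and the library's actual implementation may differ.

```python
import numpy as np


def dirichlet_regression_targets(labels, num_classes, alpha_epsilon=0.3):
    """Hypothetical helper: map class labels to log-normal means and variances."""
    labels = np.asarray(labels)
    # Start every class at alpha_epsilon, then add one to the observed class.
    alpha = np.full((labels.size, num_classes), alpha_epsilon, dtype=float)
    alpha[np.arange(labels.size), labels] += 1.0
    # Moment-matched log-normal parameters for each Gamma(alpha, 1).
    sigma_squared = np.log(1.0 / alpha + 1.0)
    mu = np.log(alpha) - sigma_squared / 2.0
    return mu, sigma_squared


labels = [0, 2, 1]
mu, sigma_squared = dirichlet_regression_targets(labels, num_classes=4)
print(mu.shape)  # one regression target per data point and class
```

Under this assumption, each class has its own regression target and noise level, which is consistent with passing a batch shape of `(NUM_CLASSES,)` to the mean and kernel. A larger `alpha_epsilon` pulls the targets for the observed and unobserved classes closer together.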
[9]:
predictions, probs = controller.classify_points(DATASET.test_x)
[10]:
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()
[11]:
controller.fit(100)
predictions, probs = controller.classify_points(DATASET.test_x)
/home/docs/checkouts/readthedocs.org/user_builds/vanguard/envs/stable/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-06 to the diagonal
warnings.warn(
[12]:
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()
Note that the model does surprisingly well without fitting, but that fitting does not seem to make much of a difference.
Dirichlet Kernel Approximation Modelling
An alternative method for modelling uses the DirichletKernelMulticlassClassification decorator. It requires the DirichletKernelClassifierLikelihood and a special marginal log likelihood, GenericExactMarginalLogLikelihood.
This method is based on [MacKenzie14]. It is a kernel machine method with a Dirichlet likelihood. The posterior over the classes is
where \(N\) is the total number of data points, \(N_{-i}\) is the total number of data points that are not in class \(i\) and \(m\) is the total number of classes. \(y_j\) is the class of the \(j\)-th training data point. Here \(\alpha\) is a hyperparameter (it is the usual Dirichlet prior hyperparameter) and can be tuned like any other if desired. The kernel \(k\) is just like any other kernel in GP modelling.
Note
This model is not actually a GP, but is similar enough in practice to warrant its inclusion in Vanguard (it is a non-parametric kernel classifier with a Dirichlet likelihood). In addition, there may exist a formulation in which this method is an approximation to a GP posterior, but we have not been able to find this formulation.
[13]:
@DirichletKernelMulticlassClassification(num_classes=NUM_CLASSES, ignore_methods=("__init__",))
class MulticlassGaussianClassifier(GaussianGPController):
    pass
/tmp/ipykernel_702/3244861142.py:1: ExperimentalFeatureWarning: The DirichletKernelMulticlassClassification decorator is currently an experimental feature. It may cause errors or give incorrect results, and may have breaking changes without warning.
@DirichletKernelMulticlassClassification(num_classes=NUM_CLASSES, ignore_methods=("__init__",))
[14]:
controller = MulticlassGaussianClassifier(
    DATASET.train_x,
    DATASET.train_y,
    kernel_class=ScaledRBFKernel,
    y_std=0,
    mean_class=ZeroMean,
    likelihood_class=DirichletKernelClassifierLikelihood,
    likelihood_kwargs={"learn_alpha": False, "alpha": 5},
    marginal_log_likelihood_class=GenericExactMarginalLogLikelihood,
    optim_kwargs={"lr": 0.1, "early_stop_patience": 5},
    rng=np.random.default_rng(random_seed),
)
/home/docs/checkouts/readthedocs.org/user_builds/vanguard/envs/stable/lib/python3.13/site-packages/torch/utils/_device.py:104: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
return func(*args, **kwargs)
/home/docs/checkouts/readthedocs.org/user_builds/vanguard/envs/stable/lib/python3.13/site-packages/vanguard/base/basecontroller.py:573: UserWarning: A regression problem with no warping may suffer from numerical instability in optimisation if the y values are not standard scaled. Using the NormaliseY decorator will likely help.
warnings.warn(
[15]:
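# Classify with the new, unfitted controller so the plot reflects this model.
predictions, probs = controller.classify_points(DATASET.test_x)
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()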
[16]:
with controller.metrics_tracker.print_metrics(every=25):
    controller.fit(100)
/home/docs/checkouts/readthedocs.org/user_builds/vanguard/envs/stable/lib/python3.13/site-packages/linear_operator/utils/interpolation.py:71: UserWarning: torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated. Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:644.)
summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
iteration: 25, loss: 0.951393723487854
iteration: 50, loss: 0.8208474516868591
iteration: 75, loss: 0.7446041703224182
iteration: 100, loss: 0.691365659236908
[17]:
predictions, probs = controller.classify_points(DATASET.test_x)
[18]:
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()
This model seems prone to overfitting the kernel hyperparameters, and particularly \(\alpha\), as seen below when \(\alpha\) is learned:
[19]:
controller = MulticlassGaussianClassifier(
    DATASET.train_x,
    DATASET.train_y,
    kernel_class=ScaledRBFKernel,
    y_std=0,
    mean_class=ZeroMean,
    likelihood_class=DirichletKernelClassifierLikelihood,
    likelihood_kwargs={"learn_alpha": True, "alpha": 5},
    marginal_log_likelihood_class=GenericExactMarginalLogLikelihood,
    optim_kwargs={"lr": 0.1, "early_stop_patience": 5},
    rng=np.random.default_rng(random_seed),
)
[20]:
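# Classify with the new, unfitted controller so the plot reflects this model.
predictions, probs = controller.classify_points(DATASET.test_x)
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()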
[21]:
with controller.metrics_tracker.print_metrics(every=25):
    controller.fit(100)
iteration: 25, loss: 0.8169602751731873
iteration: 50, loss: 0.23843780159950256
iteration: 75, loss: 0.024171294644474983
iteration: 100, loss: 0.010413208045065403
[22]:
predictions, probs = controller.classify_points(DATASET.test_x)
[23]:
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()
[24]:
plt.figure(figsize=(8, 8))
DATASET.plot_confusion_matrix(predictions)
plt.show()
Conclusions
It is unlikely that Dirichlet multi-class classification is ever worth using over previous techniques, but it is interesting how powerful it is without any training.