In this post, we explore deep kernel learning (DKL), a hybrid approach that combines the strengths of deep neural networks (DNNs) with Gaussian processes (GPs). DKL offers a powerful framework for modeling complex data patterns, enhancing the predictive capabilities and interpretability of standard GPs, and we will implement a DKL model using the GPyTorch library in Python. If you are new to Gaussian processes, we recommend reading our previous posts on Gaussian processes and Multi-output Gaussian processes before diving into DKL.
Gaussian processes (GPs)
Gaussian processes are non-parametric models used for regression and classification tasks. A GP is defined by its mean function \(m(\cdot)\) and covariance function (kernel) \(k(\cdot, \cdot)\).
\[
f \sim \mathcal{GP}(m, k)
\]
Remember that, given training data \(X = [x_1, \ldots, x_n]\) and targets \(Y = [y_1 = g(x_1), \ldots, y_n = g(x_n)]\), the predictive distribution for a test point \(x_*\) is:
\[
f_* \mid X, Y, x_* \sim \mathcal{N}\left( m(x_*) + k(x_*, X) K_n^{-1} (Y - m(X)),\ k(x_*, x_*) - k(x_*, X) K_n^{-1} k(X, x_*) \right)
\]
where \(k(x_*, X) = [k(x_*, x_1), \ldots, k(x_*, x_n)]\), \(k(X, x_*) = k(x_*, X)^\top\), and \(K_n\) is the kernel matrix for the training data \(X\) with noise added to the diagonal.
Here we can see how important the kernel is: it plays a vital role in both the mean and the variance of the predictive distribution (which is why the mean function is often simply set to zero). Its choice is therefore crucial for the model's performance, and ad-hoc kernels are designed to capture specific patterns in the data.
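For example, a common default choice is the squared-exponential (RBF) kernel,
\[
k(x, x') = \sigma_f^2 \exp\left( -\frac{\lVert x - x' \rVert^2}{2 \ell^2} \right),
\]
whose hyperparameters, the signal variance \(\sigma_f^2\) and the lengthscale \(\ell\), encode how smooth the modeled function is and over what distance inputs remain correlated; if these assumptions do not match the data, the predictions will suffer no matter how many observations are available.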
Deep kernel learning (DKL)
DKL integrates DNNs with GPs by using a DNN to learn a representation of the data, which is then used to define the GP kernel. This allows the GP to capture complex patterns in the data that a standard kernel might miss.
In DKL, the DNN acts as a feature extractor, transforming the input \(X\) into a feature representation \(Z = \phi(X; w)\), where \(w\) are the network weights. The GP kernel is then defined on these features:
\[
k_{\mathrm{DKL}}(x, x') = k\big(\phi(x; w), \phi(x'; w); \theta\big)
\]
where \(\theta\) are the hyperparameters of the base kernel \(k\).
Here lies the flexibility of DKL: the most suitable kernel is learned ad hoc for the data at hand rather than being fixed a priori. Indeed, the DNN parameters \(w\) and the GP hyperparameters \(\theta\) are jointly optimized to maximize the marginal likelihood of the data.
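Concretely, both sets of parameters enter the same training objective, the GP log marginal likelihood (written here for a zero-mean GP for brevity):
\[
\log p(Y \mid X, w, \theta) = -\frac{1}{2} Y^\top K_n^{-1} Y - \frac{1}{2} \log \lvert K_n \rvert - \frac{n}{2} \log 2\pi
\]
where the entries of \(K_n\) are computed with \(k_{\mathrm{DKL}}\) and therefore depend on both the network weights \(w\) and the kernel hyperparameters \(\theta\), so gradients with respect to both can be obtained by backpropagation.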
Implementation with GPyTorch
Let's implement a DKL model using the GPyTorch library in Python and compare its performance with a standard GP. As a synthetic dataset we will use the McCormick function, defined as:
\[
g(x_1, x_2) = \sin(x_1 + x_2) + (x_1 - x_2)^2 - 1.5\, x_1 + 2.5\, x_2 + 1,
\]
which has been widely used as a benchmark for optimization and regression tasks. To make the problem more challenging, Gaussian noise with mean 0 and standard deviation 1 is added to the training targets. Moreover, only 25 training points, arranged on a 5x5 grid, are used to fit the model. Note that this grid is not fine-grained, as the domain of the function is \(x_1 \in [-1.5, 4]\) and \(x_2 \in [-3, 4]\); the model therefore has to generalize well to make accurate predictions on unseen data.
First of all, we need to import the required libraries.
%pip install plotly

import numpy as np
import torch
import gpytorch
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import pandas as pd
from torch import nn
from plotly.subplots import make_subplots
from sklearn.metrics import mean_squared_error
Now, we can define the McCormick function and generate the synthetic dataset for training and testing.
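A minimal sketch of this step, relying on the imports above, could look as follows; the random seed and the exact tensor handling are illustrative choices rather than the precise code used for the results below.

# McCormick function: g(x1, x2) = sin(x1 + x2) + (x1 - x2)^2 - 1.5*x1 + 2.5*x2 + 1
def mccormick(x1, x2):
    return torch.sin(x1 + x2) + (x1 - x2) ** 2 - 1.5 * x1 + 2.5 * x2 + 1

torch.manual_seed(0)  # assumption: fixed seed for reproducibility

# 25 training points on a coarse 5x5 grid over the domain [-1.5, 4] x [-3, 4]
g1, g2 = torch.meshgrid(torch.linspace(-1.5, 4, 5), torch.linspace(-3, 4, 5), indexing="ij")
train_x = torch.stack([g1.reshape(-1), g2.reshape(-1)], dim=-1)

# Noisy targets: standard Gaussian noise (mean 0, std 1) added to the true values
train_y = mccormick(train_x[:, 0], train_x[:, 1]) + torch.randn(train_x.shape[0])

# Dense 100x100 grid of noise-free test points for evaluation
t1, t2 = torch.meshgrid(torch.linspace(-1.5, 4, 100), torch.linspace(-3, 4, 100), indexing="ij")
test_x = torch.stack([t1.reshape(-1), t2.reshape(-1)], dim=-1)
test_y = mccormick(test_x[:, 0], test_x[:, 1])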
Figure 1: Neural network architecture
Regarding the deep neural network, we define a simple feedforward network with two hidden layers and a softplus activation function. The architecture is defined in the FeatureExtractor class, shown in Figure 1. The remaining part is the same as the standard GP model, which is explained in detail in the Gaussian processes post.
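The following is a minimal sketch of how the FeatureExtractor and the DKL model could be put together in GPyTorch; the hidden-layer width (64), the 2D feature output, the RBF base kernel, and the training settings (Adam, 500 iterations, learning rate 0.01) are assumptions for illustration rather than the exact configuration used here.

# Feedforward feature extractor: two hidden layers with Softplus activations.
# The hidden width (64) and the 2D feature output are assumptions for illustration.
class FeatureExtractor(nn.Sequential):
    def __init__(self, input_dim=2, hidden_dim=64, feature_dim=2):
        super().__init__(
            nn.Linear(input_dim, hidden_dim), nn.Softplus(),
            nn.Linear(hidden_dim, hidden_dim), nn.Softplus(),
            nn.Linear(hidden_dim, feature_dim),
        )


# Exact GP whose kernel operates on the learned features z = phi(x).
class DKLModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.feature_extractor = FeatureExtractor()
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(ard_num_dims=2)
        )

    def forward(self, x):
        z = self.feature_extractor(x)  # map inputs to the feature space
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(z), self.covar_module(z)
        )


likelihood = gpytorch.likelihoods.GaussianLikelihood()
dkl_model = DKLModel(train_x, train_y, likelihood)

# Jointly optimize the DNN weights and the GP hyperparameters by maximizing
# the exact marginal log-likelihood (Adam, 500 iterations: illustrative choices).
dkl_model.train()
likelihood.train()
optimizer = torch.optim.Adam(dkl_model.parameters(), lr=0.01)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, dkl_model)
for _ in range(500):
    optimizer.zero_grad()
    loss = -mll(dkl_model(train_x), train_y)
    loss.backward()
    optimizer.step()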
The interactive 3D plot above shows the predictions of the DKL (blue) and GP (green) models on the McCormick function (red). The DKL model captures the complex patterns of the function more accurately than the GP model, providing a better fit to the true surface.
In terms of performance, we can compare the mean squared error (MSE) of the DKL and GP models on the test data (i.e., the McCormick function evaluated on a 100x100 grid).
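As a sketch of that comparison, assuming both models have been trained, and with gp_model and gp_likelihood as hypothetical names for a standard exact GP defined directly on the raw inputs:

# Switch to evaluation mode and predict on the dense 100x100 test grid.
dkl_model.eval()
likelihood.eval()
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    dkl_pred = likelihood(dkl_model(test_x)).mean

# gp_model / gp_likelihood: a standard exact GP with an RBF kernel on the raw
# inputs, defined and trained analogously (hypothetical names).
gp_model.eval()
gp_likelihood.eval()
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    gp_pred = gp_likelihood(gp_model(test_x)).mean

print("DKL MSE:", mean_squared_error(test_y.numpy(), dkl_pred.numpy()))
print("GP MSE:", mean_squared_error(test_y.numpy(), gp_pred.numpy()))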
The results are clear: the DKL model outperforms the GP model, demonstrating the benefits of combining deep neural networks with Gaussian processes for complex regression tasks.
NN’s embedding
Let’s visualize the embedding of the input data into the feature space learned by the DNN.
# Get the feature embedding of the training data
with torch.no_grad():
    feature_embedding = dkl_model.feature_extractor(train_x)

# Create a subplot with 1 row and 2 columns
fig = make_subplots(rows=1, cols=2, subplot_titles=('Original Training Data', 'Feature Embedding'))

# Add the original training data to the first subplot
fig.add_trace(
    go.Scatter(
        x=train_x[:, 0].numpy(),
        y=train_x[:, 1].numpy(),
        mode='markers',
        marker=dict(size=5, color=train_y.numpy(), colorscale='Viridis', opacity=0.8),
        name='Training Data',
        hoverinfo='text',
        text=[f'Index: {i}' for i in range(len(train_x))]),  # Custom hover text
    row=1, col=1)

# Add the feature embedding data to the second subplot
fig.add_trace(
    go.Scatter(
        x=feature_embedding[:, 0].numpy(),
        y=feature_embedding[:, 1].numpy(),
        mode='markers',
        marker=dict(size=5, color=train_y.numpy(), colorscale='Viridis', opacity=0.8),
        name='Feature Embedding',
        hoverinfo='text',
        text=[f'Index: {i}' for i in range(len(feature_embedding))]),  # Custom hover text
    row=1, col=2)

# Update layout for a cohesive look
fig.update_layout(height=350, width=700, showlegend=False)
fig.show()
Interestingly, the DNN has picked up on the near-symmetry of the McCormick function. As a result, it has learned an almost one-dimensional representation of the data, ordered by the value of the function. The GP can then fit the data easily, as the feature space is much better suited to the task.