Hi,
I am looking for available implementations of Deep Kernel Learning ([1511.02222] Deep Kernel Learning). I am aware that gpax (GitHub - ziatdinovmax/gpax: Gaussian Processes for Experimental Sciences) has one, specifically in this example: gpax/examples/gpax_viDKL_plasmons.ipynb at main · ziatdinovmax/gpax · GitHub. However, it is built on top of JAX.
Does anyone know of an existing implementation in BoTorch? (My motivation is that BoTorch already has many optimizers to play with.)
While it’s not a direct answer/solution, this thread may be of help!
(GitHub issue: opened 09:47PM - 07 Oct 22 UTC, closed 03:46PM - 30 Nov 22 UTC, label: announcement)
I probably lack the understanding and the language required to talk about this effectively, so here are a few follow-up questions:
- Are you familiar with structured GPs?
- Is this the right name to be using? (e.g., what's described in https://www.sciencedirect.com/science/article/pii/S0021999119300397)
From my basic understanding, it's functionally similar to [performing BO over a VAE latent space](https://botorch.org/tutorials/vae_mnist), except that the latent-space embeddings aren't as rigidly fixed: the manifold itself is learned based on what a deep kernel learning (?) model decides is "useful". At a higher level, I've been told it's useful for incorporating physical insight/domain knowledge (e.g., physical models) into active learning.
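For reference, here is my summary of the formulation in the linked paper (arXiv:1511.02222). A deep kernel keeps a standard base kernel $k$ with hyperparameters $\boldsymbol{\theta}$ but feeds it the output of a neural network $g$ with weights $\mathbf{w}$. All parameters $\boldsymbol{\gamma} = \{\mathbf{w}, \boldsymbol{\theta}\}$ are then learned jointly by maximizing the GP log marginal likelihood:

$$
k(\mathbf{x}_i, \mathbf{x}_j \mid \boldsymbol{\theta}) \;\to\; k\big(g(\mathbf{x}_i, \mathbf{w}),\, g(\mathbf{x}_j, \mathbf{w}) \mid \boldsymbol{\theta}, \mathbf{w}\big)
$$

$$
\mathcal{L} = \log p(\mathbf{y} \mid \boldsymbol{\gamma}, X) \propto -\Big[\mathbf{y}^\top (K_{\boldsymbol{\gamma}} + \sigma^2 I)^{-1} \mathbf{y} + \log \big\lvert K_{\boldsymbol{\gamma}} + \sigma^2 I \big\rvert\Big]
$$

This joint fit is what separates DKL from the VAE-latent approach. As I understand the BoTorch tutorial, the encoder there is trained separately on reconstruction and then held fixed while the GP is fit on its latent codes. In DKL, by contrast, $\mathbf{w}$ is updated by the GP's own objective.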
I'm asking based on some discussions with Sergei Kalinin about the DKL models they've been applying in microscopy, and how these might apply to other domains. See, e.g., https://arxiv.org/abs/2205.15458
Related:
- https://github.com/ziatdinovmax/gpax
- https://github.com/pycroscopy/atomai
- https://twitter.com/Sergei_Imaging/status/1467856622668103686
- [The Promises and Pitfalls of Deep Kernel Learning](https://arxiv.org/abs/2102.12108)
from [Twitter search of deep kernel learning](https://twitter.com/search?q=deep%20kernel%20learning)
From https://github.com/ziatdinovmax/gpax:
> The limitation of the standard GP is that it does not usually allow for the incorporation of prior domain knowledge and can be biased toward a trivial interpolative solution. Recently, we [introduced](https://arxiv.org/abs/2108.10280) a structured Gaussian Process (sGP), where a classical GP is augmented by a structured probabilistic model of the expected system’s behavior. This approach allows us to [balance](https://towardsdatascience.com/unknown-knowns-bayesian-inference-and-structured-gaussian-processes-why-domain-scientists-know-4659b7e924a4) the flexibility of the non-parametric GP approach with a rigid structure of prior (physical) knowledge encoded into the parametric model. Implementation-wise, we substitute a constant/zero prior mean function in GP with a probabilistic model of the expected system's behavior.
>
> ...
> For example, if we have prior knowledge that our objective function has a discontinuous 'phase transition', and a power law-like behavior before and after this transition, we may express it using a simple piecewise function
> ```python3
> from typing import Dict
>
> import jax.numpy as jnp
>
> def piecewise(x: jnp.ndarray, params: Dict[str, float]) -> jnp.ndarray:
>     """Power-law behavior before and after the transition"""
>     return jnp.piecewise(
>         x, [x < params["t"], x >= params["t"]],
>         [lambda x: x**params["beta1"], lambda x: x**params["beta2"]])
> ```
> where `jnp` corresponds to the `jax.numpy` module. This function is deterministic. To make it probabilistic, we put priors over its parameters with the help of [NumPyro](https://github.com/pyro-ppl/numpyro):
> ```python3
> import numpyro
> from numpyro import distributions
>
> def piecewise_priors():
>     # Sample model parameters
>     t = numpyro.sample("t", distributions.Uniform(0.5, 2.5))
>     beta1 = numpyro.sample("beta1", distributions.Normal(3, 1))
>     beta2 = numpyro.sample("beta2", distributions.Normal(3, 1))
>     # Return sampled parameters as a dictionary
>     return {"t": t, "beta1": beta1, "beta2": beta2}
> ```
Feel free to close as this is just a discussion post, and no worries if this doesn't fit well within the scope of Ax/BoTorch. Curious to hear your thoughts, if any!
AFalk
June 15, 2024, 11:09am
You might look into GPyTorch, which is the GP backend for BoTorch. Depending on your familiarity with GPs, it has a bit of a learning curve, but its documentation includes several tutorials on deep kernel learning:
https://docs.gpytorch.ai/en/stable/examples/06_PyTorch_NN_Integration_DKL/index.html
https://docs.gpytorch.ai/en/stable/examples/06_PyTorch_NN_Integration_DKL/KISSGP_Deep_Kernel_Regression_CUDA.html
@Utkarsh, there is a GPyTorch version of deep kernel learning ("PyTorch NN Integration (Deep Kernel Learning)" in the GPyTorch 1.12.dev60+g25da2cc documentation), which should be compatible with BoTorch. However, one of the reasons I added DKL to GPax was that I didn’t find the GPyTorch implementation flexible or easy to customize. For example, getting exact DKL to run with convnets was quite a painful process, and the same goes for placing priors over the NN weights. Is the problem JAX in particular, or is there anything I can add to the current GPax implementation that would help you?
@maxim.ziatdinov
I want to use DKL with multiple objectives. I noticed that extending the JAX version might not let me use the multi-objective optimizers available in BoTorch. However, as you mentioned, the GPyTorch implementation is not particularly easy to customize, especially for convolutional networks. I’m a bit unsure how to proceed. Do you have any thoughts?
If the goal is to use BoTorch’s advanced suite of multi-objective optimization tools, I’m afraid there’s no way around the pain of customizing GPyTorch’s DKL models. I won’t have the capacity to add those multi-objective optimization tools to GPax for the foreseeable future.
Thank you! I will try customizing GPyTorch’s DKL model.