Laplace Approximation

Implementing Laplace Approximation in JAX
Laplace Approximation
ML
Author

Madhav Kanda

Published

August 12, 2022

# Importing modules
try:
  import jax                      # JAX is a library for differentiable programming
except ModuleNotFoundError:
  %pip install jaxlib jax
  import jax
import jax.numpy as jnp           # JAX's numpy implementation
try:
  import tensorflow_probability.substrates.jax as tfp     # TFP is a library for probabilistic programming
except ModuleNotFoundError:
  %pip install tensorflow-probability
  import tensorflow_probability.substrates.jax as tfp
import matplotlib.pyplot as plt
import warnings
import seaborn as sns
from tqdm import trange
import logging
logger = logging.getLogger()
class CheckTypesFilter(logging.Filter):                   
    def filter(self, record):
        return "check_types" not in record.getMessage()


logger.addFilter(CheckTypesFilter())

Sampling from the Bernouli distribution with \(\theta\) = 0.7

bernoulli_samples = tfp.distributions.Bernoulli(
    probs=0.7
)  # Create a Bernoulli distribution with p=0.7
samples = bernoulli_samples.sample(
    sample_shape=100, seed=jax.random.PRNGKey(0)
)  # Sample 100 samples from the distribution
print(samples)

alpha = 3  # Set the parameter (alpha) of the Beta distribution
beta = 5  # Set the parameter (beta) of the Beta distribution
samples.sum()
[1 0 0 1 1 0 1 1 1 1 1 1 0 0 1 0 1 1 1 1 1 0 1 0 1 1 1 1 0 1 0 1 1 0 1 1 1
 1 0 1 1 1 1 0 0 0 1 1 1 0 1 1 0 1 1 1 1 0 1 1 1 0 1 0 0 0 1 1 1 0 1 0 0 1
 1 1 1 0 1 1 0 1 1 1 1 1 1 0 1 1 1 1 1 1 0 1 1 0 0 1]
DeviceArray(69, dtype=int32)

Negative log joint

def neg_logjoint(theta):  # Define the negative log-joint distribution
    alpha = 3
    beta = 5
    dist_prior = tfp.distributions.Beta(alpha, beta)
    dist_likelihood = tfp.distributions.Bernoulli(probs=theta)
    return -(dist_prior.log_prob(theta) + dist_likelihood.log_prob(samples).sum())

Calculating \(\theta_{map}\) by minimising the negative log joint using gradient descent

gradient = jax.value_and_grad(
    jax.jit(neg_logjoint)
)  # Define the gradient of the negative log-joint distribution
lr = 0.001  # Set the learning rate
epochs = 200  # Set the number of epochs
theta_map = 0.5  # Set the initial value of theta
losses = []
for i in trange(epochs):  # Run the optimization loop
    val, grad = gradient(theta_map)
    theta_map -= lr * grad
    losses.append(val)
plt.plot(losses)
sns.despine()
theta_map
100%|██████████| 200/200 [00:02<00:00, 71.83it/s] 
DeviceArray(0.6698113, dtype=float32, weak_type=True)

Verification of obtained \(\theta_{map}\) value using the formula:

\(\theta_{map} = \frac{n_h+\alpha-1}{n_h+n_t+\alpha+\beta-2}\)

nH = samples.sum().astype("float32")  # Compute the number of heads
nT = (samples.size - nH).astype("float32")  # Compute the number of tails
theta_check = (nH + alpha - 1) / (
    nH + nT + alpha + beta - 2
)  # Compute the posterior mean
theta_check
DeviceArray(0.6698113, dtype=float32)

Computing Hessian and Covariance

hessian = jax.hessian(neg_logjoint)(
    theta_map
)  # Compute the Hessian of the negative log-joint distribution
hessian = jnp.reshape(hessian, (1, 1))  # Reshape the Hessian to a 1x1 matrix
cov = jnp.linalg.inv(hessian)  # Compute the covariance matrix
cov
DeviceArray([[0.00208645]], dtype=float32)

Plots Comparing the distribution obtained using Laplace approximation with actual Beta Bernoulli posterior

# Compute the Laplace approximation
x = jnp.linspace(0, 1, 100)  # Create a grid of 100 points between 0 and 1
x = x.reshape(-1, 1)  # Reshape the grid to a 100x1 matrix
Laplace_Approx = tfp.distributions.MultivariateNormalFullCovariance(  # Create a multivariate normal distribution
    loc=theta_map, covariance_matrix=cov
)
Laplace_Approx_pdf = Laplace_Approx.prob(
    x
)  # Compute the probability density function of the Laplace approximation
plt.plot(x, Laplace_Approx_pdf, label="Laplace Approximation")


# Compute the true posterior distribution
alpha = 3
beta = 5
true_posterior = tfp.distributions.Beta(
    alpha + nH, beta + nT
)  # Create a Beta distribution
true_posterior_pdf = true_posterior.prob(
    x
)  # Compute the probability density function of the true posterior
plt.plot(x, true_posterior_pdf, label="True Posterior")
plt.xlim(0, 1)
plt.legend()
<matplotlib.legend.Legend at 0x7f2181895a10>

# Compute the log-probability density function of the Laplace approximation
true_posterior_pdf_log = true_posterior.log_prob(x)
Laplace_Approx_pdf_log = Laplace_Approx.log_prob(x)
plt.plot(x, Laplace_Approx_pdf_log, label="Laplace Approximation")
plt.plot(x, true_posterior_pdf_log, label="True Posterior")
plt.legend()
<matplotlib.legend.Legend at 0x7f2180719250>