```python
# Importing modules
try:
    import jax  # JAX is a library for differentiable programming
except ModuleNotFoundError:
    %pip install jaxlib jax
    import jax
import jax.numpy as jnp  # JAX's numpy implementation

try:
    import tensorflow_probability.substrates.jax as tfp  # TFP is a library for probabilistic programming
except ModuleNotFoundError:
    %pip install tensorflow-probability
    import tensorflow_probability.substrates.jax as tfp

import matplotlib.pyplot as plt
import warnings
import seaborn as sns
from tqdm import trange
import logging

logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    """Filter out TFP's noisy "check_types" log messages."""

    def filter(self, record):
        return "check_types" not in record.getMessage()


logger.addFilter(CheckTypesFilter())
```
Sampling from the Bernoulli distribution with \(\theta = 0.7\)
```python
bernoulli_samples = tfp.distributions.Bernoulli(
    probs=0.7
)  # Create a Bernoulli distribution with p=0.7
samples = bernoulli_samples.sample(
    sample_shape=100, seed=jax.random.PRNGKey(0)
)  # Sample 100 samples from the distribution
print(samples)
```

```
[1 0 0 1 1 0 1 1 1 1 1 1 0 0 1 0 1 1 1 1 1 0 1 0 1 1 1 1 0 1 0 1 1 0 1 1 1
 1 0 1 1 1 1 0 0 0 1 1 1 0 1 1 0 1 1 1 1 0 1 1 1 0 1 0 0 0 1 1 1 0 1 0 0 1
 1 1 1 0 1 1 0 1 1 1 1 1 1 0 1 1 1 1 1 1 0 1 1 0 0 1]
```

```python
alpha = 3  # Set the parameter (alpha) of the Beta distribution
beta = 5  # Set the parameter (beta) of the Beta distribution
samples.sum()  # Count the number of heads
```

```
DeviceArray(69, dtype=int32)
```
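As a quick sanity check (an addition, not part of the original notebook), the empirical mean of the draws should be close to the true \(\theta = 0.7\); with 69 heads in 100 draws it is 0.69:

```python
# Added sanity check: the sample mean should approach theta = 0.7
print(samples.mean())  # 0.69 for these 100 draws (69 heads)
```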
Negative log joint
```python
def neg_logjoint(theta):  # Define the negative log-joint distribution
    alpha = 3
    beta = 5
    dist_prior = tfp.distributions.Beta(alpha, beta)  # Beta prior over theta
    dist_likelihood = tfp.distributions.Bernoulli(probs=theta)  # Bernoulli likelihood
    return -(dist_prior.log_prob(theta) + dist_likelihood.log_prob(samples).sum())
```
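Before optimising, it can help to scan the objective on a grid of \(\theta\) values; the minimum should land near the closed-form MAP estimate computed below (about 0.67). A minimal added sketch, not from the original notebook:

```python
# Added check: evaluate the objective over a grid and locate its minimum
thetas = jnp.linspace(0.01, 0.99, 99)
objective = jax.vmap(neg_logjoint)(thetas)
print(thetas[jnp.argmin(objective)])  # Should sit near theta_map (~0.67)
```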
Calculating \(\theta_{map}\) by minimising the negative log joint using gradient descent
```python
gradient = jax.value_and_grad(
    jax.jit(neg_logjoint)
)  # Define the gradient of the negative log-joint distribution
lr = 0.001  # Set the learning rate
epochs = 200  # Set the number of epochs
theta_map = 0.5  # Set the initial value of theta
losses = []
for i in trange(epochs):  # Run the optimization loop
    val, grad = gradient(theta_map)
    theta_map -= lr * grad
    losses.append(val)

plt.plot(losses)
sns.despine()
theta_map
```
```
100%|██████████| 200/200 [00:02<00:00, 71.83it/s]

DeviceArray(0.6698113, dtype=float32, weak_type=True)
```
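The hand-rolled loop above can also be written with an off-the-shelf optimiser. A minimal sketch using optax (an addition, assuming optax is installed; the original uses the manual loop):

```python
import optax  # Assumption: optax is installed (pip install optax)

optimizer = optax.sgd(learning_rate=0.001)
theta = jnp.array(0.5)
opt_state = optimizer.init(theta)
for _ in range(200):
    grad = jax.grad(neg_logjoint)(theta)  # Gradient of the objective
    updates, opt_state = optimizer.update(grad, opt_state)
    theta = optax.apply_updates(theta, updates)
print(theta)  # Should agree with theta_map above (~0.6698)
```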
Verification of the obtained \(\theta_{map}\) value using the closed-form formula. Because the Beta prior is conjugate to the Bernoulli likelihood, the posterior is \(\text{Beta}(\alpha + n_h, \beta + n_t)\), and \(\theta_{map}\) is its mode:
\(\theta_{map} = \frac{n_h+\alpha-1}{n_h+n_t+\alpha+\beta-2}\)
```python
nH = samples.sum().astype("float32")  # Compute the number of heads
nT = (samples.size - nH).astype("float32")  # Compute the number of tails
theta_check = (nH + alpha - 1) / (
    nH + nT + alpha + beta - 2
)  # Compute the posterior mode (the MAP estimate, not the mean)
theta_check
```
```
DeviceArray(0.6698113, dtype=float32)
```
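Since the posterior is itself a Beta distribution, TFP can produce the same number directly via its `mode()` method; this cross-check is an addition to the notebook:

```python
# Added cross-check: the MAP is the mode of the conjugate Beta posterior
posterior = tfp.distributions.Beta(alpha + nH, beta + nT)
print(posterior.mode())  # ~0.6698, matching theta_check
```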
Computing Hessian and Covariance
```python
hessian = jax.hessian(neg_logjoint)(
    theta_map
)  # Compute the Hessian of the negative log-joint distribution at the MAP
hessian = jnp.reshape(hessian, (1, 1))  # Reshape the Hessian to a 1x1 matrix
cov = jnp.linalg.inv(hessian)  # Compute the covariance matrix
cov
```
```
DeviceArray([[0.00208645]], dtype=float32)
```
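The covariance can also be verified analytically: up to a constant, the negative log joint is \(-(\alpha+n_h-1)\log\theta - (\beta+n_t-1)\log(1-\theta)\), so its second derivative at the mode is \((\alpha+n_h-1)/\theta^2 + (\beta+n_t-1)/(1-\theta)^2\). A short added check:

```python
# Added check: closed-form curvature of the negative log joint at the MAP
a, b = alpha + nH, beta + nT  # Conjugate Beta posterior parameters
curvature = (a - 1) / theta_map**2 + (b - 1) / (1 - theta_map) ** 2
print(1 / curvature)  # ~0.00209, matching cov above
```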
Plots comparing the distribution obtained using the Laplace approximation with the true Beta-Bernoulli posterior
```python
# Compute the Laplace approximation
x = jnp.linspace(0, 1, 100)  # Create a grid of 100 points between 0 and 1
x = x.reshape(-1, 1)  # Reshape the grid to a 100x1 matrix
Laplace_Approx = tfp.distributions.MultivariateNormalFullCovariance(  # Create a multivariate normal distribution
    loc=theta_map, covariance_matrix=cov
)
Laplace_Approx_pdf = Laplace_Approx.prob(
    x
)  # Compute the probability density function of the Laplace approximation
plt.plot(x, Laplace_Approx_pdf, label="Laplace Approximation")

# Compute the true posterior distribution
alpha = 3
beta = 5
true_posterior = tfp.distributions.Beta(
    alpha + nH, beta + nT
)  # Create the conjugate Beta posterior
true_posterior_pdf = true_posterior.prob(
    x
)  # Compute the probability density function of the true posterior
plt.plot(x, true_posterior_pdf, label="True Posterior")
plt.xlim(0, 1)
plt.legend()
```
```
<matplotlib.legend.Legend at 0x7f2181895a10>
```
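To put a number on the mismatch visible in the plot, one can approximate the KL divergence between the two densities on a grid. A minimal added sketch (not from the original notebook), using a univariate Normal with the Laplace mean and variance and a simple Riemann sum:

```python
# Added sketch: grid-based KL(true posterior || Laplace approximation)
grid = jnp.linspace(1e-3, 1 - 1e-3, 1000)  # Avoid the Beta's endpoint singularities
p = true_posterior.prob(grid)
q = tfp.distributions.Normal(theta_map, jnp.sqrt(cov[0, 0])).prob(grid)
dx = grid[1] - grid[0]
print(jnp.sum(p * (jnp.log(p) - jnp.log(q))) * dx)  # Small when the fit is good
```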
```python
# Compare the two densities in log space
true_posterior_pdf_log = true_posterior.log_prob(x)
Laplace_Approx_pdf_log = Laplace_Approx.log_prob(x)
plt.plot(x, Laplace_Approx_pdf_log, label="Laplace Approximation")
plt.plot(x, true_posterior_pdf_log, label="True Posterior")
plt.legend()
```
```
<matplotlib.legend.Legend at 0x7f2180719250>
```