%pip -qq install numpyro
%pip -qq install ucimlrepo
Variationally Inferred Parameterization
This tutorial is also hosted in the NumPyro documentation.
Occasionally, the Hamiltonian Monte Carlo (HMC) sampler struggles to sample effectively from the posterior distribution; Neal's funnel is an illustrative case. In such situations the conventional centered parameterization proves inadequate, and we switch to a non-centered parameterization. Sometimes, however, even the non-centered parameterization does not suffice, and we turn to Variationally Inferred Parameterization (VIP), which learns a per-variable centeredness parameter in the range 0 to 1 that interpolates between the two.
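For intuition, here is a minimal sketch of Neal's funnel in its centered and non-centered forms (the model functions below are illustrative, not part of this tutorial's pipeline):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def funnel_centered():
    # centered: the scale of x depends sharply on y, which stalls HMC in the funnel's neck
    y = numpyro.sample("y", dist.Normal(0.0, 3.0))
    numpyro.sample("x", dist.Normal(jnp.zeros(9), jnp.exp(y / 2.0)))

def funnel_noncentered():
    # non-centered: sample standard normals, then rescale deterministically
    y = numpyro.sample("y", dist.Normal(0.0, 3.0))
    x_raw = numpyro.sample("x_raw", dist.Normal(jnp.zeros(9), 1.0))
    numpyro.deterministic("x", x_raw * jnp.exp(y / 2.0))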
The purpose of this tutorial is to implement Variationally Inferred Parameterization based on Automatic Reparameterization of Probabilistic Programs using LocScaleReparam in Numpyro.
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import arviz as az
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer.reparam import LocScaleReparam
from ucimlrepo import fetch_ucirepo

rng_key = jax.random.PRNGKey(0)
1. Dataset
We will be using the German Credit dataset for this illustration. The dataset consists of 1000 entries with 20 categorical/symbolic attributes prepared by Prof. Hofmann. Each entry represents a person who takes a credit from a bank, and each person is classified as a good or bad credit risk according to the set of attributes.
def load_german_credit():
    statlog_german_credit_data = fetch_ucirepo(id=144)
    X = statlog_german_credit_data.data.features
    y = statlog_german_credit_data.data.targets
    return X, y
X, y = load_german_credit()
X
 | Attribute1 | Attribute2 | Attribute3 | Attribute4 | Attribute5 | Attribute6 | Attribute7 | Attribute8 | Attribute9 | Attribute10 | Attribute11 | Attribute12 | Attribute13 | Attribute14 | Attribute15 | Attribute16 | Attribute17 | Attribute18 | Attribute19 | Attribute20
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | A11 | 6 | A34 | A43 | 1169 | A65 | A75 | 4 | A93 | A101 | 4 | A121 | 67 | A143 | A152 | 2 | A173 | 1 | A192 | A201 |
1 | A12 | 48 | A32 | A43 | 5951 | A61 | A73 | 2 | A92 | A101 | 2 | A121 | 22 | A143 | A152 | 1 | A173 | 1 | A191 | A201 |
2 | A14 | 12 | A34 | A46 | 2096 | A61 | A74 | 2 | A93 | A101 | 3 | A121 | 49 | A143 | A152 | 1 | A172 | 2 | A191 | A201 |
3 | A11 | 42 | A32 | A42 | 7882 | A61 | A74 | 2 | A93 | A103 | 4 | A122 | 45 | A143 | A153 | 1 | A173 | 2 | A191 | A201 |
4 | A11 | 24 | A33 | A40 | 4870 | A61 | A73 | 3 | A93 | A101 | 4 | A124 | 53 | A143 | A153 | 2 | A173 | 2 | A191 | A201 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | A14 | 12 | A32 | A42 | 1736 | A61 | A74 | 3 | A92 | A101 | 4 | A121 | 31 | A143 | A152 | 1 | A172 | 1 | A191 | A201 |
996 | A11 | 30 | A32 | A41 | 3857 | A61 | A73 | 4 | A91 | A101 | 4 | A122 | 40 | A143 | A152 | 1 | A174 | 1 | A192 | A201 |
997 | A14 | 12 | A32 | A43 | 804 | A61 | A75 | 4 | A93 | A101 | 4 | A123 | 38 | A143 | A152 | 1 | A173 | 1 | A191 | A201 |
998 | A11 | 45 | A32 | A43 | 1845 | A61 | A73 | 4 | A93 | A101 | 4 | A124 | 23 | A143 | A153 | 1 | A173 | 1 | A192 | A201 |
999 | A12 | 45 | A34 | A41 | 4576 | A62 | A71 | 3 | A93 | A101 | 4 | A123 | 27 | A143 | A152 | 1 | A173 | 1 | A191 | A201 |
1000 rows × 20 columns
Here, X holds the values of the 20 attributes for each person in the dataset, and y is the corresponding credit-risk label.
def data_transform(X, y):
    def categorical_to_int(x):
        d = {u: i for i, u in enumerate(np.unique(x))}
        return np.array([d[i] for i in x])

    categoricals = []
    numericals = []
    numericals.append(np.ones([len(y)]))
    for column in X:
        column = X[column]
        if column.dtype == "O":
            categoricals.append(categorical_to_int(column))
        else:
            numericals.append((column - column.mean()) / column.std())
    numericals = np.array(numericals).T
    status = np.array(y == 1, dtype=np.int32)
    status = np.squeeze(status)
    return jnp.array(numericals), jnp.array(categoricals), jnp.array(status)
Next, we transform the data so that it can be fed into the NumPyro model:
numericals, categoricals, status = data_transform(X, y)
x_numeric = numericals.astype(jnp.float32)
x_categorical = [jnp.eye(c.max() + 1)[c] for c in categoricals]
all_x = jnp.concatenate([x_numeric] + x_categorical, axis=1)
num_features = all_x.shape[1]
y = status[jnp.newaxis, Ellipsis]
2. Model
We will be using a logistic regression model with a hierarchical prior on the coefficient scales:
\[\begin{aligned} \log \tau_0 & \sim \mathcal{N}(0,10) & \log \tau_i & \sim \mathcal{N}\left(\log \tau_0, 1\right) \\ \beta_i & \sim \mathcal{N}\left(0, \tau_i\right) & y & \sim \operatorname{Bernoulli}\left(\sigma\left(\beta X^T\right)\right) \end{aligned}\]

where \(\sigma\) is the logistic sigmoid function.

def german_credit():
    log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
    log_tau_i = numpyro.sample(
        "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
    )
    beta = numpyro.sample(
        "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
    )
    numpyro.sample(
        "obs",
        dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
        obs=y,
    )
nuts_kernel = NUTS(german_credit)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:21<00:00, 94.07it/s, 63 steps of size 6.31e-02. acc. prob=0.87]
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
beta[0] 0.13 0.38 0.05 -0.36 0.74 284.06 1.00
beta[1] -0.34 0.12 -0.34 -0.52 -0.15 621.55 1.00
beta[2] -0.27 0.13 -0.27 -0.45 -0.03 542.13 1.00
beta[3] -0.30 0.10 -0.30 -0.44 -0.11 566.55 1.00
beta[4] -0.00 0.07 -0.00 -0.12 0.11 782.35 1.00
beta[5] 0.12 0.09 0.11 -0.02 0.27 728.28 1.01
beta[6] -0.08 0.08 -0.07 -0.22 0.05 822.89 1.00
beta[7] -0.05 0.07 -0.04 -0.19 0.05 752.66 1.00
beta[8] -0.42 0.32 -0.39 -0.87 0.05 198.00 1.00
beta[9] -0.07 0.26 -0.02 -0.50 0.31 220.27 1.00
beta[10] 0.26 0.31 0.18 -0.15 0.78 404.97 1.00
beta[11] 1.23 0.34 1.25 0.68 1.79 227.34 1.01
beta[12] -0.26 0.34 -0.17 -0.81 0.22 349.10 1.00
beta[13] -0.30 0.34 -0.21 -0.86 0.13 387.72 1.00
beta[14] 0.07 0.20 0.04 -0.26 0.38 240.45 1.03
beta[15] 0.10 0.22 0.05 -0.18 0.50 287.41 1.02
beta[16] 0.76 0.30 0.76 0.22 1.24 364.73 1.03
beta[17] -0.53 0.28 -0.55 -0.94 -0.05 269.95 1.00
beta[18] 0.70 0.42 0.70 -0.02 1.29 367.28 1.00
beta[19] 0.17 0.40 0.06 -0.43 0.77 333.54 1.00
beta[20] 0.03 0.19 0.01 -0.23 0.39 381.57 1.00
beta[21] 0.18 0.22 0.13 -0.14 0.53 335.48 1.00
beta[22] -0.05 0.32 -0.01 -0.56 0.46 439.54 1.00
beta[23] -0.10 0.30 -0.04 -0.63 0.30 508.20 1.00
beta[24] -0.34 0.36 -0.25 -0.94 0.12 283.15 1.00
beta[25] 0.14 0.40 0.04 -0.46 0.71 433.69 1.00
beta[26] -0.01 0.19 -0.00 -0.34 0.28 438.64 1.00
beta[27] -0.36 0.27 -0.33 -0.78 0.04 377.33 1.01
beta[28] -0.07 0.22 -0.03 -0.43 0.26 493.09 1.00
beta[29] 0.01 0.22 0.00 -0.32 0.34 448.21 1.00
beta[30] 0.35 0.43 0.22 -0.18 1.08 314.69 1.00
beta[31] 0.41 0.33 0.40 -0.10 0.90 402.62 1.00
beta[32] -0.03 0.21 -0.01 -0.39 0.30 525.23 1.00
beta[33] -0.12 0.18 -0.09 -0.41 0.16 334.94 1.00
beta[34] -0.02 0.16 -0.01 -0.24 0.26 318.25 1.00
beta[35] 0.42 0.27 0.42 -0.04 0.81 455.99 1.00
beta[36] 0.05 0.17 0.03 -0.18 0.35 506.34 1.00
beta[37] -0.12 0.25 -0.06 -0.57 0.21 470.11 1.00
beta[38] -0.07 0.20 -0.04 -0.39 0.24 410.71 1.00
beta[39] 0.36 0.24 0.35 -0.04 0.71 359.55 1.00
beta[40] 0.05 0.20 0.02 -0.29 0.35 441.70 1.00
beta[41] -0.00 0.21 0.00 -0.34 0.37 513.67 1.00
beta[42] -0.13 0.27 -0.08 -0.59 0.23 402.64 1.00
beta[43] 0.55 0.46 0.49 -0.11 1.28 570.74 1.00
beta[44] 0.19 0.21 0.15 -0.14 0.50 379.76 1.00
beta[45] -0.00 0.16 0.00 -0.25 0.26 352.19 1.00
beta[46] 0.01 0.16 0.01 -0.25 0.25 411.05 1.00
beta[47] -0.16 0.24 -0.11 -0.55 0.18 455.59 1.00
beta[48] -0.12 0.24 -0.07 -0.55 0.21 322.67 1.04
beta[49] -0.04 0.23 -0.02 -0.45 0.30 437.47 1.02
beta[50] 0.38 0.28 0.37 -0.03 0.82 266.19 1.04
beta[51] -0.14 0.22 -0.09 -0.52 0.16 406.31 1.00
beta[52] 0.19 0.23 0.14 -0.14 0.55 338.97 1.00
beta[53] 0.04 0.22 0.02 -0.23 0.43 438.03 1.00
beta[54] 0.05 0.24 0.02 -0.32 0.41 522.43 1.00
beta[55] 0.02 0.14 0.01 -0.22 0.23 562.00 1.00
beta[56] -0.01 0.13 -0.01 -0.24 0.21 638.20 1.00
beta[57] 0.01 0.17 0.00 -0.25 0.34 590.99 1.00
beta[58] -0.07 0.18 -0.04 -0.34 0.23 481.37 1.00
beta[59] 0.13 0.19 0.09 -0.12 0.47 507.56 1.00
beta[60] -0.14 0.33 -0.06 -0.64 0.37 303.00 1.00
beta[61] 0.48 0.56 0.32 -0.18 1.41 438.86 1.00
log_tau_i[0] -1.51 0.95 -1.52 -3.03 0.11 290.78 1.00
log_tau_i[1] -1.07 0.67 -1.11 -2.12 0.03 641.04 1.00
log_tau_i[2] -1.24 0.76 -1.26 -2.47 0.03 666.31 1.00
log_tau_i[3] -1.16 0.65 -1.19 -2.20 -0.10 821.60 1.00
log_tau_i[4] -2.11 0.88 -2.13 -3.50 -0.61 806.15 1.00
log_tau_i[5] -1.71 0.86 -1.68 -3.28 -0.44 697.00 1.00
log_tau_i[6] -1.88 0.84 -1.91 -3.30 -0.58 623.56 1.00
log_tau_i[7] -1.99 0.90 -1.98 -3.51 -0.65 710.21 1.00
log_tau_i[8] -1.00 0.86 -0.96 -2.23 0.52 445.30 1.00
log_tau_i[9] -1.69 0.93 -1.63 -3.17 -0.14 326.33 1.00
log_tau_i[10] -1.41 0.95 -1.35 -2.93 0.19 441.60 1.01
log_tau_i[11] -0.11 0.57 -0.12 -0.97 0.80 539.60 1.00
log_tau_i[12] -1.36 0.96 -1.31 -3.16 0.01 336.11 1.00
log_tau_i[13] -1.30 0.95 -1.26 -2.85 0.28 335.04 1.00
log_tau_i[14] -1.72 0.89 -1.70 -3.05 -0.25 584.38 1.00
log_tau_i[15] -1.65 0.92 -1.63 -3.07 -0.10 345.77 1.03
log_tau_i[16] -0.51 0.65 -0.49 -1.42 0.59 676.64 1.00
log_tau_i[17] -0.84 0.76 -0.76 -2.09 0.34 303.14 1.00
log_tau_i[18] -0.69 0.82 -0.59 -2.03 0.61 359.35 1.00
log_tau_i[19] -1.45 0.99 -1.42 -2.97 0.25 397.18 1.00
log_tau_i[20] -1.75 0.94 -1.73 -3.39 -0.40 617.54 1.00
log_tau_i[21] -1.51 0.88 -1.49 -3.16 -0.27 488.52 1.00
log_tau_i[22] -1.56 0.93 -1.56 -3.06 -0.10 348.20 1.00
log_tau_i[23] -1.58 0.94 -1.57 -3.05 0.02 278.69 1.00
log_tau_i[24] -1.26 1.00 -1.12 -2.91 0.29 205.38 1.00
log_tau_i[25] -1.53 0.95 -1.56 -3.11 0.02 351.09 1.00
log_tau_i[26] -1.73 0.91 -1.74 -3.17 -0.22 492.18 1.00
log_tau_i[27] -1.15 0.89 -1.08 -2.66 0.17 485.34 1.00
log_tau_i[28] -1.69 0.92 -1.65 -3.15 -0.19 425.75 1.00
log_tau_i[29] -1.71 0.99 -1.71 -3.19 0.01 374.58 1.00
log_tau_i[30] -1.24 0.99 -1.20 -2.74 0.50 327.47 1.00
log_tau_i[31] -1.02 0.89 -0.91 -2.40 0.51 587.85 1.00
log_tau_i[32] -1.71 0.94 -1.70 -3.22 -0.11 511.74 1.00
log_tau_i[33] -1.69 0.90 -1.68 -3.13 -0.28 538.65 1.00
log_tau_i[34] -1.82 0.92 -1.81 -3.35 -0.35 423.01 1.00
log_tau_i[35] -1.06 0.82 -1.00 -2.30 0.34 470.50 1.00
log_tau_i[36] -1.79 0.87 -1.76 -3.15 -0.34 527.47 1.00
log_tau_i[37] -1.58 0.95 -1.54 -3.11 0.04 485.52 1.00
log_tau_i[38] -1.71 0.87 -1.65 -3.18 -0.34 482.67 1.00
log_tau_i[39] -1.12 0.85 -1.01 -2.44 0.33 337.59 1.00
log_tau_i[40] -1.76 0.96 -1.73 -3.58 -0.36 533.15 1.00
log_tau_i[41] -1.74 0.94 -1.70 -3.26 -0.22 500.91 1.00
log_tau_i[42] -1.57 0.95 -1.54 -3.04 0.01 499.44 1.00
log_tau_i[43] -0.87 0.93 -0.74 -2.28 0.58 445.98 1.00
log_tau_i[44] -1.52 0.89 -1.45 -2.95 -0.13 442.63 1.00
log_tau_i[45] -1.84 0.94 -1.79 -3.21 -0.09 673.31 1.00
log_tau_i[46] -1.82 0.85 -1.83 -3.26 -0.56 579.66 1.00
log_tau_i[47] -1.54 0.90 -1.51 -3.35 -0.30 428.50 1.00
log_tau_i[48] -1.62 0.89 -1.60 -3.00 -0.15 413.30 1.01
log_tau_i[49] -1.71 0.95 -1.68 -3.23 -0.13 514.04 1.00
log_tau_i[50] -1.12 0.92 -0.99 -2.67 0.38 206.76 1.03
log_tau_i[51] -1.61 0.92 -1.58 -3.07 -0.03 477.41 1.00
log_tau_i[52] -1.54 0.90 -1.49 -2.96 -0.09 459.83 1.00
log_tau_i[53] -1.74 0.92 -1.69 -3.13 -0.14 509.51 1.00
log_tau_i[54] -1.68 0.95 -1.67 -3.07 0.10 477.21 1.00
log_tau_i[55] -1.87 0.97 -1.88 -3.49 -0.35 514.38 1.00
log_tau_i[56] -1.87 0.96 -1.84 -3.23 -0.12 574.80 1.00
log_tau_i[57] -1.77 0.86 -1.72 -3.26 -0.36 646.10 1.00
log_tau_i[58] -1.78 0.92 -1.77 -3.18 -0.15 617.59 1.00
log_tau_i[59] -1.67 0.93 -1.61 -3.19 -0.21 510.74 1.00
log_tau_i[60] -1.50 0.99 -1.44 -3.09 0.08 386.86 1.00
log_tau_i[61] -1.09 1.06 -1.00 -2.79 0.52 421.27 1.00
log_tau_zero -1.49 0.26 -1.49 -1.90 -1.05 169.88 1.00
Number of divergences: 37
From mcmc.print_summary it is evident that there are 37 divergences, so we will use Variationally Inferred Parameterization (VIP) to reduce them.
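The count is also available programmatically; NUTS collects the per-sample diverging flag among the extra fields by default:

print(int(mcmc.get_extra_fields()["diverging"].sum()))  # 37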
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
3. Reparameterization
We introduce a parameterization parameter \(\lambda \in [0,1]\) for each variable \(z \sim \mathcal{N}(\mu, \sigma)\), and transform:

=> sample an auxiliary variable \(\tilde{z} \sim \mathcal{N}(\lambda\mu, \sigma^{\lambda})\)

=> recover \(z = \mu + \sigma^{1-\lambda}(\tilde{z} - \lambda\mu)\)

Here \(\lambda = 1\) recovers the centered parameterization and \(\lambda = 0\) the fully non-centered one.
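To see the transform concretely, here is a minimal sketch for a single Gaussian variable, using the imports above (vip_sample and its arguments are our own illustrative names, not tutorial code):

def vip_sample(key, mu, sigma, lam):
    # auxiliary variable: z_tilde ~ N(lam * mu, sigma ** lam)
    z_tilde = dist.Normal(lam * mu, sigma**lam).sample(key)
    # deterministic recovery gives z ~ N(mu, sigma)
    return mu + sigma ** (1 - lam) * (z_tilde - lam * mu)

# lam=1.0 is the centered parameterization, lam=0.0 the non-centered one
z = vip_sample(jax.random.PRNGKey(1), mu=2.0, sigma=0.5, lam=0.3)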
Thus, using the above transformation, the joint density of a simple hierarchical example transforms as follows: \[\begin{aligned} p(\theta, \mu, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\mu \mid \theta, \sigma_\mu\right) \times \mathcal{N}(\mathbf{y} \mid \mu, \sigma) \end{aligned}\] \[\begin{aligned} p(\theta, \hat{\mu}, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\hat{\mu} \mid \lambda \theta, \sigma_\mu^\lambda\right) \times \mathcal{N}\left(\mathbf{y} \mid \theta+\sigma_\mu^{1-\lambda}(\hat{\mu}-\lambda \theta), \sigma\right) \end{aligned}\]

def german_credit_reparam(beta_centeredness=None):
    def model():
        log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
        log_tau_i = numpyro.sample(
            "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
        )
        with numpyro.handlers.reparam(
            config={"beta": LocScaleReparam(beta_centeredness)}
        ):
            beta = numpyro.sample(
                "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
            )
        numpyro.sample(
            "obs",
            dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
            obs=y,
        )

    return model
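When beta_centeredness is None, LocScaleReparam registers a learnable parameter named beta_centered (one \(\lambda\) per coefficient, constrained to [0, 1]); passing a fitted value instead fixes the centeredness, which is what we will do for the final MCMC run.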
Now, using SVI we optimize \(\lambda\).
model = german_credit_reparam()
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(3e-4), Trace_ELBO(10))
svi_results = svi.run(rng_key, 10000)
100%|██████████| 10000/10000 [00:16<00:00, 588.87it/s, init loss: 2165.2424, avg. loss [9501-10000]: 576.7846]
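Before re-running MCMC we can inspect the fitted centeredness values (a quick check of our own; LocScaleReparam constrains them to the unit interval):

lam = svi_results.params["beta_centered"]
print(lam.shape, float(lam.min()), float(lam.max()))  # one value per coefficient, each in [0, 1]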
reparam_model = german_credit_reparam(
    beta_centeredness=svi_results.params["beta_centered"]
)
nuts_kernel = NUTS(reparam_model)
mcmc_reparam = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_reparam.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:07<00:00, 285.41it/s, 31 steps of size 1.28e-01. acc. prob=0.89]
mcmc_reparam.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
beta_decentered[0] 0.12 0.40 0.06 -0.48 0.80 338.70 1.00
beta_decentered[1] -0.45 0.15 -0.45 -0.70 -0.21 791.23 1.00
beta_decentered[2] -0.38 0.17 -0.38 -0.65 -0.09 691.79 1.00
beta_decentered[3] -0.41 0.13 -0.41 -0.61 -0.19 1022.79 1.00
beta_decentered[4] -0.01 0.11 -0.01 -0.18 0.20 1176.84 1.00
beta_decentered[5] 0.19 0.14 0.19 -0.04 0.41 1194.41 1.00
beta_decentered[6] -0.13 0.14 -0.13 -0.36 0.09 1227.24 1.00
beta_decentered[7] -0.07 0.12 -0.06 -0.24 0.14 1096.31 1.00
beta_decentered[8] -0.46 0.34 -0.46 -0.99 0.08 330.30 1.00
beta_decentered[9] -0.03 0.32 -0.02 -0.57 0.49 310.35 1.00
beta_decentered[10] 0.35 0.39 0.30 -0.26 1.00 426.11 1.00
beta_decentered[11] 1.29 0.31 1.30 0.81 1.82 433.16 1.00
beta_decentered[12] -0.32 0.39 -0.25 -0.96 0.24 521.05 1.00
beta_decentered[13] -0.38 0.40 -0.32 -1.00 0.24 410.05 1.00
beta_decentered[14] 0.08 0.28 0.06 -0.37 0.57 457.72 1.00
beta_decentered[15] 0.14 0.30 0.10 -0.28 0.66 612.31 1.00
beta_decentered[16] 0.85 0.31 0.86 0.41 1.45 432.14 1.00
beta_decentered[17] -0.64 0.28 -0.65 -1.05 -0.14 523.15 1.00
beta_decentered[18] 0.78 0.42 0.78 0.07 1.46 545.52 1.00
beta_decentered[19] 0.15 0.39 0.08 -0.50 0.80 662.60 1.00
beta_decentered[20] 0.04 0.25 0.03 -0.39 0.40 445.85 1.00
beta_decentered[21] 0.24 0.27 0.21 -0.20 0.65 477.68 1.00
beta_decentered[22] -0.03 0.38 -0.01 -0.64 0.60 984.59 1.00
beta_decentered[23] -0.13 0.34 -0.08 -0.72 0.35 702.87 1.00
beta_decentered[24] -0.41 0.39 -0.37 -1.08 0.13 603.13 1.00
beta_decentered[25] 0.19 0.47 0.09 -0.48 0.92 529.68 1.00
beta_decentered[26] 0.00 0.25 0.01 -0.47 0.35 690.54 1.00
beta_decentered[27] -0.46 0.31 -0.46 -0.95 0.04 464.44 1.00
beta_decentered[28] -0.09 0.30 -0.06 -0.56 0.41 464.65 1.00
beta_decentered[29] 0.02 0.30 0.01 -0.47 0.52 747.44 1.00
beta_decentered[30] 0.38 0.44 0.31 -0.30 1.05 717.12 1.00
beta_decentered[31] 0.47 0.36 0.47 -0.09 1.03 564.18 1.00
beta_decentered[32] -0.03 0.26 -0.02 -0.44 0.44 572.03 1.00
beta_decentered[33] -0.17 0.25 -0.15 -0.63 0.19 713.40 1.00
beta_decentered[34] -0.02 0.21 -0.01 -0.39 0.32 620.45 1.01
beta_decentered[35] 0.53 0.31 0.55 -0.03 1.00 681.60 1.00
beta_decentered[36] 0.09 0.24 0.06 -0.27 0.49 610.06 1.00
beta_decentered[37] -0.14 0.31 -0.10 -0.74 0.28 826.87 1.00
beta_decentered[38] -0.12 0.25 -0.11 -0.53 0.30 493.49 1.00
beta_decentered[39] 0.44 0.28 0.44 -0.01 0.89 542.71 1.00
beta_decentered[40] 0.05 0.26 0.03 -0.39 0.45 709.78 1.00
beta_decentered[41] 0.02 0.30 0.01 -0.52 0.46 389.41 1.00
beta_decentered[42] -0.15 0.33 -0.10 -0.74 0.32 607.05 1.00
beta_decentered[43] 0.66 0.47 0.65 -0.11 1.38 539.31 1.01
beta_decentered[44] 0.25 0.27 0.23 -0.19 0.66 686.63 1.00
beta_decentered[45] -0.01 0.22 -0.00 -0.36 0.35 909.70 1.00
beta_decentered[46] 0.02 0.22 0.01 -0.33 0.39 741.33 1.00
beta_decentered[47] -0.25 0.31 -0.21 -0.72 0.25 487.26 1.00
beta_decentered[48] -0.18 0.30 -0.15 -0.67 0.30 400.57 1.00
beta_decentered[49] -0.07 0.30 -0.05 -0.56 0.41 594.38 1.00
beta_decentered[50] 0.46 0.31 0.46 -0.09 0.88 384.33 1.00
beta_decentered[51] -0.23 0.27 -0.19 -0.69 0.16 561.35 1.00
beta_decentered[52] 0.22 0.25 0.21 -0.19 0.60 456.41 1.00
beta_decentered[53] 0.07 0.29 0.04 -0.41 0.53 500.99 1.00
beta_decentered[54] 0.05 0.30 0.02 -0.49 0.51 896.08 1.00
beta_decentered[55] 0.01 0.23 0.01 -0.34 0.43 1033.13 1.00
beta_decentered[56] -0.02 0.19 -0.01 -0.33 0.28 717.87 1.00
beta_decentered[57] 0.01 0.23 0.01 -0.33 0.41 684.61 1.00
beta_decentered[58] -0.09 0.26 -0.08 -0.55 0.30 455.66 1.01
beta_decentered[59] 0.20 0.27 0.18 -0.20 0.63 413.57 1.01
beta_decentered[60] -0.14 0.37 -0.09 -0.76 0.46 411.83 1.00
beta_decentered[61] 0.58 0.57 0.50 -0.23 1.50 495.86 1.00
log_tau_i[0] -1.56 0.91 -1.53 -2.94 -0.01 480.69 1.00
log_tau_i[1] -1.07 0.64 -1.08 -1.99 0.06 950.02 1.00
log_tau_i[2] -1.25 0.74 -1.23 -2.36 -0.06 796.37 1.00
log_tau_i[3] -1.15 0.65 -1.18 -2.25 -0.14 838.40 1.00
log_tau_i[4] -2.09 0.90 -2.12 -3.51 -0.49 1120.31 1.00
log_tau_i[5] -1.70 0.84 -1.69 -3.10 -0.35 1291.74 1.00
log_tau_i[6] -1.87 0.92 -1.84 -3.67 -0.58 1066.38 1.00
log_tau_i[7] -2.06 0.89 -2.07 -3.42 -0.42 641.48 1.00
log_tau_i[8] -1.11 0.91 -1.03 -2.79 0.19 482.69 1.00
log_tau_i[9] -1.65 0.91 -1.62 -3.05 -0.04 772.73 1.00
log_tau_i[10] -1.31 0.94 -1.27 -2.77 0.24 578.34 1.00
log_tau_i[11] -0.04 0.49 -0.08 -0.81 0.77 736.20 1.00
log_tau_i[12] -1.37 0.98 -1.32 -2.86 0.31 623.14 1.00
log_tau_i[13] -1.28 0.97 -1.21 -2.78 0.30 653.51 1.00
log_tau_i[14] -1.70 0.95 -1.71 -3.20 -0.15 831.88 1.00
log_tau_i[15] -1.67 0.97 -1.65 -3.15 0.04 726.23 1.00
log_tau_i[16] -0.53 0.65 -0.49 -1.55 0.54 558.56 1.00
log_tau_i[17] -0.81 0.68 -0.80 -1.91 0.27 655.36 1.00
log_tau_i[18] -0.71 0.83 -0.60 -1.89 0.75 536.55 1.00
log_tau_i[19] -1.55 0.98 -1.53 -3.17 -0.02 675.27 1.00
log_tau_i[20] -1.81 0.90 -1.79 -3.36 -0.41 823.93 1.00
log_tau_i[21] -1.53 0.93 -1.51 -3.03 -0.04 767.64 1.00
log_tau_i[22] -1.60 0.99 -1.58 -3.25 -0.01 840.71 1.00
log_tau_i[23] -1.64 0.97 -1.60 -3.15 0.10 615.62 1.00
log_tau_i[24] -1.21 0.96 -1.11 -2.61 0.47 674.58 1.00
log_tau_i[25] -1.54 1.04 -1.49 -3.39 -0.01 509.69 1.00
log_tau_i[26] -1.77 0.95 -1.74 -3.27 -0.13 915.91 1.00
log_tau_i[27] -1.14 0.89 -1.07 -2.53 0.36 569.10 1.01
log_tau_i[28] -1.72 0.93 -1.66 -3.25 -0.20 911.70 1.01
log_tau_i[29] -1.70 0.91 -1.71 -3.16 -0.14 648.29 1.00
log_tau_i[30] -1.28 1.02 -1.28 -2.99 0.34 575.90 1.00
log_tau_i[31] -1.10 0.86 -1.04 -2.39 0.45 730.21 1.00
log_tau_i[32] -1.79 0.95 -1.82 -3.32 -0.19 667.95 1.00
log_tau_i[33] -1.64 0.86 -1.62 -3.04 -0.32 948.64 1.00
log_tau_i[34] -1.87 0.92 -1.88 -3.29 -0.25 909.49 1.00
log_tau_i[35] -1.00 0.85 -0.95 -2.44 0.25 672.42 1.00
log_tau_i[36] -1.76 0.92 -1.73 -3.19 -0.22 889.56 1.00
log_tau_i[37] -1.63 1.00 -1.58 -3.11 0.12 973.29 1.00
log_tau_i[38] -1.72 0.85 -1.70 -3.13 -0.30 837.73 1.00
log_tau_i[39] -1.15 0.80 -1.13 -2.32 0.18 627.15 1.00
log_tau_i[40] -1.76 0.93 -1.70 -3.36 -0.34 686.88 1.00
log_tau_i[41] -1.75 0.96 -1.74 -3.20 -0.09 612.83 1.00
log_tau_i[42] -1.62 0.91 -1.62 -3.21 -0.25 596.18 1.00
log_tau_i[43] -0.79 0.93 -0.74 -2.20 0.83 560.52 1.01
log_tau_i[44] -1.52 0.94 -1.48 -3.08 0.01 888.27 1.00
log_tau_i[45] -1.83 0.90 -1.84 -3.27 -0.44 1122.53 1.00
log_tau_i[46] -1.83 0.89 -1.78 -3.21 -0.31 990.03 1.00
log_tau_i[47] -1.56 0.93 -1.48 -3.20 -0.21 736.01 1.00
log_tau_i[48] -1.59 0.94 -1.54 -3.20 -0.01 588.77 1.00
log_tau_i[49] -1.71 0.93 -1.68 -3.17 -0.17 813.94 1.00
log_tau_i[50] -1.15 0.83 -1.11 -2.51 0.14 514.68 1.00
log_tau_i[51] -1.54 0.89 -1.51 -2.97 -0.13 780.67 1.00
log_tau_i[52] -1.59 0.93 -1.56 -3.21 -0.28 807.47 1.00
log_tau_i[53] -1.74 0.92 -1.72 -3.10 -0.15 657.89 1.00
log_tau_i[54] -1.74 0.93 -1.74 -3.09 -0.10 867.29 1.00
log_tau_i[55] -1.89 0.94 -1.87 -3.34 -0.27 1132.90 1.00
log_tau_i[56] -1.89 0.92 -1.89 -3.27 -0.27 980.72 1.00
log_tau_i[57] -1.84 0.91 -1.84 -3.33 -0.38 813.32 1.00
log_tau_i[58] -1.75 0.93 -1.74 -3.25 -0.34 583.52 1.00
log_tau_i[59] -1.59 0.96 -1.53 -3.11 -0.06 754.39 1.01
log_tau_i[60] -1.58 0.98 -1.52 -3.23 -0.08 561.05 1.00
log_tau_i[61] -0.98 1.03 -0.85 -2.55 0.79 502.63 1.00
log_tau_zero -1.49 0.26 -1.49 -1.91 -1.10 220.32 1.00
Number of divergences: 1
The number of divergences has dropped significantly, from 37 to 1.
data = az.from_numpyro(mcmc_reparam)
az.plot_trace(data, compact=True, figsize=(15, 25));
4. References
- Maria I. Gorinova, Dave Moore, Matthew D. Hoffman. Automatic Reparameterisation of Probabilistic Programs. https://arxiv.org/abs/1906.03028
- Reference implementation: https://github.com/mgorinova/autoreparam/tree/master