交互式在线版本: Open In Colab

变分推断参数化

作者:Madhav Kanda

有时,哈密顿蒙特卡罗 (HMC) 采样器在从后验分布中有效采样时会遇到挑战。一个说明性的例子是 Neal's funnel。在这种情况下,传统的中心化参数化可能不足,需要采用非中心化参数化。然而,在某些情况下,即使非中心化参数化也不够,此时就需要利用变分推断参数化 (Variationally Inferred Parameterization) 来在 0 到 1 的范围内达到所需的中心化。

本教程的目的是使用 Numpyro 中的 LocScaleReparam,基于 概率程序的自动重参数化 来实现变分推断参数化

[ ]:
%pip -qq install numpyro
%pip -qq install ucimlrepo
[ ]:
import arviz as az
import numpy as np
from ucimlrepo import fetch_ucirepo

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer.reparam import LocScaleReparam

rng_key = jax.random.PRNGKey(0)

1. 数据集

本示例将使用德国信用数据集。该数据集包含 1000 条记录,以及 Hofmann 教授准备的 20 个分类符号属性。在该数据集中,每条记录代表一个在银行获得贷款的人。每个人都根据属性集被归类为良好信用风险或不良信用风险。

[ ]:
def load_german_credit():
    statlog_german_credit_data = fetch_ucirepo(id=144)
    X = statlog_german_credit_data.data.features
    y = statlog_german_credit_data.data.targets
    return X, y
[ ]:
X, y = load_german_credit()
X
属性1 属性2 属性3 属性4 属性5 属性6 属性7 属性8 属性9 属性10 属性11 属性12 属性13 属性14 属性15 属性16 属性17 属性18 属性19 属性20
0 A11 6 A34 A43 1169 A65 A75 4 A93 A101 4 A121 67 A143 A152 2 A173 1 A192 A201
1 A12 48 A32 A43 5951 A61 A73 2 A92 A101 2 A121 22 A143 A152 1 A173 1 A191 A201
2 A14 12 A34 A46 2096 A61 A74 2 A93 A101 3 A121 49 A143 A152 1 A172 2 A191 A201
3 A11 42 A32 A42 7882 A61 A74 2 A93 A103 4 A122 45 A143 A153 1 A173 2 A191 A201
4 A11 24 A33 A40 4870 A61 A73 3 A93 A101 4 A124 53 A143 A153 2 A173 2 A191 A201
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
995 A14 12 A32 A42 1736 A61 A74 3 A92 A101 4 A121 31 A143 A152 1 A172 1 A191 A201
996 A11 30 A32 A41 3857 A61 A73 4 A91 A101 4 A122 40 A143 A152 1 A174 1 A192 A201
997 A14 12 A32 A43 804 A61 A75 4 A93 A101 4 A123 38 A143 A152 1 A173 1 A191 A201
998 A11 45 A32 A43 1845 A61 A73 4 A93 A101 4 A124 23 A143 A153 1 A173 1 A192 A201
999 A12 45 A34 A41 4576 A62 A71 3 A93 A101 4 A123 27 A143 A152 1 A173 1 A191 A201

1000 行 × 20 列

这里,X 表示 20 个属性以及数据记录中每个人对应这些属性的值,而 y 是对应这些属性的输出变量。

[ ]:
def data_transform(X, y):
    def categorical_to_int(x):
        d = {u: i for i, u in enumerate(np.unique(x))}
        return np.array([d[i] for i in x])

    categoricals = []
    numericals = []
    numericals.append(np.ones([len(y)]))
    for column in X:
        column = X[column]
        if column.dtype == "O":
            categoricals.append(categorical_to_int(column))
        else:
            numericals.append((column - column.mean()) / column.std())
    numericals = np.array(numericals).T
    status = np.array(y == 1, dtype=np.int32)
    status = np.squeeze(status)

    return jnp.array(numericals), jnp.array(categoricals), jnp.array(status)

用于输入到 Numpyro 模型的数据转换

[ ]:
numericals, categoricals, status = data_transform(X, y)
[ ]:
x_numeric = numericals.astype(jnp.float32)
x_categorical = [jnp.eye(c.max() + 1)[c] for c in categoricals]
all_x = jnp.concatenate([x_numeric] + x_categorical, axis=1)
num_features = all_x.shape[1]
y = status[jnp.newaxis, Ellipsis]

2. 模型

我们将使用一个对系数尺度具有分层先验的逻辑回归模型

\begin{align} \log \tau_0 & \sim \mathcal{N}(0,10) & \log \tau_i & \sim \mathcal{N}\left(\log \tau_0, 1\right) \\ \beta_i & \sim \mathcal{N}\left(0, \tau_i\right) & y & \sim \operatorname{Bernoulli}\left(\sigma\left(\beta X^T\right)\right) \end{align}

[ ]:
def german_credit():
    log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
    log_tau_i = numpyro.sample(
        "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
    )
    beta = numpyro.sample(
        "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
    )
    numpyro.sample(
        "obs",
        dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
        obs=y,
    )
[ ]:
nuts_kernel = NUTS(german_credit)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:21<00:00, 94.07it/s, 63 steps of size 6.31e-02. acc. prob=0.87]
[ ]:
mcmc.print_summary()
                    mean       std    median      5.0%     95.0%     n_eff     r_hat
       beta[0]      0.13      0.38      0.05     -0.36      0.74    284.06      1.00
       beta[1]     -0.34      0.12     -0.34     -0.52     -0.15    621.55      1.00
       beta[2]     -0.27      0.13     -0.27     -0.45     -0.03    542.13      1.00
       beta[3]     -0.30      0.10     -0.30     -0.44     -0.11    566.55      1.00
       beta[4]     -0.00      0.07     -0.00     -0.12      0.11    782.35      1.00
       beta[5]      0.12      0.09      0.11     -0.02      0.27    728.28      1.01
       beta[6]     -0.08      0.08     -0.07     -0.22      0.05    822.89      1.00
       beta[7]     -0.05      0.07     -0.04     -0.19      0.05    752.66      1.00
       beta[8]     -0.42      0.32     -0.39     -0.87      0.05    198.00      1.00
       beta[9]     -0.07      0.26     -0.02     -0.50      0.31    220.27      1.00
      beta[10]      0.26      0.31      0.18     -0.15      0.78    404.97      1.00
      beta[11]      1.23      0.34      1.25      0.68      1.79    227.34      1.01
      beta[12]     -0.26      0.34     -0.17     -0.81      0.22    349.10      1.00
      beta[13]     -0.30      0.34     -0.21     -0.86      0.13    387.72      1.00
      beta[14]      0.07      0.20      0.04     -0.26      0.38    240.45      1.03
      beta[15]      0.10      0.22      0.05     -0.18      0.50    287.41      1.02
      beta[16]      0.76      0.30      0.76      0.22      1.24    364.73      1.03
      beta[17]     -0.53      0.28     -0.55     -0.94     -0.05    269.95      1.00
      beta[18]      0.70      0.42      0.70     -0.02      1.29    367.28      1.00
      beta[19]      0.17      0.40      0.06     -0.43      0.77    333.54      1.00
      beta[20]      0.03      0.19      0.01     -0.23      0.39    381.57      1.00
      beta[21]      0.18      0.22      0.13     -0.14      0.53    335.48      1.00
      beta[22]     -0.05      0.32     -0.01     -0.56      0.46    439.54      1.00
      beta[23]     -0.10      0.30     -0.04     -0.63      0.30    508.20      1.00
      beta[24]     -0.34      0.36     -0.25     -0.94      0.12    283.15      1.00
      beta[25]      0.14      0.40      0.04     -0.46      0.71    433.69      1.00
      beta[26]     -0.01      0.19     -0.00     -0.34      0.28    438.64      1.00
      beta[27]     -0.36      0.27     -0.33     -0.78      0.04    377.33      1.01
      beta[28]     -0.07      0.22     -0.03     -0.43      0.26    493.09      1.00
      beta[29]      0.01      0.22      0.00     -0.32      0.34    448.21      1.00
      beta[30]      0.35      0.43      0.22     -0.18      1.08    314.69      1.00
      beta[31]      0.41      0.33      0.40     -0.10      0.90    402.62      1.00
      beta[32]     -0.03      0.21     -0.01     -0.39      0.30    525.23      1.00
      beta[33]     -0.12      0.18     -0.09     -0.41      0.16    334.94      1.00
      beta[34]     -0.02      0.16     -0.01     -0.24      0.26    318.25      1.00
      beta[35]      0.42      0.27      0.42     -0.04      0.81    455.99      1.00
      beta[36]      0.05      0.17      0.03     -0.18      0.35    506.34      1.00
      beta[37]     -0.12      0.25     -0.06     -0.57      0.21    470.11      1.00
      beta[38]     -0.07      0.20     -0.04     -0.39      0.24    410.71      1.00
      beta[39]      0.36      0.24      0.35     -0.04      0.71    359.55      1.00
      beta[40]      0.05      0.20      0.02     -0.29      0.35    441.70      1.00
      beta[41]     -0.00      0.21      0.00     -0.34      0.37    513.67      1.00
      beta[42]     -0.13      0.27     -0.08     -0.59      0.23    402.64      1.00
      beta[43]      0.55      0.46      0.49     -0.11      1.28    570.74      1.00
      beta[44]      0.19      0.21      0.15     -0.14      0.50    379.76      1.00
      beta[45]     -0.00      0.16      0.00     -0.25      0.26    352.19      1.00
      beta[46]      0.01      0.16      0.01     -0.25      0.25    411.05      1.00
      beta[47]     -0.16      0.24     -0.11     -0.55      0.18    455.59      1.00
      beta[48]     -0.12      0.24     -0.07     -0.55      0.21    322.67      1.04
      beta[49]     -0.04      0.23     -0.02     -0.45      0.30    437.47      1.02
      beta[50]      0.38      0.28      0.37     -0.03      0.82    266.19      1.04
      beta[51]     -0.14      0.22     -0.09     -0.52      0.16    406.31      1.00
      beta[52]      0.19      0.23      0.14     -0.14      0.55    338.97      1.00
      beta[53]      0.04      0.22      0.02     -0.23      0.43    438.03      1.00
      beta[54]      0.05      0.24      0.02     -0.32      0.41    522.43      1.00
      beta[55]      0.02      0.14      0.01     -0.22      0.23    562.00      1.00
      beta[56]     -0.01      0.13     -0.01     -0.24      0.21    638.20      1.00
      beta[57]      0.01      0.17      0.00     -0.25      0.34    590.99      1.00
      beta[58]     -0.07      0.18     -0.04     -0.34      0.23    481.37      1.00
      beta[59]      0.13      0.19      0.09     -0.12      0.47    507.56      1.00
      beta[60]     -0.14      0.33     -0.06     -0.64      0.37    303.00      1.00
      beta[61]      0.48      0.56      0.32     -0.18      1.41    438.86      1.00
  log_tau_i[0]     -1.51      0.95     -1.52     -3.03      0.11    290.78      1.00
  log_tau_i[1]     -1.07      0.67     -1.11     -2.12      0.03    641.04      1.00
  log_tau_i[2]     -1.24      0.76     -1.26     -2.47      0.03    666.31      1.00
  log_tau_i[3]     -1.16      0.65     -1.19     -2.20     -0.10    821.60      1.00
  log_tau_i[4]     -2.11      0.88     -2.13     -3.50     -0.61    806.15      1.00
  log_tau_i[5]     -1.71      0.86     -1.68     -3.28     -0.44    697.00      1.00
  log_tau_i[6]     -1.88      0.84     -1.91     -3.30     -0.58    623.56      1.00
  log_tau_i[7]     -1.99      0.90     -1.98     -3.51     -0.65    710.21      1.00
  log_tau_i[8]     -1.00      0.86     -0.96     -2.23      0.52    445.30      1.00
  log_tau_i[9]     -1.69      0.93     -1.63     -3.17     -0.14    326.33      1.00
 log_tau_i[10]     -1.41      0.95     -1.35     -2.93      0.19    441.60      1.01
 log_tau_i[11]     -0.11      0.57     -0.12     -0.97      0.80    539.60      1.00
 log_tau_i[12]     -1.36      0.96     -1.31     -3.16      0.01    336.11      1.00
 log_tau_i[13]     -1.30      0.95     -1.26     -2.85      0.28    335.04      1.00
 log_tau_i[14]     -1.72      0.89     -1.70     -3.05     -0.25    584.38      1.00
 log_tau_i[15]     -1.65      0.92     -1.63     -3.07     -0.10    345.77      1.03
 log_tau_i[16]     -0.51      0.65     -0.49     -1.42      0.59    676.64      1.00
 log_tau_i[17]     -0.84      0.76     -0.76     -2.09      0.34    303.14      1.00
 log_tau_i[18]     -0.69      0.82     -0.59     -2.03      0.61    359.35      1.00
 log_tau_i[19]     -1.45      0.99     -1.42     -2.97      0.25    397.18      1.00
 log_tau_i[20]     -1.75      0.94     -1.73     -3.39     -0.40    617.54      1.00
 log_tau_i[21]     -1.51      0.88     -1.49     -3.16     -0.27    488.52      1.00
 log_tau_i[22]     -1.56      0.93     -1.56     -3.06     -0.10    348.20      1.00
 log_tau_i[23]     -1.58      0.94     -1.57     -3.05      0.02    278.69      1.00
 log_tau_i[24]     -1.26      1.00     -1.12     -2.91      0.29    205.38      1.00
 log_tau_i[25]     -1.53      0.95     -1.56     -3.11      0.02    351.09      1.00
 log_tau_i[26]     -1.73      0.91     -1.74     -3.17     -0.22    492.18      1.00
 log_tau_i[27]     -1.15      0.89     -1.08     -2.66      0.17    485.34      1.00
 log_tau_i[28]     -1.69      0.92     -1.65     -3.15     -0.19    425.75      1.00
 log_tau_i[29]     -1.71      0.99     -1.71     -3.19      0.01    374.58      1.00
 log_tau_i[30]     -1.24      0.99     -1.20     -2.74      0.50    327.47      1.00
 log_tau_i[31]     -1.02      0.89     -0.91     -2.40      0.51    587.85      1.00
 log_tau_i[32]     -1.71      0.94     -1.70     -3.22     -0.11    511.74      1.00
 log_tau_i[33]     -1.69      0.90     -1.68     -3.13     -0.28    538.65      1.00
 log_tau_i[34]     -1.82      0.92     -1.81     -3.35     -0.35    423.01      1.00
 log_tau_i[35]     -1.06      0.82     -1.00     -2.30      0.34    470.50      1.00
 log_tau_i[36]     -1.79      0.87     -1.76     -3.15     -0.34    527.47      1.00
 log_tau_i[37]     -1.58      0.95     -1.54     -3.11      0.04    485.52      1.00
 log_tau_i[38]     -1.71      0.87     -1.65     -3.18     -0.34    482.67      1.00
 log_tau_i[39]     -1.12      0.85     -1.01     -2.44      0.33    337.59      1.00
 log_tau_i[40]     -1.76      0.96     -1.73     -3.58     -0.36    533.15      1.00
 log_tau_i[41]     -1.74      0.94     -1.70     -3.26     -0.22    500.91      1.00
 log_tau_i[42]     -1.57      0.95     -1.54     -3.04      0.01    499.44      1.00
 log_tau_i[43]     -0.87      0.93     -0.74     -2.28      0.58    445.98      1.00
 log_tau_i[44]     -1.52      0.89     -1.45     -2.95     -0.13    442.63      1.00
 log_tau_i[45]     -1.84      0.94     -1.79     -3.21     -0.09    673.31      1.00
 log_tau_i[46]     -1.82      0.85     -1.83     -3.26     -0.56    579.66      1.00
 log_tau_i[47]     -1.54      0.90     -1.51     -3.35     -0.30    428.50      1.00
 log_tau_i[48]     -1.62      0.89     -1.60     -3.00     -0.15    413.30      1.01
 log_tau_i[49]     -1.71      0.95     -1.68     -3.23     -0.13    514.04      1.00
 log_tau_i[50]     -1.12      0.92     -0.99     -2.67      0.38    206.76      1.03
 log_tau_i[51]     -1.61      0.92     -1.58     -3.07     -0.03    477.41      1.00
 log_tau_i[52]     -1.54      0.90     -1.49     -2.96     -0.09    459.83      1.00
 log_tau_i[53]     -1.74      0.92     -1.69     -3.13     -0.14    509.51      1.00
 log_tau_i[54]     -1.68      0.95     -1.67     -3.07      0.10    477.21      1.00
 log_tau_i[55]     -1.87      0.97     -1.88     -3.49     -0.35    514.38      1.00
 log_tau_i[56]     -1.87      0.96     -1.84     -3.23     -0.12    574.80      1.00
 log_tau_i[57]     -1.77      0.86     -1.72     -3.26     -0.36    646.10      1.00
 log_tau_i[58]     -1.78      0.92     -1.77     -3.18     -0.15    617.59      1.00
 log_tau_i[59]     -1.67      0.93     -1.61     -3.19     -0.21    510.74      1.00
 log_tau_i[60]     -1.50      0.99     -1.44     -3.09      0.08    386.86      1.00
 log_tau_i[61]     -1.09      1.06     -1.00     -2.79      0.52    421.27      1.00
  log_tau_zero     -1.49      0.26     -1.49     -1.90     -1.05    169.88      1.00

Number of divergences: 37

从 mcmc.print_summary 可以看出,存在 37 个发散。因此,我们将使用变分推断参数化 (VIP) 来减少这些发散。

[ ]:
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
../_images/tutorials_variationally_inferred_parameterization_16_0.png

3. 重参数化

我们为任意变量 \(\lambda \in [0,1]\) 引入一个参数化参数 \(z\),并进行变换

=> \(z\) ~ \(N (z | μ, σ)\)

=> 通过定义 \(z\) ~ \(N(λμ, σ^λ)\)

=> \(z\) = \(μ + σ^{1-λ}(z - λμ)\)

因此,使用上述变换,联合密度可以变换如下: \begin{align} p(\theta, \hat{\mu}, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\mu \mid \theta, \sigma_\mu\right) \times \mathcal{N}(\mathbf{y} \mid \mu, \sigma) \end{align}

\begin{align} p(\theta, \hat{\mu}, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\hat{\mu} \mid \lambda \theta, \sigma_\mu^\lambda\right) \times \mathcal{N}\left(\mathbf{y} \mid \theta+\sigma_\mu^{1-\lambda}(\hat{\mu}-\lambda \theta), \sigma\right) \end{align}

[ ]:
def german_credit_reparam(beta_centeredness=None):
    def model():
        log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
        log_tau_i = numpyro.sample(
            "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
        )
        with numpyro.handlers.reparam(
            config={"beta": LocScaleReparam(beta_centeredness)}
        ):
            beta = numpyro.sample(
                "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
            )
        numpyro.sample(
            "obs",
            dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
            obs=y,
        )

    return model

现在,使用 SVI 优化 \(\lambda\)

[ ]:
model = german_credit_reparam()
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(3e-4), Trace_ELBO(10))
svi_results = svi.run(rng_key, 10000)
100%|██████████| 10000/10000 [00:16<00:00, 588.87it/s, init loss: 2165.2424, avg. loss [9501-10000]: 576.7846]
[ ]:
reparam_model = german_credit_reparam(
    beta_centeredness=svi_results.params["beta_centered"]
)
[ ]:
nuts_kernel = NUTS(reparam_model)
mcmc_reparam = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_reparam.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:07<00:00, 285.41it/s, 31 steps of size 1.28e-01. acc. prob=0.89]
[ ]:
mcmc_reparam.print_summary()
                         mean       std    median      5.0%     95.0%     n_eff     r_hat
 beta_decentered[0]      0.12      0.40      0.06     -0.48      0.80    338.70      1.00
 beta_decentered[1]     -0.45      0.15     -0.45     -0.70     -0.21    791.23      1.00
 beta_decentered[2]     -0.38      0.17     -0.38     -0.65     -0.09    691.79      1.00
 beta_decentered[3]     -0.41      0.13     -0.41     -0.61     -0.19   1022.79      1.00
 beta_decentered[4]     -0.01      0.11     -0.01     -0.18      0.20   1176.84      1.00
 beta_decentered[5]      0.19      0.14      0.19     -0.04      0.41   1194.41      1.00
 beta_decentered[6]     -0.13      0.14     -0.13     -0.36      0.09   1227.24      1.00
 beta_decentered[7]     -0.07      0.12     -0.06     -0.24      0.14   1096.31      1.00
 beta_decentered[8]     -0.46      0.34     -0.46     -0.99      0.08    330.30      1.00
 beta_decentered[9]     -0.03      0.32     -0.02     -0.57      0.49    310.35      1.00
beta_decentered[10]      0.35      0.39      0.30     -0.26      1.00    426.11      1.00
beta_decentered[11]      1.29      0.31      1.30      0.81      1.82    433.16      1.00
beta_decentered[12]     -0.32      0.39     -0.25     -0.96      0.24    521.05      1.00
beta_decentered[13]     -0.38      0.40     -0.32     -1.00      0.24    410.05      1.00
beta_decentered[14]      0.08      0.28      0.06     -0.37      0.57    457.72      1.00
beta_decentered[15]      0.14      0.30      0.10     -0.28      0.66    612.31      1.00
beta_decentered[16]      0.85      0.31      0.86      0.41      1.45    432.14      1.00
beta_decentered[17]     -0.64      0.28     -0.65     -1.05     -0.14    523.15      1.00
beta_decentered[18]      0.78      0.42      0.78      0.07      1.46    545.52      1.00
beta_decentered[19]      0.15      0.39      0.08     -0.50      0.80    662.60      1.00
beta_decentered[20]      0.04      0.25      0.03     -0.39      0.40    445.85      1.00
beta_decentered[21]      0.24      0.27      0.21     -0.20      0.65    477.68      1.00
beta_decentered[22]     -0.03      0.38     -0.01     -0.64      0.60    984.59      1.00
beta_decentered[23]     -0.13      0.34     -0.08     -0.72      0.35    702.87      1.00
beta_decentered[24]     -0.41      0.39     -0.37     -1.08      0.13    603.13      1.00
beta_decentered[25]      0.19      0.47      0.09     -0.48      0.92    529.68      1.00
beta_decentered[26]      0.00      0.25      0.01     -0.47      0.35    690.54      1.00
beta_decentered[27]     -0.46      0.31     -0.46     -0.95      0.04    464.44      1.00
beta_decentered[28]     -0.09      0.30     -0.06     -0.56      0.41    464.65      1.00
beta_decentered[29]      0.02      0.30      0.01     -0.47      0.52    747.44      1.00
beta_decentered[30]      0.38      0.44      0.31     -0.30      1.05    717.12      1.00
beta_decentered[31]      0.47      0.36      0.47     -0.09      1.03    564.18      1.00
beta_decentered[32]     -0.03      0.26     -0.02     -0.44      0.44    572.03      1.00
beta_decentered[33]     -0.17      0.25     -0.15     -0.63      0.19    713.40      1.00
beta_decentered[34]     -0.02      0.21     -0.01     -0.39      0.32    620.45      1.01
beta_decentered[35]      0.53      0.31      0.55     -0.03      1.00    681.60      1.00
beta_decentered[36]      0.09      0.24      0.06     -0.27      0.49    610.06      1.00
beta_decentered[37]     -0.14      0.31     -0.10     -0.74      0.28    826.87      1.00
beta_decentered[38]     -0.12      0.25     -0.11     -0.53      0.30    493.49      1.00
beta_decentered[39]      0.44      0.28      0.44     -0.01      0.89    542.71      1.00
beta_decentered[40]      0.05      0.26      0.03     -0.39      0.45    709.78      1.00
beta_decentered[41]      0.02      0.30      0.01     -0.52      0.46    389.41      1.00
beta_decentered[42]     -0.15      0.33     -0.10     -0.74      0.32    607.05      1.00
beta_decentered[43]      0.66      0.47      0.65     -0.11      1.38    539.31      1.01
beta_decentered[44]      0.25      0.27      0.23     -0.19      0.66    686.63      1.00
beta_decentered[45]     -0.01      0.22     -0.00     -0.36      0.35    909.70      1.00
beta_decentered[46]      0.02      0.22      0.01     -0.33      0.39    741.33      1.00
beta_decentered[47]     -0.25      0.31     -0.21     -0.72      0.25    487.26      1.00
beta_decentered[48]     -0.18      0.30     -0.15     -0.67      0.30    400.57      1.00
beta_decentered[49]     -0.07      0.30     -0.05     -0.56      0.41    594.38      1.00
beta_decentered[50]      0.46      0.31      0.46     -0.09      0.88    384.33      1.00
beta_decentered[51]     -0.23      0.27     -0.19     -0.69      0.16    561.35      1.00
beta_decentered[52]      0.22      0.25      0.21     -0.19      0.60    456.41      1.00
beta_decentered[53]      0.07      0.29      0.04     -0.41      0.53    500.99      1.00
beta_decentered[54]      0.05      0.30      0.02     -0.49      0.51    896.08      1.00
beta_decentered[55]      0.01      0.23      0.01     -0.34      0.43   1033.13      1.00
beta_decentered[56]     -0.02      0.19     -0.01     -0.33      0.28    717.87      1.00
beta_decentered[57]      0.01      0.23      0.01     -0.33      0.41    684.61      1.00
beta_decentered[58]     -0.09      0.26     -0.08     -0.55      0.30    455.66      1.01
beta_decentered[59]      0.20      0.27      0.18     -0.20      0.63    413.57      1.01
beta_decentered[60]     -0.14      0.37     -0.09     -0.76      0.46    411.83      1.00
beta_decentered[61]      0.58      0.57      0.50     -0.23      1.50    495.86      1.00
       log_tau_i[0]     -1.56      0.91     -1.53     -2.94     -0.01    480.69      1.00
       log_tau_i[1]     -1.07      0.64     -1.08     -1.99      0.06    950.02      1.00
       log_tau_i[2]     -1.25      0.74     -1.23     -2.36     -0.06    796.37      1.00
       log_tau_i[3]     -1.15      0.65     -1.18     -2.25     -0.14    838.40      1.00
       log_tau_i[4]     -2.09      0.90     -2.12     -3.51     -0.49   1120.31      1.00
       log_tau_i[5]     -1.70      0.84     -1.69     -3.10     -0.35   1291.74      1.00
       log_tau_i[6]     -1.87      0.92     -1.84     -3.67     -0.58   1066.38      1.00
       log_tau_i[7]     -2.06      0.89     -2.07     -3.42     -0.42    641.48      1.00
       log_tau_i[8]     -1.11      0.91     -1.03     -2.79      0.19    482.69      1.00
       log_tau_i[9]     -1.65      0.91     -1.62     -3.05     -0.04    772.73      1.00
      log_tau_i[10]     -1.31      0.94     -1.27     -2.77      0.24    578.34      1.00
      log_tau_i[11]     -0.04      0.49     -0.08     -0.81      0.77    736.20      1.00
      log_tau_i[12]     -1.37      0.98     -1.32     -2.86      0.31    623.14      1.00
      log_tau_i[13]     -1.28      0.97     -1.21     -2.78      0.30    653.51      1.00
      log_tau_i[14]     -1.70      0.95     -1.71     -3.20     -0.15    831.88      1.00
      log_tau_i[15]     -1.67      0.97     -1.65     -3.15      0.04    726.23      1.00
      log_tau_i[16]     -0.53      0.65     -0.49     -1.55      0.54    558.56      1.00
      log_tau_i[17]     -0.81      0.68     -0.80     -1.91      0.27    655.36      1.00
      log_tau_i[18]     -0.71      0.83     -0.60     -1.89      0.75    536.55      1.00
      log_tau_i[19]     -1.55      0.98     -1.53     -3.17     -0.02    675.27      1.00
      log_tau_i[20]     -1.81      0.90     -1.79     -3.36     -0.41    823.93      1.00
      log_tau_i[21]     -1.53      0.93     -1.51     -3.03     -0.04    767.64      1.00
      log_tau_i[22]     -1.60      0.99     -1.58     -3.25     -0.01    840.71      1.00
      log_tau_i[23]     -1.64      0.97     -1.60     -3.15      0.10    615.62      1.00
      log_tau_i[24]     -1.21      0.96     -1.11     -2.61      0.47    674.58      1.00
      log_tau_i[25]     -1.54      1.04     -1.49     -3.39     -0.01    509.69      1.00
      log_tau_i[26]     -1.77      0.95     -1.74     -3.27     -0.13    915.91      1.00
      log_tau_i[27]     -1.14      0.89     -1.07     -2.53      0.36    569.10      1.01
      log_tau_i[28]     -1.72      0.93     -1.66     -3.25     -0.20    911.70      1.01
      log_tau_i[29]     -1.70      0.91     -1.71     -3.16     -0.14    648.29      1.00
      log_tau_i[30]     -1.28      1.02     -1.28     -2.99      0.34    575.90      1.00
      log_tau_i[31]     -1.10      0.86     -1.04     -2.39      0.45    730.21      1.00
      log_tau_i[32]     -1.79      0.95     -1.82     -3.32     -0.19    667.95      1.00
      log_tau_i[33]     -1.64      0.86     -1.62     -3.04     -0.32    948.64      1.00
      log_tau_i[34]     -1.87      0.92     -1.88     -3.29     -0.25    909.49      1.00
      log_tau_i[35]     -1.00      0.85     -0.95     -2.44      0.25    672.42      1.00
      log_tau_i[36]     -1.76      0.92     -1.73     -3.19     -0.22    889.56      1.00
      log_tau_i[37]     -1.63      1.00     -1.58     -3.11      0.12    973.29      1.00
      log_tau_i[38]     -1.72      0.85     -1.70     -3.13     -0.30    837.73      1.00
      log_tau_i[39]     -1.15      0.80     -1.13     -2.32      0.18    627.15      1.00
      log_tau_i[40]     -1.76      0.93     -1.70     -3.36     -0.34    686.88      1.00
      log_tau_i[41]     -1.75      0.96     -1.74     -3.20     -0.09    612.83      1.00
      log_tau_i[42]     -1.62      0.91     -1.62     -3.21     -0.25    596.18      1.00
      log_tau_i[43]     -0.79      0.93     -0.74     -2.20      0.83    560.52      1.01
      log_tau_i[44]     -1.52      0.94     -1.48     -3.08      0.01    888.27      1.00
      log_tau_i[45]     -1.83      0.90     -1.84     -3.27     -0.44   1122.53      1.00
      log_tau_i[46]     -1.83      0.89     -1.78     -3.21     -0.31    990.03      1.00
      log_tau_i[47]     -1.56      0.93     -1.48     -3.20     -0.21    736.01      1.00
      log_tau_i[48]     -1.59      0.94     -1.54     -3.20     -0.01    588.77      1.00
      log_tau_i[49]     -1.71      0.93     -1.68     -3.17     -0.17    813.94      1.00
      log_tau_i[50]     -1.15      0.83     -1.11     -2.51      0.14    514.68      1.00
      log_tau_i[51]     -1.54      0.89     -1.51     -2.97     -0.13    780.67      1.00
      log_tau_i[52]     -1.59      0.93     -1.56     -3.21     -0.28    807.47      1.00
      log_tau_i[53]     -1.74      0.92     -1.72     -3.10     -0.15    657.89      1.00
      log_tau_i[54]     -1.74      0.93     -1.74     -3.09     -0.10    867.29      1.00
      log_tau_i[55]     -1.89      0.94     -1.87     -3.34     -0.27   1132.90      1.00
      log_tau_i[56]     -1.89      0.92     -1.89     -3.27     -0.27    980.72      1.00
      log_tau_i[57]     -1.84      0.91     -1.84     -3.33     -0.38    813.32      1.00
      log_tau_i[58]     -1.75      0.93     -1.74     -3.25     -0.34    583.52      1.00
      log_tau_i[59]     -1.59      0.96     -1.53     -3.11     -0.06    754.39      1.01
      log_tau_i[60]     -1.58      0.98     -1.52     -3.23     -0.08    561.05      1.00
      log_tau_i[61]     -0.98      1.03     -0.85     -2.55      0.79    502.63      1.00
       log_tau_zero     -1.49      0.26     -1.49     -1.91     -1.10    220.32      1.00

Number of divergences: 1

发散的数量已从 37 个显著减少到 1 个。

[ ]:
data = az.from_numpyro(mcmc_reparam)
az.plot_trace(data, compact=True, figsize=(15, 25));
../_images/tutorials_variationally_inferred_parameterization_26_0.png

4. 参考文献:

  1. https://arxiv.org/abs/1906.03028

  2. https://github.com/mgorinova/autoreparam/tree/master