Interactive online version: Open in Colab

Bad posterior geometry and how to deal with it

HMC and its variant NUTS use gradient information to draw (approximate) samples from a posterior distribution. These gradients are computed in a particular coordinate system, and different choices of coordinate system can make HMC more or less efficient. This is analogous to the situation in constrained optimization problems where, for example, parameterizing a positive quantity via an exponential transform instead of a softplus transform results in different optimization dynamics.
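As a toy illustration of this analogy (a hedged sketch, not part of the original tutorial), consider two ways of expressing a positive quantity sigma in terms of an unconstrained variable x; the gradients with respect to x, and hence the optimization dynamics, differ between the two parameterizations:

# Illustrative sketch only: two parameterizations of a positive quantity.
import jax
import jax.numpy as jnp

def sigma_exp(x):
    return jnp.exp(x)  # exponential parameterization

def sigma_softplus(x):
    return jnp.log1p(jnp.exp(x))  # softplus parameterization

# The same unconstrained point yields different gradients under the two maps.
print(jax.grad(sigma_exp)(2.0), jax.grad(sigma_softplus)(2.0))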

Consequently it is important to pay attention to the geometry of the posterior distribution. Reparameterizing the model (i.e. changing the coordinate system) can make a big practical difference for many complex models. For the most complex models it can be absolutely essential. For the same reason it can be important to pay attention to some of the hyperparameters that control HMC/NUTS (in particular max_tree_depth and dense_mass).

In this tutorial we explore models with bad posterior geometries, and ways to deal with them, through a number of concrete examples.

[1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[1]:
from functools import partial

import numpy as np

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.diagnostics import summary
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

assert numpyro.__version__.startswith("0.18.0")

# NB: replace cpu by gpu to run this notebook on gpu
numpyro.set_platform("cpu")

We begin by writing a helper function to do NUTS inference.

[2]:
def run_inference(
    model, num_warmup=1000, num_samples=1000, max_tree_depth=10, dense_mass=False
):
    kernel = NUTS(model, max_tree_depth=max_tree_depth, dense_mass=dense_mass)
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=1,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))
    summary_dict = summary(mcmc.get_samples(), group_by_chain=False)

    # print the largest r_hat for each variable
    for k, v in summary_dict.items():
        spaces = " " * max(12 - len(k), 0)
        print("[{}] {} \t max r_hat: {:.4f}".format(k, spaces, np.max(v["r_hat"])))

Evaluating HMC/NUTS

In general it is difficult to assess whether the samples returned by HMC or NUTS represent accurate (approximate) samples from the posterior. Two general rules of thumb, however, are to look at the effective sample size (ESS) and r_hat diagnostics returned by mcmc.print_summary(). If we see values of r_hat in the range (1.0, 1.05) and effective sample sizes that are comparable to the total number of samples num_samples (assuming thinning=1), then we have good reason to believe that HMC is doing a good job. If, however, we see low effective sample sizes or large r_hats for some of the variables (e.g. r_hat = 1.15), then HMC is likely struggling with the posterior geometry. In the following we will use r_hat as our primary diagnostic metric.
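For concreteness, here is a minimal sketch (toy_model is a hypothetical stand-in for any NumPyro model) of inspecting both diagnostics, either via mcmc.print_summary() or programmatically from the summary dictionary, just as run_inference does above:

# A minimal sketch; `toy_model` is a hypothetical stand-in model.
def toy_model():
    numpyro.sample("x", dist.Normal(0.0, 1.0))

mcmc = MCMC(NUTS(toy_model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()  # reports n_eff and r_hat for each latent variable

# the same diagnostics are available programmatically:
summary_dict = summary(mcmc.get_samples(), group_by_chain=False)
for name, stats in summary_dict.items():
    print(name, "min n_eff:", np.min(stats["n_eff"]), "max r_hat:", np.max(stats["r_hat"]))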

Model reparameterization

Example #1

We begin with an example (horseshoe regression; see examples/horseshoe_regression.py for a complete example script) where reparameterization helps a lot. This particular example demonstrates a general reparameterization strategy that is useful in many models with hierarchical/multi-level structure. For more discussion of some of the issues that can arise in hierarchical models see reference [1].

[3]:
# In this unreparameterized model some of the parameters of the distributions
# explicitly depend on other parameters (in particular beta depends on lambdas and tau).
# This kind of coordinate system can be a challenge for HMC.
def _unrep_hs_model(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    betas = numpyro.sample("betas", dist.Normal(scale=tau * lambdas))
    mean_function = jnp.dot(X, betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)

To deal with the bad geometry that results from this coordinate system we change coordinates using the following re-write logic. Instead of

\[\beta \sim {\rm Normal}(0, \lambda \tau)\]

we write

\[\beta^\prime \sim {\rm Normal}(0, 1)\]

and

\[\beta \equiv \lambda \tau \beta^\prime\]

where \(\beta\) is now defined deterministically in terms of \(\lambda\), \(\tau\), and \(\beta^\prime\). In effect we have changed to a coordinate system where the different latent variables are less correlated with one another. In this new coordinate system we can expect HMC with a diagonal mass matrix to behave much better than it would in the original coordinate system.

There are basically two ways to implement this kind of reparameterization in NumPyro:

  • manually (i.e. by hand)
  • using numpyro.infer.reparam, which automates a few common reparameterization strategies

First, let's do the reparameterization by hand.

[4]:
# In this reparameterized model none of the parameters of the distributions
# explicitly depend on other parameters. This model is exactly equivalent
# to _unrep_hs_model but is expressed in a different coordinate system.
def _rep_hs_model1(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    unscaled_betas = numpyro.sample(
        "unscaled_betas", dist.Normal(scale=jnp.ones(X.shape[1]))
    )
    scaled_betas = numpyro.deterministic("betas", tau * lambdas * unscaled_betas)
    mean_function = jnp.dot(X, scaled_betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)

Next we do the reparameterization using numpyro.infer.reparam. There are at least two ways of doing this. First, let's use LocScaleReparam.

[5]:
from numpyro.infer.reparam import LocScaleReparam

# LocScaleReparam with centered=0 fully "decenters" the prior over betas.
config = {"betas": LocScaleReparam(centered=0)}
# The coordinate system of this model is equivalent to that in _rep_hs_model1 above.
_rep_hs_model2 = numpyro.handlers.reparam(_unrep_hs_model, config=config)

To show the versatility of the numpyro.infer.reparam library, let's do the reparameterization using TransformReparam instead.

[6]:
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer.reparam import TransformReparam


# In this reparameterized model none of the parameters of the distributions
# explicitly depend on other parameters. This model is exactly equivalent
# to _unrep_hs_model but is expressed in a different coordinate system.
def _rep_hs_model3(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    # instruct NumPyro to do the reparameterization automatically.
    reparam_config = {"betas": TransformReparam()}
    with numpyro.handlers.reparam(config=reparam_config):
        betas_root_variance = tau * lambdas
        # in order to use TransformReparam we have to express the prior
        # over betas as a TransformedDistribution
        betas = numpyro.sample(
            "betas",
            dist.TransformedDistribution(
                dist.Normal(0.0, jnp.ones(X.shape[1])),
                AffineTransform(0.0, betas_root_variance),
            ),
        )

    mean_function = jnp.dot(X, betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)

Finally, we verify that _rep_hs_model1, _rep_hs_model2, and _rep_hs_model3 do indeed achieve better r_hats than _unrep_hs_model.

[8]:
# create fake dataset
X = np.random.RandomState(0).randn(100, 500)
Y = X[:, 0]

print("unreparameterized model (very bad r_hats)")
run_inference(partial(_unrep_hs_model, X, Y))

print("\nreparameterized model with manual reparameterization (good r_hats)")
run_inference(partial(_rep_hs_model1, X, Y))

print("\nreparameterized model with LocScaleReparam (good r_hats)")
run_inference(partial(_rep_hs_model2, X, Y))

print("\nreparameterized model with TransformReparam (good r_hats)")
run_inference(partial(_rep_hs_model3, X, Y))
unreparameterized model (very bad r_hats)
[betas]                  max r_hat: 1.0775
[lambdas]                max r_hat: 3.2450
[tau]                    max r_hat: 2.1926

reparameterized model with manual reparameterization (good r_hats)
[betas]                  max r_hat: 1.0074
[lambdas]                max r_hat: 1.0146
[tau]                    max r_hat: 1.0036
[unscaled_betas]         max r_hat: 1.0059

reparameterized model with LocScaleReparam (good r_hats)
[betas]                  max r_hat: 1.0103
[betas_decentered]       max r_hat: 1.0060
[lambdas]                max r_hat: 1.0124
[tau]                    max r_hat: 0.9998

reparameterized model with TransformReparam (good r_hats)
[betas]                  max r_hat: 1.0087
[betas_base]             max r_hat: 1.0080
[lambdas]                max r_hat: 1.0114
[tau]                    max r_hat: 1.0029

Aside: numpyro.deterministic

In _rep_hs_model1 above we used numpyro.deterministic to define scaled_betas. We note that using this primitive is not strictly necessary; however, it has the consequence that scaled_betas will appear in the trace and therefore in the summary reported by mcmc.print_summary(). In other words we could also have written:

scaled_betas = tau * lambdas * unscaled_betas

without invoking the deterministic primitive.
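To make this concrete, here is a short sketch (reusing the fake X, Y dataset defined above; the number of warmup/posterior samples is illustrative) showing that the deterministic site "betas" is recorded in the posterior samples returned by mcmc.get_samples():

# Sketch: deterministic sites appear in the collected samples.
mcmc = MCMC(NUTS(partial(_rep_hs_model1, X, Y)), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples()
# "betas" is present because of numpyro.deterministic, alongside the
# actual latent sites "lambdas", "tau", and "unscaled_betas".
print(sorted(samples.keys()))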

Mass matrices

By default HMC/NUTS use diagonal mass matrices. For models with complex geometries it can pay to use a richer set of mass matrices.

Example #2

In this first simple example we show that using a full-rank (i.e. dense) mass matrix leads to a better r_hat.

[9]:
# Because rho is very close to 1.0 the posterior geometry
# is extremely skewed and using the "diagonal" coordinate system
# implied by dense_mass=False leads to bad results
rho = 0.9999
cov = jnp.array([[10.0, rho], [rho, 0.1]])


def mvn_model():
    numpyro.sample("x", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov))


print("dense_mass = False (bad r_hat)")
run_inference(mvn_model, dense_mass=False, max_tree_depth=3)

print("dense_mass = True (good r_hat)")
run_inference(mvn_model, dense_mass=True, max_tree_depth=3)
dense_mass = False (bad r_hat)
[x]                      max r_hat: 1.3810
dense_mass = True (good r_hat)
[x]                      max r_hat: 0.9992

Example #3

Using dense_mass=True can be very expensive when the dimension of the latent space D is very large. In addition it can be difficult to estimate a full-rank mass matrix with D^2 parameters using a moderate number of samples if D is large. In these cases dense_mass=True can be a poor choice. Luckily, the argument dense_mass can also be used to specify structured mass matrices that are richer than a diagonal mass matrix but more constrained (i.e. have fewer parameters) than a full-rank mass matrix (see the docs). In this second example we show how we can use dense_mass to specify such a structured mass matrix.

[10]:
rho = 0.9
cov = jnp.array([[10.0, rho], [rho, 0.1]])


# In this model x1 and x2 are highly correlated with one another
# but not correlated with y at all.
def partially_correlated_model():
    x1 = numpyro.sample(
        "x1", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)
    )
    x2 = numpyro.sample(
        "x2", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)
    )
    numpyro.sample("y", dist.Normal(jnp.zeros(100), 1.0))
    numpyro.sample("obs", dist.Normal(x1 - x2, 0.1), jnp.ones(2))

Now let's compare two choices of dense_mass.

[11]:
print("dense_mass = False (very bad r_hats)")
run_inference(partially_correlated_model, dense_mass=False, max_tree_depth=3)

print("\ndense_mass = True (bad r_hats)")
run_inference(partially_correlated_model, dense_mass=True, max_tree_depth=3)

# We use dense_mass=[("x1", "x2")] to specify
# a structured mass matrix in which the y-part of the mass matrix is diagonal
# and the (x1, x2) block of the mass matrix is full-rank.

# Graphically:
#
#       x1 x2 y
#   x1 | * * 0 |
#   x2 | * * 0 |
#   y  | 0 0 * |

print("\nstructured mass matrix (good r_hats)")
run_inference(partially_correlated_model, dense_mass=[("x1", "x2")], max_tree_depth=3)
dense_mass = False (very bad r_hats)
[x1]                     max r_hat: 1.5882
[x2]                     max r_hat: 1.5410
[y]                      max r_hat: 2.0179

dense_mass = True (bad r_hats)
[x1]                     max r_hat: 1.0697
[x2]                     max r_hat: 1.0738
[y]                      max r_hat: 1.2746

structured mass matrix (good r_hats)
[x1]                     max r_hat: 1.0023
[x2]                     max r_hat: 1.0024
[y]                      max r_hat: 1.0030

max_tree_depth

The hyperparameter max_tree_depth can play an important role in determining the quality of posterior samples generated by NUTS. The default value in NumPyro is max_tree_depth=10. In some models, in particular those with especially difficult geometries, it may be necessary to increase max_tree_depth above 10. In other cases where computing the gradient of the model log density is particularly expensive, it may be necessary to decrease max_tree_depth below 10 to reduce compute. As an example where a large max_tree_depth is essential, we return to a variant of example #2. (We note that in this particular case another way to improve performance would be to use dense_mass=True.)

Example #4

[12]:
# Because rho is very close to 1.0 the posterior geometry is extremely
# skewed and using small max_tree_depth leads to bad results.
rho = 0.999
dim = 200
cov = rho * jnp.ones((dim, dim)) + (1 - rho) * jnp.eye(dim)


def mvn_model():
    numpyro.sample("x", dist.MultivariateNormal(jnp.zeros(dim), covariance_matrix=cov))


print("max_tree_depth = 5 (bad r_hat)")
run_inference(mvn_model, max_tree_depth=5)

print("max_tree_depth = 10 (good r_hat)")
run_inference(mvn_model, max_tree_depth=10)
max_tree_depth = 5 (bad r_hat)
[x]                      max r_hat: 1.1159
max_tree_depth = 10 (good r_hat)
[x]                      max r_hat: 1.0166

Other strategies

  • In some cases it can make sense to use variational inference to learn a new coordinate system. For details see examples/neutra.py and reference [2]; a sketch of this strategy appears below.
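The following hedged sketch is loosely based on examples/neutra.py (the choice of guide, optimizer, learning rate, and number of SVI steps are illustrative assumptions): it first fits a normalizing-flow guide with SVI and then uses NeuTraReparam to run NUTS in the learned coordinate system.

# Hedged sketch of the NeuTra strategy; hyperparameters are illustrative.
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.reparam import NeuTraReparam

model = partial(_unrep_hs_model, X, Y)
guide = AutoBNAFNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(3e-3), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(1), 2000)

# NUTS now explores the (hopefully simpler) latent space of the flow.
neutra = NeuTraReparam(guide, svi_result.params)
run_inference(neutra.reparam(model))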

References

[1] "Hamiltonian Monte Carlo for Hierarchical Models," M. J. Betancourt, Mark Girolami.

[2] "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport," Matthew Hoffman, Pavel Sountsov, Joshua V. Dillon, Ian Langmore, Dustin Tran, Srinivas Vasudevan.

[3] "Reparameterization" in the Stan user's manual. https://mc-stan.org/docs/2_27/stan-users-guide/reparameterization-section.html