NumPyro 入门

JAX 提供支持的概率编程,用于 GPU/TPU/CPU 的自动微分和 JIT 编译。

文档和示例 | 论坛


什么是 NumPyro?

NumPyro 是一个轻量级的概率编程库,为 Pyro 提供 NumPy 后端。我们依赖 JAX 进行自动微分和 JIT 编译到 GPU/CPU。NumPyro 正在积极开发中,因此要注意其脆弱性、错误以及随着设计演变而对 API 产生的更改。

NumPyro 被设计为轻量级的,专注于提供一个灵活的基础供用户在其上构建

  • Pyro 原语:NumPyro 程序除了包含 Pyro 原语,如 sampleparam,还可以包含常规 Python 和 NumPy 代码。模型代码看起来与 Pyro 非常相似,只是 PyTorch 和 Numpy 的 API 有一些细微差异。请参阅下面的示例

  • 推断算法:NumPyro 支持多种推断算法,特别关注哈密顿蒙特卡罗等 MCMC 算法,包括无 U-Turn 采样器的实现。其他 MCMC 算法包括 MixedHMC(可容纳离散隐变量)以及 HMCECS(仅计算每次迭代中数据子集的似然)。NumPyro 的动机之一是通过 JIT 编译包含多次梯度计算的 Verlet 积分器来加速哈密顿蒙特卡罗。利用 JAX,我们可以组合 jitgrad 将整个积分步骤编译成优化的 XLA 内核。我们还通过 JIT 编译 NUTS 中的整个树构建阶段来消除 Python 开销(这可以通过使用 迭代 NUTS 实现)。此外,还有一个基本的变分推断实现,以及用于自动微分变分推断 (ADVI) 的许多灵活的(自动)指导函数。变分推断实现支持多种特性,包括支持带有离散隐变量的模型(参见 TraceGraph_ELBOTraceEnum_ELBO)。

  • 分布:numpyro.distributions 模块提供了分布类、约束和双射变换。这些分布类封装了用于 JAX 函数式伪随机数生成器 的采样器。分布模块的设计很大程度上借鉴了 PyTorch。API 的主要子集已实现,并且包含了 PyTorch 中存在的大部分常见分布。因此,Pyro 和 PyTorch 用户可以依赖与 torch.distributions 相同的 API 和批处理语义。除了分布之外,constraintstransforms 在处理有界支持的分布类时非常有用。最后,TensorFlow Probability (TFP) 中的分布可以直接用于 NumPyro 模型。

  • 效应处理器:与 Pyro 类似,可以使用 numpyro.handlers 模块中的效应处理器为 sampleparam 等原语提供非标准解释,并且这些处理器可以轻松扩展以实现自定义推断算法和推断工具。

一个简单示例 - 8 所学校

让我们通过一个简单示例来探索 NumPyro。我们将使用 Gelman 等人《贝叶斯数据分析:第 5.5 节,2003 年》中的八校示例,该示例研究辅导对八所学校 SAT 成绩的影响。

数据如下所示

>>> import numpy as np



>>> J = 8

>>> y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])

>>> sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

,其中 y 是治疗效果,sigma 是标准误差。我们为该研究构建一个分层模型,假设每所学校的组级参数 theta 从均值未知 mu 和标准差未知 tau 的正态分布中采样,而观测数据则依次从均值和标准差分别为 theta(真实效果)和 sigma 的正态分布中生成。这使得我们可以通过汇集所有观测数据来估计总体级别参数 mutau,同时仍然允许使用组级参数 theta 来考虑学校之间的个体差异。

>>> import numpyro

>>> import numpyro.distributions as dist



>>> # Eight Schools example

... def eight_schools(J, sigma, y=None):

...     mu = numpyro.sample('mu', dist.Normal(0, 5))

...     tau = numpyro.sample('tau', dist.HalfCauchy(5))

...     with numpyro.plate('J', J):

...         theta = numpyro.sample('theta', dist.Normal(mu, tau))

...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

让我们使用无 U-Turn 采样器 (NUTS) 运行 MCMC 来推断模型中未知参数的值。注意在 MCMC.run 中使用 extra_fields 参数。默认情况下,当我们使用 MCMC 运行推断时,我们只从目标(后验)分布中收集样本。然而,通过使用 extra_fields 参数可以轻松收集更多字段,例如势能或样本的接受概率。有关可收集字段的列表,请参阅 HMCState 对象。在本示例中,我们将额外收集每个样本的 potential_energy

>>> from jax import random

>>> from numpyro.infer import MCMC, NUTS



>>> nuts_kernel = NUTS(eight_schools)

>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)

>>> rng_key = random.PRNGKey(0)

>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

我们可以打印 MCMC 运行的摘要,并检查在推断期间是否观察到任何散度。此外,由于我们收集了每个样本的势能,我们可以轻松计算期望对数联合密度。

>>> mcmc.print_summary()  



                mean       std    median      5.0%     95.0%     n_eff     r_hat

        mu      4.14      3.18      3.87     -0.76      9.50    115.42      1.01

       tau      4.12      3.58      3.12      0.51      8.56     90.64      1.02

  theta[0]      6.40      6.22      5.36     -2.54     15.27    176.75      1.00

  theta[1]      4.96      5.04      4.49     -1.98     14.22    217.12      1.00

  theta[2]      3.65      5.41      3.31     -3.47     13.77    247.64      1.00

  theta[3]      4.47      5.29      4.00     -3.22     12.92    213.36      1.01

  theta[4]      3.22      4.61      3.28     -3.72     10.93    242.14      1.01

  theta[5]      3.89      4.99      3.71     -3.39     12.54    206.27      1.00

  theta[6]      6.55      5.72      5.66     -1.43     15.78    124.57      1.00

  theta[7]      4.81      5.95      4.19     -3.90     13.40    299.66      1.00



Number of divergences: 19



>>> pe = mcmc.get_extra_fields()['potential_energy']

>>> print('Expected log joint density: {:.2f}'.format(np.mean(-pe)))  

Expected log joint density: -54.55

Split Gelman Rubin 诊断 (r_hat) 的值大于 1 表明链尚未完全收敛。有效样本量 (n_eff) 的值较低,特别是对于 tau,以及散度转移的数量看起来有问题。幸运的是,这是一个常见的病态,可以通过在模型中对 tau 使用非中心化参数化来纠正。这在 NumPyro 中很容易实现,只需结合 TransformedDistribution 实例和 reparameterization 效应处理器。让我们重写相同的模型,但不是从 Normal(mu, tau) 中采样 theta,而是从一个使用 AffineTransform 变换过的基础 Normal(0, 1) 分布中采样。请注意,这样做时,NumPyro 会通过为基础 Normal(0, 1) 分布生成样本 theta_base 来运行 HMC。我们看到结果链没有出现相同的病态——所有参数的 Gelman Rubin 诊断均为 1,有效样本量看起来非常好!

>>> from numpyro.infer.reparam import TransformReparam



>>> # Eight Schools example - Non-centered Reparametrization

... def eight_schools_noncentered(J, sigma, y=None):

...     mu = numpyro.sample('mu', dist.Normal(0, 5))

...     tau = numpyro.sample('tau', dist.HalfCauchy(5))

...     with numpyro.plate('J', J):

...         with numpyro.handlers.reparam(config={'theta': TransformReparam()}):

...             theta = numpyro.sample(

...                 'theta',

...                 dist.TransformedDistribution(dist.Normal(0., 1.),

...                                              dist.transforms.AffineTransform(mu, tau)))

...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)



>>> nuts_kernel = NUTS(eight_schools_noncentered)

>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)

>>> rng_key = random.PRNGKey(0)

>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

>>> mcmc.print_summary(exclude_deterministic=False)  



                   mean       std    median      5.0%     95.0%     n_eff     r_hat

           mu      4.08      3.51      4.14     -1.69      9.71    720.43      1.00

          tau      3.96      3.31      3.09      0.01      8.34    488.63      1.00

     theta[0]      6.48      5.72      6.08     -2.53     14.96    801.59      1.00

     theta[1]      4.95      5.10      4.91     -3.70     12.82   1183.06      1.00

     theta[2]      3.65      5.58      3.72     -5.71     12.13    581.31      1.00

     theta[3]      4.56      5.04      4.32     -3.14     12.92   1282.60      1.00

     theta[4]      3.41      4.79      3.47     -4.16     10.79    801.25      1.00

     theta[5]      3.58      4.80      3.78     -3.95     11.55   1101.33      1.00

     theta[6]      6.31      5.17      5.75     -2.93     13.87   1081.11      1.00

     theta[7]      4.81      5.38      4.61     -3.29     14.05    954.14      1.00

theta_base[0]      0.41      0.95      0.40     -1.09      1.95    851.45      1.00

theta_base[1]      0.15      0.95      0.20     -1.42      1.66   1568.11      1.00

theta_base[2]     -0.08      0.98     -0.10     -1.68      1.54   1037.16      1.00

theta_base[3]      0.06      0.89      0.05     -1.42      1.47   1745.02      1.00

theta_base[4]     -0.14      0.94     -0.16     -1.65      1.45    719.85      1.00

theta_base[5]     -0.10      0.96     -0.14     -1.57      1.51   1128.45      1.00

theta_base[6]      0.38      0.95      0.42     -1.32      1.82   1026.50      1.00

theta_base[7]      0.10      0.97      0.10     -1.51      1.65   1190.98      1.00



Number of divergences: 0



>>> pe = mcmc.get_extra_fields()['potential_energy']

>>> # Compare with the earlier value

>>> print('Expected log joint density: {:.2f}'.format(np.mean(-pe)))  

Expected log joint density: -46.09

请注意,对于带有 loc,scale 参数的分布类,例如 NormalCauchyStudentT,我们还提供了 LocScaleReparam 再参数化器来实现相同的目的。相应的代码将是

with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):

    theta = numpyro.sample('theta', dist.Normal(mu, tau))

现在,假设我们有一所新学校,尚未观察到任何考试成绩,但我们想生成预测。NumPyro 为此目的提供了一个 Predictive 类。请注意,在没有观察到任何数据的情况下,我们只需使用总体级别参数来生成预测。Predictive 工具将未观测到的 mutau 位置设定为从我们上次 MCMC 运行的后验分布中抽取的值,然后向前运行模型以生成预测。

>>> from numpyro.infer import Predictive



>>> # New School

... def new_school():

...     mu = numpyro.sample('mu', dist.Normal(0, 5))

...     tau = numpyro.sample('tau', dist.HalfCauchy(5))

...     return numpyro.sample('obs', dist.Normal(mu, tau))



>>> predictive = Predictive(new_school, mcmc.get_samples())

>>> samples_predictive = predictive(random.PRNGKey(1))

>>> print(np.mean(samples_predictive['obs']))  

3.9886456

更多示例

有关如何在 NumPyro 中指定模型和进行推断的更多示例

Pyro 用户会注意到,模型指定和推断的 API,包括分布 API,在很大程度上与 Pyro 相同,这是设计使然。但是,有一些重要的核心差异(反映在内部实现中),用户应该注意。例如,在 NumPyro 中,没有全局参数存储或随机状态,这使得我们可以利用 JAX 的 JIT 编译。此外,用户可能需要以更函数式的风格编写模型,以便更好地与 JAX 配合使用。有关差异列表,请参阅常见问题解答

推断算法概述

我们概述了 NumPyro 支持的大部分推断算法,并提供了一些关于哪些推断算法可能适用于不同类型模型的指南。

MCMC

  • NUTSHMC 的一种自适应变体,可能是 NumPyro 中最常用的推断算法。请注意,NUTS 和 HMC 不直接适用于带有离散隐变量的模型,但在离散变量具有有限支持且可以进行求和(即枚举)的情况下,NumPyro 将自动对离散隐变量求和,并对剩余的连续隐变量执行 NUTS/HMC。

如上所述,在某些情况下,模型再参数化对于获得良好性能可能很重要。请注意,一般来说,随着隐空间维度的增加,推断会变得更加困难。请参阅不良几何形状教程以获取更多提示和技巧。

  • MixedHMC 对于包含连续和离散隐变量的模型来说,是一种有效的推断策略。

  • HMCECS 对于具有大量数据点的模型来说,是一种有效的推断策略。它适用于具有连续隐变量的模型。参见此处的示例。

  • BarkerMH 是一种基于梯度的 MCMC 方法,对于某些模型可能与 HMC 和 NUTS 竞争。它适用于具有连续隐变量的模型。

  • HMCGibbs 将 HMC/NUTS 步骤与自定义 Gibbs 更新结合。Gibbs 更新必须由用户指定。

  • DiscreteHMCGibbs 将 HMC/NUTS 步骤与离散隐变量的 Gibbs 更新结合。相应的 Gibbs 更新是自动计算的。

  • SA 是 NumPyro 中唯一不利用梯度的 MCMC 方法。它仅适用于具有连续隐变量的模型。预计它对于隐变量维度较低或中等的模型表现最佳。对于具有不可微分对数密度的模型,它可能是一个不错的选择。请注意,SA 通常需要非常大量的样本,因为混合速度往往很慢。从积极的一面来看,单个步骤可以很快。

与 HMC/NUTS 类似,所有剩余的 MCMC 算法如果可能都支持对离散隐变量进行枚举(参见限制)。枚举位置需要用 infer={'enumerate': 'parallel'} 进行标记,就像在注释示例中一样。

嵌套采样

随机变分推断

  • 变分目标

    • Trace_ELBO 是我们基本的 ELBO 实现。

    • TraceMeanField_ELBO 类似于 Trace_ELBO,但如果可能,会解析地计算部分 ELBO。

    • TraceGraph_ELBO 为带有离散隐变量的模型提供了方差减少策略。一般来说,带有离散隐变量的模型应始终使用此 ELBO。

    • TraceEnum_ELBO 为带有离散隐变量的模型提供了变量枚举策略。一般来说,当可以进行枚举时,带有离散隐变量的模型应始终使用此 ELBO。

  • 自动指导函数(适用于带有连续隐变量的模型)

    • AutoNormalAutoDiagonalNormal 是我们基本的均值场指导函数。如果隐空间是非欧几里得的(例如由于某个采样位置的非负约束),则会在底层自动使用适当的双射变换来将未约束空间(定义正态变分分布的位置)映射到相应的约束空间(请注意,这对于所有自动指导函数都适用)。当您尝试让变分推断在您正在开发的模型上工作时,这些指导函数是一个很好的起点。

    • AutoMultivariateNormalAutoLowRankMultivariateNormal 也构建正态变分分布,但提供了更大的灵活性,因为它们可以捕获后验中的相关性。请注意,在高维设置中拟合这些指导函数可能很困难。

    • AutoDelta 用于通过 MAP(最大后验估计)计算点估计。参见此处的使用示例。

    • AutoBNAFNormalAutoIAFNormal 提供了由归一化流参数化的灵活变分分布。

    • AutoDAIS 是一种强大的变分推断算法,它利用 HMC。它是处理高度相关后验的不错选择,但根据模型的性质,计算成本可能很高。

    • AutoSurrogateLikelihoodDAIS 是一种强大的变分推断算法,它利用 HMC 并支持数据子采样。

    • AutoSemiDAIS 为局部隐变量构建了一个类似于 AutoDAIS 的后验近似,但通过利用全局隐变量的参数化指导函数,支持在 ELBO 训练期间进行数据子采样。

    • AutoLaplaceApproximation 可用于计算拉普拉斯近似。

Stein 变分推断

有关更多详细信息,请参阅文档

安装

有限的 Windows 支持:请注意,NumPyro 未在 Windows 上测试,可能需要从源代码构建 jaxlib。有关更多详细信息,请参阅此 JAX 问题。或者,您可以安装 适用于 Linux 的 Windows 子系统,并在其上像在 Linux 系统上一样使用 NumPyro。如果您想在 Windows 上使用 GPU,另请参阅 适用于 Linux 的 Windows 子系统上的 CUDA此论坛帖子

要安装 NumPyro 和最新 CPU 版的 JAX,可以使用 pip

pip install numpyro

如果在执行上述命令时出现兼容性问题,可以强制安装一个已知

兼容的 CPU 版 JAX,使用

pip install 'numpyro[cpu]'

要在 GPU 上使用 NumPyro,您需要先安装 CUDA,然后使用以下 pip 命令

pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

如果您需要进一步指导,请参阅 JAX GPU 安装说明

要在 Cloud TPU 上运行 NumPyro,您可以查看一些 JAX 在 Cloud TPU 上的示例

对于 Cloud TPU VM,您需要按照 Cloud TPU VM JAX 快速入门指南 中详细说明的方式设置 TPU 后端。

在您确认 TPU 后端已正确设置后,

您可以使用 pip install numpyro 命令安装 NumPyro。

默认平台:如果安装了支持 CUDA 的 jaxlib 包,JAX 将默认使用 GPU。您可以使用 set_platform 工具函数 numpyro.set_platform("cpu") 在程序开始时切换到 CPU。

您也可以从源代码安装 NumPyro

git clone https://github.com/pyro-ppl/numpyro.git

cd numpyro

# install jax/jaxlib first for CUDA support

pip install -e '.[dev]'  # contains additional dependencies for NumPyro development

您也可以使用 conda 安装 NumPyro

conda install -c conda-forge numpyro

常见问题解答

  1. 与 Pyro 不同,numpyro.sample('x', dist.Normal(0, 1)) 不起作用。为什么?

    您很可能在推断上下文之外使用了 numpyro.sample 语句。JAX 没有全局随机状态,因此,分布采样器需要一个显式的随机数生成器密钥 (PRNGKey) 来生成样本。NumPyro 的推断算法在幕后使用 seed 处理器来传入随机数生成器密钥。

    您的选项包括

    • PRNGKey,例如 dist.Normal(0, 1).sample(PRNGKey(0))

    • numpyro.sample 提供 rng_key 参数。例如 numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))

    • 将代码包装在 seed 处理器中,可以作为上下文管理器使用,也可以作为包装原始可调用对象的函数使用。例如

      with handlers.seed(rng_seed=0):  # random.PRNGKey(0) is used
      
          x = numpyro.sample('x', dist.Beta(1, 1))    # uses a PRNGKey split from random.PRNGKey(0)
      
          y = numpyro.sample('y', dist.Bernoulli(x))  # uses different PRNGKey split from the last one
      

      ,或者作为高阶函数

      def fn():
      
          x = numpyro.sample('x', dist.Beta(1, 1))
      
          y = numpyro.sample('y', dist.Bernoulli(x))
      
          return y
      
      
      
      print(handlers.seed(fn, rng_seed=0)())
      
  2. 我可以使用相同的 Pyro 模型在 NumPyro 中进行推断吗?

    正如您从示例中可能注意到的那样,NumPyro 支持所有 Pyro 原语,如 sampleparamplatemodule,以及效应处理器。此外,我们确保 distributions API 基于 torch.distributions,并且 SVIMCMC 等推断类具有相同的接口。再加上 NumPy 和 PyTorch 操作 API 的相似性,确保包含 Pyro 原语语句的模型只需稍作修改即可与任一后端一起使用。下面列出了一些差异以及所需的更改示例

    • 模型中的任何 torch 操作都需要用相应的 jax.numpy 操作来编写。此外,并非所有 torch 操作都有 numpy 对应项(反之亦然),有时 API 也存在细微差异。

    • 推断上下文之外的 pyro.sample 语句需要包装在 seed 处理器中,如上所述。

    • 没有全局参数存储,因此在推断上下文之外使用 numpyro.param 将不起作用。要从 SVI 中检索优化后的参数值,请使用 SVI.get_params 方法。请注意,您仍然可以在模型内部使用 param 语句,并且 NumPyro 在 SVI 中运行模型时,将在内部使用 substitute 效应处理器来替换优化器中的值。

    • PyTorch 神经网络模块需要重写为 staxflaxhaiku 神经网络。有关这两种后端之间语法的差异,请参阅 VAEProdLDA 示例。

    • JAX 最适合函数式代码,特别是如果我们想利用 JIT 编译,NumPyro 在内部对许多推断子程序都进行了 JIT 编译。因此,如果您的模型具有 JAX 跟踪器不可见的副作用,则可能需要以更函数式的风格重写。

    对于大多数小型模型,在 NumPyro 中运行推断所需的更改应该很小。此外,我们正在开发 pyro-api,它允许您编写相同的代码并将其分派到包括 NumPyro 在内的多个后端。这必然会更具限制性,但优点是与后端无关。有关示例,请参阅文档,并请告诉我们您的反馈。

  3. 我如何为项目做贡献?

    感谢您对项目的关注!您可以查看 Github 上标记有 适合初学者的问题 标签的 issue。另外,欢迎在论坛上与我们联系。

未来/正在进行的工作

近期,我们计划开展以下工作。请为功能需求和增强功能开启新的 issue

  • 提高不同模型上推断的鲁棒性,性能分析和性能调优。

  • 支持作为 pyro-api 通用建模接口一部分的更多功能。

  • 更多推断算法,特别是那些需要二阶导数或使用 HMC 的算法。

  • Funsor 集成以支持带有延迟采样的推断算法。

  • 受 Pyro 研究目标和应用重点以及社区兴趣驱动的其他领域。

引用 NumPyro

NumPyro 背后的核心思想和对迭代 NUTS 的描述可以在这篇发表在 NeurIPS 2019 机器学习程序变换研讨会上的论文中找到。

如果您使用 NumPyro,请考虑引用

@article{phan2019composable,

  title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},

  author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},

  journal={arXiv preprint arXiv:1912.11554},

  year={2019}

}

以及

@article{bingham2019pyro,

  author    = {Eli Bingham and

               Jonathan P. Chen and

               Martin Jankowiak and

               Fritz Obermeyer and

               Neeraj Pradhan and

               Theofanis Karaletsos and

               Rohit Singh and

               Paul A. Szerlip and

               Paul Horsfall and

               Noah D. Goodman},

  title     = {Pyro: Deep Universal Probabilistic Programming},

  journal   = {J. Mach. Learn. Res.},

  volume    = {20},

  pages     = {28:1--28:6},

  year      = {2019},

  url       = {http://jmlr.org/papers/v20/18-403.html}

}