马尔可夫链蒙特卡洛 (MCMC)

我们提供 NumPyro 中 MCMC 算法的高层概述

  • NUTSHMC 的一个自适应变体,可能是 NumPyro 中最常用的 MCMC 算法。请注意,NUTS 和 HMC 不直接适用于带有离散潜变量的模型,但在离散变量支持有限且枚举可行的情况下,NumPyro 会自动对离散潜变量进行求和(即枚举),并对剩余的连续潜变量执行 NUTS/HMC。如上所述,模型重参数化在某些情况下可能对于获得良好的性能至关重要。请注意,一般来说,随着潜在空间维度的增加,我们预计推断会更加困难。有关更多提示和技巧,请参阅不良几何形状教程。

  • MixedHMC 对于包含连续和离散潜变量的模型来说,是一种有效的推断策略。

  • HMCECS 对于数据点数量庞大的模型来说,是一种有效的推断策略。它适用于具有连续潜变量的模型。有关详细用法,请参阅此示例

  • BarkerMH 是一种基于梯度的 MCMC 方法,对于某些模型来说,它可能与 HMC 和 NUTS 具有竞争力。它适用于具有连续潜变量的模型。

  • HMCGibbs 将 HMC/NUTS 步骤与自定义 Gibbs 更新相结合。Gibbs 更新必须由用户指定。

  • DiscreteHMCGibbs 将 HMC/NUTS 步骤与离散潜变量的 Gibbs 更新相结合。相应的 Gibbs 更新是自动计算的。

  • SA 是一种无梯度的 MCMC 方法。它仅适用于具有连续潜变量的模型。预计对于潜在维度低到中等的模型性能最佳。对于对数密度不可导的模型来说,这可能是一个不错的选择。请注意,SA 通常需要非常大量的样本,因为混合往往很慢。从积极的一面看,单个步骤可以很快。

  • AIES 是一种无梯度的集成 MCMC 方法,通过链之间共享信息来指导 Metropolis-Hastings 提议。它仅适用于具有连续潜变量的模型。预计对于潜在维度低到中等的模型性能最佳。对于对数密度不可导的模型来说,这可能是一个不错的选择,并且对无似然模型具有鲁棒性。AIES 通常要求链的数量是潜在参数数量的两倍(理想情况下更大)。

  • ESS 是一种无梯度的集成 MCMC 方法,通过链之间共享信息来找到好的切片采样方向。它往往比 AIES 更具样本效率。它仅适用于具有连续潜变量的模型。预计对于潜在维度低到中等的模型性能最佳,并且对于对数密度不可导的模型来说可能是一个不错的选择。ESS 通常要求链的数量是潜在参数数量的两倍(理想情况下更大)。

与 HMC/NUTS 类似,如果可能,所有剩余的 MCMC 算法都支持对离散潜变量进行枚举(参见限制)。枚举站点需要标记 infer={‘enumerate’: ‘parallel’},就像标注示例中那样。

class MCMC(sampler, *, num_warmup, num_samples, num_chains=1, thinning=1, postprocess_fn=None, chain_method='parallel', progress_bar=True, jit_model_args=False)[source]

基类:object

提供对 NumPyro 中马尔可夫链蒙特卡洛推断算法的访问。

注意

chain_method 是一个实验性参数,未来版本可能会移除。

注意

progress_bar 设置为 False 在许多情况下可以提高速度。但这可能比其他选项需要更多内存。

注意

如果在 Jupyter Notebook 中将 num_chains 设置为大于 1,则需要在启动 Jupyter 的环境中安装 ipywidgets,以便进度条正确渲染。如果您正在使用 Jupyter Notebook 或 Jupyter Lab,请同时安装相应的扩展包,例如 widgetsnbextensionjupyterlab_widgets

注意

如果您的数据集很大并且您有多个加速设备可用,您可以将计算分布到多个设备上。请确保您的 jax 版本是 v0.4.4 或更新。例如,

import jax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

X = np.random.randn(128, 3)
y = np.random.randn(128)

def model(X, y):
    beta = numpyro.sample("beta", dist.Normal(0, 1).expand([3]))
    numpyro.sample("obs", dist.Normal(X @ beta, 1), obs=y)

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
# See https://jax.net.cn/en/stable/notebooks/Distributed_arrays_and_automatic_parallelization.html
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
X_shard = jax.device_put(X, sharding.reshape(8, 1))
y_shard = jax.device_put(y, sharding.reshape(8))
mcmc.run(jax.random.PRNGKey(0), X_shard, y_shard)
参数:
  • sampler (MCMCKernel) – MCMCKernel 的实例,用于确定运行 MCMC 的采样器。目前,只有 HMCNUTS 可用。

  • num_warmup (int) – 热身步骤的数量。

  • num_samples (int) – 从马尔可夫链生成的样本数量。

  • thinning (int) – 控制保留的热身后期样本比例的正整数。例如,如果 thinning 是 2,则每隔一个样本被保留。默认为 1,即不进行稀释。

  • num_chains (int) – 要运行的 MCMC 链的数量。默认情况下,链将使用 jax.pmap() 并行运行。如果可用设备不足,链将按顺序运行。

  • postprocess_fn – 后处理可调用对象 - 用于将采样器返回的无约束样本值集合转换为位于样本站点支持范围内的约束值。此外,它还用于返回模型中确定性站点的值。

  • chain_method (str) – 一个可调用的 jax 转换,如 jax.vmap 或 ‘parallel’(默认)、‘sequential’、‘vectorized’ 之一。‘parallel’ 方法用于在 XLA 设备(CPU/GPU/TPU)上并行执行采样过程,如果设备不足以进行 ‘parallel’,则回退到 ‘sequential’ 方法按顺序绘制链。‘vectorized’ 方法是一个实验性功能,它向量化了采样方法,因此允许我们在单个设备上并行收集样本。

  • progress_bar (bool) – 是否启用进度条更新。默认为 True

  • jit_model_args (bool) – 如果设置为 True,这将把势能计算编译为模型参数的函数。因此,在大小相同但数据集不同的情况下再次调用 MCMC.run 不会导致额外的编译成本。请注意,目前这对于 num_chains > 1chain_method == 'parallel' 的情况不生效。

注意

混合并行和向量化采样是可能的,即使用显式 pmap 在多个设备上运行向量化链。目前,这样做需要禁用进度条。例如,

def do_mcmc(rng_key, n_vectorized=8):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(
        nuts_kernel,
        progress_bar=False,
        num_chains=n_vectorized,
        chain_method='vectorized'
    )
    mcmc.run(
        rng_key,
        extra_fields=("potential_energy",),
    )
    return {**mcmc.get_samples(), **mcmc.get_extra_fields()}
# Number of devices to pmap over
n_parallel = jax.local_device_count()
rng_keys = jax.random.split(PRNGKey(rng_seed), n_parallel)
traces = pmap(do_mcmc)(rng_keys)
# concatenate traces along pmap'ed axis
trace = {k: np.concatenate(v) for k, v in traces.items()}
property post_warmup_state

采样阶段之前的状态。如果此属性不为 None,则 run() 将跳过热身阶段,并从此属性指定的状态开始。

注意

此属性可用于顺序绘制 MCMC 样本。例如,

mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0))
first_100_samples = mcmc.get_samples()
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key)  # or mcmc.run(random.PRNGKey(1))
second_100_samples = mcmc.get_samples()
property last_state

采样阶段结束时的最终 MCMC 状态。

warmup(rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs)[source]

运行 MCMC 热身适应阶段。调用此方法后,self.post_warmup_state 将被设置,并且 run() 方法将跳过热身适应阶段。要针对新数据再次运行 warmup,需要再次运行 warmup()

参数:
  • rng_key (random.PRNGKey) – 用于采样的随机数生成器密钥。

  • args – 提供给 numpyro.infer.mcmc.MCMCKernel.init() 方法的参数。这些通常是 model 所需的参数。

  • extra_fields (tuple or list) – 在 MCMC 运行期间要收集的状态对象中的额外字段(除了 default_fields()),例如 HMC 的 numpyro.infer.hmc.HMCState。使用 “~`sampler.sample_field`.`sample_site`” 从收集中排除样本站点。例如,如果您使用 NUTS 采样器,则 “~z.a” 将阻止收集站点 “a”。要收集站点 “a” 在无约束空间中的样本,我们可以在此处指定变量,例如 extra_fields=(“z.a”,)

  • collect_warmup (bool) – 是否收集热身阶段的样本。默认为 False

  • init_params – 开始采样的初始参数。类型必须与提供给核的 potential_fn 的输入类型一致。如果核是通过 numpyro 模型实例化的,这里的初始参数对应于无约束空间中的潜值。

  • kwargs – 提供给 numpyro.infer.mcmc.MCMCKernel.init() 方法的关键字参数。这些通常是 model 所需的关键字参数。

run(rng_key, *args, extra_fields=(), init_params=None, **kwargs)[source]

运行 MCMC 采样器并收集样本。

参数:
  • rng_key (random.PRNGKey) – 用于采样的随机数生成器密钥。对于多链,可以提供一批 num_chains 密钥。如果 rng_key 没有 batch_size,它将被分割成一批 num_chains 密钥。

  • args – 提供给 numpyro.infer.mcmc.MCMCKernel.init() 方法的参数。这些通常是 model 所需的参数。

  • extra_fields (tuple or list of str) – 在 MCMC 运行期间要收集的状态对象中的额外字段(除了 “z”“diverging”),例如 HMC 的 numpyro.infer.hmc.HMCState。请注意,可以使用点来访问子字段,例如,可以使用 “adapt_state.step_size” 来收集每个步骤的步长。使用 “~`sampler.sample_field`.`sample_site`” 从收集中排除样本站点。例如,如果您使用 NUTS 采样器,则 “~z.a” 将阻止收集站点 “a”。要收集站点 “a” 在无约束空间中的样本,我们可以在此处指定变量,例如 extra_fields=(“z.a”,)

  • init_params – 开始采样的初始参数。类型必须与提供给核的 potential_fn 的输入类型一致。如果核是通过 numpyro 模型实例化的,这里的初始参数对应于无约束空间中的潜值。

  • kwargs – 提供给 numpyro.infer.mcmc.MCMCKernel.init() 方法的关键字参数。这些通常是 model 所需的关键字参数。

注意

jax 允许 Python 代码在编译代码尚未完成时继续执行。这在尝试分析代码速度时可能会导致问题。请参阅 https://jax.net.cn/en/stable/async_dispatch.htmlhttps://jax.net.cn/en/stable/profiling.html 以获取有关分析 jax 程序的提示。

get_samples(group_by_chain=False)[source]

获取 MCMC 运行中的样本。

参数:

group_by_chain (bool) – 是否保留链的维度。如果为 True,所有样本的前导维度大小将是 num_chains。

返回值:

样本的数据类型与 init_params 相同。如果使用包含 Pyro 原语的模型,数据类型是一个以站点名称为键的 dict,但更一般地可以是任何 jaxlib.pytree()(例如,当为接受 list 参数的 HMC 定义 potential_fn 时)。

示例

然后您可以将这些样本传递给 Predictive

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples=posterior_samples)
samples = predictive(rng_key1, *model_args, **model_kwargs)
get_extra_fields(group_by_chain=False)[source]

获取 MCMC 运行中的额外字段。

参数:

group_by_chain (bool) – 是否保留链的维度。如果为 True,所有样本的前导维度大小将是 num_chains。

返回值:

额外字段以在 run()extra_fields 关键字中指定的字段名作为键。

print_summary(prob=0.9, exclude_deterministic=True)[source]

打印运行此 MCMC 实例期间收集的后验样本统计信息。

参数:
  • prob (float) – 可信区间内样本的概率质量。

  • exclude_deterministic (bool) – 是否打印确定性站点的统计信息。

transfer_states_to_host()[source]

通过将收集的样本传输到主机设备来减少内存占用。

MCMC 核

MCMCKernel

class MCMCKernel[source]

基类:ABC

定义用于 MCMC 推断的马尔可夫转移核的接口。

示例

>>> from collections import namedtuple
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC

>>> MHState = namedtuple("MHState", ["u", "rng_key"])

>>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
...     sample_field = "u"
...
...     def __init__(self, potential_fn, step_size=0.1):
...         self.potential_fn = potential_fn
...         self.step_size = step_size
...
...     def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
...         return MHState(init_params, rng_key)
...
...     def sample(self, state, model_args, model_kwargs):
...         u, rng_key = state
...         rng_key, key_proposal, key_accept = random.split(rng_key, 3)
...         u_proposal = dist.Normal(u, self.step_size).sample(key_proposal)
...         accept_prob = jnp.exp(self.potential_fn(u) - self.potential_fn(u_proposal))
...         u_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, u_proposal, u)
...         return MHState(u_new, rng_key)

>>> def f(x):
...     return ((x - 2) ** 2).sum()

>>> kernel = MetropolisHastings(f)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.]))
>>> posterior_samples = mcmc.get_samples()
>>> mcmc.print_summary()  
postprocess_fn(model_args, model_kwargs)[source]

获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。

参数:
  • model_args – 模型的参数。

  • model_kwargs – 模型的关键字参数。

abstract init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]

初始化 MCMCKernel 并返回一个初始状态以开始采样。

参数:
  • rng_key (random.PRNGKey) – 初始化核的随机数生成器密钥。

  • num_warmup (int) – 热身步骤的数量。这在热身期间进行适应非常有用。

  • init_params (tuple) – 开始采样的初始参数。类型必须与 potential_fn 的输入类型一致。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

表示核状态的初始状态。这可以是任何注册为 pytree 的类。

abstract sample(state, model_args, model_kwargs)[source]

给定当前 state,使用给定的转移核返回下一个 state

参数:
  • state

    表示核状态的 pytree 类。对于 HMC,它由 HMCState 给出。通常,这可以是任何支持 getattr 的类。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

下一个 state

property sample_field

传递给 sample()state 对象的属性,表示 MCMC 样本。它由 postprocess_fn() 和在 MCMC.print_summary() 中报告结果时使用。

property default_fields

在 MCMC 运行期间(调用 MCMC.run() 时)默认要收集的 state 对象的属性。

property is_ensemble_kernel

表示核是否为集成核。如果为 True,则在 MCMC 运行期间(调用 MCMC.run() 时),如果 chain_method = “vectorized”,将显示 diagnostics_str。

get_diagnostics_str(state)[source]

给定当前 state,返回要添加到进度条中的诊断字符串,用于诊断目的。

BarkerMH

class BarkerMH(model=None, potential_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.4, init_strategy=<function init_to_uniform>)[source]

基类:MCMCKernel

这是一种基于梯度的 Metropolis-Hastings 类型 MCMC 算法,它使用依赖于势能梯度的偏对称提议分布(Barker 提议;参见参考文献 [1])。特别是,提议分布在当前样本处沿梯度方向倾斜。

我们期望此算法对于低到中等维度的模型特别有效,在这些模型中它可能与 HMC 和 NUTS 具有竞争力。

注意

建议在 MCMC 中使用此核并设置 progress_bar=False,以减少 JAX 的调度开销。

参考文献

  1. Barker 提议:在基于梯度的 MCMC 中结合鲁棒性和效率。Samuel Livingstone, Giacomo Zanella。

参数:
  • model – 包含 Pyro 原语 的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn

  • potential_fn – 计算给定输入参数的势能的 Python 可调用对象。potential_fn 的输入参数可以是任何 Python 集合类型,前提是提供给 init()init_params 参数类型相同。

  • step_size (float) – Barker 提议中使用的(初始)步长。

  • adapt_step_size (bool) – 是否在热身期间调整步长。默认为 adapt_step_size==True

  • adapt_mass_matrix (bool) – 是否在热身期间调整质量矩阵。默认为 adapt_mass_matrix==True

  • dense_mass (bool) – 是否使用密集(即满秩)或对角质量矩阵。(默认为 dense_mass=False)。

  • target_accept_prob (float) – 用于指导步长调整的目标接受概率。增加此值将导致步长变小,从而采样会变慢但更鲁棒。默认为 0.8。

  • init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, BarkerMH

>>> def model():
...     x = numpyro.sample("x", dist.Normal().expand([10]))
...     numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10))
>>>
>>> kernel = BarkerMH(model)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=True)
>>> mcmc.run(jax.random.PRNGKey(0))
>>> mcmc.print_summary()  
property model
property sample_field

传递给 sample()state 对象的属性,表示 MCMC 样本。它由 postprocess_fn() 和在 MCMC.print_summary() 中报告结果时使用。

get_diagnostics_str(state)[source]

给定当前 state,返回要添加到进度条中的诊断字符串,用于诊断目的。

init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]

初始化 MCMCKernel 并返回一个初始状态以开始采样。

参数:
  • rng_key (random.PRNGKey) – 初始化核的随机数生成器密钥。

  • num_warmup (int) – 热身步骤的数量。这在热身期间进行适应非常有用。

  • init_params (tuple) – 开始采样的初始参数。类型必须与 potential_fn 的输入类型一致。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

表示核状态的初始状态。这可以是任何注册为 pytree 的类。

postprocess_fn(args, kwargs)[source]

获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。

参数:
  • model_args – 模型的参数。

  • model_kwargs – 模型的关键字参数。

sample(state, model_args, model_kwargs)[source]

给定当前 state,使用给定的转移核返回下一个 state

参数:
  • state

    表示核状态的 pytree 类。对于 HMC,它由 HMCState 给出。通常,这可以是任何支持 getattr 的类。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

下一个 state

HMC

class HMC(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, num_steps=None, trajectory_length=6.283185307179586, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True)[source]

基类:MCMCKernel

使用固定轨迹长度的哈密顿蒙特卡洛推断,并提供步长和质量矩阵调整功能。

注意

在 MCMC 运行中使用核之前,postprocess_fn 将返回恒等函数。

注意

默认的初始化策略 init_to_uniform 对于某些模型来说可能不是一个好的策略。您可能想尝试其他初始化策略,例如 init_to_median

参考文献

  1. 使用哈密顿动力学的 MCMC, Radford M. Neal

参数:
  • model – 包含 Pyro 原语 的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn

  • potential_fn – 计算给定输入参数的势能的 Python 可调用对象。potential_fn 的输入参数可以是任何 Python 集合类型,前提是提供给 init()init_params 参数类型相同。

  • kinetic_fn – 返回给定逆质量矩阵和动量的动能的 Python 可调用对象。如果未提供,默认是欧几里得动能。

  • step_size (float) – 确定 verlet 积分器在使用哈密顿动力学计算轨迹时所取单个步骤的大小。如果未指定,将设置为 1。

  • inverse_mass_matrix (numpy.ndarray or dict) – 逆质量矩阵的初始值。如果 adapt_mass_matrix = True,这在热身阶段可能会被调整。如果没有指定值,则初始化为单位矩阵。对于具有一般 JAX pytree 参数的 potential_fn,质量矩阵条目的顺序是通过 jax.tree_flatten 获得的 pytree 参数展平版本的顺序,这有点模糊(更多信息参见 https://jax.net.cn/en/stable/pytrees.html)。如果 model 不为 None,这里我们可以将结构化块质量矩阵指定为一个字典,其中键是站点名称的元组,值是相应的质量矩阵块。有关结构化质量矩阵的更多信息,请参见 dense_mass 参数。

  • adapt_step_size (bool) – 一个标志,用于决定是否要在热身阶段使用 Dual Averaging 方案自适应调整 step_size。

  • adapt_mass_matrix (bool) – 一个标志,用于决定是否要在热身阶段使用 Welford 方案自适应调整质量矩阵。

  • dense_mass (bool or list) –

    此标志控制质量矩阵是密集(即满秩)还是对角(默认为 dense_mass=False)。要指定结构化质量矩阵,用户可以提供一个站点名称元组列表。每个元组代表联合质量矩阵中的一个块。例如,假设模型有潜在变量“x”、“y”、“z”(其中每个变量可以是多维的),可能的规格及其相应的质量矩阵结构如下所示:

    • dense_mass=[(“x”, “y”)]:对联合 (x, y) 使用密集质量矩阵,对 z 使用对角质量矩阵

    • dense_mass=[](等同于 dense_mass=False):对联合 (x, y, z) 使用对角质量矩阵

    • dense_mass=[(“x”, “y”, “z”)](等同于 full_mass=True):对联合 (x, y, z) 使用密集质量矩阵

    • dense_mass=[(“x”,), (“y”,), (“z”)]:对 x、y 和 z 各自使用密集质量矩阵(即块对角,有 3 个块)

  • target_accept_prob (float) – 使用 Dual Averaging 进行步长自适应的目标接受概率。增加此值将导致更小的步长,从而采样会更慢但更稳健。默认为 0.8。

  • num_steps (int) – 如果与 None 不同,则固定每次迭代允许的步数。

  • trajectory_length (float) – HMC 的 MCMC 轨迹长度。默认值为 \(2\pi\)

  • init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。

  • find_heuristic_step_size (bool) – 是否在每个适应窗口开始时使用启发式函数调整步长。默认为 False。

  • forward_mode_differentiation (bool) – 是否使用前向模式微分或反向模式微分。默认情况下,我们使用反向模式,但在某些情况下,前向模式有助于提高性能。此外,JAX 中的一些控制流实用程序(如 jax.lax.while_loopjax.lax.fori_loop)仅支持前向模式微分。更多信息请参见 JAX 的自动微分手册

  • regularize_mass_matrix (bool) – 是否在热身阶段正则化估计的质量矩阵以提高数值稳定性。默认为 True。如果 adapt_mass_matrix == False,此标志不起作用。

property model
property sample_field

传递给 sample()state 对象中表示 MCMC 样本的属性。此属性由 postprocess_fn() 使用,也用于在 MCMC.print_summary() 中报告结果。

property default_fields

在 MCMC 运行期间(调用 MCMC.run() 时)默认要收集的 state 对象的属性。

get_diagnostics_str(state)[source]

给定当前 state,返回要添加到进度条中的诊断字符串,用于诊断目的。

init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]

初始化 MCMCKernel 并返回一个初始状态以开始采样。

参数:
  • rng_key (random.PRNGKey) – 初始化核的随机数生成器密钥。

  • num_warmup (int) – 热身步骤的数量。这在热身期间进行适应非常有用。

  • init_params (tuple) – 开始采样的初始参数。类型必须与 potential_fn 的输入类型一致。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

表示核状态的初始状态。这可以是任何注册为 pytree 的类。

postprocess_fn(args, kwargs)[source]

获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。

参数:
  • model_args – 模型的参数。

  • model_kwargs – 模型的关键字参数。

sample(state, model_args, model_kwargs)[source]

从给定的 HMCState 运行 HMC 并返回结果 HMCState

参数:
  • state (HMCState) – 表示当前状态。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

运行 HMC 后的下一个 state

NUTS

class NUTS(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=None, max_tree_depth=10, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True)[source]

基类: HMC

哈密顿蒙特卡罗推断,使用具有自适应路径长度和质量矩阵自适应的 No U-Turn Sampler (NUTS)。

注意

在 MCMC 运行中使用核之前,postprocess_fn 将返回恒等函数。

注意

默认的初始化策略 init_to_uniform 对于某些模型来说可能不是一个好的策略。您可能想尝试其他初始化策略,例如 init_to_median

参考文献

  1. 使用哈密顿动力学的 MCMC, Radford M. Neal

  2. The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo,作者 Matthew D. Hoffman, 和 Andrew Gelman。

  3. A Conceptual Introduction to Hamiltonian Monte Carlo`,作者 Michael Betancourt

参数:
  • model – 包含 Pyro 原语 的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn

  • potential_fn – Python 可调用对象,用于计算给定输入参数的势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给 init_kernelinit_params 参数具有相同的类型。

  • kinetic_fn – 返回给定逆质量矩阵和动量的动能的 Python 可调用对象。如果未提供,默认是欧几里得动能。

  • step_size (float) – 确定 verlet 积分器在使用哈密顿动力学计算轨迹时所取单个步骤的大小。如果未指定,将设置为 1。

  • inverse_mass_matrix (numpy.ndarray or dict) – 逆质量矩阵的初始值。如果 adapt_mass_matrix = True,这在热身阶段可能会被调整。如果没有指定值,则初始化为单位矩阵。对于具有一般 JAX pytree 参数的 potential_fn,质量矩阵条目的顺序是通过 jax.tree_flatten 获得的 pytree 参数展平版本的顺序,这有点模糊(更多信息参见 https://jax.net.cn/en/stable/pytrees.html)。如果 model 不为 None,这里我们可以将结构化块质量矩阵指定为一个字典,其中键是站点名称的元组,值是相应的质量矩阵块。有关结构化质量矩阵的更多信息,请参见 dense_mass 参数。

  • adapt_step_size (bool) – 一个标志,用于决定是否要在热身阶段使用 Dual Averaging 方案自适应调整 step_size。

  • adapt_mass_matrix (bool) – 一个标志,用于决定是否要在热身阶段使用 Welford 方案自适应调整质量矩阵。

  • dense_mass (bool or list) –

    此标志控制质量矩阵是密集(即满秩)还是对角(默认为 dense_mass=False)。要指定结构化质量矩阵,用户可以提供一个站点名称元组列表。每个元组代表联合质量矩阵中的一个块。例如,假设模型有潜在变量“x”、“y”、“z”(其中每个变量可以是多维的),可能的规格及其相应的质量矩阵结构如下所示:

    • dense_mass=[(“x”, “y”)]:对联合 (x, y) 使用密集质量矩阵,对 z 使用对角质量矩阵

    • dense_mass=[](等同于 dense_mass=False):对联合 (x, y, z) 使用对角质量矩阵

    • dense_mass=[(“x”, “y”, “z”)](等同于 full_mass=True):对联合 (x, y, z) 使用密集质量矩阵

    • dense_mass=[(“x”,), (“y”,), (“z”)]:对 x、y 和 z 各自使用密集质量矩阵(即块对角,有 3 个块)

  • target_accept_prob (float) – 使用 Dual Averaging 进行步长自适应的目标接受概率。增加此值将导致更小的步长,从而采样会更慢但更稳健。默认为 0.8。

  • trajectory_length (float) – HMC 的 MCMC 轨迹长度。此参数在 NUTS 采样器中无效。

  • max_tree_depth (int) – NUTS 采样器倍增方案期间创建的二叉树的最大深度。默认为 10。此参数也接受整数元组 (d1, d2),其中 d1 是热身阶段的最大树深度,d2 是热身后阶段的最大树深度。

  • init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。

  • find_heuristic_step_size (bool) – 是否在每个适应窗口开始时使用启发式函数调整步长。默认为 False。

  • forward_mode_differentiation (bool) –

    是否使用前向模式微分或反向模式微分。默认情况下,我们使用反向模式,但在某些情况下,前向模式有助于提高性能。此外,JAX 中的一些控制流实用程序(如 jax.lax.while_loopjax.lax.fori_loop)仅支持前向模式微分。更多信息请参见 JAX 的自动微分手册

HMCGibbs

class HMCGibbs(inner_kernel, gibbs_fn, gibbs_sites)[source]

基类:MCMCKernel

[实验性接口]

HMC-within-Gibbs。此推断算法允许用户将通用基于梯度的推断 (HMC 或 NUTS) 与自定义 Gibbs 采样器结合使用。

请注意,提供从相应后验条件进行采样的正确 gibbs_fn 实现是用户的责任。

参数:
  • inner_kernelHMCNUTS 之一。

  • gibbs_fn – 一个 Python 可调用对象,返回基于 HMC 站点条件的 Gibbs 样本字典。必须包含一个参数 rng_key,应将其用于所有采样。还必须包含参数 hmc_sitesgibbs_sites,每个参数都是一个字典,键是站点名称,值是相应的样本值。请注意,给定的 gibbs_fn 可能不需要使用所有这些样本值。

  • gibbs_sites (list) – Gibbs 采样器覆盖的潜在变量站点名称列表。

示例

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, NUTS, HMCGibbs
...
>>> def model():
...     x = numpyro.sample("x", dist.Normal(0.0, 2.0))
...     y = numpyro.sample("y", dist.Normal(0.0, 2.0))
...     numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))
...
>>> def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
...     y = hmc_sites['y']
...     new_x = dist.Normal(0.8 * (1-y), jnp.sqrt(0.8)).sample(rng_key)
...     return {'x': new_x}
...
>>> hmc_kernel = NUTS(model)
>>> kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])
>>> mcmc = MCMC(kernel, num_warmup=100, num_samples=100, progress_bar=False)
>>> mcmc.run(random.PRNGKey(0))
>>> mcmc.print_summary()  
sample_field = 'z'
property model
get_diagnostics_str(state)[source]

给定当前 state,返回要添加到进度条中的诊断字符串,用于诊断目的。

postprocess_fn(args, kwargs)[source]

获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。

参数:
  • model_args – 模型的参数。

  • model_kwargs – 模型的关键字参数。

init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]

初始化 MCMCKernel 并返回一个初始状态以开始采样。

参数:
  • rng_key (random.PRNGKey) – 初始化核的随机数生成器密钥。

  • num_warmup (int) – 热身步骤的数量。这在热身期间进行适应非常有用。

  • init_params (tuple) – 开始采样的初始参数。类型必须与 potential_fn 的输入类型一致。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

表示核状态的初始状态。这可以是任何注册为 pytree 的类。

sample(state, model_args, model_kwargs)[source]

给定当前 state,使用给定的转移核返回下一个 state

参数:
  • state

    表示核状态的 pytree 类。对于 HMC,它由 HMCState 给出。通常,这可以是任何支持 getattr 的类。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

下一个 state

DiscreteHMCGibbs

class DiscreteHMCGibbs(inner_kernel, *, random_walk=False, modified=False)[source]

基类: HMCGibbs

[实验性接口]

HMCGibbs 的子类,对离散潜在站点执行 Metropolis 更新。

注意

站点更新顺序在每一步随机排列。

注意

此类支持离散潜在变量的枚举。要边缘化一个离散潜在站点,我们可以在其相应的 sample() 语句中指定 infer={‘enumerate’: ‘parallel’} 关键字。

参数:
  • inner_kernelHMCNUTS 之一。

  • random_walk (bool) – 如果为 False,则使用 Gibbs 采样从条件分布 p(gibbs_site | remaining sites) 中抽取样本。否则,从 gibbs_site 的域中均匀抽取样本。默认为 False。

  • modified (bool) – 是否使用参考文献 [1] 中建议的修改后的提议,该提议总是为当前的 Gibbs 站点提议一个新状态。默认为 False。修改后的方案在文献中称为“修改后的 Gibbs 采样器”或“Metropolised Gibbs 采样器”。

参考文献

  1. Peskun’s theorem and a modified discrete-state Gibbs sampler,作者 Liu, J. S. (1996)

示例

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import DiscreteHMCGibbs, MCMC, NUTS
...
>>> def model(probs, locs):
...     c = numpyro.sample("c", dist.Categorical(probs))
...     numpyro.sample("x", dist.Normal(locs[c], 0.5))
...
>>> probs = jnp.array([0.15, 0.3, 0.3, 0.25])
>>> locs = jnp.array([-2, 0, 2, 4])
>>> kernel = DiscreteHMCGibbs(NUTS(model), modified=True)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=100000, progress_bar=False)
>>> mcmc.run(random.PRNGKey(0), probs, locs)
>>> mcmc.print_summary()  
>>> samples = mcmc.get_samples()["x"]
>>> assert abs(jnp.mean(samples) - 1.3) < 0.1
>>> assert abs(jnp.var(samples) - 4.36) < 0.5
init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]

初始化 MCMCKernel 并返回一个初始状态以开始采样。

参数:
  • rng_key (random.PRNGKey) – 初始化核的随机数生成器密钥。

  • num_warmup (int) – 热身步骤的数量。这在热身期间进行适应非常有用。

  • init_params (tuple) – 开始采样的初始参数。类型必须与 potential_fn 的输入类型一致。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

表示核状态的初始状态。这可以是任何注册为 pytree 的类。

sample(state, model_args, model_kwargs)[source]

给定当前 state,使用给定的转移核返回下一个 state

参数:
  • state

    表示核状态的 pytree 类。对于 HMC,它由 HMCState 给出。通常,这可以是任何支持 getattr 的类。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

下一个 state

MixedHMC

class MixedHMC(inner_kernel, *, num_discrete_updates=None, random_walk=False, modified=False)[source]

基类: DiscreteHMCGibbs

混合哈密顿蒙特卡罗 (参考文献 [1]) 的实现。

注意

每次 MCMC 迭代更新的离散站点数 (n_D 在参考文献 [1] 中) 固定为值 1。

参考文献

  1. Mixed Hamiltonian Monte Carlo for Mixed Discrete and Continuous Variables,作者 Guangyao Zhou (2020)

  2. Peskun’s theorem and a modified discrete-state Gibbs sampler,作者 Liu, J. S. (1996)

参数:
  • inner_kernel – 一个 HMC 核函数。

  • num_discrete_updates (int) – 更新离散变量的次数。默认为离散潜在变量的数量。

  • random_walk (bool) – 如果为 False,则使用 Gibbs 采样从条件分布 p(gibbs_site | remaining sites) 中抽取样本,其中 gibbs_site 是模型中的一个离散样本站点。否则,从 gibbs_site 的域中均匀抽取样本。默认为 False。

  • modified (bool) – 是否使用参考文献 [2] 中建议的修改后的提议,该提议总是为当前的 Gibbs 站点(即离散站点)提议一个新状态。默认为 False。修改后的方案在文献中称为“修改后的 Gibbs 采样器”或“Metropolised Gibbs 采样器”。

示例

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import HMC, MCMC, MixedHMC
...
>>> def model(probs, locs):
...     c = numpyro.sample("c", dist.Categorical(probs))
...     numpyro.sample("x", dist.Normal(locs[c], 0.5))
...
>>> probs = jnp.array([0.15, 0.3, 0.3, 0.25])
>>> locs = jnp.array([-2, 0, 2, 4])
>>> kernel = MixedHMC(HMC(model, trajectory_length=1.2), num_discrete_updates=20)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=100000, progress_bar=False)
>>> mcmc.run(random.PRNGKey(0), probs, locs)
>>> mcmc.print_summary()  
>>> samples = mcmc.get_samples()
>>> assert "x" in samples and "c" in samples
>>> assert abs(jnp.mean(samples["x"]) - 1.3) < 0.1
>>> assert abs(jnp.var(samples["x"]) - 4.36) < 0.5
init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]

初始化 MCMCKernel 并返回一个初始状态以开始采样。

参数:
  • rng_key (random.PRNGKey) – 初始化核的随机数生成器密钥。

  • num_warmup (int) – 热身步骤的数量。这在热身期间进行适应非常有用。

  • init_params (tuple) – 开始采样的初始参数。类型必须与 potential_fn 的输入类型一致。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

表示核状态的初始状态。这可以是任何注册为 pytree 的类。

sample(state, model_args, model_kwargs)[source]

给定当前 state,使用给定的转移核返回下一个 state

参数:
  • state

    表示核状态的 pytree 类。对于 HMC,它由 HMCState 给出。通常,这可以是任何支持 getattr 的类。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

下一个 state

HMCECS

class HMCECS(inner_kernel, *, num_blocks=1, proxy=None)[source]

基类: HMCGibbs

[实验性接口]

具有能量守恒子采样的 HMC。

这是 HMCGibbs 的子类,用于对使用 plate 原语进行子采样语句的模型执行 HMC-within-Gibbs。它实现了参考文献 [1] 的算法 1,但使用对数似然的朴素估计(无控制变量),因此可能导致高方差。

此函数可以将子采样索引划分为块,并在每个 MCMC 步骤中只更新一个块,以提高提议子采样的接受率,详细信息见 [3]。

注意

新的子采样索引在每个 MCMC 步骤中随机有放回地提议。

参考文献

  1. Hamiltonian Monte Carlo with energy conserving subsampling,作者 Dang, K. D., Quiroz, M., Kohn, R., Minh-Ngoc, T., & Villani, M. (2019)

  2. Speeding Up MCMC by Efficient Data Subsampling,作者 Quiroz, M., Kohn, R., Villani, M., & Tran, M. N. (2018)

  3. The Block Pseudo-Margional Sampler,作者 Tran, M.-N., Kohn, R., Quiroz, M. Villani, M. (2017)

  4. The Fundamental Incompatibility of Scalable Hamiltonian Monte Carlo and Naive Data Subsampling,作者 Betancourt, M. (2015)

参数:
  • inner_kernelHMCNUTS 之一。

  • num_blocks (int) – 将子采样划分为的块数。

  • proxy – 用于似然估计的 taylor_proxy(),或者 None 表示朴素(轨迹间)子采样,如 [4] 中所述。

示例

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import HMCECS, MCMC, NUTS
...
>>> def model(data):
...     x = numpyro.sample("x", dist.Normal(0, 1))
...     with numpyro.plate("N", data.shape[0], subsample_size=100):
...         batch = numpyro.subsample(data, event_dim=0)
...         numpyro.sample("obs", dist.Normal(x, 1), obs=batch)
...
>>> data = random.normal(random.PRNGKey(0), (10000,)) + 1
>>> kernel = HMCECS(NUTS(model), num_blocks=10)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(0), data)
>>> samples = mcmc.get_samples()["x"]
>>> assert abs(jnp.mean(samples) - 1.) < 0.1
postprocess_fn(args, kwargs)[source]

获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。

参数:
  • model_args – 模型的参数。

  • model_kwargs – 模型的关键字参数。

init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]

初始化 MCMCKernel 并返回一个初始状态以开始采样。

参数:
  • rng_key (random.PRNGKey) – 初始化核的随机数生成器密钥。

  • num_warmup (int) – 热身步骤的数量。这在热身期间进行适应非常有用。

  • init_params (tuple) – 开始采样的初始参数。类型必须与 potential_fn 的输入类型一致。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

表示核状态的初始状态。这可以是任何注册为 pytree 的类。

sample(state, model_args, model_kwargs)[source]

给定当前 state,使用给定的转移核返回下一个 state

参数:
  • state

    表示核状态的 pytree 类。对于 HMC,它由 HMCState 给出。通常,这可以是任何支持 getattr 的类。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

下一个 state

static taylor_proxy(reference_params, degree=2)[source]

这只是一个方便的静态方法,用于调用 taylor_proxy()

SA

class SA(model=None, potential_fn=None, adapt_state_size=None, dense_mass=True, init_strategy=<function init_to_uniform>)[source]

基类:MCMCKernel

样本自适应 MCMC,一种无梯度采样器。

这是一种非常快的采样器(按 n_eff / s 计算),但需要许多热身(burn-in)步。在每个 MCMC 步骤中,我们只需评估一个点的势能函数。

请注意,与参考文献 [1] 不同,我们返回的是大小为 num_chains x num_samples 的近似后验样本的随机选择(即稀疏)子集,而不是 num_chains x num_samples x adapt_state_size。

注意

我们建议在使用 MCMC 时将此核函数与 progress_bar=False 一起使用,以减少 JAX 的调度开销。

参考文献

  1. Sample Adaptive MCMC (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc),作者 Michael Zhu

参数:
  • model – 包含 Pyro 原语 的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn

  • potential_fn – Python 可调用对象,用于计算给定输入参数的目标势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给 init()init_params 参数具有相同的类型。

  • adapt_state_size (int) – 用于生成提议分布的点数。默认为潜在变量大小的 2 倍。

  • dense_mass (bool) – 一个标志,用于决定质量矩阵是密集还是对角(默认为 dense_mass=True

  • init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。

init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]

初始化 MCMCKernel 并返回一个初始状态以开始采样。

参数:
  • rng_key (random.PRNGKey) – 初始化核的随机数生成器密钥。

  • num_warmup (int) – 热身步骤的数量。这在热身期间进行适应非常有用。

  • init_params (tuple) – 开始采样的初始参数。类型必须与 potential_fn 的输入类型一致。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

表示核状态的初始状态。这可以是任何注册为 pytree 的类。

property model
property sample_field

传递给 sample()state 对象中表示 MCMC 样本的属性。此属性由 postprocess_fn() 使用,也用于在 MCMC.print_summary() 中报告结果。

property default_fields

在 MCMC 运行期间(调用 MCMC.run() 时)默认要收集的 state 对象的属性。

get_diagnostics_str(state)[source]

给定当前 state,返回要添加到进度条中的诊断字符串,用于诊断目的。

postprocess_fn(args, kwargs)[source]

获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。

参数:
  • model_args – 模型的参数。

  • model_kwargs – 模型的关键字参数。

sample(state, model_args, model_kwargs)[source]

从给定的 SAState 运行 SA 并返回结果 SAState

参数:
  • state (SAState) – 表示当前状态。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

运行 SA 后的下一个 state

EnsembleSampler

class EnsembleSampler(model=None, potential_fn=None, *, randomize_split, init_strategy)[source]

基类: MCMCKernel, ABC

集成采样器的抽象类。每个 MCMC 样本被分为两个子迭代,其中更新一半的集成。

参数:
  • model – 包含 Pyro 原语 的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn

  • potential_fn – Python 可调用对象,用于计算给定输入参数的势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给 init()init_params 参数具有相同的类型。

  • randomize_split (bool) – 是否在每次迭代时随机排列链的顺序。

  • init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。

property model
property sample_field

传递给 sample()state 对象中表示 MCMC 样本的属性。此属性由 postprocess_fn() 使用,也用于在 MCMC.print_summary() 中报告结果。

property is_ensemble_kernel

表示核是否为集成核。如果为 True,则在 MCMC 运行期间(调用 MCMC.run() 时),如果 chain_method = “vectorized”,将显示 diagnostics_str。

abstract init_inner_state(rng_key)[source]

返回 inner_state

abstract update_active_chains(active, inactive, inner_state)[source]

返回 (更新的活动链集, 更新的内部状态)

init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]

初始化 MCMCKernel 并返回一个初始状态以开始采样。

参数:
  • rng_key (random.PRNGKey) – 初始化核的随机数生成器密钥。

  • num_warmup (int) – 热身步骤的数量。这在热身期间进行适应非常有用。

  • init_params (tuple) – 开始采样的初始参数。类型必须与 potential_fn 的输入类型一致。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

表示核状态的初始状态。这可以是任何注册为 pytree 的类。

postprocess_fn(args, kwargs)[source]

获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。

参数:
  • model_args – 模型的参数。

  • model_kwargs – 模型的关键字参数。

sample(state, model_args, model_kwargs)[source]

给定当前 state,使用给定的转移核返回下一个 state

参数:
  • state

    表示核状态的 pytree 类。对于 HMC,它由 HMCState 给出。通常,这可以是任何支持 getattr 的类。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

返回值:

下一个 state

AIES

class AIES(model=None, potential_fn=None, randomize_split=False, moves=None, init_strategy=<function init_to_uniform>)[source]

基类: EnsembleSampler

仿射不变集成采样 (Affine-Invariant Ensemble Sampling):一种无梯度方法,通过链之间共享信息来改进 Metropolis-Hastings 提议。适用于低到中维模型。通常,num_chains 应该至少是模型维度的两倍。

注意

此核函数必须与 MCMC 中的 num_chains > 1chain_method=”vectorized 一起使用。链数必须能被 2 整除。

参考文献

  1. emcee: The MCMC Hammer (https://iopscience.iop.org/article/10.1086/670067),

    作者 Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, 和 Jonathan Goodman。

参数:
  • model – 包含 Pyro 原语 的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn

  • potential_fn – Python 可调用对象,用于计算给定输入参数的势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给 init()init_params 参数具有相同的类型。

  • randomize_split (bool) – 是否在每次迭代时随机排列链的顺序。默认为 False。

  • moves – 一个字典,将移动映射到其被选中的相应概率。有效键为 AIES.DEMove()AIES.StretchMove()。两者在实践中通常表现良好。如果概率之和超过 1,则概率将被归一化。默认为 {AIES.DEMove(): 1.0}

  • init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, AIES

>>> def model():
...    x = numpyro.sample("x", dist.Normal().expand([10]))
...    numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10))
>>>
>>> kernel = AIES(model, moves={AIES.DEMove() : 0.5,
...                            AIES.StretchMove() : 0.5})
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized')
>>> mcmc.run(jax.random.PRNGKey(0))
get_diagnostics_str(state)[source]

给定当前 state,返回要添加到进度条中的诊断字符串,用于诊断目的。

init_inner_state(rng_key)[source]

返回 inner_state

update_active_chains(active, inactive, inner_state)[source]

返回 (更新的活动链集, 更新的内部状态)

static DEMove(sigma=1e-05, g0=None)[source]

使用差分进化的提议。

差分进化提议 的实现遵循 Nelson 等人 (2013)

参数:
  • sigma – (可选)用于拉伸提议向量的高斯分布的标准差。默认为 1.0e-5

  • (可选) (g0) – 提议向量的平均拉伸因子。默认情况下,按照两篇参考文献的建议,它是 2.38 / sqrt(2*ndim)

static StretchMove(a=2.0)[source]

Goodman & Weare (2010) 的“拉伸移动”,并进行并行化,如 Foreman-Mackey 等人 (2013) 所述。

参数:

a – (可选)拉伸比例参数。(默认值:2.0

ESS

class ESS(model=None, potential_fn=None, randomize_split=True, moves=None, max_steps=10000, max_iter=10000, init_mu=1.0, tune_mu=True, init_strategy=<function init_to_uniform>)[source]

基类: EnsembleSampler

集成切片采样 (Ensemble Slice Sampling):一种无梯度方法,通过链之间共享信息来找到更好的切片采样方向。适用于低到中维模型。通常,num_chains 应该至少是模型维度的两倍。

注意

此核函数必须与 MCMC 中的 num_chains > 1chain_method=”vectorized 一起使用。链数必须能被 2 整除。

参考文献

  1. zeus: a PYTHON implementation of ensemble slice sampling for efficient Bayesian parameter inference (https://academic.oup.com/mnras/article/508/3/3589/6381726),

    作者 Minas Karamanis, Florian Beutler, 和 John A. Peacock。

  2. Ensemble slice sampling (https://link.springer.com/article/10.1007/s11222-021-10038-2),

    作者 Minas Karamanis, Florian Beutler。

参数:
  • model – 包含 Pyro 原语 的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn

  • potential_fn – Python 可调用对象,用于计算给定输入参数的势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给 init()init_params 参数具有相同的类型。

  • randomize_split (bool) – 是否在每次迭代时随机排列链的顺序。默认为 True。

  • moves – 一个字典,将移动映射到其被选中的相应概率。如果概率之和超过 1,则概率将被归一化。有效键包括:ESS.DifferentialMove() -> 默认提议,适用于广泛的目标分布,ESS.GaussianMove() -> 适用于近似正态分布的目标,ESS.KDEMove() -> 适用于多峰后验 - 需要大的 num_chains,并且必须良好初始化,ESS.RandomMove() -> 无链交互,用于调试。默认为 {ESS.DifferentialMove(): 1.0}

  • max_steps (int) – 每样本最大步出步数。默认为 10,000。

  • max_iter (int) – 每样本最大扩展/收缩次数。默认为 10,000。

  • init_mu (float) – 初始比例因子。默认为 1.0。

  • tune_mu (bool) – 是否调整初始比例因子。默认为 True。

  • init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, ESS

>>> def model():
...    x = numpyro.sample("x", dist.Normal().expand([10]))
...    numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10))
>>>
>>> kernel = ESS(model, moves={ESS.DifferentialMove() : 0.8,
...                            ESS.RandomMove() : 0.2})
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized')
>>> mcmc.run(jax.random.PRNGKey(0))
init_inner_state(rng_key)[source]

返回 inner_state

update_active_chains(active, inactive, inner_state)[source]

返回 (更新的活动链集, 更新的内部状态)

static RandomMove()[source]

Karamanis & Beutler (2020) 的“随机移动”并进行并行化。使用此移动时,步行者沿随机方向移动。步行者之间没有通信,此移动对应于朴素的切片采样方法。此移动仅应用于调试目的。

static KDEMove(bw_method=None)[source]

Karamanis & Beutler (2020) 的“KDE 移动”并进行并行化。使用此移动时,使用高斯核密度估计方法跟踪互补集成步行者的分布。然后步行者沿从该分布采样的随机方向向量移动。

static GaussianMove()[source]

Karamanis & Beutler (2020) 的“高斯移动”并进行并行化。使用此移动时,步行者沿由互补集成步行者的高斯近似采样的随机向量定义的方向移动。

static DifferentialMove()[source]

Karamanis & Beutler (2020) 的“差分移动”并进行并行化。使用此移动时,步行者沿由从互补集成中随机(无放回)采样对定义的随机方向移动。这是默认选择,适用于广泛的目标分布。

hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS')[source]

哈密顿蒙特卡罗推断,使用固定步数或具有自适应路径长度的 No U-Turn Sampler (NUTS)。

参考文献

  1. 使用哈密顿动力学的 MCMC, Radford M. Neal

  2. The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo,作者 Matthew D. Hoffman, 和 Andrew Gelman。

  3. A Conceptual Introduction to Hamiltonian Monte Carlo`,作者 Michael Betancourt

参数:
  • potential_fn – Python 可调用对象,用于计算给定输入参数的势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给 init_kernelinit_params 参数具有相同的类型。

  • potential_fn_gen – Python 可调用对象,当提供模型参数/关键字参数时,返回 potential_fn。可以提供此参数以便在数据变化的情况下对同一模型进行推断。如果数据形状保持不变,我们可以编译一次 sample_kernel,并将其用于多次推断运行。

  • kinetic_fn – 返回给定逆质量矩阵和动量的动能的 Python 可调用对象。如果未提供,默认是欧几里得动能。

  • algo (str) – 是否运行固定步数的 HMC 或自适应路径长度的 NUTS。默认为 NUTS

返回值:

一个包含两个可调用对象的元组 (init_kernel, sample_kernel),第一个用于初始化采样器,第二个用于在现有样本基础上生成样本。

警告

强烈建议您使用更高层的 MCMC API,而不是直接使用此接口。

示例

>>> import jax
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer.hmc import hmc
>>> from numpyro.infer.util import initialize_model
>>> from numpyro.util import fori_collect

>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(2), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(3), jnp.ones(3)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
>>>
>>> model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,))
>>> init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(model_info.param_info,
...                         trajectory_length=10,
...                         num_warmup=300)
>>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
...                        transform=lambda state: model_info.postprocess_fn(state.z))
>>> print(jnp.mean(samples['coefs'], axis=0))  
[0.9153987 2.0754058 2.9621222]
init_kernel(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, *, num_steps=None, trajectory_length=6.283185307179586, max_tree_depth=10, find_heuristic_step_size=False, regularize_mass_matrix=True, model_args=(), model_kwargs=None, rng_key=None)

初始化 HMC 采样器。

参数:
  • init_params – 开始采样的初始参数。其类型必须与 potential_fn 的输入类型一致。

  • num_warmup (int) – 热身步数;热身期间生成的样本将被丢弃。

  • step_size (float) – 确定 verlet 积分器在使用哈密顿动力学计算轨迹时所取单个步骤的大小。如果未指定,将设置为 1。

  • inverse_mass_matrix (numpy.ndarray or dict) – 逆质量矩阵的初始值。如果 adapt_mass_matrix = True,这在热身阶段可能会被调整。如果没有指定值,则初始化为单位矩阵。对于具有一般 JAX pytree 参数的 potential_fn,质量矩阵条目的顺序是通过 jax.tree_flatten 获得的 pytree 参数展平版本的顺序,这有点模糊(更多信息参见 https://jax.net.cn/en/stable/pytrees.html)。如果 model 不为 None,这里我们可以将结构化块质量矩阵指定为一个字典,其中键是站点名称的元组,值是相应的质量矩阵块。有关结构化质量矩阵的更多信息,请参见 dense_mass 参数。

  • adapt_step_size (bool) – 一个标志,用于决定是否要在热身阶段使用 Dual Averaging 方案自适应调整 step_size。

  • adapt_mass_matrix (bool) – 一个标志,用于决定是否要在热身阶段使用 Welford 方案自适应调整质量矩阵。

  • dense_mass (bool or list) –

    此标志控制质量矩阵是密集(即满秩)还是对角(默认为 dense_mass=False)。要指定结构化质量矩阵,用户可以提供一个站点名称元组列表。每个元组代表联合质量矩阵中的一个块。例如,假设模型有潜在变量“x”、“y”、“z”(其中每个变量可以是多维的),可能的规格及其相应的质量矩阵结构如下所示:

    • dense_mass=[(“x”, “y”)]:对联合 (x, y) 使用密集质量矩阵,对 z 使用对角质量矩阵

    • dense_mass=[](等同于 dense_mass=False):对联合 (x, y, z) 使用对角质量矩阵

    • dense_mass=[(“x”, “y”, “z”)](等同于 full_mass=True):对联合 (x, y, z) 使用密集质量矩阵

    • dense_mass=[(“x”,), (“y”,), (“z”)]:对 x、y 和 z 各自使用密集质量矩阵(即块对角,有 3 个块)

  • target_accept_prob (float) – 使用 Dual Averaging 进行步长自适应的目标接受概率。增加此值将导致更小的步长,从而采样会更慢但更稳健。默认为 0.8。

  • num_steps (int) – 如果与 None 不同,则固定每次迭代允许的步数。

  • trajectory_length (float) – HMC 的 MCMC 轨迹长度。默认值为 \(2\pi\)

  • max_tree_depth (int) – NUTS 采样器倍增方案期间创建的二叉树的最大深度。默认为 10。此参数也接受整数元组 (d1, d2),其中 d1 是热身阶段的最大树深度,d2 是热身后阶段的最大树深度。

  • find_heuristic_step_size (bool) – 是否在每个适应窗口开始时使用启发式函数调整步长。默认为 False。

  • forward_mode_differentiation (bool) –

    是否使用前向模式微分或反向模式微分。默认情况下,我们使用反向模式,但在某些情况下,前向模式有助于提高性能。此外,JAX 中的一些控制流实用程序(如 jax.lax.while_loopjax.lax.fori_loop)仅支持前向模式微分。更多信息请参见 JAX 的自动微分手册

  • regularize_mass_matrix (bool) – 是否在热身阶段正则化估计的质量矩阵以提高数值稳定性。默认为 True。如果 adapt_mass_matrix == False,此标志不起作用。

  • model_args (tuple) – 如果指定了 potential_fn_gen,则为模型参数。

  • model_kwargs (dict) – 如果指定了 potential_fn_gen,则为模型关键字参数。

  • rng_key (jax.random.PRNGKey) – 用作随机性来源的随机键。

sample_kernel(hmc_state, model_args=(), model_kwargs=None)

给定现有的 HMCState,使用固定(可能已适应)步长运行 HMC 并返回新的 HMCState

参数:
  • hmc_state – 当前样本(及关联状态)。

  • model_args (tuple) – 如果指定了 potential_fn_gen,则为模型参数。

  • model_kwargs (dict) – 如果指定了 potential_fn_gen,则为模型关键字参数。

返回值:

通过模拟哈密顿动力学从现有状态生成的新提议 HMCState

taylor_proxy(reference_params, degree)[source]

用于无偏对数似然估计的控制变量,使用围绕参考参数的泰勒展开。建议在 [1] 中用于子采样。

参数:
  • reference_params (dict) – 在 MLE 或 MAP 估计处的模型参数化。

  • degree – 泰勒展开的项数,可以是 1 或 2。

参考文献

[1] On Markov chain Monte Carlo Methods For Tall Data

作者 Bardenet., R., Doucet, A., Holmes, C. (2017)

BarkerMHState = <class 'numpyro.infer.barker.BarkerMHState'>

一个 namedtuple(),包含以下字段:

  • i - 迭代次数。热身结束后重置为 0。

  • z - Python 集合,表示潜在站点的值(来自后验的无约束样本)。

  • potential_energy - 在给定 z 值处计算的势能。

  • z_grad - 势能相对于潜在样本站点的梯度。

  • accept_prob - 提议的接受概率。请注意,如果被拒绝,z 不对应于提议。

  • mean_accept_prob - 热身自适应或采样期间(用于诊断)直到当前迭代的平均接受概率。

  • adapt_state - 一个 HMCAdaptState namedtuple,包含热身期间的自适应信息

    • step_size - 下一次迭代中积分器使用的步长。

    • inverse_mass_matrix - 下一次迭代中使用的逆质量矩阵。

    • mass_matrix_sqrt - 下一次迭代中使用的质量矩阵的平方根。在密集质量矩阵的情况下,这是质量矩阵的 Cholesky 分解。

  • rng_key - 用于生成提议等的随机数生成器种子。

HMCState = <class 'numpyro.infer.hmc.HMCState'>

一个 namedtuple(),包含以下字段:

  • i - 迭代次数。热身结束后重置为 0。

  • z - Python 集合,表示潜在站点的值(来自后验的无约束样本)。

  • z_grad - 势能相对于潜在样本站点的梯度。

  • potential_energy - 在给定 z 值处计算的势能。

  • energy - 当前状态的势能和动能之和。

  • r - 当前动量变量。如果为 None,则在每个采样步骤开始时将抽取新的动量变量。

  • trajectory_length - 每个采样步骤中运行 HMC 动力学的时间长度。此字段在 NUTS 中不使用。

  • num_steps - 哈密顿轨迹中的步数(用于诊断)。在 HMC 采样器中,为了适应步长,trajectory_length 应为 None。在 NUTS 采样器中,轨迹的树深度可以从此字段计算得出:tree_depth = np.log2(num_steps).astype(int) + 1

  • accept_prob - 提议的接受概率。请注意,如果被拒绝,z 不对应于提议。

  • mean_accept_prob - 热身自适应或采样期间(用于诊断)直到当前迭代的平均接受概率。

  • diverging - 一个布尔值,指示当前轨迹是否发散。

  • adapt_state - 一个 HMCAdaptState namedtuple,包含热身期间的自适应信息

    • step_size - 下一次迭代中积分器使用的步长。

    • inverse_mass_matrix - 下一次迭代中使用的逆质量矩阵。

    • mass_matrix_sqrt - 下一次迭代中使用的质量矩阵的平方根。在密集质量矩阵的情况下,这是质量矩阵的 Cholesky 分解。

  • rng_key - 用于迭代的随机数生成器种子。

HMCGibbsState = <class 'numpyro.infer.hmc_gibbs.HMCGibbsState'>
  • z - 当前潜在变量值(包括 HMC 和 Gibbs 站点)的字典。

  • hmc_state - 当前 HMCState

  • rng_key - 当前步骤的随机键。

SAState = <class 'numpyro.infer.sa.SAState'>

在样本自适应 MCMC 中使用的 namedtuple()。包含以下字段:

  • i - 迭代次数。热身结束后重置为 0。

  • z - Python 集合,表示潜在站点的值(来自后验的无约束样本)。

  • potential_energy - 在给定 z 值处计算的势能。

  • accept_prob - 提议的接受概率。请注意,如果被拒绝,z 不对应于提议。

  • mean_accept_prob - 热身或采样期间(用于诊断)直到当前迭代的平均接受概率。

  • diverging - 一个布尔值,指示新样本的势能是否与当前势能发散。

  • adapt_state - 一个 SAAdaptState namedtuple,包含自适应信息

    • zs - 用于生成提议的点/状态。

    • pes - zs 的势能。

    • loc - 这些 zs 的均值。

    • inv_mass_matrix_sqrt - 如果使用密集质量矩阵,这是 zs 协方差的 Cholesky 分解。否则,这是这些 zs 的标准差。

  • rng_key - 用于迭代的随机数生成器种子。

EnsembleSamplerState = <class 'numpyro.infer.ensemble.EnsembleSamplerState'>

一个 namedtuple(),包含以下字段:

  • z - Python 集合,表示潜在站点的值(来自后验的无约束样本)。

  • inner_state - 一个 namedtuple,包含更新一半集成所需的信息。

  • rng_key - 用于生成提议等的随机数生成器种子。

AIESState = <class 'numpyro.infer.ensemble.AIESState'>

一个 namedtuple(),包含以下字段。

  • i - 迭代次数。

  • accept_prob - 提议的接受概率。请注意,如果被拒绝,z 不对应于提议。

  • mean_accept_prob - 热身自适应或采样期间(用于诊断)直到当前迭代的平均接受概率。

  • rng_key - 用于生成提议等的随机数生成器种子。

ESSState = <class 'numpyro.infer.ensemble.ESSState'>

用作集成采样器内部状态的 namedtuple()。包含以下字段:

  • i - 迭代次数。

  • n_expansions - 当前批次中的扩展次数。用于调整 mu。

  • n_contractions - 当前批次中的收缩次数。用于调整 mu。

  • mu - 比例因子。如果 tune_mu=True,则进行调整。

  • rng_key - 用于生成提议等的随机数生成器种子。

TensorFlow 核函数

TensorFlow Probability (TFP) MCMC 核函数的薄包装器。有关 TFP MCMC 核函数接口的详细信息,请参见其 TransitionKernel 文档

TFPKernel

class TFPKernel(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)[source]

TensorFlow Probability (TFP) MCMC 转换核函数的薄包装器。TFP 中的参数 target_log_prob_fnmodelpotential_fn(后者是 target_log_prob_fn 的负值)代替。

此类可用于将 TFP 核函数转换为 NumPyro 兼容的核函数,如下所示:

from numpyro.contrib.tfp.mcmc import TFPKernel

kernel = TFPKernel[tfp.mcmc.NoUTurnSampler](model, step_size=1.)

注意

默认情况下,未校准的核函数将作为 MetropolisHastings 核函数的内层核函数。

注意

对于 ReplicaExchangeMC,TFP 要求内层核函数的 step_size 形状必须为 [len(inverse_temperatures), 1][len(inverse_temperatures), latent_size]

参数:
  • model – 包含 Pyro 原语 的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn

  • potential_fn – Python 可调用对象,用于计算给定输入参数的目标势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给 init()init_params 参数具有相同的类型。

  • init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。

  • kernel_kwargs – 要传递给 TFP 核函数构造器的其他参数。

HamiltonianMonteCarlo

class HamiltonianMonteCarlo(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

tensorflow_probability.substrates.jax.mcmc.hmc.HamiltonianMonteCarlo 包装到 TFPKernel 中。TFP 内核构造中的第一个参数 target_log_prob_fnmodelpotential_fn 替代。

MetropolisAdjustedLangevinAlgorithm

class MetropolisAdjustedLangevinAlgorithm(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

tensorflow_probability.substrates.jax.mcmc.langevin.MetropolisAdjustedLangevinAlgorithm 包装到 TFPKernel 中。TFP 内核构造中的第一个参数 target_log_prob_fnmodelpotential_fn 替代。

NoUTurnSampler

class NoUTurnSampler(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

tensorflow_probability.substrates.jax.mcmc.nuts.NoUTurnSampler 包装到 TFPKernel 中。TFP 内核构造中的第一个参数 target_log_prob_fnmodelpotential_fn 替代。

RandomWalkMetropolis

class RandomWalkMetropolis(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

tensorflow_probability.substrates.jax.mcmc.random_walk_metropolis.RandomWalkMetropolis 包装到 TFPKernel 中。TFP 内核构造中的第一个参数 target_log_prob_fnmodelpotential_fn 替代。

ReplicaExchangeMC

class ReplicaExchangeMC(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

tensorflow_probability.substrates.jax.mcmc.replica_exchange_mc.ReplicaExchangeMC 包装到 TFPKernel 中。TFP 内核构造中的第一个参数 target_log_prob_fnmodelpotential_fn 替代。

SliceSampler

class SliceSampler(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

tensorflow_probability.substrates.jax.mcmc.slice_sampler_kernel.SliceSampler 包装到 TFPKernel 中。TFP 内核构造中的第一个参数 target_log_prob_fnmodelpotential_fn 替代。

UncalibratedHamiltonianMonteCarlo

class UncalibratedHamiltonianMonteCarlo(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

tensorflow_probability.substrates.jax.mcmc.hmc.UncalibratedHamiltonianMonteCarlo 包装到 TFPKernel 中。TFP 内核构造中的第一个参数 target_log_prob_fnmodelpotential_fn 替代。

UncalibratedLangevin

class UncalibratedLangevin(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

tensorflow_probability.substrates.jax.mcmc.langevin.UncalibratedLangevin 包装到 TFPKernel 中。TFP 内核构造中的第一个参数 target_log_prob_fnmodelpotential_fn 替代。

UncalibratedRandomWalk

class UncalibratedRandomWalk(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

tensorflow_probability.substrates.jax.mcmc.random_walk_metropolis.UncalibratedRandomWalk 包装到 TFPKernel 中。TFP 内核构造中的第一个参数 target_log_prob_fnmodelpotential_fn 替代。

MCMC 工具函数

initialize_model(rng_key, model, *, init_strategy=<function init_to_uniform>, dynamic_args=False, model_args=(), model_kwargs=None, forward_mode_differentiation=False, validate_grad=True)[source]

(实验性接口)此辅助函数内部调用 get_potential_fn()find_valid_initial_params(),返回一个元组 (init_params_info, potential_fn, postprocess_fn, model_trace)。

参数:
  • rng_key (jax.random.PRNGKey) – 用于从先验分布采样的随机数生成器种子。返回的 init_params 将具有 rng_key.shape[:-1] 的批量形状。

  • model – 包含 Pyro 原语的 Python 可调用对象。

  • init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。

  • dynamic_args (bool) – 如果为 True,则 potential_fnconstraints_fn 本身依赖于模型参数。当提供 *model_args, **model_kwargs 时,它们分别返回可调用对象 potential_fnconstraints_fn

  • model_args (tuple) – 提供给模型的参数。

  • model_kwargs (dict) – 提供给模型的关键字参数。

  • forward_mode_differentiation (bool) –

    是否使用前向模式微分或反向模式微分。默认情况下,我们使用反向模式,但在某些情况下,前向模式有助于提高性能。此外,JAX 中的一些控制流实用程序(如 jax.lax.while_loopjax.lax.fori_loop)仅支持前向模式微分。更多信息请参见 JAX 的自动微分手册

  • validate_grad (bool) – 是否验证初始参数的梯度。默认为 True。

返回值:

一个命名元组 ModelInfo,包含字段 (param_info, potential_fn, postprocess_fn, model_trace),其中 param_info 是一个命名元组 ParamInfo,包含用于初始化 MCMC 的先验值、相应的势能及其梯度;postprocess_fn 是一个可调用对象,它使用逆变换将无约束的 HMC 样本转换为位于站点支持范围内的约束值,此外还返回模型中 deterministic 站点的值。

fori_collect(lower: int, upper: int, body_fun: ~typing.Callable, init_val: ~typing.Any, transform: ~typing.Callable = <function identity>, progbar: bool = True, return_last_val: bool = False, collection_size=None, thinning=1, **progbar_opts)[source]

此循环构造类似于 fori_loop(),但额外增加了从循环体中收集值的功能。此外,它允许通过 transform 对这些样本进行后处理,并更新进度条。请注意,progbar=False 会更快,尤其是在收集大量样本时。请参考 hmc() 中的示例用法。

参数:
  • lower (int) – 开始收集工作的索引。换句话说,我们将跳过收集前 lower 个值。

  • upper (int) – 循环体运行的次数。

  • body_fun – 一个可调用对象,接受一个 np.ndarray 集合并返回具有相同形状和 dtype 的集合。

  • init_val – 传递给 body_fun 的初始值。可以是包含 np.ndarray 对象的任何 Python 集合类型。

  • transform – 一个可调用对象,用于后处理 body_fn 返回的值。

  • progbar – 是否发布进度条更新。

  • return_last_val (bool) – 如果为 True,则也返回最后一个值。其类型与 init_val 相同。

  • thinning – 控制保留值的稀疏比率的正整数。默认为 1,即不进行稀疏处理。

  • collection_size (int) – 返回集合的大小。如果未指定,大小将为 (upper - lower) // thinning。如果大小大于 (upper - lower) // thinning,则只有前 (upper - lower) // thinning 个条目非零。

  • **progbar_opts – 可选的额外进度条参数。可以提供一个 diagnostics_fn,当传入来自 body_fun 的当前值时,它返回一个字符串用于更新进度条的后缀。此外,还可以提供一个 progbar_desc 关键字参数,用于标记进度条。

返回值:

init_val 类型相同的集合,其中 np.ndarray 对象的值沿主轴收集。

consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None)[source]

遵循共识蒙特卡罗算法合并子后验。

参考文献

  1. Bayes and big data: The consensus Monte Carlo algorithm, Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman, Edward I. George, Robert E. McCulloch

参数:
  • subposteriors (list) – 一个列表,其中每个元素都是一个样本集合。

  • num_draws (int) – 从合并后验分布中抽取的样本数量。

  • diagonal (bool) – 是否使用方差或协方差计算权重,默认为 False(使用协方差)。

  • rng_key (jax.random.PRNGKey) – 随机性来源,默认为 jax.random.PRNGKey(0)

返回值:

如果 num_draws 为 None,则直接合并子后验而不重新采样;否则,返回一个包含 num_draws 个样本的集合,其数据结构与每个子后验相同。

parametric(subposteriors, diagonal=False)[source]

遵循(易于并行化的)参数化蒙特卡罗算法合并子后验。

参考文献

  1. Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing

参数:
  • subposteriors (list) – 一个列表,其中每个元素都是一个样本集合。

  • diagonal (bool) – 是否使用方差或协方差计算权重,默认为 False(使用协方差)。

返回值:

合并后验分布的估计均值和方差/协方差参数。

parametric_draws(subposteriors, num_draws, diagonal=False, rng_key=None)[source]

遵循(易于并行化的)参数化蒙特卡罗算法合并子后验。

参考文献

  1. Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing

参数:
  • subposteriors (list) – 一个列表,其中每个元素都是一个样本集合。

  • num_draws (int) – 从合并后验分布中抽取的样本数量。

  • diagonal (bool) – 是否使用方差或协方差计算权重,默认为 False(使用协方差)。

  • rng_key (jax.random.PRNGKey) – 随机性来源,默认为 jax.random.PRNGKey(0)

返回值:

一个包含 num_draws 个样本的集合,其数据结构与每个子后验相同。