马尔可夫链蒙特卡洛 (MCMC)
我们提供 NumPyro 中 MCMC 算法的高层概述
NUTS 是 HMC 的一个自适应变体,可能是 NumPyro 中最常用的 MCMC 算法。请注意,NUTS 和 HMC 不直接适用于带有离散潜变量的模型,但在离散变量支持有限且枚举可行的情况下,NumPyro 会自动对离散潜变量进行求和(即枚举),并对剩余的连续潜变量执行 NUTS/HMC。如上所述,模型重参数化在某些情况下可能对于获得良好的性能至关重要。请注意,一般来说,随着潜在空间维度的增加,我们预计推断会更加困难。有关更多提示和技巧,请参阅不良几何形状教程。
MixedHMC 对于包含连续和离散潜变量的模型来说,是一种有效的推断策略。
HMCECS 对于数据点数量庞大的模型来说,是一种有效的推断策略。它适用于具有连续潜变量的模型。有关详细用法,请参阅此示例。
BarkerMH 是一种基于梯度的 MCMC 方法,对于某些模型来说,它可能与 HMC 和 NUTS 具有竞争力。它适用于具有连续潜变量的模型。
HMCGibbs 将 HMC/NUTS 步骤与自定义 Gibbs 更新相结合。Gibbs 更新必须由用户指定。
DiscreteHMCGibbs 将 HMC/NUTS 步骤与离散潜变量的 Gibbs 更新相结合。相应的 Gibbs 更新是自动计算的。
SA 是一种无梯度的 MCMC 方法。它仅适用于具有连续潜变量的模型。预计对于潜在维度低到中等的模型性能最佳。对于对数密度不可导的模型来说,这可能是一个不错的选择。请注意,SA 通常需要非常大量的样本,因为混合往往很慢。从积极的一面看,单个步骤可以很快。
AIES 是一种无梯度的集成 MCMC 方法,通过链之间共享信息来指导 Metropolis-Hastings 提议。它仅适用于具有连续潜变量的模型。预计对于潜在维度低到中等的模型性能最佳。对于对数密度不可导的模型来说,这可能是一个不错的选择,并且对无似然模型具有鲁棒性。AIES 通常要求链的数量是潜在参数数量的两倍(理想情况下更大)。
ESS 是一种无梯度的集成 MCMC 方法,通过链之间共享信息来找到好的切片采样方向。它往往比 AIES 更具样本效率。它仅适用于具有连续潜变量的模型。预计对于潜在维度低到中等的模型性能最佳,并且对于对数密度不可导的模型来说可能是一个不错的选择。ESS 通常要求链的数量是潜在参数数量的两倍(理想情况下更大)。
与 HMC/NUTS 类似,如果可能,所有剩余的 MCMC 算法都支持对离散潜变量进行枚举(参见限制)。枚举站点需要标记 infer={‘enumerate’: ‘parallel’},就像标注示例中那样。
- class MCMC(sampler, *, num_warmup, num_samples, num_chains=1, thinning=1, postprocess_fn=None, chain_method='parallel', progress_bar=True, jit_model_args=False)[source]
基类:
object
提供对 NumPyro 中马尔可夫链蒙特卡洛推断算法的访问。
注意
chain_method 是一个实验性参数,未来版本可能会移除。
注意
将 progress_bar 设置为 False 在许多情况下可以提高速度。但这可能比其他选项需要更多内存。
注意
如果在 Jupyter Notebook 中将 num_chains 设置为大于 1,则需要在启动 Jupyter 的环境中安装 ipywidgets,以便进度条正确渲染。如果您正在使用 Jupyter Notebook 或 Jupyter Lab,请同时安装相应的扩展包,例如 widgetsnbextension 或 jupyterlab_widgets。
注意
如果您的数据集很大并且您有多个加速设备可用,您可以将计算分布到多个设备上。请确保您的 jax 版本是 v0.4.4 或更新。例如,
import jax from jax.experimental import mesh_utils from jax.sharding import PositionalSharding import numpy as np import numpyro import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS X = np.random.randn(128, 3) y = np.random.randn(128) def model(X, y): beta = numpyro.sample("beta", dist.Normal(0, 1).expand([3])) numpyro.sample("obs", dist.Normal(X @ beta, 1), obs=y) mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) # See https://jax.net.cn/en/stable/notebooks/Distributed_arrays_and_automatic_parallelization.html sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) X_shard = jax.device_put(X, sharding.reshape(8, 1)) y_shard = jax.device_put(y, sharding.reshape(8)) mcmc.run(jax.random.PRNGKey(0), X_shard, y_shard)
- 参数:
sampler (MCMCKernel) –
MCMCKernel
的实例,用于确定运行 MCMC 的采样器。目前,只有HMC
和NUTS
可用。num_warmup (int) – 热身步骤的数量。
num_samples (int) – 从马尔可夫链生成的样本数量。
thinning (int) – 控制保留的热身后期样本比例的正整数。例如,如果 thinning 是 2,则每隔一个样本被保留。默认为 1,即不进行稀释。
num_chains (int) – 要运行的 MCMC 链的数量。默认情况下,链将使用
jax.pmap()
并行运行。如果可用设备不足,链将按顺序运行。postprocess_fn – 后处理可调用对象 - 用于将采样器返回的无约束样本值集合转换为位于样本站点支持范围内的约束值。此外,它还用于返回模型中确定性站点的值。
chain_method (str) – 一个可调用的 jax 转换,如 jax.vmap 或 ‘parallel’(默认)、‘sequential’、‘vectorized’ 之一。‘parallel’ 方法用于在 XLA 设备(CPU/GPU/TPU)上并行执行采样过程,如果设备不足以进行 ‘parallel’,则回退到 ‘sequential’ 方法按顺序绘制链。‘vectorized’ 方法是一个实验性功能,它向量化了采样方法,因此允许我们在单个设备上并行收集样本。
progress_bar (bool) – 是否启用进度条更新。默认为
True
。jit_model_args (bool) – 如果设置为 True,这将把势能计算编译为模型参数的函数。因此,在大小相同但数据集不同的情况下再次调用 MCMC.run 不会导致额外的编译成本。请注意,目前这对于
num_chains > 1
且chain_method == 'parallel'
的情况不生效。
注意
混合并行和向量化采样是可能的,即使用显式 pmap 在多个设备上运行向量化链。目前,这样做需要禁用进度条。例如,
def do_mcmc(rng_key, n_vectorized=8): nuts_kernel = NUTS(model) mcmc = MCMC( nuts_kernel, progress_bar=False, num_chains=n_vectorized, chain_method='vectorized' ) mcmc.run( rng_key, extra_fields=("potential_energy",), ) return {**mcmc.get_samples(), **mcmc.get_extra_fields()} # Number of devices to pmap over n_parallel = jax.local_device_count() rng_keys = jax.random.split(PRNGKey(rng_seed), n_parallel) traces = pmap(do_mcmc)(rng_keys) # concatenate traces along pmap'ed axis trace = {k: np.concatenate(v) for k, v in traces.items()}
- property post_warmup_state
采样阶段之前的状态。如果此属性不为 None,则
run()
将跳过热身阶段,并从此属性指定的状态开始。注意
此属性可用于顺序绘制 MCMC 样本。例如,
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100) mcmc.run(random.PRNGKey(0)) first_100_samples = mcmc.get_samples() mcmc.post_warmup_state = mcmc.last_state mcmc.run(mcmc.post_warmup_state.rng_key) # or mcmc.run(random.PRNGKey(1)) second_100_samples = mcmc.get_samples()
- property last_state
采样阶段结束时的最终 MCMC 状态。
- warmup(rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs)[source]
运行 MCMC 热身适应阶段。调用此方法后,self.post_warmup_state 将被设置,并且
run()
方法将跳过热身适应阶段。要针对新数据再次运行 warmup,需要再次运行warmup()
。- 参数:
rng_key (random.PRNGKey) – 用于采样的随机数生成器密钥。
args – 提供给
numpyro.infer.mcmc.MCMCKernel.init()
方法的参数。这些通常是 model 所需的参数。extra_fields (tuple or list) – 在 MCMC 运行期间要收集的状态对象中的额外字段(除了
default_fields()
),例如 HMC 的numpyro.infer.hmc.HMCState
。使用 “~`sampler.sample_field`.`sample_site`” 从收集中排除样本站点。例如,如果您使用 NUTS 采样器,则 “~z.a” 将阻止收集站点 “a”。要收集站点 “a” 在无约束空间中的样本,我们可以在此处指定变量,例如 extra_fields=(“z.a”,)。collect_warmup (bool) – 是否收集热身阶段的样本。默认为 False。
init_params – 开始采样的初始参数。类型必须与提供给核的 potential_fn 的输入类型一致。如果核是通过 numpyro 模型实例化的,这里的初始参数对应于无约束空间中的潜值。
kwargs – 提供给
numpyro.infer.mcmc.MCMCKernel.init()
方法的关键字参数。这些通常是 model 所需的关键字参数。
- run(rng_key, *args, extra_fields=(), init_params=None, **kwargs)[source]
运行 MCMC 采样器并收集样本。
- 参数:
rng_key (random.PRNGKey) – 用于采样的随机数生成器密钥。对于多链,可以提供一批 num_chains 密钥。如果 rng_key 没有 batch_size,它将被分割成一批 num_chains 密钥。
args – 提供给
numpyro.infer.mcmc.MCMCKernel.init()
方法的参数。这些通常是 model 所需的参数。extra_fields (tuple or list of str) – 在 MCMC 运行期间要收集的状态对象中的额外字段(除了 “z”、“diverging”),例如 HMC 的
numpyro.infer.hmc.HMCState
。请注意,可以使用点来访问子字段,例如,可以使用 “adapt_state.step_size” 来收集每个步骤的步长。使用 “~`sampler.sample_field`.`sample_site`” 从收集中排除样本站点。例如,如果您使用 NUTS 采样器,则 “~z.a” 将阻止收集站点 “a”。要收集站点 “a” 在无约束空间中的样本,我们可以在此处指定变量,例如 extra_fields=(“z.a”,)。init_params – 开始采样的初始参数。类型必须与提供给核的 potential_fn 的输入类型一致。如果核是通过 numpyro 模型实例化的,这里的初始参数对应于无约束空间中的潜值。
kwargs – 提供给
numpyro.infer.mcmc.MCMCKernel.init()
方法的关键字参数。这些通常是 model 所需的关键字参数。
注意
jax 允许 Python 代码在编译代码尚未完成时继续执行。这在尝试分析代码速度时可能会导致问题。请参阅 https://jax.net.cn/en/stable/async_dispatch.html 和 https://jax.net.cn/en/stable/profiling.html 以获取有关分析 jax 程序的提示。
- get_samples(group_by_chain=False)[source]
获取 MCMC 运行中的样本。
- 参数:
group_by_chain (bool) – 是否保留链的维度。如果为 True,所有样本的前导维度大小将是 num_chains。
- 返回值:
样本的数据类型与 init_params 相同。如果使用包含 Pyro 原语的模型,数据类型是一个以站点名称为键的 dict,但更一般地可以是任何
jaxlib.pytree()
(例如,当为接受 list 参数的 HMC 定义 potential_fn 时)。
示例
然后您可以将这些样本传递给
Predictive
posterior_samples = mcmc.get_samples() predictive = Predictive(model, posterior_samples=posterior_samples) samples = predictive(rng_key1, *model_args, **model_kwargs)
MCMC 核
MCMCKernel
- class MCMCKernel[source]
基类:
ABC
定义用于
MCMC
推断的马尔可夫转移核的接口。示例
>>> from collections import namedtuple >>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import MCMC >>> MHState = namedtuple("MHState", ["u", "rng_key"]) >>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel): ... sample_field = "u" ... ... def __init__(self, potential_fn, step_size=0.1): ... self.potential_fn = potential_fn ... self.step_size = step_size ... ... def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): ... return MHState(init_params, rng_key) ... ... def sample(self, state, model_args, model_kwargs): ... u, rng_key = state ... rng_key, key_proposal, key_accept = random.split(rng_key, 3) ... u_proposal = dist.Normal(u, self.step_size).sample(key_proposal) ... accept_prob = jnp.exp(self.potential_fn(u) - self.potential_fn(u_proposal)) ... u_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, u_proposal, u) ... return MHState(u_new, rng_key) >>> def f(x): ... return ((x - 2) ** 2).sum() >>> kernel = MetropolisHastings(f) >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000) >>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.])) >>> posterior_samples = mcmc.get_samples() >>> mcmc.print_summary()
- postprocess_fn(model_args, model_kwargs)[source]
获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。
- 参数:
model_args – 模型的参数。
model_kwargs – 模型的关键字参数。
- abstract init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]
初始化 MCMCKernel 并返回一个初始状态以开始采样。
- property sample_field
传递给
sample()
的 state 对象的属性,表示 MCMC 样本。它由postprocess_fn()
和在MCMC.print_summary()
中报告结果时使用。
- property default_fields
在 MCMC 运行期间(调用
MCMC.run()
时)默认要收集的 state 对象的属性。
- property is_ensemble_kernel
表示核是否为集成核。如果为 True,则在 MCMC 运行期间(调用
MCMC.run()
时),如果 chain_method = “vectorized”,将显示 diagnostics_str。
BarkerMH
- class BarkerMH(model=None, potential_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.4, init_strategy=<function init_to_uniform>)[source]
基类:
MCMCKernel
这是一种基于梯度的 Metropolis-Hastings 类型 MCMC 算法,它使用依赖于势能梯度的偏对称提议分布(Barker 提议;参见参考文献 [1])。特别是,提议分布在当前样本处沿梯度方向倾斜。
我们期望此算法对于低到中等维度的模型特别有效,在这些模型中它可能与 HMC 和 NUTS 具有竞争力。
注意
建议在
MCMC
中使用此核并设置 progress_bar=False,以减少 JAX 的调度开销。参考文献
Barker 提议:在基于梯度的 MCMC 中结合鲁棒性和效率。Samuel Livingstone, Giacomo Zanella。
- 参数:
model – 包含 Pyro
原语
的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn。potential_fn – 计算给定输入参数的势能的 Python 可调用对象。potential_fn 的输入参数可以是任何 Python 集合类型,前提是提供给
init()
的 init_params 参数类型相同。step_size (float) – Barker 提议中使用的(初始)步长。
adapt_step_size (bool) – 是否在热身期间调整步长。默认为
adapt_step_size==True
。adapt_mass_matrix (bool) – 是否在热身期间调整质量矩阵。默认为
adapt_mass_matrix==True
。dense_mass (bool) – 是否使用密集(即满秩)或对角质量矩阵。(默认为
dense_mass=False
)。target_accept_prob (float) – 用于指导步长调整的目标接受概率。增加此值将导致步长变小,从而采样会变慢但更鲁棒。默认为 0.8。
init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。
示例
>>> import jax >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import MCMC, BarkerMH >>> def model(): ... x = numpyro.sample("x", dist.Normal().expand([10])) ... numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) >>> >>> kernel = BarkerMH(model) >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=True) >>> mcmc.run(jax.random.PRNGKey(0)) >>> mcmc.print_summary()
- property model
- property sample_field
传递给
sample()
的 state 对象的属性,表示 MCMC 样本。它由postprocess_fn()
和在MCMC.print_summary()
中报告结果时使用。
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]
初始化 MCMCKernel 并返回一个初始状态以开始采样。
HMC
- class HMC(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, num_steps=None, trajectory_length=6.283185307179586, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True)[source]
基类:
MCMCKernel
使用固定轨迹长度的哈密顿蒙特卡洛推断,并提供步长和质量矩阵调整功能。
注意
在 MCMC 运行中使用核之前,postprocess_fn 将返回恒等函数。
注意
默认的初始化策略
init_to_uniform
对于某些模型来说可能不是一个好的策略。您可能想尝试其他初始化策略,例如init_to_median
。参考文献
使用哈密顿动力学的 MCMC, Radford M. Neal
- 参数:
model – 包含 Pyro
原语
的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn。potential_fn – 计算给定输入参数的势能的 Python 可调用对象。potential_fn 的输入参数可以是任何 Python 集合类型,前提是提供给
init()
的 init_params 参数类型相同。kinetic_fn – 返回给定逆质量矩阵和动量的动能的 Python 可调用对象。如果未提供,默认是欧几里得动能。
step_size (float) – 确定 verlet 积分器在使用哈密顿动力学计算轨迹时所取单个步骤的大小。如果未指定,将设置为 1。
inverse_mass_matrix (numpy.ndarray or dict) – 逆质量矩阵的初始值。如果 adapt_mass_matrix = True,这在热身阶段可能会被调整。如果没有指定值,则初始化为单位矩阵。对于具有一般 JAX pytree 参数的 potential_fn,质量矩阵条目的顺序是通过 jax.tree_flatten 获得的 pytree 参数展平版本的顺序,这有点模糊(更多信息参见 https://jax.net.cn/en/stable/pytrees.html)。如果 model 不为 None,这里我们可以将结构化块质量矩阵指定为一个字典,其中键是站点名称的元组,值是相应的质量矩阵块。有关结构化质量矩阵的更多信息,请参见 dense_mass 参数。
adapt_step_size (bool) – 一个标志,用于决定是否要在热身阶段使用 Dual Averaging 方案自适应调整 step_size。
adapt_mass_matrix (bool) – 一个标志,用于决定是否要在热身阶段使用 Welford 方案自适应调整质量矩阵。
此标志控制质量矩阵是密集(即满秩)还是对角(默认为
dense_mass=False
)。要指定结构化质量矩阵,用户可以提供一个站点名称元组列表。每个元组代表联合质量矩阵中的一个块。例如,假设模型有潜在变量“x”、“y”、“z”(其中每个变量可以是多维的),可能的规格及其相应的质量矩阵结构如下所示:dense_mass=[(“x”, “y”)]:对联合 (x, y) 使用密集质量矩阵,对 z 使用对角质量矩阵
dense_mass=[](等同于 dense_mass=False):对联合 (x, y, z) 使用对角质量矩阵
dense_mass=[(“x”, “y”, “z”)](等同于 full_mass=True):对联合 (x, y, z) 使用密集质量矩阵
dense_mass=[(“x”,), (“y”,), (“z”)]:对 x、y 和 z 各自使用密集质量矩阵(即块对角,有 3 个块)
target_accept_prob (float) – 使用 Dual Averaging 进行步长自适应的目标接受概率。增加此值将导致更小的步长,从而采样会更慢但更稳健。默认为 0.8。
num_steps (int) – 如果与 None 不同,则固定每次迭代允许的步数。
trajectory_length (float) – HMC 的 MCMC 轨迹长度。默认值为 \(2\pi\)。
init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。
find_heuristic_step_size (bool) – 是否在每个适应窗口开始时使用启发式函数调整步长。默认为 False。
forward_mode_differentiation (bool) – 是否使用前向模式微分或反向模式微分。默认情况下,我们使用反向模式,但在某些情况下,前向模式有助于提高性能。此外,JAX 中的一些控制流实用程序(如 jax.lax.while_loop 或 jax.lax.fori_loop)仅支持前向模式微分。更多信息请参见 JAX 的自动微分手册。
regularize_mass_matrix (bool) – 是否在热身阶段正则化估计的质量矩阵以提高数值稳定性。默认为 True。如果
adapt_mass_matrix == False
,此标志不起作用。
- property model
- property sample_field
传递给
sample()
的 state 对象中表示 MCMC 样本的属性。此属性由postprocess_fn()
使用,也用于在MCMC.print_summary()
中报告结果。
- property default_fields
在 MCMC 运行期间(调用
MCMC.run()
时)默认要收集的 state 对象的属性。
- init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]
初始化 MCMCKernel 并返回一个初始状态以开始采样。
NUTS
- class NUTS(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=None, max_tree_depth=10, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True)[source]
基类:
HMC
哈密顿蒙特卡罗推断,使用具有自适应路径长度和质量矩阵自适应的 No U-Turn Sampler (NUTS)。
注意
在 MCMC 运行中使用核之前,postprocess_fn 将返回恒等函数。
注意
默认的初始化策略
init_to_uniform
对于某些模型来说可能不是一个好的策略。您可能想尝试其他初始化策略,例如init_to_median
。参考文献
使用哈密顿动力学的 MCMC, Radford M. Neal
The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo,作者 Matthew D. Hoffman, 和 Andrew Gelman。
A Conceptual Introduction to Hamiltonian Monte Carlo`,作者 Michael Betancourt
- 参数:
model – 包含 Pyro
原语
的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn。potential_fn – Python 可调用对象,用于计算给定输入参数的势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给 init_kernel 的 init_params 参数具有相同的类型。
kinetic_fn – 返回给定逆质量矩阵和动量的动能的 Python 可调用对象。如果未提供,默认是欧几里得动能。
step_size (float) – 确定 verlet 积分器在使用哈密顿动力学计算轨迹时所取单个步骤的大小。如果未指定,将设置为 1。
inverse_mass_matrix (numpy.ndarray or dict) – 逆质量矩阵的初始值。如果 adapt_mass_matrix = True,这在热身阶段可能会被调整。如果没有指定值,则初始化为单位矩阵。对于具有一般 JAX pytree 参数的 potential_fn,质量矩阵条目的顺序是通过 jax.tree_flatten 获得的 pytree 参数展平版本的顺序,这有点模糊(更多信息参见 https://jax.net.cn/en/stable/pytrees.html)。如果 model 不为 None,这里我们可以将结构化块质量矩阵指定为一个字典,其中键是站点名称的元组,值是相应的质量矩阵块。有关结构化质量矩阵的更多信息,请参见 dense_mass 参数。
adapt_step_size (bool) – 一个标志,用于决定是否要在热身阶段使用 Dual Averaging 方案自适应调整 step_size。
adapt_mass_matrix (bool) – 一个标志,用于决定是否要在热身阶段使用 Welford 方案自适应调整质量矩阵。
此标志控制质量矩阵是密集(即满秩)还是对角(默认为
dense_mass=False
)。要指定结构化质量矩阵,用户可以提供一个站点名称元组列表。每个元组代表联合质量矩阵中的一个块。例如,假设模型有潜在变量“x”、“y”、“z”(其中每个变量可以是多维的),可能的规格及其相应的质量矩阵结构如下所示:dense_mass=[(“x”, “y”)]:对联合 (x, y) 使用密集质量矩阵,对 z 使用对角质量矩阵
dense_mass=[](等同于 dense_mass=False):对联合 (x, y, z) 使用对角质量矩阵
dense_mass=[(“x”, “y”, “z”)](等同于 full_mass=True):对联合 (x, y, z) 使用密集质量矩阵
dense_mass=[(“x”,), (“y”,), (“z”)]:对 x、y 和 z 各自使用密集质量矩阵(即块对角,有 3 个块)
target_accept_prob (float) – 使用 Dual Averaging 进行步长自适应的目标接受概率。增加此值将导致更小的步长,从而采样会更慢但更稳健。默认为 0.8。
trajectory_length (float) – HMC 的 MCMC 轨迹长度。此参数在 NUTS 采样器中无效。
max_tree_depth (int) – NUTS 采样器倍增方案期间创建的二叉树的最大深度。默认为 10。此参数也接受整数元组 (d1, d2),其中 d1 是热身阶段的最大树深度,d2 是热身后阶段的最大树深度。
init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。
find_heuristic_step_size (bool) – 是否在每个适应窗口开始时使用启发式函数调整步长。默认为 False。
forward_mode_differentiation (bool) –
是否使用前向模式微分或反向模式微分。默认情况下,我们使用反向模式,但在某些情况下,前向模式有助于提高性能。此外,JAX 中的一些控制流实用程序(如 jax.lax.while_loop 或 jax.lax.fori_loop)仅支持前向模式微分。更多信息请参见 JAX 的自动微分手册。
HMCGibbs
- class HMCGibbs(inner_kernel, gibbs_fn, gibbs_sites)[source]
基类:
MCMCKernel
[实验性接口]
HMC-within-Gibbs。此推断算法允许用户将通用基于梯度的推断 (HMC 或 NUTS) 与自定义 Gibbs 采样器结合使用。
请注意,提供从相应后验条件进行采样的正确 gibbs_fn 实现是用户的责任。
- 参数:
示例
>>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import MCMC, NUTS, HMCGibbs ... >>> def model(): ... x = numpyro.sample("x", dist.Normal(0.0, 2.0)) ... y = numpyro.sample("y", dist.Normal(0.0, 2.0)) ... numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0])) ... >>> def gibbs_fn(rng_key, gibbs_sites, hmc_sites): ... y = hmc_sites['y'] ... new_x = dist.Normal(0.8 * (1-y), jnp.sqrt(0.8)).sample(rng_key) ... return {'x': new_x} ... >>> hmc_kernel = NUTS(model) >>> kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x']) >>> mcmc = MCMC(kernel, num_warmup=100, num_samples=100, progress_bar=False) >>> mcmc.run(random.PRNGKey(0)) >>> mcmc.print_summary()
- sample_field = 'z'
- property model
- postprocess_fn(args, kwargs)[source]
获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。
- 参数:
model_args – 模型的参数。
model_kwargs – 模型的关键字参数。
DiscreteHMCGibbs
- class DiscreteHMCGibbs(inner_kernel, *, random_walk=False, modified=False)[source]
基类:
HMCGibbs
[实验性接口]
HMCGibbs
的子类,对离散潜在站点执行 Metropolis 更新。注意
站点更新顺序在每一步随机排列。
注意
此类支持离散潜在变量的枚举。要边缘化一个离散潜在站点,我们可以在其相应的
sample()
语句中指定 infer={‘enumerate’: ‘parallel’} 关键字。- 参数:
参考文献
Peskun’s theorem and a modified discrete-state Gibbs sampler,作者 Liu, J. S. (1996)
示例
>>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import DiscreteHMCGibbs, MCMC, NUTS ... >>> def model(probs, locs): ... c = numpyro.sample("c", dist.Categorical(probs)) ... numpyro.sample("x", dist.Normal(locs[c], 0.5)) ... >>> probs = jnp.array([0.15, 0.3, 0.3, 0.25]) >>> locs = jnp.array([-2, 0, 2, 4]) >>> kernel = DiscreteHMCGibbs(NUTS(model), modified=True) >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=100000, progress_bar=False) >>> mcmc.run(random.PRNGKey(0), probs, locs) >>> mcmc.print_summary() >>> samples = mcmc.get_samples()["x"] >>> assert abs(jnp.mean(samples) - 1.3) < 0.1 >>> assert abs(jnp.var(samples) - 4.36) < 0.5
MixedHMC
- class MixedHMC(inner_kernel, *, num_discrete_updates=None, random_walk=False, modified=False)[source]
基类:
DiscreteHMCGibbs
混合哈密顿蒙特卡罗 (参考文献 [1]) 的实现。
注意
每次 MCMC 迭代更新的离散站点数 (n_D 在参考文献 [1] 中) 固定为值 1。
参考文献
Mixed Hamiltonian Monte Carlo for Mixed Discrete and Continuous Variables,作者 Guangyao Zhou (2020)
Peskun’s theorem and a modified discrete-state Gibbs sampler,作者 Liu, J. S. (1996)
- 参数:
inner_kernel – 一个
HMC
核函数。num_discrete_updates (int) – 更新离散变量的次数。默认为离散潜在变量的数量。
random_walk (bool) – 如果为 False,则使用 Gibbs 采样从条件分布 p(gibbs_site | remaining sites) 中抽取样本,其中 gibbs_site 是模型中的一个离散样本站点。否则,从 gibbs_site 的域中均匀抽取样本。默认为 False。
modified (bool) – 是否使用参考文献 [2] 中建议的修改后的提议,该提议总是为当前的 Gibbs 站点(即离散站点)提议一个新状态。默认为 False。修改后的方案在文献中称为“修改后的 Gibbs 采样器”或“Metropolised Gibbs 采样器”。
示例
>>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import HMC, MCMC, MixedHMC ... >>> def model(probs, locs): ... c = numpyro.sample("c", dist.Categorical(probs)) ... numpyro.sample("x", dist.Normal(locs[c], 0.5)) ... >>> probs = jnp.array([0.15, 0.3, 0.3, 0.25]) >>> locs = jnp.array([-2, 0, 2, 4]) >>> kernel = MixedHMC(HMC(model, trajectory_length=1.2), num_discrete_updates=20) >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=100000, progress_bar=False) >>> mcmc.run(random.PRNGKey(0), probs, locs) >>> mcmc.print_summary() >>> samples = mcmc.get_samples() >>> assert "x" in samples and "c" in samples >>> assert abs(jnp.mean(samples["x"]) - 1.3) < 0.1 >>> assert abs(jnp.var(samples["x"]) - 4.36) < 0.5
HMCECS
- class HMCECS(inner_kernel, *, num_blocks=1, proxy=None)[source]
基类:
HMCGibbs
[实验性接口]
具有能量守恒子采样的 HMC。
这是
HMCGibbs
的子类,用于对使用plate
原语进行子采样语句的模型执行 HMC-within-Gibbs。它实现了参考文献 [1] 的算法 1,但使用对数似然的朴素估计(无控制变量),因此可能导致高方差。此函数可以将子采样索引划分为块,并在每个 MCMC 步骤中只更新一个块,以提高提议子采样的接受率,详细信息见 [3]。
注意
新的子采样索引在每个 MCMC 步骤中随机有放回地提议。
参考文献
Hamiltonian Monte Carlo with energy conserving subsampling,作者 Dang, K. D., Quiroz, M., Kohn, R., Minh-Ngoc, T., & Villani, M. (2019)
Speeding Up MCMC by Efficient Data Subsampling,作者 Quiroz, M., Kohn, R., Villani, M., & Tran, M. N. (2018)
The Block Pseudo-Margional Sampler,作者 Tran, M.-N., Kohn, R., Quiroz, M. Villani, M. (2017)
The Fundamental Incompatibility of Scalable Hamiltonian Monte Carlo and Naive Data Subsampling,作者 Betancourt, M. (2015)
- 参数:
num_blocks (int) – 将子采样划分为的块数。
proxy – 用于似然估计的
taylor_proxy()
,或者 None 表示朴素(轨迹间)子采样,如 [4] 中所述。
示例
>>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import HMCECS, MCMC, NUTS ... >>> def model(data): ... x = numpyro.sample("x", dist.Normal(0, 1)) ... with numpyro.plate("N", data.shape[0], subsample_size=100): ... batch = numpyro.subsample(data, event_dim=0) ... numpyro.sample("obs", dist.Normal(x, 1), obs=batch) ... >>> data = random.normal(random.PRNGKey(0), (10000,)) + 1 >>> kernel = HMCECS(NUTS(model), num_blocks=10) >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000) >>> mcmc.run(random.PRNGKey(0), data) >>> samples = mcmc.get_samples()["x"] >>> assert abs(jnp.mean(samples) - 1.) < 0.1
- postprocess_fn(args, kwargs)[source]
获取一个函数,该函数将样本站点上的无约束值转换为受限到站点支持范围内的值,并返回模型中的确定性站点。
- 参数:
model_args – 模型的参数。
model_kwargs – 模型的关键字参数。
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]
初始化 MCMCKernel 并返回一个初始状态以开始采样。
SA
- class SA(model=None, potential_fn=None, adapt_state_size=None, dense_mass=True, init_strategy=<function init_to_uniform>)[source]
基类:
MCMCKernel
样本自适应 MCMC,一种无梯度采样器。
这是一种非常快的采样器(按 n_eff / s 计算),但需要许多热身(burn-in)步。在每个 MCMC 步骤中,我们只需评估一个点的势能函数。
请注意,与参考文献 [1] 不同,我们返回的是大小为 num_chains x num_samples 的近似后验样本的随机选择(即稀疏)子集,而不是 num_chains x num_samples x adapt_state_size。
注意
我们建议在使用
MCMC
时将此核函数与 progress_bar=False 一起使用,以减少 JAX 的调度开销。参考文献
Sample Adaptive MCMC (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc),作者 Michael Zhu
- 参数:
model – 包含 Pyro
原语
的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn。potential_fn – Python 可调用对象,用于计算给定输入参数的目标势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给
init()
的 init_params 参数具有相同的类型。adapt_state_size (int) – 用于生成提议分布的点数。默认为潜在变量大小的 2 倍。
dense_mass (bool) – 一个标志,用于决定质量矩阵是密集还是对角(默认为
dense_mass=True
)init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。
- init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]
初始化 MCMCKernel 并返回一个初始状态以开始采样。
- property model
- property sample_field
传递给
sample()
的 state 对象中表示 MCMC 样本的属性。此属性由postprocess_fn()
使用,也用于在MCMC.print_summary()
中报告结果。
- property default_fields
在 MCMC 运行期间(调用
MCMC.run()
时)默认要收集的 state 对象的属性。
EnsembleSampler
- class EnsembleSampler(model=None, potential_fn=None, *, randomize_split, init_strategy)[source]
基类:
MCMCKernel
,ABC
集成采样器的抽象类。每个 MCMC 样本被分为两个子迭代,其中更新一半的集成。
- 参数:
- property model
- property sample_field
传递给
sample()
的 state 对象中表示 MCMC 样本的属性。此属性由postprocess_fn()
使用,也用于在MCMC.print_summary()
中报告结果。
- property is_ensemble_kernel
表示核是否为集成核。如果为 True,则在 MCMC 运行期间(调用
MCMC.run()
时),如果 chain_method = “vectorized”,将显示 diagnostics_str。
- init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]
初始化 MCMCKernel 并返回一个初始状态以开始采样。
AIES
- class AIES(model=None, potential_fn=None, randomize_split=False, moves=None, init_strategy=<function init_to_uniform>)[source]
基类:
EnsembleSampler
仿射不变集成采样 (Affine-Invariant Ensemble Sampling):一种无梯度方法,通过链之间共享信息来改进 Metropolis-Hastings 提议。适用于低到中维模型。通常,num_chains 应该至少是模型维度的两倍。
注意
此核函数必须与
MCMC
中的 num_chains > 1 和 chain_method=”vectorized 一起使用。链数必须能被 2 整除。参考文献
- emcee: The MCMC Hammer (https://iopscience.iop.org/article/10.1086/670067),
作者 Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, 和 Jonathan Goodman。
- 参数:
model – 包含 Pyro
原语
的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn。potential_fn – Python 可调用对象,用于计算给定输入参数的势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给
init()
的 init_params 参数具有相同的类型。randomize_split (bool) – 是否在每次迭代时随机排列链的顺序。默认为 False。
moves – 一个字典,将移动映射到其被选中的相应概率。有效键为 AIES.DEMove() 和 AIES.StretchMove()。两者在实践中通常表现良好。如果概率之和超过 1,则概率将被归一化。默认为 {AIES.DEMove(): 1.0}。
init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。
示例
>>> import jax >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import MCMC, AIES >>> def model(): ... x = numpyro.sample("x", dist.Normal().expand([10])) ... numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) >>> >>> kernel = AIES(model, moves={AIES.DEMove() : 0.5, ... AIES.StretchMove() : 0.5}) >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') >>> mcmc.run(jax.random.PRNGKey(0))
- static DEMove(sigma=1e-05, g0=None)[source]
使用差分进化的提议。
此 差分进化提议 的实现遵循 Nelson 等人 (2013)。
- 参数:
sigma – (可选)用于拉伸提议向量的高斯分布的标准差。默认为 1.0e-5。
(可选) (g0) – 提议向量的平均拉伸因子。默认情况下,按照两篇参考文献的建议,它是 2.38 / sqrt(2*ndim)。
- static StretchMove(a=2.0)[source]
Goodman & Weare (2010) 的“拉伸移动”,并进行并行化,如 Foreman-Mackey 等人 (2013) 所述。
- 参数:
a – (可选)拉伸比例参数。(默认值:
2.0
)
ESS
- class ESS(model=None, potential_fn=None, randomize_split=True, moves=None, max_steps=10000, max_iter=10000, init_mu=1.0, tune_mu=True, init_strategy=<function init_to_uniform>)[source]
基类:
EnsembleSampler
集成切片采样 (Ensemble Slice Sampling):一种无梯度方法,通过链之间共享信息来找到更好的切片采样方向。适用于低到中维模型。通常,num_chains 应该至少是模型维度的两倍。
注意
此核函数必须与
MCMC
中的 num_chains > 1 和 chain_method=”vectorized 一起使用。链数必须能被 2 整除。参考文献
- zeus: a PYTHON implementation of ensemble slice sampling for efficient Bayesian parameter inference (https://academic.oup.com/mnras/article/508/3/3589/6381726),
作者 Minas Karamanis, Florian Beutler, 和 John A. Peacock。
- Ensemble slice sampling (https://link.springer.com/article/10.1007/s11222-021-10038-2),
作者 Minas Karamanis, Florian Beutler。
- 参数:
model – 包含 Pyro
原语
的 Python 可调用对象。如果提供了 model,将使用 model 推断 potential_fn。potential_fn – Python 可调用对象,用于计算给定输入参数的势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给
init()
的 init_params 参数具有相同的类型。randomize_split (bool) – 是否在每次迭代时随机排列链的顺序。默认为 True。
moves – 一个字典,将移动映射到其被选中的相应概率。如果概率之和超过 1,则概率将被归一化。有效键包括:ESS.DifferentialMove() -> 默认提议,适用于广泛的目标分布,ESS.GaussianMove() -> 适用于近似正态分布的目标,ESS.KDEMove() -> 适用于多峰后验 - 需要大的 num_chains,并且必须良好初始化,ESS.RandomMove() -> 无链交互,用于调试。默认为 {ESS.DifferentialMove(): 1.0}。
max_steps (int) – 每样本最大步出步数。默认为 10,000。
max_iter (int) – 每样本最大扩展/收缩次数。默认为 10,000。
init_mu (float) – 初始比例因子。默认为 1.0。
tune_mu (bool) – 是否调整初始比例因子。默认为 True。
init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。
示例
>>> import jax >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import MCMC, ESS >>> def model(): ... x = numpyro.sample("x", dist.Normal().expand([10])) ... numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) >>> >>> kernel = ESS(model, moves={ESS.DifferentialMove() : 0.8, ... ESS.RandomMove() : 0.2}) >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') >>> mcmc.run(jax.random.PRNGKey(0))
- static RandomMove()[source]
Karamanis & Beutler (2020) 的“随机移动”并进行并行化。使用此移动时,步行者沿随机方向移动。步行者之间没有通信,此移动对应于朴素的切片采样方法。此移动仅应用于调试目的。
- static KDEMove(bw_method=None)[source]
Karamanis & Beutler (2020) 的“KDE 移动”并进行并行化。使用此移动时,使用高斯核密度估计方法跟踪互补集成步行者的分布。然后步行者沿从该分布采样的随机方向向量移动。
- static GaussianMove()[source]
Karamanis & Beutler (2020) 的“高斯移动”并进行并行化。使用此移动时,步行者沿由互补集成步行者的高斯近似采样的随机向量定义的方向移动。
- static DifferentialMove()[source]
Karamanis & Beutler (2020) 的“差分移动”并进行并行化。使用此移动时,步行者沿由从互补集成中随机(无放回)采样对定义的随机方向移动。这是默认选择,适用于广泛的目标分布。
- hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS')[source]
哈密顿蒙特卡罗推断,使用固定步数或具有自适应路径长度的 No U-Turn Sampler (NUTS)。
参考文献
使用哈密顿动力学的 MCMC, Radford M. Neal
The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo,作者 Matthew D. Hoffman, 和 Andrew Gelman。
A Conceptual Introduction to Hamiltonian Monte Carlo`,作者 Michael Betancourt
- 参数:
potential_fn – Python 可调用对象,用于计算给定输入参数的势能。传递给 potential_fn 的输入参数可以是任何 python 集合类型,前提是传递给 init_kernel 的 init_params 参数具有相同的类型。
potential_fn_gen – Python 可调用对象,当提供模型参数/关键字参数时,返回 potential_fn。可以提供此参数以便在数据变化的情况下对同一模型进行推断。如果数据形状保持不变,我们可以编译一次 sample_kernel,并将其用于多次推断运行。
kinetic_fn – 返回给定逆质量矩阵和动量的动能的 Python 可调用对象。如果未提供,默认是欧几里得动能。
algo (str) – 是否运行固定步数的
HMC
或自适应路径长度的NUTS
。默认为NUTS
。
- 返回值:
一个包含两个可调用对象的元组 (init_kernel, sample_kernel),第一个用于初始化采样器,第二个用于在现有样本基础上生成样本。
警告
强烈建议您使用更高层的
MCMC
API,而不是直接使用此接口。示例
>>> import jax >>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer.hmc import hmc >>> from numpyro.infer.util import initialize_model >>> from numpyro.util import fori_collect >>> true_coefs = jnp.array([1., 2., 3.]) >>> data = random.normal(random.PRNGKey(2), (2000, 3)) >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3)) >>> >>> def model(data, labels): ... coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(3), jnp.ones(3))) ... intercept = numpyro.sample('intercept', dist.Normal(0., 10.)) ... return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels) >>> >>> model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,)) >>> init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS') >>> hmc_state = init_kernel(model_info.param_info, ... trajectory_length=10, ... num_warmup=300) >>> samples = fori_collect(0, 500, sample_kernel, hmc_state, ... transform=lambda state: model_info.postprocess_fn(state.z)) >>> print(jnp.mean(samples['coefs'], axis=0)) [0.9153987 2.0754058 2.9621222]
- init_kernel(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, *, num_steps=None, trajectory_length=6.283185307179586, max_tree_depth=10, find_heuristic_step_size=False, regularize_mass_matrix=True, model_args=(), model_kwargs=None, rng_key=None)
初始化 HMC 采样器。
- 参数:
init_params – 开始采样的初始参数。其类型必须与 potential_fn 的输入类型一致。
num_warmup (int) – 热身步数;热身期间生成的样本将被丢弃。
step_size (float) – 确定 verlet 积分器在使用哈密顿动力学计算轨迹时所取单个步骤的大小。如果未指定,将设置为 1。
inverse_mass_matrix (numpy.ndarray or dict) – 逆质量矩阵的初始值。如果 adapt_mass_matrix = True,这在热身阶段可能会被调整。如果没有指定值,则初始化为单位矩阵。对于具有一般 JAX pytree 参数的 potential_fn,质量矩阵条目的顺序是通过 jax.tree_flatten 获得的 pytree 参数展平版本的顺序,这有点模糊(更多信息参见 https://jax.net.cn/en/stable/pytrees.html)。如果 model 不为 None,这里我们可以将结构化块质量矩阵指定为一个字典,其中键是站点名称的元组,值是相应的质量矩阵块。有关结构化质量矩阵的更多信息,请参见 dense_mass 参数。
adapt_step_size (bool) – 一个标志,用于决定是否要在热身阶段使用 Dual Averaging 方案自适应调整 step_size。
adapt_mass_matrix (bool) – 一个标志,用于决定是否要在热身阶段使用 Welford 方案自适应调整质量矩阵。
此标志控制质量矩阵是密集(即满秩)还是对角(默认为
dense_mass=False
)。要指定结构化质量矩阵,用户可以提供一个站点名称元组列表。每个元组代表联合质量矩阵中的一个块。例如,假设模型有潜在变量“x”、“y”、“z”(其中每个变量可以是多维的),可能的规格及其相应的质量矩阵结构如下所示:dense_mass=[(“x”, “y”)]:对联合 (x, y) 使用密集质量矩阵,对 z 使用对角质量矩阵
dense_mass=[](等同于 dense_mass=False):对联合 (x, y, z) 使用对角质量矩阵
dense_mass=[(“x”, “y”, “z”)](等同于 full_mass=True):对联合 (x, y, z) 使用密集质量矩阵
dense_mass=[(“x”,), (“y”,), (“z”)]:对 x、y 和 z 各自使用密集质量矩阵(即块对角,有 3 个块)
target_accept_prob (float) – 使用 Dual Averaging 进行步长自适应的目标接受概率。增加此值将导致更小的步长,从而采样会更慢但更稳健。默认为 0.8。
num_steps (int) – 如果与 None 不同,则固定每次迭代允许的步数。
trajectory_length (float) – HMC 的 MCMC 轨迹长度。默认值为 \(2\pi\)。
max_tree_depth (int) – NUTS 采样器倍增方案期间创建的二叉树的最大深度。默认为 10。此参数也接受整数元组 (d1, d2),其中 d1 是热身阶段的最大树深度,d2 是热身后阶段的最大树深度。
find_heuristic_step_size (bool) – 是否在每个适应窗口开始时使用启发式函数调整步长。默认为 False。
forward_mode_differentiation (bool) –
是否使用前向模式微分或反向模式微分。默认情况下,我们使用反向模式,但在某些情况下,前向模式有助于提高性能。此外,JAX 中的一些控制流实用程序(如 jax.lax.while_loop 或 jax.lax.fori_loop)仅支持前向模式微分。更多信息请参见 JAX 的自动微分手册。
regularize_mass_matrix (bool) – 是否在热身阶段正则化估计的质量矩阵以提高数值稳定性。默认为 True。如果
adapt_mass_matrix == False
,此标志不起作用。model_args (tuple) – 如果指定了 potential_fn_gen,则为模型参数。
model_kwargs (dict) – 如果指定了 potential_fn_gen,则为模型关键字参数。
rng_key (jax.random.PRNGKey) – 用作随机性来源的随机键。
- sample_kernel(hmc_state, model_args=(), model_kwargs=None)
给定现有的
HMCState
,使用固定(可能已适应)步长运行 HMC 并返回新的HMCState
。
- taylor_proxy(reference_params, degree)[source]
用于无偏对数似然估计的控制变量,使用围绕参考参数的泰勒展开。建议在 [1] 中用于子采样。
- 参数:
reference_params (dict) – 在 MLE 或 MAP 估计处的模型参数化。
degree – 泰勒展开的项数,可以是 1 或 2。
参考文献
- [1] On Markov chain Monte Carlo Methods For Tall Data
作者 Bardenet., R., Doucet, A., Holmes, C. (2017)
- BarkerMHState = <class 'numpyro.infer.barker.BarkerMHState'>
一个
namedtuple()
,包含以下字段:i - 迭代次数。热身结束后重置为 0。
z - Python 集合,表示潜在站点的值(来自后验的无约束样本)。
potential_energy - 在给定
z
值处计算的势能。z_grad - 势能相对于潜在样本站点的梯度。
accept_prob - 提议的接受概率。请注意,如果被拒绝,
z
不对应于提议。mean_accept_prob - 热身自适应或采样期间(用于诊断)直到当前迭代的平均接受概率。
adapt_state - 一个
HMCAdaptState
namedtuple,包含热身期间的自适应信息step_size - 下一次迭代中积分器使用的步长。
inverse_mass_matrix - 下一次迭代中使用的逆质量矩阵。
mass_matrix_sqrt - 下一次迭代中使用的质量矩阵的平方根。在密集质量矩阵的情况下,这是质量矩阵的 Cholesky 分解。
rng_key - 用于生成提议等的随机数生成器种子。
- HMCState = <class 'numpyro.infer.hmc.HMCState'>
一个
namedtuple()
,包含以下字段:i - 迭代次数。热身结束后重置为 0。
z - Python 集合,表示潜在站点的值(来自后验的无约束样本)。
z_grad - 势能相对于潜在样本站点的梯度。
potential_energy - 在给定
z
值处计算的势能。energy - 当前状态的势能和动能之和。
r - 当前动量变量。如果为 None,则在每个采样步骤开始时将抽取新的动量变量。
trajectory_length - 每个采样步骤中运行 HMC 动力学的时间长度。此字段在 NUTS 中不使用。
num_steps - 哈密顿轨迹中的步数(用于诊断)。在 HMC 采样器中,为了适应步长,trajectory_length 应为 None。在 NUTS 采样器中,轨迹的树深度可以从此字段计算得出:tree_depth = np.log2(num_steps).astype(int) + 1。
accept_prob - 提议的接受概率。请注意,如果被拒绝,
z
不对应于提议。mean_accept_prob - 热身自适应或采样期间(用于诊断)直到当前迭代的平均接受概率。
diverging - 一个布尔值,指示当前轨迹是否发散。
adapt_state - 一个
HMCAdaptState
namedtuple,包含热身期间的自适应信息step_size - 下一次迭代中积分器使用的步长。
inverse_mass_matrix - 下一次迭代中使用的逆质量矩阵。
mass_matrix_sqrt - 下一次迭代中使用的质量矩阵的平方根。在密集质量矩阵的情况下,这是质量矩阵的 Cholesky 分解。
rng_key - 用于迭代的随机数生成器种子。
- HMCGibbsState = <class 'numpyro.infer.hmc_gibbs.HMCGibbsState'>
z - 当前潜在变量值(包括 HMC 和 Gibbs 站点)的字典。
hmc_state - 当前
HMCState
。rng_key - 当前步骤的随机键。
- SAState = <class 'numpyro.infer.sa.SAState'>
在样本自适应 MCMC 中使用的
namedtuple()
。包含以下字段:i - 迭代次数。热身结束后重置为 0。
z - Python 集合,表示潜在站点的值(来自后验的无约束样本)。
potential_energy - 在给定
z
值处计算的势能。accept_prob - 提议的接受概率。请注意,如果被拒绝,
z
不对应于提议。mean_accept_prob - 热身或采样期间(用于诊断)直到当前迭代的平均接受概率。
diverging - 一个布尔值,指示新样本的势能是否与当前势能发散。
adapt_state - 一个
SAAdaptState
namedtuple,包含自适应信息zs - 用于生成提议的点/状态。
pes - zs 的势能。
loc - 这些 zs 的均值。
inv_mass_matrix_sqrt - 如果使用密集质量矩阵,这是 zs 协方差的 Cholesky 分解。否则,这是这些 zs 的标准差。
rng_key - 用于迭代的随机数生成器种子。
- EnsembleSamplerState = <class 'numpyro.infer.ensemble.EnsembleSamplerState'>
一个
namedtuple()
,包含以下字段:z - Python 集合,表示潜在站点的值(来自后验的无约束样本)。
inner_state - 一个 namedtuple,包含更新一半集成所需的信息。
rng_key - 用于生成提议等的随机数生成器种子。
- AIESState = <class 'numpyro.infer.ensemble.AIESState'>
一个
namedtuple()
,包含以下字段。i - 迭代次数。
accept_prob - 提议的接受概率。请注意,如果被拒绝,
z
不对应于提议。mean_accept_prob - 热身自适应或采样期间(用于诊断)直到当前迭代的平均接受概率。
rng_key - 用于生成提议等的随机数生成器种子。
- ESSState = <class 'numpyro.infer.ensemble.ESSState'>
用作集成采样器内部状态的
namedtuple()
。包含以下字段:i - 迭代次数。
n_expansions - 当前批次中的扩展次数。用于调整 mu。
n_contractions - 当前批次中的收缩次数。用于调整 mu。
mu - 比例因子。如果 tune_mu=True,则进行调整。
rng_key - 用于生成提议等的随机数生成器种子。
TensorFlow 核函数
TensorFlow Probability (TFP) MCMC 核函数的薄包装器。有关 TFP MCMC 核函数接口的详细信息,请参见其 TransitionKernel 文档。
TFPKernel
- class TFPKernel(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)[source]
TensorFlow Probability (TFP) MCMC 转换核函数的薄包装器。TFP 中的参数 target_log_prob_fn 被 model 或 potential_fn(后者是 target_log_prob_fn 的负值)代替。
此类可用于将 TFP 核函数转换为 NumPyro 兼容的核函数,如下所示:
from numpyro.contrib.tfp.mcmc import TFPKernel kernel = TFPKernel[tfp.mcmc.NoUTurnSampler](model, step_size=1.)
注意
默认情况下,未校准的核函数将作为
MetropolisHastings
核函数的内层核函数。注意
对于
ReplicaExchangeMC
,TFP 要求内层核函数的 step_size 形状必须为 [len(inverse_temperatures), 1] 或 [len(inverse_temperatures), latent_size]。
HamiltonianMonteCarlo
- class HamiltonianMonteCarlo(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
将 tensorflow_probability.substrates.jax.mcmc.hmc.HamiltonianMonteCarlo 包装到
TFPKernel
中。TFP 内核构造中的第一个参数 target_log_prob_fn 被 model 或 potential_fn 替代。
MetropolisAdjustedLangevinAlgorithm
- class MetropolisAdjustedLangevinAlgorithm(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
将 tensorflow_probability.substrates.jax.mcmc.langevin.MetropolisAdjustedLangevinAlgorithm 包装到
TFPKernel
中。TFP 内核构造中的第一个参数 target_log_prob_fn 被 model 或 potential_fn 替代。
NoUTurnSampler
- class NoUTurnSampler(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
将 tensorflow_probability.substrates.jax.mcmc.nuts.NoUTurnSampler 包装到
TFPKernel
中。TFP 内核构造中的第一个参数 target_log_prob_fn 被 model 或 potential_fn 替代。
RandomWalkMetropolis
- class RandomWalkMetropolis(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
将 tensorflow_probability.substrates.jax.mcmc.random_walk_metropolis.RandomWalkMetropolis 包装到
TFPKernel
中。TFP 内核构造中的第一个参数 target_log_prob_fn 被 model 或 potential_fn 替代。
ReplicaExchangeMC
- class ReplicaExchangeMC(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
将 tensorflow_probability.substrates.jax.mcmc.replica_exchange_mc.ReplicaExchangeMC 包装到
TFPKernel
中。TFP 内核构造中的第一个参数 target_log_prob_fn 被 model 或 potential_fn 替代。
SliceSampler
- class SliceSampler(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
将 tensorflow_probability.substrates.jax.mcmc.slice_sampler_kernel.SliceSampler 包装到
TFPKernel
中。TFP 内核构造中的第一个参数 target_log_prob_fn 被 model 或 potential_fn 替代。
UncalibratedHamiltonianMonteCarlo
- class UncalibratedHamiltonianMonteCarlo(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
将 tensorflow_probability.substrates.jax.mcmc.hmc.UncalibratedHamiltonianMonteCarlo 包装到
TFPKernel
中。TFP 内核构造中的第一个参数 target_log_prob_fn 被 model 或 potential_fn 替代。
UncalibratedLangevin
- class UncalibratedLangevin(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
将 tensorflow_probability.substrates.jax.mcmc.langevin.UncalibratedLangevin 包装到
TFPKernel
中。TFP 内核构造中的第一个参数 target_log_prob_fn 被 model 或 potential_fn 替代。
UncalibratedRandomWalk
- class UncalibratedRandomWalk(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
将 tensorflow_probability.substrates.jax.mcmc.random_walk_metropolis.UncalibratedRandomWalk 包装到
TFPKernel
中。TFP 内核构造中的第一个参数 target_log_prob_fn 被 model 或 potential_fn 替代。
MCMC 工具函数
- initialize_model(rng_key, model, *, init_strategy=<function init_to_uniform>, dynamic_args=False, model_args=(), model_kwargs=None, forward_mode_differentiation=False, validate_grad=True)[source]
(实验性接口)此辅助函数内部调用
get_potential_fn()
和find_valid_initial_params()
,返回一个元组 (init_params_info, potential_fn, postprocess_fn, model_trace)。- 参数:
rng_key (jax.random.PRNGKey) – 用于从先验分布采样的随机数生成器种子。返回的 init_params 将具有
rng_key.shape[:-1]
的批量形状。model – 包含 Pyro 原语的 Python 可调用对象。
init_strategy (callable) – 每个站点的初始化函数。有关可用函数,请参阅初始化策略部分。
dynamic_args (bool) – 如果为 True,则 potential_fn 和 constraints_fn 本身依赖于模型参数。当提供 *model_args, **model_kwargs 时,它们分别返回可调用对象 potential_fn 和 constraints_fn。
model_args (tuple) – 提供给模型的参数。
model_kwargs (dict) – 提供给模型的关键字参数。
forward_mode_differentiation (bool) –
是否使用前向模式微分或反向模式微分。默认情况下,我们使用反向模式,但在某些情况下,前向模式有助于提高性能。此外,JAX 中的一些控制流实用程序(如 jax.lax.while_loop 或 jax.lax.fori_loop)仅支持前向模式微分。更多信息请参见 JAX 的自动微分手册。
validate_grad (bool) – 是否验证初始参数的梯度。默认为 True。
- 返回值:
一个命名元组 ModelInfo,包含字段 (param_info, potential_fn, postprocess_fn, model_trace),其中 param_info 是一个命名元组 ParamInfo,包含用于初始化 MCMC 的先验值、相应的势能及其梯度;postprocess_fn 是一个可调用对象,它使用逆变换将无约束的 HMC 样本转换为位于站点支持范围内的约束值,此外还返回模型中 deterministic 站点的值。
- fori_collect(lower: int, upper: int, body_fun: ~typing.Callable, init_val: ~typing.Any, transform: ~typing.Callable = <function identity>, progbar: bool = True, return_last_val: bool = False, collection_size=None, thinning=1, **progbar_opts)[source]
此循环构造类似于
fori_loop()
,但额外增加了从循环体中收集值的功能。此外,它允许通过 transform 对这些样本进行后处理,并更新进度条。请注意,progbar=False 会更快,尤其是在收集大量样本时。请参考hmc()
中的示例用法。- 参数:
lower (int) – 开始收集工作的索引。换句话说,我们将跳过收集前 lower 个值。
upper (int) – 循环体运行的次数。
body_fun – 一个可调用对象,接受一个 np.ndarray 集合并返回具有相同形状和 dtype 的集合。
init_val – 传递给 body_fun 的初始值。可以是包含 np.ndarray 对象的任何 Python 集合类型。
transform – 一个可调用对象,用于后处理 body_fn 返回的值。
progbar – 是否发布进度条更新。
return_last_val (bool) – 如果为 True,则也返回最后一个值。其类型与 init_val 相同。
thinning – 控制保留值的稀疏比率的正整数。默认为 1,即不进行稀疏处理。
collection_size (int) – 返回集合的大小。如果未指定,大小将为
(upper - lower) // thinning
。如果大小大于(upper - lower) // thinning
,则只有前(upper - lower) // thinning
个条目非零。**progbar_opts – 可选的额外进度条参数。可以提供一个 diagnostics_fn,当传入来自 body_fun 的当前值时,它返回一个字符串用于更新进度条的后缀。此外,还可以提供一个 progbar_desc 关键字参数,用于标记进度条。
- 返回值:
与 init_val 类型相同的集合,其中 np.ndarray 对象的值沿主轴收集。
- consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None)[source]
遵循共识蒙特卡罗算法合并子后验。
参考文献
Bayes and big data: The consensus Monte Carlo algorithm, Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman, Edward I. George, Robert E. McCulloch
- parametric(subposteriors, diagonal=False)[source]
遵循(易于并行化的)参数化蒙特卡罗算法合并子后验。
参考文献
Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing