随机变分推断 (SVI)
我们简要概述 NumPyro 中三种最常用的 ELBO 实现
Trace_ELBO 是我们基本的 ELBO 实现。
TraceMeanField_ELBO 类似于
Trace_ELBO
,但如果可能的话,它可以解析计算 ELBO 的一部分。TraceGraph_ELBO 为具有离散潜在变量的模型提供了方差缩减策略。一般来说,对于具有离散潜在变量的模型,应始终使用此 ELBO。
TraceEnum_ELBO 为具有离散潜在变量的模型提供了变量枚举策略。一般来说,当可以枚举时,对于具有离散潜在变量的模型,应始终使用此 ELBO。
- class SVI(model, guide, optim, loss, **static_kwargs)[source]
基类:
object
给定 ELBO 损失目标的随机变分推断。
参考资料
SVI 第一部分:Pyro 中的随机变分推断简介, (https://pyro.org.cn/examples/svi_part_i.html)
示例
>>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.distributions import constraints >>> from numpyro.infer import Predictive, SVI, Trace_ELBO >>> def model(data): ... f = numpyro.sample("latent_fairness", dist.Beta(10, 10)) ... with numpyro.plate("N", data.shape[0] if data is not None else 10): ... numpyro.sample("obs", dist.Bernoulli(f), obs=data) >>> def guide(data): ... alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive) ... beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key), ... constraint=constraints.positive) ... numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q)) >>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)]) >>> optimizer = numpyro.optim.Adam(step_size=0.0005) >>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) >>> svi_result = svi.run(random.PRNGKey(0), 2000, data) >>> params = svi_result.params >>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"]) >>> # use guide to make predictive >>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000) >>> samples = predictive(random.PRNGKey(1), data=None) >>> # get posterior samples >>> predictive = Predictive(guide, params=params, num_samples=1000) >>> posterior_samples = predictive(random.PRNGKey(1), data=None) >>> # use posterior samples to make predictive >>> predictive = Predictive(model, posterior_samples, params=params, num_samples=1000) >>> samples = predictive(random.PRNGKey(1), data=None)
- 参数:
model – 使用 Pyro 原语的模型 Python 可调用对象。
guide – 使用 Pyro 原语的 Guide(识别网络)Python 可调用对象。
optim –
一个
_NumpyroOptim
的实例,一个jax.example_libraries.optimizers.Optimizer
或一个 OptaxGradientTransformation
。如果您传递一个 Optax 优化器,它将自动使用numpyro.optim.optax_to_numpyro()
进行封装。>>> from optax import adam, chain, clip >>> svi = SVI(model, guide, chain(clip(10.0), adam(1e-3)), loss=Trace_ELBO())
loss – ELBO 损失,即负证据下界,用于最小化。
static_kwargs – 模型/Guide 的静态参数,即在拟合过程中保持不变的参数。
- 返回:
包含 (init_fn, update_fn, evaluate) 的元组。
- update(svi_state, *args, forward_mode_differentiation=False, **kwargs)[source]
使用优化器执行 SVI 的单个步骤(可能在批次/小批次数据上)。
- 参数:
svi_state – 当前 SVI 状态。
args – 模型/Guide 的参数(这些参数在拟合过程中可能会变化)。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。默认为 False。
kwargs – 模型/Guide 的关键字参数(这些参数在拟合过程中可能会变化)。
- 返回:
包含 (svi_state, loss) 的元组。
- stable_update(svi_state, *args, forward_mode_differentiation=False, **kwargs)[source]
类似于
update()
,但如果损失或新状态包含无效值,则返回当前状态。- 参数:
svi_state – 当前 SVI 状态。
args – 模型/Guide 的参数(这些参数在拟合过程中可能会变化)。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。默认为 False。
kwargs – 模型/Guide 的关键字参数(这些参数在拟合过程中可能会变化)。
- 返回:
包含 (svi_state, loss) 的元组。
- run(rng_key, num_steps, *args, progress_bar=True, stable_update=False, forward_mode_differentiation=False, init_state=None, init_params=None, **kwargs)[source]
(实验性接口) 运行 SVI num_steps 迭代次数,然后返回优化后的参数和每一步的损失堆叠。如果 num_steps 很大,设置 progress_bar=False 可以使运行更快。
注意
对于复杂的训练过程(例如需要提前停止、分 epoch 训练、参数/关键字参数变化等),我们建议使用更灵活的方法
init()
、update()
、evaluate()
来定制您的训练流程。- 参数:
rng_key (jax.random.PRNGKey) – 随机数生成器种子。
num_steps (int) – 优化步数。
args – 模型/Guide 的参数
progress_bar (bool) – 是否启用进度条更新。默认为
True
。stable_update (bool) – 是否使用
stable_update()
更新状态。默认为 False。forward_mode_differentiation (bool) – 是否使用前向模式微分或反向模式微分。默认情况下,我们使用反向模式,但在某些情况下,前向模式对于提高性能很有用。此外,JAX 上的一些控制流工具,例如 jax.lax.while_loop 或 jax.lax.fori_loop,仅支持前向模式微分。有关更多信息,请参阅 JAX 的 Autodiff Cookbook。
init_state (SVIState) –
如果不是 None,则从上次 SVI 运行的最终状态开始 SVI。用法
svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) svi_result = svi.run(random.PRNGKey(0), 2000, data) # upon inspection of svi_result the user decides that the model has not converged # continue from the end of the previous svi run rather than beginning again from iteration 0 svi_result = svi.run(random.PRNGKey(1), 2000, data, init_state=svi_result.state)
init_params (dict) – 如果不是 None,则使用此字典中的值初始化
numpyro.param
站点,而不是使用init_value
在numpyro.param
原语中初始化。kwargs – 模型/Guide 的关键字参数
- 返回:
一个命名元组,包含字段 params 和 losses,其中 params 包含
numpyro.param
站点的优化值,losses 是过程中收集的损失。- 返回类型:
- SVIState = <class 'numpyro.infer.svi.SVIState'>
- 一个
namedtuple()
,包含以下字段: optim_state - 当前优化器的状态。
mutable_state - 用于存储 “mutable” 站点值的额外状态
rng_key - 用于迭代的随机数生成器种子。
- 一个
- SVIRunResult = <class 'numpyro.infer.svi.SVIRunResult'>
- 一个
namedtuple()
,包含以下字段: params - 优化后的参数。
state - 最后的
SVIState
losses - 每一步收集的损失。
- 一个
ELBO
- class ELBO(num_particles=1, vectorize_particles=True)[source]
基类:
object
所有 ELBO 目标的基类。
子类应该实现
loss()
或loss_with_mutable_state()
。- 参数:
num_particles – 用于构成 ELBO(梯度)估计量的粒子/样本数量。
vectorize_particles – 是否使用 jax.vmap 并行计算 num_particles 多个粒子的 ELBO。如果为 False,则使用 jax.lax.map。默认为 True。您还可以传递一个可调用对象来指定自定义向量化策略,例如 jax.pmap。
- can_infer_discrete = False
- loss(rng_key, param_map, model, guide, *args, **kwargs)[source]
使用 num_particles 数量的样本/粒子来评估 ELBO 的估计量。
- 参数:
rng_key (jax.random.PRNGKey) – 随机数生成器种子。
param_map (dict) – 以站点名称为键的当前参数值的字典。
model – 使用 NumPyro 原语的模型的 Python 可调用对象。
guide – 使用 NumPyro 原语的 Guide 的 Python 可调用对象。
args – 模型/Guide 的参数(这些参数在拟合过程中可能会变化)。
kwargs – 模型/Guide 的关键字参数(这些参数在拟合过程中可能会变化)。
- 返回:
要最小化的负证据下界 (ELBO)。
- loss_with_mutable_state(rng_key, param_map, model, guide, *args, **kwargs)[source]
类似于
loss()
,但也会更新并返回可变状态,该状态存储mutable()
站点的值。- 参数:
rng_key (jax.random.PRNGKey) – 随机数生成器种子。
param_map (dict) – 以站点名称为键的当前参数值的字典。
model – 使用 NumPyro 原语的模型的 Python 可调用对象。
guide – 使用 NumPyro 原语的 Guide 的 Python 可调用对象。
args – 模型/Guide 的参数(这些参数在拟合过程中可能会变化)。
kwargs – 模型/Guide 的关键字参数(这些参数在拟合过程中可能会变化)。
- 返回:
包含 ELBO 损失和可变状态的字典
Trace_ELBO
- class Trace_ELBO(num_particles: int = 1, vectorize_particles: bool = True, multi_sample_guide: bool = False, sum_sites: bool = True)[source]
基类:
ELBO
ELBO-based SVI 的 trace 实现。估计器根据参考资料 [1] 和 [2] 的思路构建。模型或 Guide 的依赖结构没有限制。
这是证据下界(变分推断中的基本目标)的最基本实现。此实现具有各种限制(例如,它仅支持带有重参数化采样器的随机变量),但可用作构建更复杂损失目标的模板。
有关更多详细信息,请参阅 https://pyro.org.cn/examples/svi_part_i.html。
参考资料
概率编程中的自动化变分推断, David Wingate, Theo Weber
黑箱变分推断, Rajesh Ranganath, Sean Gerrish, David M. Blei
- 参数:
num_particles – 用于构成 ELBO(梯度)估计量的粒子/样本数量。
vectorize_particles – 是否使用 jax.vmap 并行计算 num_particles 多个粒子的 ELBO。如果为 False,则使用 jax.lax.map。默认为 True。您还可以传递一个可调用对象来指定自定义向量化策略,例如 jax.pmap。
multi_sample_guide – 是否假设 Guide 提出多个样本。
sum_sites – 是否汇总所有站点的 ELBO 贡献,或以站点为键返回贡献字典。
- loss_with_mutable_state(rng_key, param_map, model, guide, *args, **kwargs)[source]
类似于
loss()
,但也会更新并返回可变状态,该状态存储mutable()
站点的值。- 参数:
rng_key (jax.random.PRNGKey) – 随机数生成器种子。
param_map (dict) – 以站点名称为键的当前参数值的字典。
model – 使用 NumPyro 原语的模型的 Python 可调用对象。
guide – 使用 NumPyro 原语的 Guide 的 Python 可调用对象。
args – 模型/Guide 的参数(这些参数在拟合过程中可能会变化)。
kwargs – 模型/Guide 的关键字参数(这些参数在拟合过程中可能会变化)。
- 返回:
包含 ELBO 损失和可变状态的字典
TraceEnum_ELBO
- class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, vectorize_particles=True)[source]
基类:
ELBO
(实验性) ELBO-based SVI 的 TraceEnum 实现。梯度估计器根据参考资料 [1] 的思路构建,专门用于 ELBO 的情况。它支持模型和 Guide 的任意依赖结构。
利用 trace 中记录的细粒度条件依赖信息来减少梯度估计器的方差。特别是,使用来源跟踪 [2] 来查找依赖于每个不可重参数化样本站点的
cost
项。使用平板因子图的 TVE 算法 [3] 消除枚举变量。注意
目前,此目标不支持 AutoContinous guides。我们建议用户使用 AutoNormal guide 作为替代的自动解决方案。
参考资料
- [1] Storchastic:通用随机自动微分框架,
Emile van Kriekenc, Jakub M. Tomczak, Annette ten Teije
- [2] 概率程序用于高效推断的非标准解释,
David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
- [3] 平板因子图的张量变量消除,
Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu, Neeraj Pradhan, Alexander M. Rush, Noah Goodman
- can_infer_discrete = True
- loss(rng_key, param_map, model, guide, *args, **kwargs)[source]
使用 num_particles 数量的样本/粒子来评估 ELBO 的估计量。
- 参数:
rng_key (jax.random.PRNGKey) – 随机数生成器种子。
param_map (dict) – 以站点名称为键的当前参数值的字典。
model – 使用 NumPyro 原语的模型的 Python 可调用对象。
guide – 使用 NumPyro 原语的 Guide 的 Python 可调用对象。
args – 模型/Guide 的参数(这些参数在拟合过程中可能会变化)。
kwargs – 模型/Guide 的关键字参数(这些参数在拟合过程中可能会变化)。
- 返回:
要最小化的负证据下界 (ELBO)。
TraceGraph_ELBO
- class TraceGraph_ELBO(num_particles=1, vectorize_particles=True)[source]
基类:
ELBO
ELBO-based SVI 的 TraceGraph 实现。梯度估计器根据参考资料 [1] 的思路构建,专门用于 ELBO 的情况。它支持模型和 Guide 的任意依赖结构。利用 trace 中记录的细粒度条件依赖信息来减少梯度估计器的方差。特别是,使用来源跟踪 [2] 来查找依赖于每个不可重参数化样本站点的
cost
项。参考资料
- [1] 使用随机计算图的梯度估计,
John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel
- [2] 概率程序用于高效推断的非标准解释,
David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
- can_infer_discrete = True
- loss(rng_key, param_map, model, guide, *args, **kwargs)[source]
使用 num_particles 数量的样本/粒子来评估 ELBO 的估计量。
- 参数:
rng_key (jax.random.PRNGKey) – 随机数生成器种子。
param_map (dict) – 以站点名称为键的当前参数值的字典。
model – 使用 NumPyro 原语的模型的 Python 可调用对象。
guide – 使用 NumPyro 原语的 Guide 的 Python 可调用对象。
args – 模型/Guide 的参数(这些参数在拟合过程中可能会变化)。
kwargs – 模型/Guide 的关键字参数(这些参数在拟合过程中可能会变化)。
- 返回:
要最小化的负证据下界 (ELBO)。
TraceMeanField_ELBO
- class TraceMeanField_ELBO(num_particles: int = 1, vectorize_particles: bool = True, sum_sites: bool = True)[source]
基类:
ELBO
ELBO-based SVI 的 trace 实现。这是 NumPyro 中目前唯一一个在可行时使用解析 KL 散度的 ELBO 估计器。
- 参数:
num_particles – 用于构成 ELBO(梯度)估计量的粒子/样本数量。
vectorize_particles – 是否使用 jax.vmap 并行计算 num_particles 多个粒子的 ELBO。如果为 False,则使用 jax.lax.map。默认为 True。您还可以传递一个可调用对象来指定自定义向量化策略,例如 jax.pmap。
sum_sites – 是否汇总所有站点的 ELBO 贡献,或以站点为键返回贡献字典。
警告
如果平均场条件不满足,此估计器可能会给出不正确的结果。平均场条件是此估计器正确的充分而非必要条件。精确条件是对于 Guide 中的每个潜在变量 z,其在模型中的父节点不能包含任何在 Guide 中是 z 的后代的潜在变量。此处“模型中的父节点”和“Guide 中的后代”是相对于对应的(统计)依赖结构而言的。例如,如果模型和 Guide 具有相同的依赖结构,则此条件始终满足。
- loss_with_mutable_state(rng_key, param_map, model, guide, *args, **kwargs)[source]
类似于
loss()
,但也会更新并返回可变状态,该状态存储mutable()
站点的值。- 参数:
rng_key (jax.random.PRNGKey) – 随机数生成器种子。
param_map (dict) – 以站点名称为键的当前参数值的字典。
model – 使用 NumPyro 原语的模型的 Python 可调用对象。
guide – 使用 NumPyro 原语的 Guide 的 Python 可调用对象。
args – 模型/Guide 的参数(这些参数在拟合过程中可能会变化)。
kwargs – 模型/Guide 的关键字参数(这些参数在拟合过程中可能会变化)。
- 返回:
包含 ELBO 损失和可变状态的字典
RenyiELBO
- class RenyiELBO(alpha=0, num_particles=2)[source]
基类:
ELBO
Renyi \(\alpha\) 散度变分推断的实现,遵循参考资料 [1]。为了使目标成为严格下界,我们要求 \(\alpha \ge 0\)。然而,请注意,根据参考资料 [1],取决于数据集,\(\alpha < 0\) 可能会得到更好的结果。在特殊情况 \(\alpha = 0\) 下,目标函数是参考资料 [2] 中推导出的重要加权自编码器。
注意
设置 \(\alpha < 1\) 得到比通常 ELBO 更好的边界。
- 参数:
alpha (float) – \(\alpha\) 散度的阶。此处 \(\alpha \neq 1\)。默认为 0。
num_particles – 用于构成目标(梯度)估计量的粒子/样本数量。默认为 2。
vectorize_particles – 是否使用 jax.vmap 并行计算 num_particles 多个粒子的 ELBO。如果为 False,则使用 jax.lax.map。默认为 True。您还可以传递一个可调用对象来指定自定义向量化策略,例如 jax.pmap。
示例
def model(data): with numpyro.plate("batch", 10000, subsample_size=100): latent = numpyro.sample("latent", dist.Normal(0, 1)) batch = numpyro.subsample(data, event_dim=0) numpyro.sample("data", dist.Bernoulli(logits=latent), obs=batch) def guide(data): w_loc = numpyro.param("w_loc", 1.) w_scale = numpyro.param("w_scale", 1.) with numpyro.plate("batch", 10000, subsample_size=100): batch = numpyro.subsample(data, event_dim=0) loc = w_loc * batch scale = jnp.exp(w_scale * batch) numpyro.sample("latent", dist.Normal(loc, scale)) elbo = RenyiELBO(num_particles=10) svi = SVI(model, guide, optax.adam(0.1), elbo)
参考资料
Renyi 散度变分推断, Yingzhen Li, Richard E. Turner
重要加权自编码器, Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
- loss(rng_key, param_map, model, guide, *args, **kwargs)[source]
使用 num_particles 数量的样本/粒子来评估 ELBO 的估计量。
- 参数:
rng_key (jax.random.PRNGKey) – 随机数生成器种子。
param_map (dict) – 以站点名称为键的当前参数值的字典。
model – 使用 NumPyro 原语的模型的 Python 可调用对象。
guide – 使用 NumPyro 原语的 Guide 的 Python 可调用对象。
args – 模型/Guide 的参数(这些参数在拟合过程中可能会变化)。
kwargs – 模型/Guide 的关键字参数(这些参数在拟合过程中可能会变化)。
- 返回:
要最小化的负证据下界 (ELBO)。