运行时工具
enable_validation
- enable_validation(is_validate=True)[source]
在 NumPyro 中启用或禁用验证检查。验证检查提供有用的警告和错误,例如 NaN 检查、验证分布参数和支持值等,这对调试很有用。
注意
此工具在 JAX 的 JIT 编译或向量化变换
jax.vmap()
下不起作用。- 参数:
is_validate (bool) – 是否启用验证检查。
validation_enabled
enable_x64
set_platform
set_host_device_count
- set_host_device_count(n: int) None [source]
默认情况下,XLA 将所有 CPU 核心视为一个设备。此工具告知 XLA 有 n 个主机 (CPU) 设备可供使用。因此,这使得 JAX 中的并行映射
jax.pmap()
可以在 CPU 平台上工作。注意
此工具仅在程序开始时生效。在底层,它会设置环境变量 XLA_FLAGS=–xla_force_host_platform_device_count=[num_devices],其中 [num_device] 是所需的 CPU 设备数量 n。
警告
我们对在 XLA 中使用 xla_force_host_platform_device_count 标志的副作用的理解不完整。如果您在使用此工具时发现一些奇怪的现象,请通过我们的问题页面或论坛告知我们。更多信息可在此 JAX 问题中找到。
- 参数:
n (int) – 要使用的 CPU 设备数量。
推断工具
Predictive
- class Predictive(model: Callable, posterior_samples: dict | None = None, *, guide: Callable | None = None, params: dict | None = None, num_samples: int | None = None, return_sites: Sequence[str] | None = None, infer_discrete: bool = False, parallel: bool = False, batch_ndims: int | None = None, exclude_deterministic: bool = True)[source]
基类:
object
此类用于构建预测分布。预测分布是通过在来自 posterior_samples 的隐变量样本上运行模型获得的。
警告
Predictive 类的接口是实验性的,将来可能会更改。
请注意,为了按预期返回预测分布,模型中的观测变量(限制似然项)必须设置为 None(参见示例)。
- 参数:
model – 包含 Pyro 原语的 Python 可调用对象。
posterior_samples (dict) – 后验样本字典。
guide (callable) – 可选的指南,用于获取 posterior_samples 中不存在的站点的后验样本。
params (dict) – 模型/指南中 param 站点值的字典。
num_samples (int) – 样本数量
return_sites (list) – 要返回的站点;默认情况下,仅返回 posterior_samples 中不存在的样本站点。
infer_discrete (bool) – 是否从后验中采样离散站点,条件是观测值和
posterior_samples
中的其他隐变量值。在底层,这些站点将被标记为site["infer"]["enumerate"] = "parallel"
。请参阅 Pyro 枚举教程,了解 infer_discrete 如何工作。请注意,这需要安装funsor
。parallel (bool) – 是否使用 JAX 向量化映射
jax.vmap()
进行并行预测。默认为 False。batch_ndims –
后验样本或参数中的批次维度数。如果为 None,则在设置了指南(即不为 None)时默认为 0,否则为 1。批处理后验样本的用法:
设置 batch_ndims=0 以获得单个样本的预测
设置 batch_ndims=1 以获得形状为 (num_samples x …) 的 posterior_samples 的预测(与 guide=None 时 `batch_ndims=None` 相同)
设置 batch_ndims=2 以获得形状为 (num_chains x N x …) 的 posterior_samples 的预测。请注意,如果 num_samples 参数不为 None,其值应等于 num_chains x N。
批处理参数的用法:
设置 batch_ndims=0 从指南和参数中获取 1 个样本(与带指南的 batch_ndims=None 相同)
设置 batch_ndims=1 从指南和参数的一维批次中获取形状为 (num_samples x batch_size x …) 的预测
exclude_deterministic – 指示是否忽略后验样本中的确定性站点。
- 返回:
预测分布的样本字典。
示例
给定一个模型
def model(X, y=None): ... return numpyro.sample("obs", likelihood, obs=y)
您可以从先验预测中采样
predictive = Predictive(model, num_samples=1000) y_pred = predictive(rng_key, X)["obs"]
请注意,上面没有将 y 的值传递给 predictive,导致 y 被设置为 None。在使用 Predictive 时,将观测变量设置为 None 是使方法按预期工作的必要条件。
如果您还有后验样本,则可以从后验预测中采样
predictive = Predictive(model, posterior_samples=posterior_samples) y_pred = predictive(rng_key, X)["obs"]
请参阅
SVI
和MCMCKernel
的文档字符串,以查看上下文中的示例代码。
log_density
compute_log_probs
- compute_log_probs(model, model_args: tuple, model_kwargs: dict, params: dict, sum_log_prob: bool = True)[source]
(实验性接口)计算给定隐变量值
params
时模型的每个站点的对数密度。- 参数:
model – 包含 NumPyro 原语的 Python 可调用对象。
model_args – 提供给模型的参数。
model_kwargs – 提供给模型的关键字参数。
params – 按站点名称键控的当前参数值字典。
sum_log_prob – 在批次维度上对数概率求和。
- 返回:
将站点名称映射到对数密度和相应模型迹的字典。
get_transforms
transform_fn
constrain_fn
unconstrain_fn
potential_energy
log_likelihood
- log_likelihood(model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs)[source]
(实验性接口)返回给定所有隐变量样本的模型中观测节点的对数似然。
- 参数:
model – 包含 Pyro 原语的 Python 可调用对象。
posterior_samples (dict) – 后验样本字典。
args – 模型参数。
batch_ndims –
后验样本中的批次维度数。一些用法:
设置 batch_ndims=0 以获取单个样本的对数似然
设置 batch_ndims=1 以获取形状为 (num_samples x …) 的 posterior_samples 的对数似然
设置 batch_ndims=2 以获取形状为 (num_chains x num_samples x …) 的 posterior_samples 的对数似然
kwargs – 模型关键字参数。
- 返回:
观测站点的对数似然字典。
find_valid_initial_params
- find_valid_initial_params(rng_key, model, *, init_strategy=<function init_to_uniform>, enum=False, model_args=(), model_kwargs=None, prototype_params=None, forward_mode_differentiation=False, validate_grad=True)[source]
(实验性接口)给定包含 Pyro 原语的模型,返回所有参数的初始有效无约束值。此函数还返回相应的势能、梯度以及一个 is_valid 标志,指示初始参数是否有效。如果对数密度的值和梯度具有有限值,则参数值被视为有效。
- 参数:
rng_key (jax.random.PRNGKey) – 从先验采样的随机数生成器种子。返回的 init_params 将具有批处理形状
rng_key.shape[:-1]
。model – 包含 Pyro 原语的 Python 可调用对象。
init_strategy (callable) – 每个站点的初始化函数。
enum (bool) – 是否枚举离散隐变量站点。
model_args (tuple) – 提供给模型的参数。
model_kwargs (dict) – 提供给模型的关键字参数。
prototype_params (dict) – 可选的原型参数,用于定义初始参数的形状。
forward_mode_differentiation (bool) – 是否使用前向模式微分或反向模式微分。默认为 False。
validate_grad (bool) – 是否验证初始参数的梯度。默认为 True。
- 返回:
init_params_info 和 is_valid 的元组,其中 init_params_info 是包含初始参数、其势能及其梯度的元组。
初始化策略
init_to_feasible
init_to_mean
- init_to_mean(site=None)[source]
初始化为先验均值。对于没有实现 .mean 属性的先验,我们采用
init_to_median()
策略。
init_to_median
- init_to_median(site=None, num_samples=15)[source]
初始化为先验中位数。对于没有实现 .sample 方法的先验,我们采用
init_to_uniform()
策略。- 参数:
num_samples (int) – 用于计算中位数的先验点数。
init_to_sample
- init_to_sample(site=None)[source]
初始化为先验样本。对于没有实现 .sample 方法的先验,我们采用
init_to_uniform()
策略。
init_to_uniform
init_to_value
- init_to_value(site=None, values={})[source]
初始化为 values 中指定的值。对于未出现在 values 中的站点,我们采用
init_to_uniform()
策略。- 参数:
values (dict) – 按站点名称键控的初始值字典。
张量索引
- vindex(tensor, args)[source]
带有广播语义的向量化高级索引。
另请参阅便利包装器
Vindex
。这对于编写与批处理和枚举兼容的索引代码非常有用,特别是对于使用离散随机变量选择混合成分。
例如,假设
x
是一个参数,其len(x.shape) == 3
,并且我们希望将表达式x[i, :, j]
从整数i,j
推广到带有批次维度和枚举维度(但没有事件维度)的张量i,j
。然后我们可以使用Vindex
编写推广版本xij = Vindex(x)[i, :, j] batch_shape = broadcast_shape(i.shape, j.shape) event_shape = (x.size(1),) assert xij.shape == batch_shape + event_shape
为了处理
x
也可能包含批次维度的情况(例如,如果x
在平板上下文中使用向量化粒子进行采样),vindex()
使用特殊约定,Ellipsis
表示批次维度(因此...
只能出现在左侧,不能出现在中间或右侧)。假设x
的事件维度为 3。那么我们可以写成old_batch_shape = x.shape[:-3] old_event_shape = x.shape[-3:] xij = Vindex(x)[..., i, :, j] # The ... denotes unknown batch shape. new_batch_shape = broadcast_shape(old_batch_shape, i.shape, j.shape) new_event_shape = (x.size(1),) assert xij.shape = new_batch_shape + new_event_shape
请注意,这种对
Ellipsis
的特殊处理与 NEP [1] 不同。形式上,此函数假设
每个参数要么是
Ellipsis
、slice(None)
、整数,要么是批处理整数张量(即事件形状为空)。此函数不支持非平凡切片或布尔张量掩码。Ellipsis
只能作为args[0]
出现在左侧。如果
args[0] 不是 Ellipsis
,则tensor
未进行批处理,并且其事件维度等于len(args)
。如果
args[0] 是 Ellipsis
,则tensor
已进行批处理,并且其事件维度等于len(args[1:])
。tensor
事件维度左侧的维度被视为批次维度,并将与张量参数的维度进行广播。
请注意,如果所有参数都不是具有
len(shape) > 0
的张量,则此函数的行为与标准索引相同if not any(isinstance(a, jnp.ndarray) and len(a.shape) > 0 for a in args): assert Vindex(x)[args] == x[args]
参考文献
- [1] https://numpy.com.cn/neps/nep-0021-advanced-indexing.html
将
vindex
作为向量化索引的辅助函数引入。此实现类似于提议的符号x.vindex[]
,只是对Ellipsis
的处理略有不同。
- 参数:
tensor (jnp.ndarray) – 要索引的张量。
args (tuple) – 一个索引,作为
__getitem__
的参数。
- 返回:
tensor[args]
的非标准解释。- 返回类型:
jnp.ndarray
模型检查
get_dependencies
- get_dependencies(model: Callable, model_args: tuple | None = None, model_kwargs: dict | None = None) dict[str, object] [source]
推断条件模型的依赖结构。
这会返回一个嵌套字典,结构如下
{ "prior_dependencies": { "variable1": {"variable1": set()}, "variable2": {"variable1": set(), "variable2": set()}, ... }, "posterior_dependencies": { "variable1": {"variable1": {"plate1"}, "variable2": set()}, ... }, }
其中
prior_dependencies 是一个字典,将下游隐变量和观测变量映射到另一个字典,后者将它们依赖的上游隐变量映射到引起完全依赖的平板集合。也就是说,包含的平板引入二次方的依赖关系,如完全二分图,而排除的平板仅引入线性的依赖关系,如独立并行的边集。先验依赖关系遵循原始模型顺序。
posterior_dependencies 是一个类似的字典,但将隐变量映射到它们在后验中依赖的隐变量或观测站点。后验依赖关系与模型顺序相反。
依赖关系会省略
numpyro.deterministic
站点和numpyro.sample(..., Delta(...))
站点。示例
这里是一个没有平板的简单示例。我们看到每个节点都依赖于自身,并且只有隐变量出现在后验中
def model_1(): a = numpyro.sample("a", dist.Normal(0, 1)) numpyro.sample("b", dist.Normal(a, 1), obs=0.0) assert get_dependencies(model_1) == { "prior_dependencies": { "a": {"a": set()}, "b": {"a": set(), "b": set()}, }, "posterior_dependencies": { "a": {"a": set(), "b": set()}, }, }
这里有一个示例,其中两个变量
a
和b
在先验中条件独立,但在后验中由于所谓的“碰撞”变量c
而条件依赖。a
和b
都影响 c`,因此在给定c
的值时它们是条件相关的。这在图形模型文献中被称为“道德化”。def model_2(): a = numpyro.sample("a", dist.Normal(0, 1)) b = numpyro.sample("b", dist.LogNormal(0, 1)) c = numpyro.sample("c", dist.Normal(a, b)) numpyro.sample("d", dist.Normal(c, 1), obs=0.) assert get_dependencies(model_2) == { "prior_dependencies": { "a": {"a": set()}, "b": {"b": set()}, "c": {"a": set(), "b": set(), "c": set()}, "d": {"c": set(), "d": set()}, }, "posterior_dependencies": { "a": {"a": set(), "b": set(), "c": set()}, "b": {"b": set(), "c": set()}, "c": {"c": set(), "d": set()}, }, }
在存在平板的情况下,依赖关系可能更复杂。到目前为止,所有的字典值都是空的平板集合,但在下面的后验中,我们看到
c
通过平板p
依赖于自身。这意味着,在c
的元素中,例如c[0]
依赖于c[1]
(这就是我们明确允许变量依赖于自身的原因)def model_3(): with numpyro.plate("p", 5): a = numpyro.sample("a", dist.Normal(0, 1)) numpyro.sample("b", dist.Normal(a.sum(), 1), obs=0.0) assert get_dependencies(model_3) == { "prior_dependencies": { "a": {"a": set()}, "b": {"a": set(), "b": set()}, }, "posterior_dependencies": { "a": {"a": {"p"}, "b": set()}, }, }
- [1] S.Webb, A.Goliński, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018)
“生成模型的忠实反演,实现有效的摊销推理” https://dl.acm.org/doi/10.5555/3327144.3327229
get_model_relations
- get_model_relations(model, model_args=None, model_kwargs=None)[source]
从给定模型和可选数据中推断 RVs 和 plate 的关系。更多详情请参阅 https://github.com/pyro-ppl/numpyro/issues/949。
这将返回一个包含以下键的字典:
“sample_sample” 将每个下游采样站点映射到其所依赖的上游采样站点列表;
“sample_param” 将每个下游采样站点映射到其所依赖的上游参数站点列表;
“sample_dist” 将每个采样站点映射到该站点的分布名称;
“param_constraint” 将每个参数站点映射到该站点的约束名称;
“plate_sample” 将每个 plate 名称映射到该 plate 内的采样站点列表;以及
“observe” 是一个观测到的采样站点列表。
例如,对于模型
def model(data): m = numpyro.sample('m', dist.Normal(0, 1)) sd = numpyro.sample('sd', dist.LogNormal(m, 1)) with numpyro.plate('N', len(data)): numpyro.sample('obs', dist.Normal(m, sd), obs=data)
关系为
{'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']}, 'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'}, 'plate_sample': {'N': ['obs']}, 'observed': ['obs']}
- 参数:
model (callable) – 要检查的模型。
model_args – 可选的模型参数元组 (args)。
model_kwargs – 可选的模型关键字参数字典 (kwargs)。
- 返回类型:
可视化工具
render_model
Trace 检查
- format_shapes(trace: dict, *, compute_log_prob: bool = False, title: str = 'Trace Shapes:', last_site: str | None = None)[source]
给定函数的 trace,返回一个字符串,显示 trace 中所有站点的形状表格。
使用
trace
handler(或用于枚举的 funsortrace
handler)生成 trace。- 参数:
用法
def model(*args, **kwargs): ... with numpyro.handlers.seed(rng_seed=1): trace = numpyro.handlers.trace(model).get_trace(*args, **kwargs) print(numpyro.util.format_shapes(trace))