运行时工具

enable_validation

enable_validation(is_validate=True)[source]

在 NumPyro 中启用或禁用验证检查。验证检查提供有用的警告和错误,例如 NaN 检查、验证分布参数和支持值等,这对调试很有用。

注意

此工具在 JAX 的 JIT 编译或向量化变换 jax.vmap() 下不起作用。

参数:

is_validate (bool) – 是否启用验证检查。

validation_enabled

validation_enabled(is_validate=True)[source]

一个上下文管理器,在临时启用/禁用验证检查时很有用。

参数:

is_validate (bool) – 是否启用验证检查。

enable_x64

enable_x64(use_x64: bool = True) None[source]

将默认数组类型更改为使用 NumPy 中的 64 位精度。

参数:

use_x64 (bool) – 当为 True 时,JAX 数组默认使用 64 位;否则使用 32 位。

set_platform

set_platform(platform: str | None = None) None[source]

将平台更改为 CPU、GPU 或 TPU。此工具仅在程序开始时生效。

参数:

platform (str) – 可以是 ‘cpu’、‘gpu’ 或 ‘tpu’。

set_host_device_count

set_host_device_count(n: int) None[source]

默认情况下,XLA 将所有 CPU 核心视为一个设备。此工具告知 XLA 有 n 个主机 (CPU) 设备可供使用。因此,这使得 JAX 中的并行映射 jax.pmap() 可以在 CPU 平台上工作。

注意

此工具仅在程序开始时生效。在底层,它会设置环境变量 XLA_FLAGS=–xla_force_host_platform_device_count=[num_devices],其中 [num_device] 是所需的 CPU 设备数量 n

警告

我们对在 XLA 中使用 xla_force_host_platform_device_count 标志的副作用的理解不完整。如果您在使用此工具时发现一些奇怪的现象,请通过我们的问题页面或论坛告知我们。更多信息可在此 JAX 问题中找到。

参数:

n (int) – 要使用的 CPU 设备数量。

推断工具

Predictive

class Predictive(model: Callable, posterior_samples: dict | None = None, *, guide: Callable | None = None, params: dict | None = None, num_samples: int | None = None, return_sites: Sequence[str] | None = None, infer_discrete: bool = False, parallel: bool = False, batch_ndims: int | None = None, exclude_deterministic: bool = True)[source]

基类: object

此类用于构建预测分布。预测分布是通过在来自 posterior_samples 的隐变量样本上运行模型获得的。

警告

Predictive 类的接口是实验性的,将来可能会更改。

请注意,为了按预期返回预测分布,模型中的观测变量(限制似然项)必须设置为 None(参见示例)。

参数:
  • model – 包含 Pyro 原语的 Python 可调用对象。

  • posterior_samples (dict) – 后验样本字典。

  • guide (callable) – 可选的指南,用于获取 posterior_samples 中不存在的站点的后验样本。

  • params (dict) – 模型/指南中 param 站点值的字典。

  • num_samples (int) – 样本数量

  • return_sites (list) – 要返回的站点;默认情况下,仅返回 posterior_samples 中不存在的样本站点。

  • infer_discrete (bool) – 是否从后验中采样离散站点,条件是观测值和 posterior_samples 中的其他隐变量值。在底层,这些站点将被标记为 site["infer"]["enumerate"] = "parallel"。请参阅 Pyro 枚举教程,了解 infer_discrete 如何工作。请注意,这需要安装 funsor

  • parallel (bool) – 是否使用 JAX 向量化映射 jax.vmap() 进行并行预测。默认为 False。

  • batch_ndims

    后验样本或参数中的批次维度数。如果为 None,则在设置了指南(即不为 None)时默认为 0,否则为 1。批处理后验样本的用法:

    • 设置 batch_ndims=0 以获得单个样本的预测

    • 设置 batch_ndims=1 以获得形状为 (num_samples x …)posterior_samples 的预测(与 guide=None 时 `batch_ndims=None` 相同)

    • 设置 batch_ndims=2 以获得形状为 (num_chains x N x …)posterior_samples 的预测。请注意,如果 num_samples 参数不为 None,其值应等于 num_chains x N

    批处理参数的用法:

    • 设置 batch_ndims=0 从指南和参数中获取 1 个样本(与带指南的 batch_ndims=None 相同)

    • 设置 batch_ndims=1 从指南和参数的一维批次中获取形状为 (num_samples x batch_size x …) 的预测

  • exclude_deterministic – 指示是否忽略后验样本中的确定性站点。

返回:

预测分布的样本字典。

示例

给定一个模型

def model(X, y=None):
    ...
    return numpyro.sample("obs", likelihood, obs=y)

您可以从先验预测中采样

predictive = Predictive(model, num_samples=1000)
y_pred = predictive(rng_key, X)["obs"]

请注意,上面没有将 y 的值传递给 predictive,导致 y 被设置为 None。在使用 Predictive 时,将观测变量设置为 None 是使方法按预期工作的必要条件。

如果您还有后验样本,则可以从后验预测中采样

predictive = Predictive(model, posterior_samples=posterior_samples)
y_pred = predictive(rng_key, X)["obs"]

请参阅 SVIMCMCKernel 的文档字符串,以查看上下文中的示例代码。

log_density

log_density(model, model_args: tuple, model_kwargs: dict, params: dict)[source]

(实验性接口)计算给定隐变量值 params 时模型的联合对数密度。

参数:
  • model – 包含 NumPyro 原语的 Python 可调用对象。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

  • params – 按站点名称键控的当前参数值字典。

返回:

联合对数密度和相应的模型迹。

compute_log_probs

compute_log_probs(model, model_args: tuple, model_kwargs: dict, params: dict, sum_log_prob: bool = True)[source]

(实验性接口)计算给定隐变量值 params 时模型的每个站点的对数密度。

参数:
  • model – 包含 NumPyro 原语的 Python 可调用对象。

  • model_args – 提供给模型的参数。

  • model_kwargs – 提供给模型的关键字参数。

  • params – 按站点名称键控的当前参数值字典。

  • sum_log_prob – 在批次维度上对数概率求和。

返回:

将站点名称映射到对数密度和相应模型迹的字典。

get_transforms

get_transforms(model, model_args, model_kwargs, params)[source]

(实验性接口)给定 NumPyro 模型,通过 biject_to() 检索(逆)变换。此函数支持 'param' 站点。注意:参数值仅用于检索模型迹。

参数:
  • model – 包含 NumPyro 原语的可调用对象。

  • model_args (tuple) – 提供给模型的参数。

  • model_kwargs (dict) – 提供给模型的关键字参数。

  • params (dict) – 按站点名称键控的值字典。

返回:

按站点名称键控的变换 dict

transform_fn

transform_fn(transforms, params, invert=False)[source]

(实验性接口)将 transforms 字典中的变换应用于 params 字典中的值,并返回按相同名称键控的变换后的值的可调用对象。

参数:
  • transforms – 按名称键控的变换字典。transformsparams 中的名称应对应。

  • params – 按名称键控的数组字典。

  • invert – 是否应用变换的逆变换。

返回:

变换后的参数 dict

constrain_fn

constrain_fn(model, model_args, model_kwargs, params, return_deterministic=False)[source]

(实验性接口)给定无约束参数 params,获取 model 中每个隐变量站点的值。transforms 用于将这些无约束参数转换为 model 中相应先验的基本值。如果先验是变换分布,则相应的基本值位于基本分布的支持域中。否则,基本值位于分布的支持域中。

参数:
  • model – 包含 NumPyro 原语的可调用对象。

  • model_args (tuple) – 提供给模型的参数。

  • model_kwargs (dict) – 提供给模型的关键字参数。

  • params (dict) – 按站点名称键控的无约束值字典。

  • return_deterministic (bool) – 是否返回模型中 deterministic 站点的值。默认为 False

返回:

变换后的参数 dict

unconstrain_fn

unconstrain_fn(model, model_args, model_kwargs, params)[source]

(实验性接口)给定 NumPyro 模型和参数字典,此函数应用正确的变换将参数值从约束空间转换为无约束空间。

参数:
  • model – 包含 NumPyro 原语的可调用对象。

  • model_args (tuple) – 提供给模型的参数。

  • model_kwargs (dict) – 提供给模型的关键字参数。

  • params (dict) – 按站点名称键控的约束值字典。

返回:

按站点名称键控的变换 dict

potential_energy

potential_energy(model, model_args, model_kwargs, params, enum=False)[source]

(实验性接口)计算给定无约束参数的模型的势能。在底层,我们将把这些无约束参数转换为属于 model 中相应先验的支持域的值。

参数:
  • model – 包含 NumPyro 原语的可调用对象。

  • model_args (tuple) – 提供给模型的参数。

  • model_kwargs (dict) – 提供给模型的关键字参数。

  • params (dict) – model 的无约束参数。

  • enum (bool) – 是否枚举离散隐变量站点。

返回:

给定无约束参数的势能。

log_likelihood

log_likelihood(model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs)[source]

(实验性接口)返回给定所有隐变量样本的模型中观测节点的对数似然。

参数:
  • model – 包含 Pyro 原语的 Python 可调用对象。

  • posterior_samples (dict) – 后验样本字典。

  • args – 模型参数。

  • batch_ndims

    后验样本中的批次维度数。一些用法:

    • 设置 batch_ndims=0 以获取单个样本的对数似然

    • 设置 batch_ndims=1 以获取形状为 (num_samples x …)posterior_samples 的对数似然

    • 设置 batch_ndims=2 以获取形状为 (num_chains x num_samples x …)posterior_samples 的对数似然

  • kwargs – 模型关键字参数。

返回:

观测站点的对数似然字典。

find_valid_initial_params

find_valid_initial_params(rng_key, model, *, init_strategy=<function init_to_uniform>, enum=False, model_args=(), model_kwargs=None, prototype_params=None, forward_mode_differentiation=False, validate_grad=True)[source]

(实验性接口)给定包含 Pyro 原语的模型,返回所有参数的初始有效无约束值。此函数还返回相应的势能、梯度以及一个 is_valid 标志,指示初始参数是否有效。如果对数密度的值和梯度具有有限值,则参数值被视为有效。

参数:
  • rng_key (jax.random.PRNGKey) – 从先验采样的随机数生成器种子。返回的 init_params 将具有批处理形状 rng_key.shape[:-1]

  • model – 包含 Pyro 原语的 Python 可调用对象。

  • init_strategy (callable) – 每个站点的初始化函数。

  • enum (bool) – 是否枚举离散隐变量站点。

  • model_args (tuple) – 提供给模型的参数。

  • model_kwargs (dict) – 提供给模型的关键字参数。

  • prototype_params (dict) – 可选的原型参数,用于定义初始参数的形状。

  • forward_mode_differentiation (bool) – 是否使用前向模式微分或反向模式微分。默认为 False。

  • validate_grad (bool) – 是否验证初始参数的梯度。默认为 True。

返回:

init_params_infois_valid 的元组,其中 init_params_info 是包含初始参数、其势能及其梯度的元组。

初始化策略

init_to_feasible

init_to_feasible(site=None)[source]

初始化为任意可行点,忽略分布参数。

init_to_mean

init_to_mean(site=None)[source]

初始化为先验均值。对于没有实现 .mean 属性的先验,我们采用 init_to_median() 策略。

init_to_median

init_to_median(site=None, num_samples=15)[source]

初始化为先验中位数。对于没有实现 .sample 方法的先验,我们采用 init_to_uniform() 策略。

参数:

num_samples (int) – 用于计算中位数的先验点数。

init_to_sample

init_to_sample(site=None)[source]

初始化为先验样本。对于没有实现 .sample 方法的先验,我们采用 init_to_uniform() 策略。

init_to_uniform

init_to_uniform(site=None, radius=2)[source]

在无约束域的区域 (-radius, radius) 中初始化为随机点。

参数:

radius (float) – 指定在无约束域中绘制初始点的范围。

init_to_value

init_to_value(site=None, values={})[source]

初始化为 values 中指定的值。对于未出现在 values 中的站点,我们采用 init_to_uniform() 策略。

参数:

values (dict) – 按站点名称键控的初始值字典。

张量索引

vindex(tensor, args)[source]

带有广播语义的向量化高级索引。

另请参阅便利包装器 Vindex

这对于编写与批处理和枚举兼容的索引代码非常有用,特别是对于使用离散随机变量选择混合成分。

例如,假设 x 是一个参数,其 len(x.shape) == 3,并且我们希望将表达式 x[i, :, j] 从整数 i,j 推广到带有批次维度和枚举维度(但没有事件维度)的张量 i,j。然后我们可以使用 Vindex 编写推广版本

xij = Vindex(x)[i, :, j]

batch_shape = broadcast_shape(i.shape, j.shape)
event_shape = (x.size(1),)
assert xij.shape == batch_shape + event_shape

为了处理 x 也可能包含批次维度的情况(例如,如果 x 在平板上下文中使用向量化粒子进行采样),vindex() 使用特殊约定,Ellipsis 表示批次维度(因此 ... 只能出现在左侧,不能出现在中间或右侧)。假设 x 的事件维度为 3。那么我们可以写成

old_batch_shape = x.shape[:-3]
old_event_shape = x.shape[-3:]

xij = Vindex(x)[..., i, :, j]   # The ... denotes unknown batch shape.

new_batch_shape = broadcast_shape(old_batch_shape, i.shape, j.shape)
new_event_shape = (x.size(1),)
assert xij.shape = new_batch_shape + new_event_shape

请注意,这种对 Ellipsis 的特殊处理与 NEP [1] 不同。

形式上,此函数假设

  1. 每个参数要么是 Ellipsisslice(None)、整数,要么是批处理整数张量(即事件形状为空)。此函数不支持非平凡切片或布尔张量掩码。Ellipsis 只能作为 args[0] 出现在左侧。

  2. 如果 args[0] 不是 Ellipsis,则 tensor 未进行批处理,并且其事件维度等于 len(args)

  3. 如果 args[0] Ellipsis,则 tensor 已进行批处理,并且其事件维度等于 len(args[1:])tensor 事件维度左侧的维度被视为批次维度,并将与张量参数的维度进行广播。

请注意,如果所有参数都不是具有 len(shape) > 0 的张量,则此函数的行为与标准索引相同

if not any(isinstance(a, jnp.ndarray) and len(a.shape) > 0 for a in args):
    assert Vindex(x)[args] == x[args]

参考文献

[1] https://numpy.com.cn/neps/nep-0021-advanced-indexing.html

vindex 作为向量化索引的辅助函数引入。此实现类似于提议的符号 x.vindex[],只是对 Ellipsis 的处理略有不同。

参数:
  • tensor (jnp.ndarray) – 要索引的张量。

  • args (tuple) – 一个索引,作为 __getitem__ 的参数。

返回:

tensor[args] 的非标准解释。

返回类型:

jnp.ndarray

class Vindex(tensor)[source]

基类: object

周围 vindex() 的便利包装器。

以下是等效的

Vindex(x)[..., i, j, :]
vindex(x, (Ellipsis, i, j, slice(None)))
参数:

tensor (jnp.ndarray) – 要索引的张量。

返回:

一个带有特殊 __getitem__() 方法的对象。

模型检查

get_dependencies

get_dependencies(model: Callable, model_args: tuple | None = None, model_kwargs: dict | None = None) dict[str, object][source]

推断条件模型的依赖结构。

这会返回一个嵌套字典,结构如下

{
    "prior_dependencies": {
        "variable1": {"variable1": set()},
        "variable2": {"variable1": set(), "variable2": set()},
        ...
    },
    "posterior_dependencies": {
        "variable1": {"variable1": {"plate1"}, "variable2": set()},
        ...
    },
}

其中

  • prior_dependencies 是一个字典,将下游隐变量和观测变量映射到另一个字典,后者将它们依赖的上游隐变量映射到引起完全依赖的平板集合。也就是说,包含的平板引入二次方的依赖关系,如完全二分图,而排除的平板仅引入线性的依赖关系,如独立并行的边集。先验依赖关系遵循原始模型顺序。

  • posterior_dependencies 是一个类似的字典,但将隐变量映射到它们在后验中依赖的隐变量或观测站点。后验依赖关系与模型顺序相反。

依赖关系会省略 numpyro.deterministic 站点和 numpyro.sample(..., Delta(...)) 站点。

示例

这里是一个没有平板的简单示例。我们看到每个节点都依赖于自身,并且只有隐变量出现在后验中

def model_1():
    a = numpyro.sample("a", dist.Normal(0, 1))
    numpyro.sample("b", dist.Normal(a, 1), obs=0.0)

assert get_dependencies(model_1) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set()},
    },
}

这里有一个示例,其中两个变量 ab 在先验中条件独立,但在后验中由于所谓的“碰撞”变量 c 而条件依赖。ab 都影响 c`,因此在给定 c 的值时它们是条件相关的。这在图形模型文献中被称为“道德化”。

def model_2():
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.LogNormal(0, 1))
    c = numpyro.sample("c", dist.Normal(a, b))
    numpyro.sample("d", dist.Normal(c, 1), obs=0.)

assert get_dependencies(model_2) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"b": set()},
        "c": {"a": set(), "b": set(), "c": set()},
        "d": {"c": set(), "d": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set(), "c": set()},
        "b": {"b": set(), "c": set()},
        "c": {"c": set(), "d": set()},
    },
}

在存在平板的情况下,依赖关系可能更复杂。到目前为止,所有的字典值都是空的平板集合,但在下面的后验中,我们看到 c 通过平板 p 依赖于自身。这意味着,在 c 的元素中,例如 c[0] 依赖于 c[1](这就是我们明确允许变量依赖于自身的原因)

def model_3():
    with numpyro.plate("p", 5):
        a = numpyro.sample("a", dist.Normal(0, 1))
    numpyro.sample("b", dist.Normal(a.sum(), 1), obs=0.0)

assert get_dependencies(model_3) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": {"p"}, "b": set()},
    },
}
[1] S.Webb, A.Goliński, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018)

“生成模型的忠实反演,实现有效的摊销推理” https://dl.acm.org/doi/10.5555/3327144.3327229

参数:
  • model (callable) – 一个模型。

  • model_args (tuple) – 可选的模型参数元组 (args)。

  • model_kwargs (dict) – 可选的模型关键字参数字典 (kwargs)。

返回:

一个元数据字典(见上文)。

返回类型:

dict

get_model_relations

get_model_relations(model, model_args=None, model_kwargs=None)[source]

从给定模型和可选数据中推断 RVs 和 plate 的关系。更多详情请参阅 https://github.com/pyro-ppl/numpyro/issues/949

这将返回一个包含以下键的字典:

  • “sample_sample” 将每个下游采样站点映射到其所依赖的上游采样站点列表;

  • “sample_param” 将每个下游采样站点映射到其所依赖的上游参数站点列表;

  • “sample_dist” 将每个采样站点映射到该站点的分布名称;

  • “param_constraint” 将每个参数站点映射到该站点的约束名称;

  • “plate_sample” 将每个 plate 名称映射到该 plate 内的采样站点列表;以及

  • “observe” 是一个观测到的采样站点列表。

例如,对于模型

def model(data):
    m = numpyro.sample('m', dist.Normal(0, 1))
    sd = numpyro.sample('sd', dist.LogNormal(m, 1))
    with numpyro.plate('N', len(data)):
        numpyro.sample('obs', dist.Normal(m, sd), obs=data)

关系为

{'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
 'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
 'plate_sample': {'N': ['obs']},
 'observed': ['obs']}
参数:
  • model (callable) – 要检查的模型。

  • model_args – 可选的模型参数元组 (args)。

  • model_kwargs – 可选的模型关键字参数字典 (kwargs)。

返回类型:

dict

可视化工具

render_model

render_model(model, model_args=None, model_kwargs=None, filename=None, render_distributions=False, render_params=False)[source]

封装渲染模型所需的所有函数。

警告

此工具不支持 scan() 原语。如果要渲染时间序列模型,可以尝试使用 Python for 循环重写代码。

参数:
  • model – 要渲染的模型。

  • model_args – 传递给模型的位序参数。

  • model_kwargs – 传递给模型的关键字参数。

  • filename (str) – 用于保存渲染模型的文件的路径。

  • render_distributions (bool) – 是否在图中包含 RV 分布标注。

  • render_params (bool) – 是否在图中显示参数 (params)。

Trace 检查

format_shapes(trace: dict, *, compute_log_prob: bool = False, title: str = 'Trace Shapes:', last_site: str | None = None)[source]

给定函数的 trace,返回一个字符串,显示 trace 中所有站点的形状表格。

使用 trace handler(或用于枚举的 funsor trace handler)生成 trace。

参数:
  • trace (dict) – 要格式化的模型 trace。

  • compute_log_prob – 计算对数概率并在表中显示其形状。接受 True / False,或一个函数,该函数在给定包含站点级元数据的字典时,返回是否应计算对数概率并将其包含在表中。

  • title (str) – 形状表格的标题。

  • last_site (str | None) – 模型中站点的名称。如果提供,后续站点将不会在表中显示。

用法

def model(*args, **kwargs):
    ...

with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(model).get_trace(*args, **kwargs)
print(numpyro.util.format_shapes(trace))