Pyro 原语

param

param(name: str, init_value: Array | ndarray | bool_ | number | bool | int | float | complex | Callable | None = None, **kwargs) Array | ndarray | bool_ | number | bool | int | float | complex | None[source]

将给定的站点(site)标注为可优化参数,以便与 jax.example_libraries.optimizers 一起使用。关于如何在推断算法中使用 param 语句的示例,请参考 SVI

参数:
  • name (str) – 站点的名称。

  • init_value (jnp.ndarray or callable) – 用户指定的初始值,或者一个惰性可调用对象,它接受 JAX 随机 PRNGKey 并返回一个数组。请注意,在 NumPyro 中没有全局参数存储,因此使用此值初始化优化器的责任在于用户实现的推断算法。

  • constraint (numpyro.distributions.constraints.Constraint) – NumPyro 约束,默认为 constraints.real

  • event_dim (int) – (可选)与批量处理无关的最右侧维度数量。该维度左侧的维度将被视为批量维度;如果 param 语句位于子采样 plate 内,则参数的相应批量维度将相应地被子采样。如果未指定,所有维度将被视为事件维度,并且不会执行子采样。

返回:

参数的值。除非被包裹在像 substitute 这样的 handler 中,否则这将简单地返回初始值。

sample

sample(name: str, fn: DistributionLike, obs: Array | ndarray | bool_ | number | bool | int | float | complex | None = None, rng_key: Array | ndarray | bool_ | number | bool | int | float | complex | None = None, sample_shape: tuple[int, ...] = (), infer: dict | None = None, obs_mask: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) Array | ndarray | bool_ | number | bool | int | float | complex[source]

从随机函数 fn 返回一个随机样本。当被包裹在像 substitute 这样的效果处理器 (effect handler) 中时,这可能会产生额外的副作用。

注意

根据设计,sample 原语旨在 NumPyro 模型内部使用。然后使用 seed handler 向 fn 注入随机状态。在这些情况下,rng_key 关键字将不起作用。

参数:
  • name (str) – 样本站点 (sample site) 的名称。

  • fn – 返回样本的随机函数。

  • obs (jnp.ndarray) – 观测值

  • rng_key (jax.random.PRNGKey) – fn 的可选随机密钥。

  • sample_shape – 要抽取样本的形状 (shape)。

  • infer (dict) – 包含推断算法附加信息的字典(可选)。例如,如果 fn 是离散分布,设置 infer={‘enumerate’: ‘parallel’} 可以告诉 MCMC 对这个离散潜变量站点进行边际化。

  • obs_mask (jnp.ndarray) – 可选的布尔数组掩码,其形状可与 fn.batch_shape 广播。如果提供,mask=True 的事件将以 obs 为条件,其余事件将通过采样进行插补。这引入了一个名为 name + "_unobserved" 的潜变量样本站点,SVI 中的 guides 应使用它。请注意,此参数不适用于 MCMC。

返回:

从随机函数 fn 采样。

plate

class plate(name: str, size: int, subsample_size: int | None = None, dim: int | None = None)[source]

用于标注条件独立变量的构造。在 plate 上下文管理器中,sample 站点将自动广播到 plate 的大小。此外,如果指定了 subsample_size,某些推断算法可能会应用一个比例因子。

注意

这可以用于对数据进行 mini-batch 子采样

with plate("data", len(data), subsample_size=100) as ind:
    batch = data[ind]
    assert len(batch) == 100
参数:
  • name (str) – plate 的名称。

  • size (int) – plate 的大小。

  • subsample_size (int) – 可选参数,表示 mini-batch 的大小。这可以用于推断算法应用比例因子,例如当使用 mini-batch 计算 ELBO 时。

  • dim (int) – 可选参数,指定张量中的哪个维度用作 plate 维度。如果为 None(默认),则分配最右侧的可用维度。

plate_stack

plate_stack(prefix: str, sizes: list[int], rightmost_dim: int = -1) Generator[None, None, None][source]

创建一个连续堆叠的 plate,带维度

rightmost_dim - len(sizes), ..., rightmost_dim
参数:
  • prefix (str) – plate 的名称前缀。

  • sizes (iterable) – plate 大小的可迭代对象。

  • rightmost_dim (int) – 最右侧维度,从右边开始计数。

subsample

subsample(data: Array | ndarray | bool_ | number | bool | int | float | complex, event_dim: int) Array | ndarray | bool_ | number | bool | int | float | complex[source]

实验性子采样语句 (EXPERIMENTAL Subsampling statement),根据包含的 plate 对数据进行子采样。

当通过传递 subsample_size 关键字参数,由 plate 自动执行子采样时,这通常在 model() 的参数上调用。例如,以下是等效的

# Version 1. using indexing
def model(data):
    with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
        data = data[ind]
        # ...

# Version 2. using numpyro.subsample()
def model(data):
    with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
        data = numpyro.subsample(data, event_dim=0)
        # ...
参数:
  • data (jnp.ndarray) – 批量数据的张量。

  • event_dim (int) – 数据张量的事件维度。其左侧的维度被视为批量维度。

返回:

data 的子采样版本

返回类型:

ndarray

deterministic

deterministic(name: str, value: Array | ndarray | bool_ | number | bool | int | float | complex) Array | ndarray | bool_ | number | bool | int | float | complex[source]

用于指定模型中的确定性站点 (deterministic sites)。请注意,大多数效果处理器不会操作确定性站点(除了 trace()),因此确定性站点应无副作用。使用确定性节点的用例是在模型执行轨迹中记录任何值。

参数:
  • name (str) – 确定性站点的名称。

  • value (jnp.ndarray) – 要在轨迹中记录的确定性值。

prng_key

prng_key() Array | None[source]

一个语句,用于在 seed handler 下抽取伪随机数生成器密钥 PRNGKey()

返回:

一个形状为 (2,) 且 dtype 为 unit32 的 PRNG 密钥。

factor

factor(name: str, log_factor: Array | ndarray | bool_ | number | bool | int | float | complex) None[source]

Factor 语句,用于向概率模型添加任意对数概率因子 (log probability factor)。

参数:
  • name (str) – 简单样本 (trivial sample) 的名称。

  • log_factor (jnp.ndarray) – 一个可能经过批量处理的对数概率因子。

get_mask

get_mask() Array | ndarray | bool_ | number | bool | int | float | complex | None[source]

记录包含的 handlers.mask handler 的效果。这在预测期间非常有用,可以避免昂贵的 numpyro.factor() 计算,尤其是在不需要计算对数密度时,例如:

def model():
    # ...
    if numpyro.get_mask() is not False:
        log_density = my_expensive_computation()
        numpyro.factor("foo", log_density)
    # ...
返回:

掩码。

返回类型:

None, bool, or jnp.ndarray

module

module(name: str, nn: tuple, input_shape: tuple | None = None) Callable[source]

在模型内部声明一个 stax 风格的神经网络,以便通过 param() 语句注册其参数用于优化。

参数:
  • name (str) – 要注册模块的名称。

  • nn (tuple) – 通过 stax 构造函数获得的一个包含 (init_fn, apply_fn) 的元组。

  • input_shape (tuple) – 神经网络接受的输入形状 (shape)。

返回:

一个绑定了参数的 apply_fn,它接受一个数组作为输入并返回神经网络转换后的输出数组。

flax_module

flax_module(name, nn_module, *args, input_shape=None, apply_rng=None, mutable=None, **kwargs)[source]

在模型内部声明一个 flax 风格的神经网络,以便通过 param() 语句注册其参数用于优化。

给定一个 flax nn_module,在 flax 中使用给定的一组参数评估模块,我们使用:nn_module.apply(params, x)。在 NumPyro 模型中,模式将是

net = flax_module("net", nn_module)
y = net(x)

或者带有 dropout 层

net = flax_module("net", nn_module, apply_rng=["dropout"])
rng_key = numpyro.prng_key()
y = net(x, rngs={"dropout": rng_key})
参数:
  • name (str) – 要注册模块的名称。

  • nn_module (flax.linen.Module) – 一个拥有 .init 和 .apply 方法的 flax 模块

  • args – 初始化 flax 神经网络的可选参数,作为 input_shape 的替代方案

  • input_shape (tuple) – 神经网络接受的输入形状 (shape)。

  • apply_rng (list) – 一个列表,指示 nn_module 需要哪些额外的 rng _种类_。例如,当 nn_module 包含 dropout 层时,我们需要设置 apply_rng=["dropout"]。默认为 None,表示不需要额外的 rng 密钥。有关 Flax 如何处理 dropout 等随机层的更多信息,请参阅 Flax Linen Intro

  • mutable (list) – 一个列表,指示 nn_module 的可变状态 (mutable states)。例如,如果你的模块有 BatchNorm 层,我们需要定义 mutable=["batch_stats"]。有关更多信息,请参阅上述 Flax Linen Intro 教程。

  • 一个绑定了参数的可调用对象,它接受一个数组作为输入并返回神经网络转换后的输出数组。

返回:

haiku_module

haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kwargs)[source]

在模型内部声明一个 haiku 风格的神经网络,以便通过 param() 语句注册其参数用于优化。

给定一个 haiku nn_module,在 haiku 中使用给定的一组参数评估模块,我们使用:nn_module.apply(params, None, x)。在 NumPyro 模型中,模式将是

对于一个haiku nn_module,在haiku中,要用给定参数集评估模块,我们使用: nn_module.apply(params, None, x)。在NumPyro模型中,模式将是

net = haiku_module("net", nn_module)
y = net(x)  # or y = net(rng_key, x)

或者带有 dropout 层

net = haiku_module("net", nn_module, apply_rng=True)
rng_key = numpyro.prng_key()
y = net(rng_key, x)
参数:
  • name (str) – 要注册模块的名称。

  • nn_module (haiku.Transformed or haiku.TransformedWithState) – 一个拥有 .init 和 .apply 方法的 haiku 模块

  • args – 初始化 flax 神经网络的可选参数,作为 input_shape 的替代方案

  • input_shape (tuple) – 神经网络接受的输入形状 (shape)。

  • apply_rng (bool) – 一个标志,指示返回的可调用对象是否需要一个 rng 参数(例如,当 nn_module 包含 dropout 层时)。默认为 False,表示不需要 rng 参数。如果为 True,则返回的可调用对象的签名 nn = haiku_module(..., apply_rng=True) 将是 nn(rng_key, x)(而不是 nn(x))。

  • 一个绑定了参数的可调用对象,它接受一个数组作为输入并返回神经网络转换后的输出数组。

返回:

haiku_module

nnx_module

nnx_module(name, nn_module)[source]

在模型内部声明一个 nnx 风格的神经网络,以便通过 param() 语句注册其参数用于优化。

给定一个 flax NNX nn_module,要评估模块,我们直接调用它。在 NumPyro 模型中,模式将是

# Eager initialization outside the model
module = nn_module(...)

# Inside the model
net = nnx_module("net", module)
y = net(x)
参数:
  • name (str) – 要注册模块的名称。

  • nn_module (flax.nnx.Module) – 一个预初始化的 flax nnx 模块实例。

返回:

一个接受数组作为输入并返回神经网络转换后输出数组的可调用对象。

random_flax_module

random_flax_module(name, nn_module, prior, *args, input_shape=None, apply_rng=None, mutable=None, **kwargs)[source]

一个原语,用于为 Flax 模块 nn_module 的参数设置先验。

注意

Flax 模块的参数存储在嵌套字典中。例如,定义如下的模块 B

class A(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, x):
        return nn.Dense(1, use_bias=False, name='dense')(x)

class B(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, x):
        return A(name='inner')(x)

其参数为 {‘inner’: {‘dense’: {‘kernel’: param_value}}}。在参数 prior 中,要指定 kernel 参数,我们使用点连接路径:prior={“inner.dense.kernel”: param_prior}

参数:
  • name (str) – NumPyro 模块的名称

  • flax.linen.Module – 要在 NumPyro 中注册的模块

  • prior (dict, Distribution or callable) –

    一个 NumPyro 分布,或者一个 Python 字典,其键是参数名称,值是相应的分布。例如

    net = random_flax_module("net",
                             flax.linen.Dense(features=1),
                             prior={"bias": dist.Cauchy(), "kernel": dist.Normal()},
                             input_shape=(4,))
    

    或者,我们可以使用可调用对象。例如,以下是等效的

    prior=(lambda name, shape: dist.Cauchy() if name == "bias" else dist.Normal())
    prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}
    

  • args – 初始化 flax 神经网络的可选参数,作为 input_shape 的替代方案

  • input_shape (tuple) – 神经网络接受的输入形状 (shape)。

  • apply_rng (list) –

    一个列表,指示 nn_module 需要哪些额外的 rng _种类_。例如,当 nn_module 包含 dropout 层时,我们需要设置 apply_rng=["dropout"]。默认为 None,表示不需要额外的 rng 密钥。有关 Flax 如何处理 dropout 等随机层的更多信息,请参阅 Flax Linen Intro

  • mutable (list) – 一个列表,指示 nn_module 的可变状态 (mutable states)。例如,如果你的模块有 BatchNorm 层,我们需要定义 mutable=["batch_stats"]。有关更多信息,请参阅上述 Flax Linen Intro 教程。

  • 一个绑定了参数的可调用对象,它接受一个数组作为输入并返回神经网络转换后的输出数组。

返回:

一个采样的模块

示例

# NB: this example is ported from https://github.com/ctallec/pyvarinf/blob/master/main_regression.ipynb
>>> import numpy as np; np.random.seed(0)
>>> import tqdm
>>> from flax import linen as nn
>>> from jax import jit, random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.module import random_flax_module
>>> from numpyro.infer import Predictive, SVI, TraceMeanField_ELBO, autoguide, init_to_feasible
...
>>> class Net(nn.Module):
...     n_units: int
...
...     @nn.compact
...     def __call__(self, x):
...         x = nn.Dense(self.n_units)(x[..., None])
...         x = nn.relu(x)
...         x = nn.Dense(self.n_units)(x)
...         x = nn.relu(x)
...         mean = nn.Dense(1)(x)
...         rho = nn.Dense(1)(x)
...         return mean.squeeze(), rho.squeeze()
...
>>> def generate_data(n_samples):
...     x = np.random.normal(size=n_samples)
...     y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2
...     return x, y
...
>>> def model(x, y=None, batch_size=None):
...     module = Net(n_units=32)
...     net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=())
...     with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
...         batch_x = numpyro.subsample(x, event_dim=0)
...         batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
...         mean, rho = net(batch_x)
...         sigma = nn.softplus(rho)
...         numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)
...
>>> n_train_data = 5000
>>> x_train, y_train = generate_data(n_train_data)
>>> guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
>>> svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())
>>> n_iterations = 3000
>>> svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=256)
>>> params, losses = svi_result.params, svi_result.losses
>>> n_test_data = 100
>>> x_test, y_test = generate_data(n_test_data)
>>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
>>> y_pred = predictive(random.PRNGKey(1), x_test[:100])["obs"].copy()
>>> assert losses[-1] < 3000
>>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1

random_haiku_module

random_haiku_module(name, nn_module, prior, *args, input_shape=None, apply_rng=False, **kwargs)[source]

一个原语,用于为 Haiku 模块 nn_module 的参数设置先验。

参数:
  • name (str) – NumPyro 模块的名称

  • nn_module (haiku.Transformed or haiku.TransformedWithState) – 要在 NumPyro 中注册的模块

  • prior (dict, Distribution or callable) –

    一个 NumPyro 分布,或者一个 Python 字典,其键是参数名称,值是相应的分布。例如

    net = random_haiku_module("net",
                              haiku.transform(lambda x: hk.Linear(1)(x)),
                              prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()},
                              input_shape=(4,))
    

    或者,我们可以使用可调用对象。例如,以下是等效的

    prior=(lambda name, shape: dist.Cauchy() if name.startswith("b") else dist.Normal())
    prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}
    

  • args – 初始化 flax 神经网络的可选参数,作为 input_shape 的替代方案

  • input_shape (tuple) – 神经网络接受的输入形状 (shape)。

  • apply_rng (bool) – 一个标志,指示返回的可调用对象是否需要一个 rng 参数(例如,当 nn_module 包含 dropout 层时)。默认为 False,表示不需要 rng 参数。如果为 True,则返回的可调用对象的签名 nn = haiku_module(..., apply_rng=True) 将是 nn(rng_key, x)(而不是 nn(x))。

  • 一个绑定了参数的可调用对象,它接受一个数组作为输入并返回神经网络转换后的输出数组。

返回:

一个采样的模块

random_nnx_module

random_nnx_module(name, nn_module, prior)[source]

一个原语,用于创建一个随机的 nnx 风格神经网络,可用于 MCMC 采样器。神经网络的参数将从 prior 中采样。

参数:
  • name (str) – 要注册模块的名称。

  • nn_module (flax.nnx.Module) – 一个预初始化的 flax nnx 模块实例。

  • prior

    一个分布、一个分布字典或一个可调用对象。如果它是一个分布,所有参数将从同一分布中采样。如果它是一个字典,它将参数名称映射到分布。如果它是一个可调用对象,它接受参数名称和参数形状作为输入并返回一个分布。例如

    class Linear(nnx.Module):
        def __init__(self, din, dout, *, rngs):
            self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
            self.b = nnx.Param(jnp.zeros((dout,)))
    
        def __call__(self, x):
            return x @ self.w + self.b
    
    # Eager initialization
    linear = Linear(din=4, dout=1, rngs=nnx.Rngs(params=random.PRNGKey(0)))
    net = random_nnx_module("net", linear, prior={"w": dist.Normal(), "b": dist.Cauchy()})
    

    或者,我们可以使用可调用对象。例如,以下是等效的

    prior=(lambda name, shape: dist.Cauchy() if name.endswith("b") else dist.Normal())
    prior={"w": dist.Normal(), "b": dist.Cauchy()}
    

返回:

一个接受数组作为输入并返回神经网络转换后输出数组的可调用对象。

scan

scan(f: Callable, init, xs, length: int | None = None, reverse: bool = False, history: int = 1)[source]

这个原语在 xs 的前导数组轴上扫描一个函数,同时携带状态。有关更多信息,请参阅 jax.lax.scan()

用法:

>>> import numpy as np
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.control_flow import scan
>>>
>>> def gaussian_hmm(y=None, T=10):
...     def transition(x_prev, y_curr):
...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
...         return x_curr, (x_curr, y_curr)
...
...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
...     _, (x, y) = scan(transition, x0, y, length=T)
...     return (x, y)
>>>
>>> # here we do some quick tests
>>> with numpyro.handlers.seed(rng_seed=0):
...     x, y = gaussian_hmm(np.arange(10.))
>>> assert x.shape == (10,) and y.shape == (10,)
>>> assert np.all(y == np.arange(10))
>>>
>>> with numpyro.handlers.seed(rng_seed=0):  # generative
...     x, y = gaussian_hmm()
>>> assert x.shape == (10,) and y.shape == (10,)

警告

这是一个实验性实用函数,允许用户将 JAX 控制流与 NumPyro 的效果处理器一起使用。目前,支持 scan 主体 f 内的 sampledeterministic 站点。如果你发现任何效果处理器或分布不受支持,请提交 issue。

注意

plate 上下文内对齐 scan 维度是模糊的。因此不支持以下模式

with numpyro.plate('N', 10):
    last, ys = scan(f, init, xs)

所有 plate 语句都应放在 f 内部。例如,相应的可工作代码是

def g(*args, **kwargs):
    with numpyro.plate('N', 10):
        return f(*arg, **kwargs)

last, ys = scan(g, init, xs)

注意

目前不支持嵌套的 scan。

注意

我们可以在 f 中对离散潜变量进行 scan。联合密度使用时间维度上的并行 scan(参考 [1])进行评估,这降低了并行复杂度至 O(log(length))

带有离散潜变量的 scantrace 将包含以下站点

  • 初始化点(init sites):这些点属于 f 的前 history 个跟踪(traces)。第 i 个跟踪的点名称将以前缀 ‘_PREV_’ * (2 * history - 1 - i) 开头。

  • 扫描点(scanned sites):这些点收集 f 的剩余扫描循环的值。一个额外的时间维度 _time_foo 将被添加到这些点,其中 foo 是出现在 f 中的第一个点的名称。

并非所有转换函数 f 都受支持。Pyro枚举教程 [2] 中的所有限制在此仍然适用。此外,scan 外部不应有任何点依赖于 scan 的第一个输出(最后一个 carry 值)。

参考资料

  1. 贝叶斯平滑器的时序并行化, Simo Sarkka, Angel F. Garcia-Fernandez (https://arxiv.org/abs/1905.13002)

  2. 离散潜在变量的推断 (https://pyro.org.cn/examples/enumeration.html#Dependencies-among-plates)

参数:
  • f (callable) – 要扫描的函数。

  • init – 初始 carrying 状态。

  • xs – 我们沿着主轴扫描的值。这可以是任何 JAX pytree(例如数组的列表/字典)。

  • length (int | None) – 可选值,指定 xs 的长度,但当 xs 是空 pytree(例如 None)时可以使用。

  • reverse (bool) – 可选布尔值,指定是正向(默认)还是反向运行扫描迭代。

  • history (int) – 当前上下文可见的先前上下文的数量。默认为 1。如果为零,则类似于 numpyro.plate

返回:

scan 的输出,引用自 jax.lax.scan() 文档:“类型为 (c, [b]) 的对,其中第一个元素表示最终的循环 carrying 值,第二个元素表示 f 的第二个输出在沿输入的主轴扫描时堆叠的结果”。

cond

cond(pred: bool, true_fun: Callable, false_fun: Callable, operand: Any) Any[source]

此原语根据条件应用 true_funfalse_fun。有关更多信息,请参阅 jax.lax.cond()

用法:

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from jax import random
>>> from numpyro.contrib.control_flow import cond
>>> from numpyro.infer import SVI, Trace_ELBO
>>>
>>> def model():
...     def true_fun(_):
...         return numpyro.sample("x", dist.Normal(20.0))
...
...     def false_fun(_):
...         return numpyro.sample("x", dist.Normal(0.0))
...
...     cluster = numpyro.sample("cluster", dist.Normal())
...     return cond(cluster > 0, true_fun, false_fun, None)
>>>
>>> def guide():
...     m1 = numpyro.param("m1", 10.0)
...     s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
...     m2 = numpyro.param("m2", 10.0)
...     s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)
...
...     def true_fun(_):
...         return numpyro.sample("x", dist.Normal(m1, s1))
...
...     def false_fun(_):
...         return numpyro.sample("x", dist.Normal(m2, s2))
...
...     cluster = numpyro.sample("cluster", dist.Normal())
...     return cond(cluster > 0, true_fun, false_fun, None)
>>>
>>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
>>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500)

警告

这是一个实验性实用函数,允许用户将 JAX 控制流与 NumPyro 的效果处理器一起使用。目前,支持 true_funfalse_fun 中的 sampledeterministic 点。如果您发现任何效果处理器或分布不受支持,请提交 issue。

警告

当前,cond 原语不支持枚举,也不能在 numpyro.plate 上下文中使用。

注意

所有 sample 点必须属于同一分布类。例如,以下是不受支持的

cond(
    True,
    lambda _: numpyro.sample("x", dist.Normal()),
    lambda _: numpyro.sample("x", dist.Laplace()),
    None,
)
参数:
  • pred (bool) – 布尔标量类型,指示应用哪个分支函数。

  • true_fun (callable) – 当 pred 为真时应用的函数。

  • false_fun (callable) – 当 pred 为假时应用的函数。

  • operand – 根据 pred 应用到任一分支的操作数输入。这可以是任何 JAX PyTree(例如数组的列表/字典)。

返回:

应用的分支函数的输出。