Pyro 原语
param
- param(name: str, init_value: Array | ndarray | bool_ | number | bool | int | float | complex | Callable | None = None, **kwargs) Array | ndarray | bool_ | number | bool | int | float | complex | None [source]
将给定的站点(site)标注为可优化参数,以便与
jax.example_libraries.optimizers
一起使用。关于如何在推断算法中使用 param 语句的示例,请参考SVI
。- 参数:
name (str) – 站点的名称。
init_value (jnp.ndarray or callable) – 用户指定的初始值,或者一个惰性可调用对象,它接受 JAX 随机 PRNGKey 并返回一个数组。请注意,在 NumPyro 中没有全局参数存储,因此使用此值初始化优化器的责任在于用户实现的推断算法。
constraint (numpyro.distributions.constraints.Constraint) – NumPyro 约束,默认为
constraints.real
。event_dim (int) – (可选)与批量处理无关的最右侧维度数量。该维度左侧的维度将被视为批量维度;如果 param 语句位于子采样 plate 内,则参数的相应批量维度将相应地被子采样。如果未指定,所有维度将被视为事件维度,并且不会执行子采样。
- 返回:
参数的值。除非被包裹在像
substitute
这样的 handler 中,否则这将简单地返回初始值。
sample
- sample(name: str, fn: DistributionLike, obs: Array | ndarray | bool_ | number | bool | int | float | complex | None = None, rng_key: Array | ndarray | bool_ | number | bool | int | float | complex | None = None, sample_shape: tuple[int, ...] = (), infer: dict | None = None, obs_mask: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) Array | ndarray | bool_ | number | bool | int | float | complex [source]
从随机函数 fn 返回一个随机样本。当被包裹在像
substitute
这样的效果处理器 (effect handler) 中时,这可能会产生额外的副作用。注意
根据设计,sample 原语旨在 NumPyro 模型内部使用。然后使用
seed
handler 向 fn 注入随机状态。在这些情况下,rng_key 关键字将不起作用。- 参数:
name (str) – 样本站点 (sample site) 的名称。
fn – 返回样本的随机函数。
obs (jnp.ndarray) – 观测值
rng_key (jax.random.PRNGKey) – fn 的可选随机密钥。
sample_shape – 要抽取样本的形状 (shape)。
infer (dict) – 包含推断算法附加信息的字典(可选)。例如,如果 fn 是离散分布,设置 infer={‘enumerate’: ‘parallel’} 可以告诉 MCMC 对这个离散潜变量站点进行边际化。
obs_mask (jnp.ndarray) – 可选的布尔数组掩码,其形状可与
fn.batch_shape
广播。如果提供,mask=True 的事件将以obs
为条件,其余事件将通过采样进行插补。这引入了一个名为name + "_unobserved"
的潜变量样本站点,SVI 中的 guides 应使用它。请注意,此参数不适用于 MCMC。
- 返回:
从随机函数 fn 采样。
plate
- class plate(name: str, size: int, subsample_size: int | None = None, dim: int | None = None)[source]
用于标注条件独立变量的构造。在 plate 上下文管理器中,sample 站点将自动广播到 plate 的大小。此外,如果指定了 subsample_size,某些推断算法可能会应用一个比例因子。
注意
这可以用于对数据进行 mini-batch 子采样
with plate("data", len(data), subsample_size=100) as ind: batch = data[ind] assert len(batch) == 100
plate_stack
subsample
- subsample(data: Array | ndarray | bool_ | number | bool | int | float | complex, event_dim: int) Array | ndarray | bool_ | number | bool | int | float | complex [source]
实验性子采样语句 (EXPERIMENTAL Subsampling statement),根据包含的
plate
对数据进行子采样。当通过传递
subsample_size
关键字参数,由plate
自动执行子采样时,这通常在model()
的参数上调用。例如,以下是等效的# Version 1. using indexing def model(data): with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind: data = data[ind] # ... # Version 2. using numpyro.subsample() def model(data): with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()): data = numpyro.subsample(data, event_dim=0) # ...
- 参数:
data (jnp.ndarray) – 批量数据的张量。
event_dim (int) – 数据张量的事件维度。其左侧的维度被视为批量维度。
- 返回:
data
的子采样版本- 返回类型:
ndarray
deterministic
- deterministic(name: str, value: Array | ndarray | bool_ | number | bool | int | float | complex) Array | ndarray | bool_ | number | bool | int | float | complex [source]
用于指定模型中的确定性站点 (deterministic sites)。请注意,大多数效果处理器不会操作确定性站点(除了
trace()
),因此确定性站点应无副作用。使用确定性节点的用例是在模型执行轨迹中记录任何值。- 参数:
name (str) – 确定性站点的名称。
value (jnp.ndarray) – 要在轨迹中记录的确定性值。
prng_key
factor
get_mask
- get_mask() Array | ndarray | bool_ | number | bool | int | float | complex | None [source]
记录包含的
handlers.mask
handler 的效果。这在预测期间非常有用,可以避免昂贵的numpyro.factor()
计算,尤其是在不需要计算对数密度时,例如:def model(): # ... if numpyro.get_mask() is not False: log_density = my_expensive_computation() numpyro.factor("foo", log_density) # ...
- 返回:
掩码。
- 返回类型:
None, bool, or jnp.ndarray
module
flax_module
- flax_module(name, nn_module, *args, input_shape=None, apply_rng=None, mutable=None, **kwargs)[source]
在模型内部声明一个
flax
风格的神经网络,以便通过param()
语句注册其参数用于优化。给定一个 flax
nn_module
,在 flax 中使用给定的一组参数评估模块,我们使用:nn_module.apply(params, x)
。在 NumPyro 模型中,模式将是net = flax_module("net", nn_module) y = net(x)
或者带有 dropout 层
net = flax_module("net", nn_module, apply_rng=["dropout"]) rng_key = numpyro.prng_key() y = net(x, rngs={"dropout": rng_key})
- 参数:
name (str) – 要注册模块的名称。
nn_module (flax.linen.Module) – 一个拥有 .init 和 .apply 方法的 flax 模块
args – 初始化 flax 神经网络的可选参数,作为 input_shape 的替代方案
input_shape (tuple) – 神经网络接受的输入形状 (shape)。
apply_rng (list) – 一个列表,指示
nn_module
需要哪些额外的 rng _种类_。例如,当nn_module
包含 dropout 层时,我们需要设置apply_rng=["dropout"]
。默认为 None,表示不需要额外的 rng 密钥。有关 Flax 如何处理 dropout 等随机层的更多信息,请参阅 Flax Linen Intro。mutable (list) – 一个列表,指示
nn_module
的可变状态 (mutable states)。例如,如果你的模块有 BatchNorm 层,我们需要定义mutable=["batch_stats"]
。有关更多信息,请参阅上述 Flax Linen Intro 教程。一个绑定了参数的可调用对象,它接受一个数组作为输入并返回神经网络转换后的输出数组。
- 返回:
haiku_module
haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kwargs)[source]
-
在模型内部声明一个
haiku
风格的神经网络,以便通过param()
语句注册其参数用于优化。 给定一个 haiku
nn_module
,在 haiku 中使用给定的一组参数评估模块,我们使用:nn_module.apply(params, None, x)
。在 NumPyro 模型中,模式将是对于一个haiku
nn_module
,在haiku中,要用给定参数集评估模块,我们使用:nn_module.apply(params, None, x)
。在NumPyro模型中,模式将是net = haiku_module("net", nn_module) y = net(x) # or y = net(rng_key, x)
或者带有 dropout 层
net = haiku_module("net", nn_module, apply_rng=True) rng_key = numpyro.prng_key() y = net(rng_key, x)
- 参数:
name (str) – 要注册模块的名称。
nn_module (haiku.Transformed or haiku.TransformedWithState) – 一个拥有 .init 和 .apply 方法的 haiku 模块
args – 初始化 flax 神经网络的可选参数,作为 input_shape 的替代方案
input_shape (tuple) – 神经网络接受的输入形状 (shape)。
apply_rng (bool) – 一个标志,指示返回的可调用对象是否需要一个 rng 参数(例如,当
nn_module
包含 dropout 层时)。默认为 False,表示不需要 rng 参数。如果为 True,则返回的可调用对象的签名nn = haiku_module(..., apply_rng=True)
将是nn(rng_key, x)
(而不是nn(x)
)。一个绑定了参数的可调用对象,它接受一个数组作为输入并返回神经网络转换后的输出数组。
- 返回:
haiku_module
nnx_module
- nnx_module(name, nn_module)[source]
在模型内部声明一个
nnx
风格的神经网络,以便通过param()
语句注册其参数用于优化。给定一个 flax NNX
nn_module
,要评估模块,我们直接调用它。在 NumPyro 模型中,模式将是# Eager initialization outside the model module = nn_module(...) # Inside the model net = nnx_module("net", module) y = net(x)
- 参数:
name (str) – 要注册模块的名称。
nn_module (flax.nnx.Module) – 一个预初始化的 flax nnx 模块实例。
- 返回:
一个接受数组作为输入并返回神经网络转换后输出数组的可调用对象。
random_flax_module
- random_flax_module(name, nn_module, prior, *args, input_shape=None, apply_rng=None, mutable=None, **kwargs)[source]
一个原语,用于为 Flax 模块 nn_module 的参数设置先验。
注意
Flax 模块的参数存储在嵌套字典中。例如,定义如下的模块 B
class A(flax.linen.Module): @flax.linen.compact def __call__(self, x): return nn.Dense(1, use_bias=False, name='dense')(x) class B(flax.linen.Module): @flax.linen.compact def __call__(self, x): return A(name='inner')(x)
其参数为 {‘inner’: {‘dense’: {‘kernel’: param_value}}}。在参数 prior 中,要指定 kernel 参数,我们使用点连接路径:prior={“inner.dense.kernel”: param_prior}。
- 参数:
name (str) – NumPyro 模块的名称
flax.linen.Module – 要在 NumPyro 中注册的模块
prior (dict, Distribution or callable) –
一个 NumPyro 分布,或者一个 Python 字典,其键是参数名称,值是相应的分布。例如
net = random_flax_module("net", flax.linen.Dense(features=1), prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}, input_shape=(4,))
或者,我们可以使用可调用对象。例如,以下是等效的
prior=(lambda name, shape: dist.Cauchy() if name == "bias" else dist.Normal()) prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}
args – 初始化 flax 神经网络的可选参数,作为 input_shape 的替代方案
input_shape (tuple) – 神经网络接受的输入形状 (shape)。
apply_rng (list) –
一个列表,指示
nn_module
需要哪些额外的 rng _种类_。例如,当nn_module
包含 dropout 层时,我们需要设置apply_rng=["dropout"]
。默认为 None,表示不需要额外的 rng 密钥。有关 Flax 如何处理 dropout 等随机层的更多信息,请参阅 Flax Linen Intro。mutable (list) – 一个列表,指示
nn_module
的可变状态 (mutable states)。例如,如果你的模块有 BatchNorm 层,我们需要定义mutable=["batch_stats"]
。有关更多信息,请参阅上述 Flax Linen Intro 教程。一个绑定了参数的可调用对象,它接受一个数组作为输入并返回神经网络转换后的输出数组。
- 返回:
一个采样的模块
示例
# NB: this example is ported from https://github.com/ctallec/pyvarinf/blob/master/main_regression.ipynb >>> import numpy as np; np.random.seed(0) >>> import tqdm >>> from flax import linen as nn >>> from jax import jit, random >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.contrib.module import random_flax_module >>> from numpyro.infer import Predictive, SVI, TraceMeanField_ELBO, autoguide, init_to_feasible ... >>> class Net(nn.Module): ... n_units: int ... ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(self.n_units)(x[..., None]) ... x = nn.relu(x) ... x = nn.Dense(self.n_units)(x) ... x = nn.relu(x) ... mean = nn.Dense(1)(x) ... rho = nn.Dense(1)(x) ... return mean.squeeze(), rho.squeeze() ... >>> def generate_data(n_samples): ... x = np.random.normal(size=n_samples) ... y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2 ... return x, y ... >>> def model(x, y=None, batch_size=None): ... module = Net(n_units=32) ... net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=()) ... with numpyro.plate("batch", x.shape[0], subsample_size=batch_size): ... batch_x = numpyro.subsample(x, event_dim=0) ... batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None ... mean, rho = net(batch_x) ... sigma = nn.softplus(rho) ... numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y) ... >>> n_train_data = 5000 >>> x_train, y_train = generate_data(n_train_data) >>> guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible) >>> svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO()) >>> n_iterations = 3000 >>> svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=256) >>> params, losses = svi_result.params, svi_result.losses >>> n_test_data = 100 >>> x_test, y_test = generate_data(n_test_data) >>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000) >>> y_pred = predictive(random.PRNGKey(1), x_test[:100])["obs"].copy() >>> assert losses[-1] < 3000 >>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1
random_haiku_module
- random_haiku_module(name, nn_module, prior, *args, input_shape=None, apply_rng=False, **kwargs)[source]
一个原语,用于为 Haiku 模块 nn_module 的参数设置先验。
- 参数:
name (str) – NumPyro 模块的名称
nn_module (haiku.Transformed or haiku.TransformedWithState) – 要在 NumPyro 中注册的模块
prior (dict, Distribution or callable) –
一个 NumPyro 分布,或者一个 Python 字典,其键是参数名称,值是相应的分布。例如
net = random_haiku_module("net", haiku.transform(lambda x: hk.Linear(1)(x)), prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()}, input_shape=(4,))
或者,我们可以使用可调用对象。例如,以下是等效的
prior=(lambda name, shape: dist.Cauchy() if name.startswith("b") else dist.Normal()) prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}
args – 初始化 flax 神经网络的可选参数,作为 input_shape 的替代方案
input_shape (tuple) – 神经网络接受的输入形状 (shape)。
apply_rng (bool) – 一个标志,指示返回的可调用对象是否需要一个 rng 参数(例如,当
nn_module
包含 dropout 层时)。默认为 False,表示不需要 rng 参数。如果为 True,则返回的可调用对象的签名nn = haiku_module(..., apply_rng=True)
将是nn(rng_key, x)
(而不是nn(x)
)。一个绑定了参数的可调用对象,它接受一个数组作为输入并返回神经网络转换后的输出数组。
- 返回:
一个采样的模块
random_nnx_module
- random_nnx_module(name, nn_module, prior)[source]
一个原语,用于创建一个随机的
nnx
风格神经网络,可用于 MCMC 采样器。神经网络的参数将从prior
中采样。- 参数:
name (str) – 要注册模块的名称。
nn_module (flax.nnx.Module) – 一个预初始化的 flax nnx 模块实例。
prior –
一个分布、一个分布字典或一个可调用对象。如果它是一个分布,所有参数将从同一分布中采样。如果它是一个字典,它将参数名称映射到分布。如果它是一个可调用对象,它接受参数名称和参数形状作为输入并返回一个分布。例如
class Linear(nnx.Module): def __init__(self, din, dout, *, rngs): self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): return x @ self.w + self.b # Eager initialization linear = Linear(din=4, dout=1, rngs=nnx.Rngs(params=random.PRNGKey(0))) net = random_nnx_module("net", linear, prior={"w": dist.Normal(), "b": dist.Cauchy()})
或者,我们可以使用可调用对象。例如,以下是等效的
prior=(lambda name, shape: dist.Cauchy() if name.endswith("b") else dist.Normal()) prior={"w": dist.Normal(), "b": dist.Cauchy()}
- 返回:
一个接受数组作为输入并返回神经网络转换后输出数组的可调用对象。
scan
- scan(f: Callable, init, xs, length: int | None = None, reverse: bool = False, history: int = 1)[source]
这个原语在 xs 的前导数组轴上扫描一个函数,同时携带状态。有关更多信息,请参阅
jax.lax.scan()
。用法:
>>> import numpy as np >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.contrib.control_flow import scan >>> >>> def gaussian_hmm(y=None, T=10): ... def transition(x_prev, y_curr): ... x_curr = numpyro.sample('x', dist.Normal(x_prev, 1)) ... y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr) ... return x_curr, (x_curr, y_curr) ... ... x0 = numpyro.sample('x_0', dist.Normal(0, 1)) ... _, (x, y) = scan(transition, x0, y, length=T) ... return (x, y) >>> >>> # here we do some quick tests >>> with numpyro.handlers.seed(rng_seed=0): ... x, y = gaussian_hmm(np.arange(10.)) >>> assert x.shape == (10,) and y.shape == (10,) >>> assert np.all(y == np.arange(10)) >>> >>> with numpyro.handlers.seed(rng_seed=0): # generative ... x, y = gaussian_hmm() >>> assert x.shape == (10,) and y.shape == (10,)
警告
这是一个实验性实用函数,允许用户将 JAX 控制流与 NumPyro 的效果处理器一起使用。目前,支持 scan 主体 f 内的 sample 和 deterministic 站点。如果你发现任何效果处理器或分布不受支持,请提交 issue。
注意
在 plate 上下文内对齐 scan 维度是模糊的。因此不支持以下模式
with numpyro.plate('N', 10): last, ys = scan(f, init, xs)
所有 plate 语句都应放在 f 内部。例如,相应的可工作代码是
def g(*args, **kwargs): with numpyro.plate('N', 10): return f(*arg, **kwargs) last, ys = scan(g, init, xs)
注意
目前不支持嵌套的 scan。
注意
我们可以在 f 中对离散潜变量进行 scan。联合密度使用时间维度上的并行 scan(参考 [1])进行评估,这降低了并行复杂度至 O(log(length))。
带有离散潜变量的 scan 的
trace
将包含以下站点初始化点(init sites):这些点属于 f 的前 history 个跟踪(traces)。第 i 个跟踪的点名称将以前缀 ‘_PREV_’ * (2 * history - 1 - i) 开头。
扫描点(scanned sites):这些点收集 f 的剩余扫描循环的值。一个额外的时间维度 _time_foo 将被添加到这些点,其中 foo 是出现在 f 中的第一个点的名称。
并非所有转换函数 f 都受支持。Pyro枚举教程 [2] 中的所有限制在此仍然适用。此外,scan 外部不应有任何点依赖于 scan 的第一个输出(最后一个 carry 值)。
参考资料
贝叶斯平滑器的时序并行化, Simo Sarkka, Angel F. Garcia-Fernandez (https://arxiv.org/abs/1905.13002)
离散潜在变量的推断 (https://pyro.org.cn/examples/enumeration.html#Dependencies-among-plates)
- 参数:
- 返回:
scan 的输出,引用自
jax.lax.scan()
文档:“类型为 (c, [b]) 的对,其中第一个元素表示最终的循环 carrying 值,第二个元素表示 f 的第二个输出在沿输入的主轴扫描时堆叠的结果”。
cond
- cond(pred: bool, true_fun: Callable, false_fun: Callable, operand: Any) Any [source]
此原语根据条件应用
true_fun
或false_fun
。有关更多信息,请参阅jax.lax.cond()
。用法:
>>> import numpyro >>> import numpyro.distributions as dist >>> from jax import random >>> from numpyro.contrib.control_flow import cond >>> from numpyro.infer import SVI, Trace_ELBO >>> >>> def model(): ... def true_fun(_): ... return numpyro.sample("x", dist.Normal(20.0)) ... ... def false_fun(_): ... return numpyro.sample("x", dist.Normal(0.0)) ... ... cluster = numpyro.sample("cluster", dist.Normal()) ... return cond(cluster > 0, true_fun, false_fun, None) >>> >>> def guide(): ... m1 = numpyro.param("m1", 10.0) ... s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive) ... m2 = numpyro.param("m2", 10.0) ... s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive) ... ... def true_fun(_): ... return numpyro.sample("x", dist.Normal(m1, s1)) ... ... def false_fun(_): ... return numpyro.sample("x", dist.Normal(m2, s2)) ... ... cluster = numpyro.sample("cluster", dist.Normal()) ... return cond(cluster > 0, true_fun, false_fun, None) >>> >>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100)) >>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500)
警告
这是一个实验性实用函数,允许用户将 JAX 控制流与 NumPyro 的效果处理器一起使用。目前,支持 true_fun 和 false_fun 中的 sample 和 deterministic 点。如果您发现任何效果处理器或分布不受支持,请提交 issue。
警告
当前,
cond
原语不支持枚举,也不能在numpyro.plate
上下文中使用。注意
所有
sample
点必须属于同一分布类。例如,以下是不受支持的cond( True, lambda _: numpyro.sample("x", dist.Normal()), lambda _: numpyro.sample("x", dist.Laplace()), None, )
- 参数:
pred (bool) – 布尔标量类型,指示应用哪个分支函数。
true_fun (callable) – 当
pred
为真时应用的函数。false_fun (callable) – 当
pred
为假时应用的函数。operand – 根据
pred
应用到任一分支的操作数输入。这可以是任何 JAX PyTree(例如数组的列表/字典)。
- 返回:
应用的分支函数的输出。