效果处理器
NumPyro 提供了一小组效果处理器,它们模仿了 Pyro 的 poutine 模块。有关效果处理器的更一般教程,建议读者阅读 Poutine: Pyro 中的效果处理器编程指南。这些简单的效果处理器可以组合使用或添加新的处理器,以便实现自定义推断工具和算法。
当一个处理器,例如 handlers.seed,应用于 NumPyro 中的模型时(例如 seeded_model = handlers.seed(model, rng_seed=0)),它会创建一个具有状态属性的可调用对象。这些属性可能会干扰 JAX 原语,例如 jax.jit、jax.vmap 和 jax.grad。为了确保与 JAX 原语正确组合,处理器应在模型使用的函数或上下文中局部应用,而不是全局应用。例如
# Good: can be used in a jitted function
def seeded_model(data):
return handlers.seed(model, rng_seed=0)(data)
# Bad: might create tracer-leaks when used in a jitted function
seeded_model = handlers.seed(model, rng_seed=0)
示例
作为一个示例,我们使用 seed
、trace
和 substitute
处理器定义下面的 log_likelihood 函数。我们首先创建一个逻辑回归模型,并使用 MCMC()
从回归参数的后验分布中进行采样。log_likelihood 函数使用效果处理器,通过将采样点替换为后验分布中的值来运行模型,并计算单个数据点的对数密度。log_predictive_density 函数计算联合后验分布中每次抽样的对数似然,并汇总所有数据点的结果,但它通过使用 JAX 的自动向量化转换 vmap 来实现,这样我们就无需遍历所有数据点。
>>> import jax.numpy as jnp
>>> from jax import random, vmap
>>> from jax.scipy.special import logsumexp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>> from numpyro.infer import MCMC, NUTS
>>> N, D = 3000, 3
>>> def logistic_regression(data, labels):
... coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(D), jnp.ones(D)))
... intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
... logits = jnp.sum(coefs * data + intercept, axis=-1)
... return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)
>>> data = random.normal(random.PRNGKey(0), (N, D))
>>> true_coefs = jnp.arange(1., D + 1.)
>>> logits = jnp.sum(true_coefs * data, axis=-1)
>>> labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
>>> num_warmup, num_samples = 1000, 1000
>>> mcmc = MCMC(NUTS(model=logistic_regression), num_warmup=num_warmup, num_samples=num_samples)
>>> mcmc.run(random.PRNGKey(2), data, labels)
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]
>>> mcmc.print_summary()
mean sd 5.5% 94.5% n_eff Rhat
coefs[0] 0.96 0.07 0.85 1.07 455.35 1.01
coefs[1] 2.05 0.09 1.91 2.20 332.00 1.01
coefs[2] 3.18 0.13 2.96 3.37 320.27 1.00
intercept -0.03 0.02 -0.06 0.00 402.53 1.00
>>> def log_likelihood(rng_key, params, model, *args, **kwargs):
... model = handlers.substitute(handlers.seed(model, rng_key), params)
... model_trace = handlers.trace(model).get_trace(*args, **kwargs)
... obs_node = model_trace['obs']
... return obs_node['fn'].log_prob(obs_node['value'])
>>> def log_predictive_density(rng_key, params, model, *args, **kwargs):
... n = list(params.values())[0].shape[0]
... log_lk_fn = vmap(lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs))
... log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
... return jnp.sum(logsumexp(log_lk_vals, 0) - jnp.log(n))
>>> print(log_predictive_density(random.PRNGKey(2), mcmc.get_samples(),
... logistic_regression, data, labels))
-874.89813
block
- class block(fn: 可调用对象 | None = None, hide_fn: 可调用对象 | None = None, hide: 列表[str] | None = None, expose_types: 列表[str] | None = None, expose: 列表[str] | None = None)[source]
基类:
Messenger
给定一个可调用对象 fn,返回另一个可调用对象,该对象选择性地隐藏堆栈上其他效果处理器的原语点。在没有参数的情况下,所有原语点都被阻塞。hide_fn 优先于 hide,而 hide 的优先级高于 expose_types,最后是 expose。只考虑具有优先级的参数。
- 参数:
- 返回值:
包含 NumPyro 原语的 Python 可调用对象。
示例
>>> from jax import random >>> import numpyro >>> from numpyro.handlers import block, seed, trace >>> import numpyro.distributions as dist >>> def model(): ... a = numpyro.sample('a', dist.Normal(0., 1.)) ... return numpyro.sample('b', dist.Normal(a, 1.)) >>> model = seed(model, random.PRNGKey(0)) >>> block_all = block(model) >>> block_a = block(model, lambda site: site['name'] == 'a') >>> trace_block_all = trace(block_all).get_trace() >>> assert not {'a', 'b'}.intersection(trace_block_all.keys()) >>> trace_block_a = trace(block_a).get_trace() >>> assert 'a' not in trace_block_a >>> assert 'b' in trace_block_a
collapse
condition
- class condition(fn: 可调用对象 | None = None, data: 字典[str, Array | ndarray | bool_ | number | bool | int | float | complex] | None = None, condition_fn: 可调用对象 | None = None)[source]
基类:
Messenger
将未观测的采样点条件设置为来自 data 或 condition_fn 的值。类似于
substitute
,但只影响 sample 点并将 is_observed 属性更改为 True。- 参数:
fn – 包含 NumPyro 原语的 Python 可调用对象。
data (字典) – 以点名称为键的 numpy.ndarray 值字典。
condition_fn – 一个可调用对象,它接受一个点字典并返回一个 numpy 数组或 None(在这种情况下,处理器没有副作用)。
示例
>>> from jax import random >>> import numpyro >>> from numpyro.handlers import condition, seed, substitute, trace >>> import numpyro.distributions as dist >>> def model(): ... numpyro.sample('a', dist.Normal(0., 1.)) >>> model = seed(model, random.PRNGKey(0)) >>> exec_trace = trace(condition(model, {'a': -1})).get_trace() >>> assert exec_trace['a']['value'] == -1 >>> assert exec_trace['a']['is_observed']
do
- class do(fn: 可调用对象 | None = None, data: 字典[str, Array | ndarray | bool_ | number | bool | int | float | complex ] | None = None)[source]
基类:
Messenger
给定一个包含一些采样语句的随机函数以及一个以名称为键的值字典,将这些点的返回值设置为就像它们被硬编码到这些值一样,并引入同名的新采样点,这些点的值不会传播。
可以自由地与
condition()
组合,以表示潜在结果的反事实分布。有关更多详细信息和理论,请参阅单世界干预图 [1]。这相当于将 z = numpyro.sample(“z”, …) 替换为 z = 1.,并引入一个新的采样点 numpyro.sample(“z”, …),其值不在其他地方使用。
参考文献
单世界干预图:入门指南, Thomas Richardson, James Robins
- 参数:
fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)
data – 一个将采样点名称映射到干预值的
dict
示例
>>> import jax.numpy as jnp >>> import numpyro >>> from numpyro.handlers import do, trace, seed >>> import numpyro.distributions as dist >>> def model(x): ... s = numpyro.sample("s", dist.LogNormal()) ... z = numpyro.sample("z", dist.Normal(x, s)) ... return z ** 2 >>> intervened_model = handlers.do(model, data={"z": 1.}) >>> with trace() as exec_trace: ... z_square = seed(intervened_model, 0)(1) >>> assert exec_trace['z']['value'] != 1. >>> assert not exec_trace['z']['is_observed'] >>> assert not exec_trace['z'].get('stop', None) >>> assert z_square == 1
infer_config
lift
- class lift(fn: 可调用对象 | None = None, prior: DistributionLike | 字典[str, DistributionLike] | None = None)[source]
基类:
Messenger
给定一个包含
param
调用和先验分布的随机函数,创建一个新的随机函数,其中所有 param 调用都被从先验分布中采样所取代。先验分布应该是一个分布对象或一个从名称到分布对象的字典。考虑下面的 NumPyro 程序
>>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.handlers import lift >>> >>> def model(x): ... s = numpyro.param("s", 0.5) ... z = numpyro.sample("z", dist.Normal(x, s)) ... return z ** 2 >>> lifted_model = lift(model, prior={"s": dist.Exponential(0.3)})
lift 使
param
语句的行为类似于使用prior
中的分布的sample
语句。在此示例中,点 s 的行为就好像被替换为s = numpyro.sample("s", dist.Exponential(0.3))
。- 参数:
fn – 其参数将被提升为随机值的函数
prior – 先验函数,可以是 Distribution 对象或 Distribution 字典
mask
reparam
- class reparam(fn: 可调用对象 | None = None, config: 字典 | 可调用对象 | None = None)[source]
基类:
Messenger
将每个受影响的采样点重新参数化为一个或多个辅助采样点,然后进行确定性转换 [1]。
要指定重新参数化器,请将
config
字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参阅numpyro.infer.reparam
模块。注意 某些重新参数化器可以检查它们影响的函数的
*args,**kwargs
输入;这些重新参数化器需要将handlers.reparam
用作装饰器而不是上下文管理器。- [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
“Automatic Reparameterisation of Probabilistic Programs” https://arxiv.org/pdf/1906.03028.pdf
replay
- class replay(fn: 可调用对象 | None = None, trace: OrderedDict[str, 字典[str, Any]] | None = None)[source]
基类:
Messenger
给定一个可调用对象 fn 和一个执行跟踪 trace,返回一个可调用对象,该对象将 fn 中的 sample 调用替换为 trace 中相应点名称的值。
- 参数:
fn – 包含 NumPyro 原语的 Python 可调用对象。
trace – 一个包含执行元数据的 OrderedDict。
示例
>>> from jax import random >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.handlers import replay, seed, trace >>> def model(): ... numpyro.sample('a', dist.Normal(0., 1.)) >>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace() >>> print(exec_trace['a']['value']) -0.20584235 >>> replayed_trace = trace(replay(model, exec_trace)).get_trace() >>> print(exec_trace['a']['value']) -0.20584235 >>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
scale
scope
- class scope(fn: 可调用对象 | None = None, prefix: str = '', divider: str = '/', *, hide_types: 列表[str] | None = None)[source]
基类:
Messenger
此处理器将一个前缀后跟一个分隔符添加到采样点的名称前。
示例
>>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.handlers import scope, seed, trace >>> def model(): ... with scope(prefix="a"): ... with scope(prefix="b", divider="."): ... return numpyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/b.x" in trace(seed(model, 0)).get_trace()
- 参数:
seed
- class seed(fn: 可调用对象 | None = None, rng_seed: Array | None = None, hide_types: 列表[str] | None = None)[source]
基类:
Messenger
JAX 使用函数式伪随机数生成器,需要将种子
PRNGKey()
传递给每个随机函数。seed 处理器允许我们使用PRNGKey()
初始化随机函数。函数内部的每次sample()
原语调用都会分割这个初始种子,以便我们在后续每次调用中使用新的种子,而无需显式地将 PRNGKey 传递给每次 sample 调用。- 参数:
fn – 包含 NumPyro 原语的 Python 可调用对象。
rng_seed (int, jnp.ndarray scalar, or jax.random.PRNGKey) – 一个随机数生成器种子。
hide_types (列表) – 一个可选的点类型列表,跳过设定种子,例如 [‘plate’]。
注意
与 Pyro 不同,numpyro.sample 原语必须包装在 seed 处理器中才能使用,因为它没有全局随机状态。因此,用户需要将 seed 用作上下文管理器以从分布中生成样本,或者用作其模型可调用对象的装饰器(参见下方)。
注意
seed 处理器有一个可变属性 rng_key,它在每次 sample 调用后都会改变。因此,此类的实例(例如 seed(model, rng_seed=0))在进行 JIT 编译时可能会产生 tracer-leaks。一个解决方案是在函数中关闭该实例,例如 seeded_model = lambda *args: seed(model, rng_seed=0)(*args)。这个 seeded_model 可以进行 JIT 编译。
示例
>>> from jax import random >>> import numpyro >>> import numpyro.handlers >>> import numpyro.distributions as dist >>> # as context manager >>> with handlers.seed(rng_seed=1): ... x = numpyro.sample('x', dist.Normal(0., 1.)) >>> def model(): ... return numpyro.sample('y', dist.Normal(0., 1.)) >>> # as function decorator (/modifier) >>> y = handlers.seed(model, rng_seed=1)() >>> assert x == y
- stateful = False
substitute
- class substitute(fn: 可调用对象 | None = None, data: 字典[str, Array] | None = None, substitute_fn: 可调用对象 | None = None)[source]
基类:
Messenger
给定一个可调用对象 fn 和一个以点名称为键的字典 data(或者,一个可调用对象 substitute_fn),返回一个可调用对象,该对象将 fn 中的所有原语调用替换为 data 中与点名称匹配的键的值。如果点名称不在 data 中,则没有副作用。
如果提供了 substitute_fn,则该点的值将替换为调用 substitute_fn 为给定点返回的值。
注意
此处理器主要用于内部算法。对于基于观测数据对生成模型进行条件设置,请使用
condition
处理器。- 参数:
fn – 包含 NumPyro 原语的 Python 可调用对象。
data (字典) – 以点名称为键的 numpy.ndarray 值字典。
substitute_fn – 一个可调用对象,它接受一个点字典并返回一个 numpy 数组或 None(在这种情况下,处理器没有副作用)。
示例
>>> from jax import random >>> import numpyro >>> from numpyro.handlers import seed, substitute, trace >>> import numpyro.distributions as dist >>> def model(): ... numpyro.sample('a', dist.Normal(0., 1.)) >>> model = seed(model, random.PRNGKey(0)) >>> exec_trace = trace(substitute(model, {'a': -1})).get_trace() >>> assert exec_trace['a']['value'] == -1
trace
- class trace(fn: 可调用对象 | None = None)[source]
基类:
Messenger
返回一个处理器,该处理器记录 fn 内原语调用的输入和输出。
示例
>>> from jax import random >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.handlers import seed, trace >>> import pprint as pp >>> def model(): ... numpyro.sample('a', dist.Normal(0., 1.)) >>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace() >>> pp.pprint(exec_trace) OrderedDict([('a', {'args': (), 'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>, 'is_observed': False, 'kwargs': {'rng_key': Array([0, 0], dtype=uint32)}, 'name': 'a', 'type': 'sample', 'value': Array(-0.20584235, dtype=float32)})])