效果处理器

NumPyro 提供了一小组效果处理器,它们模仿了 Pyro 的 poutine 模块。有关效果处理器的更一般教程,建议读者阅读 Poutine: Pyro 中的效果处理器编程指南。这些简单的效果处理器可以组合使用或添加新的处理器,以便实现自定义推断工具和算法。

当一个处理器,例如 handlers.seed,应用于 NumPyro 中的模型时(例如 seeded_model = handlers.seed(model, rng_seed=0)),它会创建一个具有状态属性的可调用对象。这些属性可能会干扰 JAX 原语,例如 jax.jitjax.vmapjax.grad。为了确保与 JAX 原语正确组合,处理器应在模型使用的函数或上下文中局部应用,而不是全局应用。例如

# Good: can be used in a jitted function
def seeded_model(data):
    return handlers.seed(model, rng_seed=0)(data)

# Bad: might create tracer-leaks when used in a jitted function
seeded_model = handlers.seed(model, rng_seed=0)

示例

作为一个示例,我们使用 seedtracesubstitute 处理器定义下面的 log_likelihood 函数。我们首先创建一个逻辑回归模型,并使用 MCMC() 从回归参数的后验分布中进行采样。log_likelihood 函数使用效果处理器,通过将采样点替换为后验分布中的值来运行模型,并计算单个数据点的对数密度。log_predictive_density 函数计算联合后验分布中每次抽样的对数似然,并汇总所有数据点的结果,但它通过使用 JAX 的自动向量化转换 vmap 来实现,这样我们就无需遍历所有数据点。

>>> import jax.numpy as jnp
>>> from jax import random, vmap
>>> from jax.scipy.special import logsumexp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>> from numpyro.infer import MCMC, NUTS

>>> N, D = 3000, 3
>>> def logistic_regression(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(D), jnp.ones(D)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     logits = jnp.sum(coefs * data + intercept, axis=-1)
...     return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

>>> data = random.normal(random.PRNGKey(0), (N, D))
>>> true_coefs = jnp.arange(1., D + 1.)
>>> logits = jnp.sum(true_coefs * data, axis=-1)
>>> labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

>>> num_warmup, num_samples = 1000, 1000
>>> mcmc = MCMC(NUTS(model=logistic_regression), num_warmup=num_warmup, num_samples=num_samples)
>>> mcmc.run(random.PRNGKey(2), data, labels)  
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]
>>> mcmc.print_summary()  


                   mean         sd       5.5%      94.5%      n_eff       Rhat
    coefs[0]       0.96       0.07       0.85       1.07     455.35       1.01
    coefs[1]       2.05       0.09       1.91       2.20     332.00       1.01
    coefs[2]       3.18       0.13       2.96       3.37     320.27       1.00
   intercept      -0.03       0.02      -0.06       0.00     402.53       1.00

>>> def log_likelihood(rng_key, params, model, *args, **kwargs):
...     model = handlers.substitute(handlers.seed(model, rng_key), params)
...     model_trace = handlers.trace(model).get_trace(*args, **kwargs)
...     obs_node = model_trace['obs']
...     return obs_node['fn'].log_prob(obs_node['value'])

>>> def log_predictive_density(rng_key, params, model, *args, **kwargs):
...     n = list(params.values())[0].shape[0]
...     log_lk_fn = vmap(lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs))
...     log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
...     return jnp.sum(logsumexp(log_lk_vals, 0) - jnp.log(n))

>>> print(log_predictive_density(random.PRNGKey(2), mcmc.get_samples(),
...       logistic_regression, data, labels))  
-874.89813

block

class block(fn: 可调用对象 | None = None, hide_fn: 可调用对象 | None = None, hide: 列表[str] | None = None, expose_types: 列表[str] | None = None, expose: 列表[str] | None = None)[source]

基类: Messenger

给定一个可调用对象 fn,返回另一个可调用对象,该对象选择性地隐藏堆栈上其他效果处理器的原语点。在没有参数的情况下,所有原语点都被阻塞。hide_fn 优先于 hide,而 hide 的优先级高于 expose_types,最后是 expose。只考虑具有优先级的参数。

参数:
  • fn (可调用对象) – 包含 NumPyro 原语的 Python 可调用对象。

  • hide_fn (可调用对象) – 一个函数,给定一个包含点级别元数据的字典时,返回是否应该阻塞它。

  • hide (列表) – 要隐藏的点名称列表。

  • expose_types (列表) – 要暴露的点类型列表,例如 [‘param’]

  • expose (列表) – 要暴露的点名称列表。

返回值:

包含 NumPyro 原语的 Python 可调用对象。

示例

>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import block, seed, trace
>>> import numpyro.distributions as dist

>>> def model():
...     a = numpyro.sample('a', dist.Normal(0., 1.))
...     return numpyro.sample('b', dist.Normal(a, 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> block_all = block(model)
>>> block_a = block(model, lambda site: site['name'] == 'a')
>>> trace_block_all = trace(block_all).get_trace()
>>> assert not {'a', 'b'}.intersection(trace_block_all.keys())
>>> trace_block_a =  trace(block_a).get_trace()
>>> assert 'a' not in trace_block_a
>>> assert 'b' in trace_block_a
process_message(msg: 字典[str, Any]) None[source]

由子类实现。

collapse

class collapse(*args, **kwargs)[source]

基类: trace

实验性 通过延迟采样并尝试使用共轭关系来折叠上下文中的所有点。如果不知道共轭关系,这将失败。使用采样点结果的代码必须编写为接受 Funsors 而不是 Tensors。这需要安装 funsor

process_message(msg: 字典[str, Any]) None[source]

由子类实现。

condition

class condition(fn: 可调用对象 | None = None, data: 字典[str, Array | ndarray | bool_ | number | bool | int | float | complex] | None = None, condition_fn: 可调用对象 | None = None)[source]

基类: Messenger

将未观测的采样点条件设置为来自 datacondition_fn 的值。类似于 substitute,但只影响 sample 点并将 is_observed 属性更改为 True

参数:
  • fn – 包含 NumPyro 原语的 Python 可调用对象。

  • data (字典) – 以点名称为键的 numpy.ndarray 值字典。

  • condition_fn – 一个可调用对象,它接受一个点字典并返回一个 numpy 数组或 None(在这种情况下,处理器没有副作用)。

示例

>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import condition, seed, substitute, trace
>>> import numpyro.distributions as dist

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> exec_trace = trace(condition(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1
>>> assert exec_trace['a']['is_observed']
process_message(msg)[source]

由子类实现。

do

class do(fn: 可调用对象 | None = None, data: 字典[str, Array | ndarray | bool_ | number | bool | int | float | complex ] | None = None)[source]

基类: Messenger

给定一个包含一些采样语句的随机函数以及一个以名称为键的值字典,将这些点的返回值设置为就像它们被硬编码到这些值一样,并引入同名的新采样点,这些点的值不会传播。

可以自由地与 condition() 组合,以表示潜在结果的反事实分布。有关更多详细信息和理论,请参阅单世界干预图 [1]。

这相当于将 z = numpyro.sample(“z”, …) 替换为 z = 1.,并引入一个新的采样点 numpyro.sample(“z”, …),其值不在其他地方使用。

参考文献

  1. 单世界干预图:入门指南, Thomas Richardson, James Robins

参数:
  • fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)

  • data – 一个将采样点名称映射到干预值的 dict

示例

>>> import jax.numpy as jnp
>>> import numpyro
>>> from numpyro.handlers import do, trace, seed
>>> import numpyro.distributions as dist
>>> def model(x):
...     s = numpyro.sample("s", dist.LogNormal())
...     z = numpyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> intervened_model = handlers.do(model, data={"z": 1.})
>>> with trace() as exec_trace:
...     z_square = seed(intervened_model, 0)(1)
>>> assert exec_trace['z']['value'] != 1.
>>> assert not exec_trace['z']['is_observed']
>>> assert not exec_trace['z'].get('stop', None)
>>> assert z_square == 1
process_message(msg: 字典[str, Any]) None[source]

由子类实现。

infer_config

class infer_config(fn: 可调用对象 | None = None, config_fn: 可调用对象 | None = None)[source]

基类: Messenger

给定一个包含 NumPyro 原语调用的可调用对象 fn 以及一个接受跟踪点并返回字典的可调用对象 config_fn,将采样点的推断(infer)关键字参数值更新为 config_fn(site)。

参数:
  • fn – 一个随机函数(包含 NumPyro 原语调用的可调用对象)

  • config_fn – 一个接受点并返回推断字典的可调用对象

process_message(msg: 字典[str, Any]) None[source]

由子类实现。

lift

class lift(fn: 可调用对象 | None = None, prior: DistributionLike | 字典[str, DistributionLike] | None = None)[source]

基类: Messenger

给定一个包含 param 调用和先验分布的随机函数,创建一个新的随机函数,其中所有 param 调用都被从先验分布中采样所取代。先验分布应该是一个分布对象或一个从名称到分布对象的字典。

考虑下面的 NumPyro 程序

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import lift
>>>
>>> def model(x):
...     s = numpyro.param("s", 0.5)
...     z = numpyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = lift(model, prior={"s": dist.Exponential(0.3)})

lift 使 param 语句的行为类似于使用 prior 中的分布的 sample 语句。在此示例中,点 s 的行为就好像被替换为 s = numpyro.sample("s", dist.Exponential(0.3))

参数:
  • fn – 其参数将被提升为随机值的函数

  • prior – 先验函数,可以是 Distribution 对象或 Distribution 字典

process_message(msg: 字典[str, Any]) None[source]

由子类实现。

mask

class mask(fn: 可调用对象 | None = None, mask: Array | ndarray | bool_ | number | bool | int | float | complex | None = True)[source]

基类: Messenger

此消息处理器逐元素屏蔽部分采样语句。

参数:

mask – 一个布尔值或布尔值数组,用于逐元素屏蔽采样点的对数概率(True 包含该点,False 排除该点)。

process_message(msg: 字典[str, Any]) None[source]

由子类实现。

reparam

class reparam(fn: 可调用对象 | None = None, config: 字典 | 可调用对象 | None = None)[source]

基类: Messenger

将每个受影响的采样点重新参数化为一个或多个辅助采样点,然后进行确定性转换 [1]。

要指定重新参数化器,请将 config 字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参阅 numpyro.infer.reparam 模块。

注意 某些重新参数化器可以检查它们影响的函数的 *args,**kwargs 输入;这些重新参数化器需要将 handlers.reparam 用作装饰器而不是上下文管理器。

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

“Automatic Reparameterisation of Probabilistic Programs” https://arxiv.org/pdf/1906.03028.pdf

参数:

config (字典 or 可调用对象) – 配置,可以是将点名称映射到 Reparam 的字典,也可以是将点映射到 Reparam 或 None 的函数。

process_message(msg: 字典[str, Any]) None[source]

由子类实现。

replay

class replay(fn: 可调用对象 | None = None, trace: OrderedDict[str, 字典[str, Any]] | None = None)[source]

基类: Messenger

给定一个可调用对象 fn 和一个执行跟踪 trace,返回一个可调用对象,该对象将 fn 中的 sample 调用替换为 trace 中相应点名称的值。

参数:
  • fn – 包含 NumPyro 原语的 Python 可调用对象。

  • trace – 一个包含执行元数据的 OrderedDict。

示例

>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import replay, seed, trace

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> print(exec_trace['a']['value'])  
-0.20584235
>>> replayed_trace = trace(replay(model, exec_trace)).get_trace()
>>> print(exec_trace['a']['value'])  
-0.20584235
>>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
process_message(msg: 字典[str, Any]) None[source]

由子类实现。

scale

class scale(fn: 可调用对象 | None = None, scale: Array | ndarray | bool_ | number | bool | int | float | complex | None = 1.0)[source]

基类: Messenger

此消息处理器重新缩放对数概率得分。

这通常用于数据子采样或数据分层采样(例如在欺诈检测中,负样本数量远远超过正样本)。

参数:

scale (float or numpy.ndarray) – 一个正比例因子,可以广播到对数概率的形状。

process_message(msg: 字典[str, Any]) None[source]

由子类实现。

scope

class scope(fn: 可调用对象 | None = None, prefix: str = '', divider: str = '/', *, hide_types: 列表[str] | None = None)[source]

基类: Messenger

此处理器将一个前缀后跟一个分隔符添加到采样点的名称前。

示例

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import scope, seed, trace

>>> def model():
...     with scope(prefix="a"):
...         with scope(prefix="b", divider="."):
...             return numpyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/b.x" in trace(seed(model, 0)).get_trace()
参数:
  • fn – 包含 NumPyro 原语的 Python 可调用对象。

  • prefix (str) – 要添加到采样名称前的字符串

  • divider (str) – 用于连接前缀和采样名称的字符串;默认为 ‘/’

  • hide_types (列表) – 一个可选的点类型列表,跳过重命名。

process_message(msg: 字典[str, Any]) None[source]

由子类实现。

seed

class seed(fn: 可调用对象 | None = None, rng_seed: Array | None = None, hide_types: 列表[str] | None = None)[source]

基类: Messenger

JAX 使用函数式伪随机数生成器,需要将种子 PRNGKey() 传递给每个随机函数。seed 处理器允许我们使用 PRNGKey() 初始化随机函数。函数内部的每次 sample() 原语调用都会分割这个初始种子,以便我们在后续每次调用中使用新的种子,而无需显式地将 PRNGKey 传递给每次 sample 调用。

参数:
  • fn – 包含 NumPyro 原语的 Python 可调用对象。

  • rng_seed (int, jnp.ndarray scalar, or jax.random.PRNGKey) – 一个随机数生成器种子。

  • hide_types (列表) – 一个可选的点类型列表,跳过设定种子,例如 [‘plate’]。

注意

与 Pyro 不同,numpyro.sample 原语必须包装在 seed 处理器中才能使用,因为它没有全局随机状态。因此,用户需要将 seed 用作上下文管理器以从分布中生成样本,或者用作其模型可调用对象的装饰器(参见下方)。

注意

seed 处理器有一个可变属性 rng_key,它在每次 sample 调用后都会改变。因此,此类的实例(例如 seed(model, rng_seed=0))在进行 JIT 编译时可能会产生 tracer-leaks。一个解决方案是在函数中关闭该实例,例如 seeded_model = lambda *args: seed(model, rng_seed=0)(*args)。这个 seeded_model 可以进行 JIT 编译。

示例

>>> from jax import random
>>> import numpyro
>>> import numpyro.handlers
>>> import numpyro.distributions as dist

>>> # as context manager
>>> with handlers.seed(rng_seed=1):
...     x = numpyro.sample('x', dist.Normal(0., 1.))

>>> def model():
...     return numpyro.sample('y', dist.Normal(0., 1.))

>>> # as function decorator (/modifier)
>>> y = handlers.seed(model, rng_seed=1)()
>>> assert x == y
stateful = False
process_message(msg: 字典[str, Any]) None[source]

由子类实现。

substitute

class substitute(fn: 可调用对象 | None = None, data: 字典[str, Array] | None = None, substitute_fn: 可调用对象 | None = None)[source]

基类: Messenger

给定一个可调用对象 fn 和一个以点名称为键的字典 data(或者,一个可调用对象 substitute_fn),返回一个可调用对象,该对象将 fn 中的所有原语调用替换为 data 中与点名称匹配的键的值。如果点名称不在 data 中,则没有副作用。

如果提供了 substitute_fn,则该点的值将替换为调用 substitute_fn 为给定点返回的值。

注意

此处理器主要用于内部算法。对于基于观测数据对生成模型进行条件设置,请使用 condition 处理器。

参数:
  • fn – 包含 NumPyro 原语的 Python 可调用对象。

  • data (字典) – 以点名称为键的 numpy.ndarray 值字典。

  • substitute_fn – 一个可调用对象,它接受一个点字典并返回一个 numpy 数组或 None(在这种情况下,处理器没有副作用)。

示例

>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import seed, substitute, trace
>>> import numpyro.distributions as dist

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> exec_trace = trace(substitute(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1
process_message(msg: 字典[str, Any]) None[source]

由子类实现。

trace

class trace(fn: 可调用对象 | None = None)[source]

基类: Messenger

返回一个处理器,该处理器记录 fn 内原语调用的输入和输出。

示例

>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import seed, trace
>>> import pprint as pp

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> pp.pprint(exec_trace)  
OrderedDict([('a',
              {'args': (),
               'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>,
               'is_observed': False,
               'kwargs': {'rng_key': Array([0, 0], dtype=uint32)},
               'name': 'a',
               'type': 'sample',
               'value': Array(-0.20584235, dtype=float32)})])
postprocess_message(msg: 字典[str, Any]) None[source]

由子类实现。

get_trace(*args, **kwargs) OrderedDict[str, 字典[str, Any]][source]

运行包装的可调用对象并返回记录的跟踪。

参数:
  • *args – 可调用对象的参数。

  • **kwargs – 可调用对象的关键字参数。

返回值:

包含执行跟踪的 OrderedDict