基于 Funsor 的 NumPyro

请参阅 GitHub 仓库 获取有关 Funsor 的更多信息。

效果处理器

class enum(fn=None, first_available_dim=None)[source]

基类:BaseEnumMessenger

在标记为 infer={"enumerate": "parallel"} 的离散采样点上进行并行枚举。

参数:
  • fn (callable) – 包含 NumPyro 原语的 Python 可调用对象。

  • first_available_dim (int) – 从右边计数,第一个可用于并行枚举的张量维度。此维度及其左边的所有维度可能会被 Pyro 内部使用。这应该是一个负整数或 None。

process_message(msg)[source]

由子类实现。

class infer_config(fn: Callable | None = None, config_fn: Callable | None = None)[source]

基类:Messenger

给定一个包含 NumPyro 原语调用的可调用对象 fn 和一个接受跟踪点并返回字典的可调用对象 config_fn,将采样点处的 infer kwarg 的值更新为 config_fn(site)。

参数:
  • fn – 一个随机函数(包含 NumPyro 原语调用的可调用对象)

  • config_fn – 一个接受点并返回 infer 字典的可调用对象

process_message(msg: dict[str, Any]) None[source]

由子类实现。

markov(fn=None, history=1, keep=False)[source]

马尔可夫依赖声明。

这在统计上等同于一个内存管理区域。

参数:
  • fn (callable) – 包含 NumPyro 原语的 Python 可调用对象。

  • history (int) – 当前上下文可见的前一个上下文的数量。默认为 1。如果为零,则类似于 numpyro.primitives.plate

  • keep (bool) – 如果为 True,帧可重放。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(以其共同祖先为条件)。

class plate(name, size, subsample_size=None, dim=None)[source]

基类:GlobalNamedMessenger

numpyro.primitives.plate 原语的另一种实现。请注意,只有此版本与枚举兼容。

还有一个上下文管理器 plate_to_enum_plate(),它将 numpyro.plate 语句转换为此版本。

参数:
  • name (str) – plate 的名称。

  • size (int) – plate 的大小。

  • subsample_size (int) – 可选参数,表示 mini-batch 的大小。这可用于通过推断算法应用缩放因子。例如,在使用 mini-batch 计算 ELBO 时。

  • dim (int) – 可选参数,指定张量中的哪个维度用作 plate dim。如果为 None(默认),则分配最右边可用的 dim。

process_message(msg)[source]

由子类实现。

postprocess_message(msg)[source]

由子类实现。

to_data(x, name_to_dim=None, dim_type=DimType.LOCAL)[source]

Funsor 中提取 Python 对象的原语。

参数:
  • x (Funsor) – 一个 Funsor 对象

  • name_to_dim (OrderedDict) – 一个可选的输入提示,将来自 x 的维度名称映射到返回值中的维度位置。

  • dim_type (int) – 可以是 0、1 或 2。此可选参数指示应将维度视为“local”、“global”或“visible”,这可用于与全局 DimStack 交互。

返回:

x 等效的非 Funsor 对象。

to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL)[source]

将 Python 对象转换为 Funsor 的原语。

参数:
  • x – 一个对象。

  • output (funsor.domains.Domain) – 一个可选的输出提示,用于唯一地将数据转换为 Funsor(例如,当 x 是字符串时)。

  • dim_to_name (OrderedDict) – 一个可选的映射,将负的 batch 维度映射到名称字符串。

  • dim_type (int) – 可以是 0、1 或 2。此可选参数指示应将维度视为“local”、“global”或“visible”,这可用于与全局 DimStack 交互。

返回:

x 等效的 Funsor 对象。

返回类型:

funsor.terms.Funsor

class trace(fn: Callable | None = None)[source]

基类:trace

此版本的 trace 处理器记录执行后进行打包所需的信息。

每个采样点都用一个“dim_to_name”字典进行标注,可以直接传递给 to_funsor()

postprocess_message(msg)[source]

由子类实现。

推断工具

config_enumerate(fn=None, default='parallel')[source]

为 NumPyro 模型中的所有相关站点配置枚举。

当配置离散变量的详尽枚举时,这将配置所有满足 .has_enumerate_support == True 的采样点。

这既可以用作函数

model = config_enumerate(model)

或用作装饰器

@config_enumerate
def model(*args, **kwargs):
    ...

注意

目前,仅支持 default='parallel'

参数:
  • fn (callable) – 包含 NumPyro 原语的 Python 可调用对象。

  • default (str) – 要使用的枚举策略,可以是 “sequential”、“parallel” 或 None。默认为 “parallel”。

infer_discrete(fn=None, first_available_dim=None, temperature=1, rng_key=None)[source]

一个处理器,用于从后验中采样标记为 site["infer"]["enumerate"] = "parallel" 的离散点,以观测值为条件。

示例

@infer_discrete(first_available_dim=-1, temperature=0)
@config_enumerate
def viterbi_decoder(data, hidden_dim=10):
    transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
    means = jnp.arange(float(hidden_dim))
    states = [0]
    for t in markov(range(len(data))):
        states.append(numpyro.sample("states_{}".format(t),
                                     dist.Categorical(transition[states[-1]])))
        numpyro.sample("obs_{}".format(t),
                       dist.Normal(means[states[-1]], 1.),
                       obs=data[t])
    return states  # returns maximum likelihood states
参数:
  • fn – 一个随机函数(包含 NumPyro 原语调用的可调用对象)

  • first_available_dim (int) – 从右边计数,第一个可用于并行枚举的张量维度。此维度及其左边的所有维度可能会被 Pyro 内部使用。这应该是一个负整数。

  • temperature (int) – 可以是 1(通过前向过滤后向采样进行采样)或 0(通过类似 Viterbi 的 MAP 推断进行优化)。默认为 1(采样)。

  • rng_key (jax.random.PRNGKey) – 随机数生成器密钥,用于 temperature=1first_available_dim is None 的情况。

log_density(model, model_args, model_kwargs, params)[source]

类似于 numpyro.infer.util.log_density(),但适用于具有离散隐变量的模型。在内部,这使用 funsor 来边缘化离散隐点并评估联合对数概率。

参数:
  • model

    包含 NumPyro 原语的 Python 可调用对象。通常,模型已通过使用 enum 处理器进行枚举。

    def model(*args, **kwargs):
        ...
    
    log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)
    

  • model_args (tuple) – 提供给模型的参数 (args)。

  • model_kwargs (dict) – 提供给模型的关键字参数 (kwargs)。

  • params (dict) – 按点名称键控的当前参数值字典。

返回:

联合密度的对数以及相应的模型跟踪。

plate_to_enum_plate()[source]

一个上下文管理器,用于将 numpyro.plate 语句替换为基于 funsor 的 plate

这在对包含 numpyro.plate 语句的常规 NumPyro 程序进行推断时非常有用。例如,要获取其离散隐点被枚举的 model 的跟踪,我们可以使用

enum_model = numpyro.contrib.funsor.enum(model)
with plate_to_enum_plate():
    model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace(
        *model_args, **model_kwargs)