基于 Funsor 的 NumPyro
请参阅 GitHub 仓库 获取有关 Funsor 的更多信息。
效果处理器
- class enum(fn=None, first_available_dim=None)[source]
基类:
BaseEnumMessenger
在标记为
infer={"enumerate": "parallel"}
的离散采样点上进行并行枚举。- 参数:
fn (callable) – 包含 NumPyro 原语的 Python 可调用对象。
first_available_dim (int) – 从右边计数,第一个可用于并行枚举的张量维度。此维度及其左边的所有维度可能会被 Pyro 内部使用。这应该是一个负整数或 None。
- class infer_config(fn: Callable | None = None, config_fn: Callable | None = None)[source]
基类:
Messenger
给定一个包含 NumPyro 原语调用的可调用对象 fn 和一个接受跟踪点并返回字典的可调用对象 config_fn,将采样点处的 infer kwarg 的值更新为 config_fn(site)。
- 参数:
fn – 一个随机函数(包含 NumPyro 原语调用的可调用对象)
config_fn – 一个接受点并返回 infer 字典的可调用对象
- markov(fn=None, history=1, keep=False)[source]
马尔可夫依赖声明。
这在统计上等同于一个内存管理区域。
- 参数:
fn (callable) – 包含 NumPyro 原语的 Python 可调用对象。
history (int) – 当前上下文可见的前一个上下文的数量。默认为 1。如果为零,则类似于
numpyro.primitives.plate
。keep (bool) – 如果为 True,帧可重放。这在分支时很重要:如果
keep=True
,同一级别的相邻分支可以相互依赖;如果keep=False
,相邻分支是独立的(以其共同祖先为条件)。
- class plate(name, size, subsample_size=None, dim=None)[source]
基类:
GlobalNamedMessenger
numpyro.primitives.plate
原语的另一种实现。请注意,只有此版本与枚举兼容。还有一个上下文管理器
plate_to_enum_plate()
,它将 numpyro.plate 语句转换为此版本。- 参数:
- to_data(x, name_to_dim=None, dim_type=DimType.LOCAL)[source]
从
Funsor
中提取 Python 对象的原语。- 参数:
x (Funsor) – 一个 Funsor 对象
name_to_dim (OrderedDict) – 一个可选的输入提示,将来自 x 的维度名称映射到返回值中的维度位置。
dim_type (int) – 可以是 0、1 或 2。此可选参数指示应将维度视为“local”、“global”或“visible”,这可用于与全局
DimStack
交互。
- 返回:
与 x 等效的非 Funsor 对象。
- to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL)[source]
将 Python 对象转换为
Funsor
的原语。- 参数:
x – 一个对象。
output (funsor.domains.Domain) – 一个可选的输出提示,用于唯一地将数据转换为 Funsor(例如,当 x 是字符串时)。
dim_to_name (OrderedDict) – 一个可选的映射,将负的 batch 维度映射到名称字符串。
dim_type (int) – 可以是 0、1 或 2。此可选参数指示应将维度视为“local”、“global”或“visible”,这可用于与全局
DimStack
交互。
- 返回:
与 x 等效的 Funsor 对象。
- 返回类型:
funsor.terms.Funsor
推断工具
- config_enumerate(fn=None, default='parallel')[source]
为 NumPyro 模型中的所有相关站点配置枚举。
当配置离散变量的详尽枚举时,这将配置所有满足
.has_enumerate_support == True
的采样点。这既可以用作函数
model = config_enumerate(model)
或用作装饰器
@config_enumerate def model(*args, **kwargs): ...
注意
目前,仅支持
default='parallel'
。- 参数:
fn (callable) – 包含 NumPyro 原语的 Python 可调用对象。
default (str) – 要使用的枚举策略,可以是 “sequential”、“parallel” 或 None。默认为 “parallel”。
- infer_discrete(fn=None, first_available_dim=None, temperature=1, rng_key=None)[source]
一个处理器,用于从后验中采样标记为
site["infer"]["enumerate"] = "parallel"
的离散点,以观测值为条件。示例
@infer_discrete(first_available_dim=-1, temperature=0) @config_enumerate def viterbi_decoder(data, hidden_dim=10): transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim) means = jnp.arange(float(hidden_dim)) states = [0] for t in markov(range(len(data))): states.append(numpyro.sample("states_{}".format(t), dist.Categorical(transition[states[-1]]))) numpyro.sample("obs_{}".format(t), dist.Normal(means[states[-1]], 1.), obs=data[t]) return states # returns maximum likelihood states
- 参数:
- log_density(model, model_args, model_kwargs, params)[source]
类似于
numpyro.infer.util.log_density()
,但适用于具有离散隐变量的模型。在内部,这使用funsor
来边缘化离散隐点并评估联合对数概率。- 参数:
- 返回:
联合密度的对数以及相应的模型跟踪。
- plate_to_enum_plate()[source]
一个上下文管理器,用于将 numpyro.plate 语句替换为基于 funsor 的
plate
。这在对包含 numpyro.plate 语句的常规 NumPyro 程序进行推断时非常有用。例如,要获取其离散隐点被枚举的 model 的跟踪,我们可以使用
enum_model = numpyro.contrib.funsor.enum(model) with plate_to_enum_plate(): model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace( *model_args, **model_kwargs)