重新参数化器

模块 numpyro.infer.reparam 包含了用于 numpyro.handlers.reparam 效果的重新参数化策略。这些策略对于改变条件差的参数空间的几何形状以使后验分布形状更好很有用。它们可以与各种推断算法一起使用,例如 Auto*Normal guides 和 MCMC。

class Reparam[source]

基类: ABC

重新参数化器的基类。

位置-尺度去中心化

class LocScaleReparam(centered=None, shape_params=())[source]

基类: Reparam

通用的去中心化重新参数化器 [1],用于由 locscale(以及可能的额外 shape_params)参数化的隐变量。

这种重新参数化仅适用于隐变量,不适用于似然。

参考文献

  1. Automatic Reparameterisation of Probabilistic Programs, Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

参数:
  • centered (float) – 可选的中心化参数。如果为 None(默认),则学习一个在 [0,1] 范围内、初始化值为 0.5 的每站点、每元素的中心化参数。要采样该参数,考虑使用 lift 处理器以及类似 Uniform(0, 1) 的先验,将该参数转换为隐变量。如果为 0,则完全去中心化分布;如果为 1,则保持中心化分布不变。

  • shape_params (tuplelist) – 从中心化分布复制到去中心化分布时保持不变的额外参数名称列表。

__call__(name, fn, obs)[source]
参数:
  • name (str) – 采样站点的名称。

  • fn (Distribution) – 一个分布。

  • obs (numpy.ndarray) – 观测值或 None。

返回:

一对 (new_fn, value)。

神经传输

class NeuTraReparam(guide, params)[source]

基类: Reparam

多个隐变量的神经传输重新参数化器 [1]。

这使用训练好的 AutoContinuous guide 来改变模型的几何形状,通常用于 MCMC 等。示例用法

# Step 1. Train a guide
guide = AutoIAFNormal(model)
svi = SVI(model, guide, ...)
# ...train the guide...

# Step 2. Use trained guide in NeuTra MCMC
neutra = NeuTraReparam(guide)
model = neutra.reparam(model)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...

这种重新参数化仅适用于隐变量,不适用于似然。请注意,所有站点必须共享同一个 NeuTraReparam 实例,并且模型必须具有静态结构。

[1] Hoffman, M. 等人 (2019)

“利用神经传输中和哈密顿蒙特卡洛中的不良几何形状” https://arxiv.org/abs/1903.03704

参数:
  • guide (AutoContinuous) – 一个 guide。

  • params – guide 的训练参数。

reparam(fn=None)[source]
__call__(name, fn, obs)[source]
参数:
  • name (str) – 采样站点的名称。

  • fn (Distribution) – 一个分布。

  • obs (numpy.ndarray) – 观测值或 None。

返回:

一对 (new_fn, value)。

transform_sample(latent)[source]

给定来自变形后验(可能带有批量维度)的隐变量样本,返回一个 dict,其中包含模型中隐变量站点的样本。

参数:

latent – 来自变形后验的样本(可能带有批量)。

返回:

一个 dict,键是模型中的隐变量站点,值是对应的样本。

返回类型:

dict

变换后的分布

class TransformReparam[source]

基类: Reparam

TransformedDistribution 隐变量的重新参数化器。

这对于具有复杂、改变几何形状的变换的变换分布很有用,在这种情况下,后验在 base_dist 的空间中具有简单的形状。

这种重新参数化仅适用于隐变量,不适用于似然。

__call__(name, fn, obs)[source]
参数:
  • name (str) – 采样站点的名称。

  • fn (Distribution) – 一个分布。

  • obs (numpy.ndarray) – 观测值或 None。

返回:

一对 (new_fn, value)。

投影正态分布

class ProjectedNormalReparam[source]

基类: Reparam

ProjectedNormal 隐变量的重新参数化器。

这种重新参数化仅适用于隐变量,不适用于似然。

__call__(name, fn, obs)[source]
参数:
  • name (str) – 采样站点的名称。

  • fn (Distribution) – 一个分布。

  • obs (numpy.ndarray) – 观测值或 None。

返回:

一对 (new_fn, value)。

圆形分布

class CircularReparam[source]

基类: Reparam

VonMises 隐变量的重新参数化器。

__call__(name, fn, obs)[source]
参数:
  • name (str) – 采样站点的名称。

  • fn (Distribution) – 一个分布。

  • obs (numpy.ndarray) – 观测值或 None。

返回:

一对 (new_fn, value)。

显式重新参数化

class ExplicitReparam(transform)[source]

基类: Reparam

将隐变量 x 显式重新参数化到几何形状更易处理的变换空间 y = transform(x)。这个重新参数化器类似于 TransformReparam,但允许将重新参数化与模型声明解耦。

参数:

transform – 到重新参数化空间的双射变换。

示例

>>> from jax import random
>>> from jax import numpy as jnp
>>> import numpyro
>>> from numpyro import handlers, distributions as dist
>>> from numpyro.infer import MCMC, NUTS
>>> from numpyro.infer.reparam import ExplicitReparam
>>>
>>> def model():
...    numpyro.sample("x", dist.Gamma(4, 4))
>>>
>>> # Sample in unconstrained space using a soft-plus instead of exp transform.
>>> reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
>>> reparametrized = handlers.reparam(model, {"x": reparam})
>>> kernel = NUTS(model=reparametrized)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
>>> mcmc.run(random.PRNGKey(2))  
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
__call__(name, fn, obs)[source]
参数:
  • name (str) – 采样站点的名称。

  • fn (Distribution) – 一个分布。

  • obs (numpy.ndarray) – 观测值或 None。

返回:

一对 (new_fn, value)。