重新参数化器
模块 numpyro.infer.reparam
包含了用于 numpyro.handlers.reparam
效果的重新参数化策略。这些策略对于改变条件差的参数空间的几何形状以使后验分布形状更好很有用。它们可以与各种推断算法一起使用,例如 Auto*Normal
guides 和 MCMC。
位置-尺度去中心化
- class LocScaleReparam(centered=None, shape_params=())[source]
基类:
Reparam
通用的去中心化重新参数化器 [1],用于由
loc
和scale
(以及可能的额外shape_params
)参数化的隐变量。这种重新参数化仅适用于隐变量,不适用于似然。
参考文献
Automatic Reparameterisation of Probabilistic Programs, Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
- 参数:
- __call__(name, fn, obs)[source]
- 参数:
name (str) – 采样站点的名称。
fn (Distribution) – 一个分布。
obs (numpy.ndarray) – 观测值或 None。
- 返回:
一对 (
new_fn
,value
)。
神经传输
- class NeuTraReparam(guide, params)[source]
基类:
Reparam
多个隐变量的神经传输重新参数化器 [1]。
这使用训练好的
AutoContinuous
guide 来改变模型的几何形状,通常用于 MCMC 等。示例用法# Step 1. Train a guide guide = AutoIAFNormal(model) svi = SVI(model, guide, ...) # ...train the guide... # Step 2. Use trained guide in NeuTra MCMC neutra = NeuTraReparam(guide) model = neutra.reparam(model) nuts = NUTS(model) # ...now use the model in HMC or NUTS...
这种重新参数化仅适用于隐变量,不适用于似然。请注意,所有站点必须共享同一个
NeuTraReparam
实例,并且模型必须具有静态结构。- [1] Hoffman, M. 等人 (2019)
“利用神经传输中和哈密顿蒙特卡洛中的不良几何形状” https://arxiv.org/abs/1903.03704
- 参数:
guide (AutoContinuous) – 一个 guide。
params – guide 的训练参数。
- __call__(name, fn, obs)[source]
- 参数:
name (str) – 采样站点的名称。
fn (Distribution) – 一个分布。
obs (numpy.ndarray) – 观测值或 None。
- 返回:
一对 (
new_fn
,value
)。
变换后的分布
- class TransformReparam[source]
基类:
Reparam
TransformedDistribution
隐变量的重新参数化器。这对于具有复杂、改变几何形状的变换的变换分布很有用,在这种情况下,后验在
base_dist
的空间中具有简单的形状。这种重新参数化仅适用于隐变量,不适用于似然。
- __call__(name, fn, obs)[source]
- 参数:
name (str) – 采样站点的名称。
fn (Distribution) – 一个分布。
obs (numpy.ndarray) – 观测值或 None。
- 返回:
一对 (
new_fn
,value
)。
投影正态分布
圆形分布
显式重新参数化
- class ExplicitReparam(transform)[source]
基类:
Reparam
将隐变量
x
显式重新参数化到几何形状更易处理的变换空间y = transform(x)
。这个重新参数化器类似于TransformReparam
,但允许将重新参数化与模型声明解耦。- 参数:
transform – 到重新参数化空间的双射变换。
示例
>>> from jax import random >>> from jax import numpy as jnp >>> import numpyro >>> from numpyro import handlers, distributions as dist >>> from numpyro.infer import MCMC, NUTS >>> from numpyro.infer.reparam import ExplicitReparam >>> >>> def model(): ... numpyro.sample("x", dist.Gamma(4, 4)) >>> >>> # Sample in unconstrained space using a soft-plus instead of exp transform. >>> reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv) >>> reparametrized = handlers.reparam(model, {"x": reparam}) >>> kernel = NUTS(model=reparametrized) >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1) >>> mcmc.run(random.PRNGKey(2)) sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
- __call__(name, fn, obs)[source]
- 参数:
name (str) – 采样站点的名称。
fn (Distribution) – 一个分布。
obs (numpy.ndarray) – 观测值或 None。
- 返回:
一对 (
new_fn
,value
)。