自动导引生成
我们简要概述了 NumPyro 中可用的自动生成导引。
AutoNormal 和 AutoDiagonalNormal 是我们基本的平均场导引。如果隐空间是非欧几里得的(例如由于某个采样点上的正性约束),则会自动在内部使用适当的双射变换,将无约束空间(定义正态变分分布的空间)映射到相应的约束空间(请注意,所有自动导引都是如此)。在尝试使变分推断适用于您正在开发的模型时,这些导引是一个很好的起点。
AutoMultivariateNormal 和 AutoLowRankMultivariateNormal 也构建正态变分分布,但提供了更大的灵活性,因为它们可以捕获后验中的相关性。请注意,这些导引在高维设置中可能难以拟合。
AutoBNAFNormal 和 AutoIAFNormal 提供了由归一化流参数化的灵活变分分布。
AutoDAIS 是一种强大的变分推断算法,利用了 HMC。它是处理高度相关后验的一个很好的选择,但根据模型的性质,计算成本可能会很高。
AutoSurrogateLikelihoodDAIS 是一种强大的变分推断算法,利用了 HMC 并支持数据子采样。
AutoSemiDAIS 为局部隐变量构建了一个类似于 AutoDAIS 的后验近似,但通过利用全局隐变量的参数化导引,支持在 ELBO 训练期间进行数据子采样。
AutoLaplaceApproximation 可用于计算拉普拉斯近似。
AutoGuideList 可用于组合多个自动导引。
AutoGuide
- class AutoGuide(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, create_plates=None)[source]
基类:
ABC
自动导引的基类。
派生类必须实现
__call__()
方法。- 参数:
- abstract sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
从模型中隐变量站点的近似后验生成样本。
- 参数:
rng_key (jax.random.PRNGKey) – 用于抽取样本的随机密钥。
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。args – 提供给模型/导引的参数。
sample_shape (tuple) – 每个隐变量站点的样本形状,默认为 ()。
kwargs – 提供给模型/导引的关键字参数。
- 返回值:
包含此导引抽取样本的字典。
- 返回类型:
- median(params)[source]
返回每个隐变量的后验中位数。
- 参数:
params (dict) – 包含参数值的字典。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
将采样站点名称映射到中位数值的字典。
- 返回类型:
AutoGuideList
- class AutoGuideList(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, create_plates=None)[source]
基类:
AutoGuide
用于组合多个自动导引的容器类。
示例用法
rng_key_init = random.PRNGKey(0) guide = AutoGuideList(my_model) guide.append( AutoNormal( numpyro.handlers.block(model, hide=["coefs"]) ) ) guide.append( AutoDelta( numpyro.handlers.block(model, expose=["coefs"]) ) ) svi = SVI(model, guide, optim, Trace_ELBO()) svi_state = svi.init(rng_key_init, data, labels) params = svi.get_params(svi_state)
- 参数:
model (callable) – 一个 NumPyro 模型
- append(part)[source]
为模型的一部分添加一个自动或自定义导引。该导引应通过阻塞模型来限制作用于采样站点的子集。任何两个部分都不应作用于同一个采样站点。
- 参数:
part (AutoGuide) – 要添加的部分导引
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
从模型中隐变量站点的近似后验生成样本。
- 参数:
rng_key (jax.random.PRNGKey) – 用于抽取样本的随机密钥。
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。args – 提供给模型/导引的参数。
sample_shape (tuple) – 每个隐变量站点的样本形状,默认为 ()。
kwargs – 提供给模型/导引的关键字参数。
- 返回值:
包含此导引抽取样本的字典。
- 返回类型:
- median(params)[source]
返回每个隐变量的后验中位数。
- 参数:
params (dict) – 包含参数值的字典。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
将采样站点名称映射到中位数值的字典。
- 返回类型:
AutoContinuous
- class AutoContinuous(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, create_plates=None)[source]
基类:
AutoGuide
连续值自动微分变分推断 [1] 实现的基类。
每个派生类都实现自己的
_get_posterior()
方法。假定模型结构和隐变量维度是固定的,并且所有隐变量都是连续的。
参考
自动微分变分推断, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
- 参数:
- get_base_dist()[source]
当将后验重参数化为
TransformedDistribution
时,返回其基础分布。这不应依赖于模型的 *args, **kwargs。
- get_transform(params)[source]
返回导引学习到的变换,用于从无约束(近似)后验生成样本。
- 参数:
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
后验分布的变换
- 返回类型:
- get_posterior(params)[source]
返回后验分布。
- 参数:
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
从模型中隐变量站点的近似后验生成样本。
- 参数:
rng_key (jax.random.PRNGKey) – 用于抽取样本的随机密钥。
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。args – 提供给模型/导引的参数。
sample_shape (tuple) – 每个隐变量站点的样本形状,默认为 ()。
kwargs – 提供给模型/导引的关键字参数。
- 返回值:
包含此导引抽取样本的字典。
- 返回类型:
AutoBNAFNormal
- class AutoBNAFNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, num_flows=1, hidden_factors=[8, 8])[source]
基类:
AutoContinuous
此
AutoContinuous
实现使用通过BlockNeuralAutoregressiveTransform
变换的对角正态分布来构建覆盖整个隐空间的导引。此导引不依赖于模型的*args, **kwargs
。用法
guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[50, 50], ...) svi = SVI(model, guide, ...)
参考文献
Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz
- 参数:
- get_base_dist()[source]
当将后验重参数化为
TransformedDistribution
时,返回其基础分布。这不应依赖于模型的 *args, **kwargs。
AutoDiagonalNormal
- class AutoDiagonalNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1)[source]
基类:
AutoContinuous
此
AutoContinuous
实现使用具有对角协方差矩阵的正态分布来构建覆盖整个隐空间的导引。此导引不依赖于模型的*args, **kwargs
。用法
guide = AutoDiagonalNormal(model, ...) svi = SVI(model, guide, ...)
- scale_constraint = SoftplusPositive(lower_bound=0.0)
- get_base_dist()[source]
当将后验重参数化为
TransformedDistribution
时,返回其基础分布。这不应依赖于模型的 *args, **kwargs。
- get_transform(params)[source]
返回导引学习到的变换,用于从无约束(近似)后验生成样本。
- 参数:
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
后验分布的变换
- 返回类型:
- median(params)[source]
返回每个隐变量的后验中位数。
- 参数:
params (dict) – 包含参数值的字典。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
将采样站点名称映射到中位数值的字典。
- 返回类型:
AutoMultivariateNormal
- class AutoMultivariateNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1)[source]
基类:
AutoContinuous
此
AutoContinuous
实现使用多元正态分布来构建覆盖整个隐空间的导引。此导引不依赖于模型的*args, **kwargs
。用法
guide = AutoMultivariateNormal(model, ...) svi = SVI(model, guide, ...)
- scale_tril_constraint = ScaledUnitLowerCholesky()
- get_base_dist()[source]
当将后验重参数化为
TransformedDistribution
时,返回其基础分布。这不应依赖于模型的 *args, **kwargs。
- get_transform(params)[source]
返回导引学习到的变换,用于从无约束(近似)后验生成样本。
- 参数:
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
后验分布的变换
- 返回类型:
- median(params)[source]
返回每个隐变量的后验中位数。
- 参数:
params (dict) – 包含参数值的字典。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
将采样站点名称映射到中位数值的字典。
- 返回类型:
AutoIAFNormal
- class AutoIAFNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, num_flows=3, hidden_dims=None, skip_connections=False, nonlinearity=(<function elementwise.<locals>.<lambda>>, <function elementwise.<locals>.<lambda>>))[source]
基类:
AutoContinuous
此
AutoContinuous
实现使用通过InverseAutoregressiveTransform
变换的对角正态分布来构建覆盖整个隐空间的导引。此导引不依赖于模型的*args, **kwargs
。用法
guide = AutoIAFNormal(model, hidden_dims=[20], skip_connections=True, ...) svi = SVI(model, guide, ...)
- 参数:
model (callable) – 生成模型。
prefix (str) – 将被前置到所有 param 内部站点的字符串前缀。
init_loc_fn (callable) – 按站点的初始化函数。
num_flows (int) – 要使用的流数量,默认为 3。
hidden_dims (list) – 每层隐藏单元的维度。默认为
[latent_dim, latent_dim]
。skip_connections (bool) – 是否从每个流的输入到输出添加跳跃连接。默认为 False。
nonlinearity (callable) – 前馈网络中使用的非线性函数。默认为
jax.example_libraries.stax.Elu()
。
- get_base_dist()[source]
当将后验重参数化为
TransformedDistribution
时,返回其基础分布。这不应依赖于模型的 *args, **kwargs。
AutoLaplaceApproximation
- class AutoLaplaceApproximation(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, create_plates=None, hessian_fn=None)[source]
基类:
AutoContinuous
拉普拉斯近似(二次近似)通过无约束空间中的多元正态分布近似后验 \(\log p(z | x)\)。在底层,它使用 Delta 分布构建覆盖整个(无约束)隐空间的 MAP(即点估计)导引。其协方差由 \(-\log p(x, z)\) 在 z 的 MAP 点处的 Hessian 矩阵的逆给出。
用法
guide = AutoLaplaceApproximation(model, ...) svi = SVI(model, guide, ...)
- 参数:
hessian_fn (callable) – 实验性功能:一个函数,接受函数 f 和向量 x,并返回 f 在 x 处的 Hessian 矩阵。默认情况下,我们使用
lambda f, x: jax.hessian(f)(x)
。其他替代方案可以是lambda f, x: jax.jacobian(jax.jacobian(f))(x)
或lambda f, x: jax.hessian(f)(x) + 1e-3 * jnp.eye(x.shape[0])
。后一个示例在 f 在 x 处的 Hessian 矩阵非正定时很有帮助。请注意,输出的 Hessian 矩阵是拉普拉斯近似的精度矩阵。
- get_base_dist()[source]
当将后验重参数化为
TransformedDistribution
时,返回其基础分布。这不应依赖于模型的 *args, **kwargs。
- get_transform(params)[source]
返回导引学习到的变换,用于从无约束(近似)后验生成样本。
- 参数:
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
后验分布的变换
- 返回类型:
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
从模型中隐变量站点的近似后验生成样本。
- 参数:
rng_key (jax.random.PRNGKey) – 用于抽取样本的随机密钥。
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。args – 提供给模型/导引的参数。
sample_shape (tuple) – 每个隐变量站点的样本形状,默认为 ()。
kwargs – 提供给模型/导引的关键字参数。
- 返回值:
包含此导引抽取样本的字典。
- 返回类型:
- median(params)[source]
返回每个隐变量的后验中位数。
- 参数:
params (dict) – 包含参数值的字典。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
将采样站点名称映射到中位数值的字典。
- 返回类型:
AutoLowRankMultivariateNormal
- class AutoLowRankMultivariateNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1, rank=None)[source]
基类:
AutoContinuous
此
AutoContinuous
实现使用低秩多元正态分布来构建覆盖整个隐空间的导引。此导引不依赖于模型的*args, **kwargs
。用法
guide = AutoLowRankMultivariateNormal(model, rank=2, ...) svi = SVI(model, guide, ...)
- scale_constraint = SoftplusPositive(lower_bound=0.0)
- get_base_dist()[source]
当将后验重参数化为
TransformedDistribution
时,返回其基础分布。这不应依赖于模型的 *args, **kwargs。
- get_transform(params)[source]
返回导引学习到的变换,用于从无约束(近似)后验生成样本。
- 参数:
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
后验分布的变换
- 返回类型:
- median(params)[source]
返回每个隐变量的后验中位数。
- 参数:
params (dict) – 包含参数值的字典。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
将采样站点名称映射到中位数值的字典。
- 返回类型:
AutoNormal
- class AutoNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1, create_plates=None)[source]
基类:
AutoGuide
此
AutoGuide
实现使用正态分布来构建覆盖整个隐空间的导引。此导引不依赖于模型的*args, **kwargs
。这应该等同于
AutoDiagonalNormal
,但具有更方便的站点名称,并且更好地支持平均场 ELBO。用法
guide = AutoNormal(model) svi = SVI(model, guide, ...)
- 参数:
model (callable) – 一个 NumPyro 模型。
prefix (str) – 将被前置到所有 param 内部站点的字符串前缀。
init_loc_fn (callable) – 按站点的初始化函数。有关可用函数,请参见初始化策略部分。
init_scale (float) – 每个(无约束变换的)隐变量标准差的初始比例。
create_plates (callable) – 一个可选函数,输入与
model()
相同的*args,**kwargs
并返回一个numpyro.plate
或 plates 的可迭代对象。未返回的 plates 将照常自动创建。这对于数据子采样很有用。
- scale_constraint = SoftplusPositive(lower_bound=0.0)
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
从模型中隐变量站点的近似后验生成样本。
- 参数:
rng_key (jax.random.PRNGKey) – 用于抽取样本的随机密钥。
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。args – 提供给模型/导引的参数。
sample_shape (tuple) – 每个隐变量站点的样本形状,默认为 ()。
kwargs – 提供给模型/导引的关键字参数。
- 返回值:
包含此导引抽取样本的字典。
- 返回类型:
- median(params)[source]
返回每个隐变量的后验中位数。
- 参数:
params (dict) – 包含参数值的字典。参数可以通过
SVI
的get_params()
方法获取。- 返回值:
将采样站点名称映射到中位数值的字典。
- 返回类型:
AutoDelta
- class AutoDelta(model, *, prefix='auto', init_loc_fn=<function init_to_median>, create_plates=None)[source]
基类:
AutoGuide
此
AutoGuide
实现使用 Delta 分布构建覆盖整个隐空间的 MAP 导引。此导引不依赖于模型的*args, **kwargs
。注意
此类在约束空间中进行 MAP 推断。
用法
guide = AutoDelta(model) svi = SVI(model, guide, ...)
- 参数:
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
从模型中隐变量站点的近似后验生成样本。
- 参数:
rng_key (jax.random.PRNGKey) – 用于抽取样本的随机密钥。
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。args – 提供给模型/导引的参数。
sample_shape (tuple) – 每个隐变量站点的样本形状,默认为 ()。
kwargs – 提供给模型/导引的关键字参数。
- 返回值:
包含此导引抽取样本的字典。
- 返回类型:
AutoDAIS
- class AutoDAIS(model, *, K=4, base_dist='diagonal', eta_init=0.01, eta_max=0.1, gamma_init=0.9, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1)[source]
基类:
AutoContinuous
此
AutoDAIS
实现使用可微分退火重要性采样 (DAIS) [1, 2] 构建覆盖整个隐空间的导引。变分分布(即导引)的样本通过(未校正的)哈密顿蒙特卡洛和退火重要性采样的组合生成。在 [1] 中,同一算法被称为“未校正哈密顿退火”。请注意,AutoDAIS 不能与数据子采样结合使用。
参考
MCMC Variational Inference via Uncorrected Hamiltonian Annealing, Tomas Geffner, Justin Domke
Differentiable Annealed Importance Sampling and the Perils of Gradient Noise, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
用法
guide = AutoDAIS(model) svi = SVI(model, guide, ...)
- 参数:
model (callable) – 一个 NumPyro 模型。
prefix (str) – 将前置到所有 param 内部站点的字符串前缀。
K (int) – 控制使用的 HMC 步数的正整数。默认为 4。
base_dist (str) – 控制基础正态变分分布是由“对角线”协方差矩阵参数化还是由下三角“Cholesky”因子参数化的满秩协方差矩阵。默认为“对角线”。
eta_init (float) – HMC 中使用的步长的初始值。默认为 0.01。
eta_max (float) – HMC 中可学习步长的最大值。默认为 0.1。
gamma_init (float) – HMC 中部分动量刷新期间使用的可学习阻尼因子的初始值。默认为 0.9。
init_loc_fn (callable) – 按站点的初始化函数。有关可用函数,请参见初始化策略部分。
init_scale (float) – 每个(无约束变换的)隐变量基础变分分布标准差的初始比例。默认为 0.1。
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
从模型中隐变量站点的近似后验生成样本。
- 参数:
rng_key (jax.random.PRNGKey) – 用于抽取样本的随机密钥。
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。args – 提供给模型/导引的参数。
sample_shape (tuple) – 每个隐变量站点的样本形状,默认为 ()。
kwargs – 提供给模型/导引的关键字参数。
- 返回值:
包含此导引抽取样本的字典。
- 返回类型:
AutoSemiDAIS
- class AutoSemiDAIS(model, local_model, global_guide=None, local_guide=None, *, prefix='auto', K=4, eta_init=0.01, eta_max=0.1, gamma_init=0.9, init_scale=0.1, subsample_plate=None, use_global_dais_params=False)[source]
基类:
AutoGuide
此
AutoSemiDAIS
实现 [1] 将全局潜在变量的参数化变分分布与可微分退火重要性采样 (DAIS) [2, 3] 相结合,用于推断局部潜在变量。与AutoDAIS
不同,此引导可与数据子采样结合使用。请注意,得到的 ELBO 可被理解为参考文献 [4] 中描述的“局部增强界”的特定实现。参考文献
Variational Inference with Locally Enhanced Bounds for Hierarchical Models,Tomas Geffner,Justin Domke
MCMC Variational Inference via Uncorrected Hamiltonian Annealing, Tomas Geffner, Justin Domke
Differentiable Annealed Importance Sampling and the Perils of Gradient Noise, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
Surrogate Likelihoods for Variational Annealed Importance Sampling,Martin Jankowiak,Du Phan
用法
def global_model(): return numpyro.sample("theta", dist.Normal(0, 1)) def local_model(theta): with numpyro.plate("data", 8, subsample_size=2): tau = numpyro.sample("tau", dist.Gamma(5.0, 5.0)) numpyro.sample("obs", dist.Normal(0.0, tau), obs=jnp.ones(2)) model = lambda: local_model(global_model()) global_guide = AutoNormal(global_model) guide = AutoSemiDAIS(model, local_model, global_guide, K=4) svi = SVI(model, guide, ...) # sample posterior for particular data subset {3, 7} with handlers.substitute(data={"data": jnp.array([3, 7])}): samples = guide.sample_posterior(random.PRNGKey(1), params)
- 参数:
model (callable) – 包含全局和局部潜在变量的 NumPyro 模型。
local_model (callable) – model 中仅包含局部潜在变量的部分。local_model 的签名应为仅包含全局潜在变量的全局模型的返回类型。
global_guide (callable) – 全局潜在变量的引导,例如 autoguide。返回类型应为潜在样本站点名称和相应样本的字典。如果模型中没有全局变量,可以将其设置为 None。
local_guide (callable) – 可选的引导,用于指定局部潜在变量的 DAIS 基础分布。
prefix (str) – 将作为前缀添加到所有内部站点的字符串。
K (int) – 控制使用的 HMC 步数的正整数。默认为 4。
eta_init (float) – HMC 中使用的步长的初始值。默认为 0.01。
eta_max (float) – HMC 中可学习步长的最大值。默认为 0.1。
gamma_init (float) – HMC 中部分动量刷新期间使用的可学习阻尼因子的初始值。默认为 0.9。
init_scale (float) – 每个(无约束变换后的)局部潜在变量的变分分布标准差的初始尺度。默认为 0.1。
subsample_plate (str) – 子采样 plate 站点的可选名称。当模型具有未指定 subsample_size 的子采样 plate 或模型具有 subsample_size 等于 plate 大小的子采样 plate 时,这是必需的。
use_global_dais_params (bool) – 控制 DAIS 动态(HMC 步长、HMC 质量矩阵等)的参数应该是全局的(即,子采样 plate 中所有数据点共用)还是局部的(即,子采样 plate 中的每个数据点都有独立的参数)。注意,我们不对基础分布使用全局参数。
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
从模型中隐变量站点的近似后验生成样本。
- 参数:
rng_key (jax.random.PRNGKey) – 用于抽取样本的随机密钥。
params (dict) – 模型和自动导引的当前参数。参数可以通过
SVI
的get_params()
方法获取。args – 提供给模型/导引的参数。
sample_shape (tuple) – 每个隐变量站点的样本形状,默认为 ()。
kwargs – 提供给模型/导引的关键字参数。
- 返回值:
包含此导引抽取样本的字典。
- 返回类型:
AutoSurrogateLikelihoodDAIS
- class AutoSurrogateLikelihoodDAIS(model, surrogate_model, *, K=4, eta_init=0.01, eta_max=0.1, gamma_init=0.9, prefix='auto', base_dist='diagonal', init_loc_fn=<function init_to_uniform>, init_scale=0.1)[source]
基类:
AutoDAIS
此
AutoSurrogateLikelihoodDAIS
实现提供了参考文献 [1] 中描述的支持 mini-batch 的变分分布族。它结合了用户提供的代理似然与可微分退火重要性采样 (DAIS) [2, 3]。它不适用于包含局部潜在变量的模型(参阅AutoSemiDAIS
),但与AutoDAIS
不同的是,它可以结合数据子采样使用。参考
Surrogate likelihoods for variational annealed importance sampling,Martin Jankowiak,Du Phan
MCMC Variational Inference via Uncorrected Hamiltonian Annealing, Tomas Geffner, Justin Domke
Differentiable Annealed Importance Sampling and the Perils of Gradient Noise, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
用法
# logistic regression model for data {X, Y} def model(X, Y): theta = numpyro.sample( "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1) ) with numpyro.plate("N", 100, subsample_size=10): X_batch = numpyro.subsample(X, event_dim=1) Y_batch = numpyro.subsample(Y, event_dim=0) numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch) # surrogate model defined by prior and surrogate likelihood. # a convenient choice for specifying the latter is to compute the likelihood on # a randomly chosen data subset (here {X_surr, Y_surr} of size 20) and then use # handlers.scale to scale the log likelihood by a vector of learnable weights. def surrogate_model(X_surr, Y_surr): theta = numpyro.sample( "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1) ) omegas = numpyro.param( "omegas", 5.0 * jnp.ones(20), constraint=dist.constraints.positive ) with numpyro.plate("N", 20), numpyro.handlers.scale(scale=omegas): numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr) guide = AutoSurrogateLikelihoodDAIS(model, surrogate_model) svi = SVI(model, guide, ...)
- 参数:
model (callable) – 一个 NumPyro 模型。
surrogate_model (callable) – 用作代理模型以引导定义变分分布的 HMC 动力学的 NumPyro 模型。特别是,surrogate_model 应该包含与 model 相同的先验,但应包含一个易于评估的似然参数化 ansatz。后者的一个简单 ansatz 涉及计算固定数据子集的似然,并通过一个可学习的正权重向量对结果的对数似然进行缩放。请参阅上面的使用示例。
prefix (str) – 将前置到所有 param 内部站点的字符串前缀。
K (int) – 控制使用的 HMC 步数的正整数。默认为 4。
base_dist (str) – 控制基础正态变分分布是由“对角线”协方差矩阵参数化还是由下三角“Cholesky”因子参数化的满秩协方差矩阵。默认为“对角线”。
eta_init (float) – HMC 中使用的步长的初始值。默认为 0.01。
eta_max (float) – HMC 中可学习步长的最大值。默认为 0.1。
gamma_init (float) – HMC 中部分动量刷新期间使用的可学习阻尼因子的初始值。默认为 0.9。
init_loc_fn (callable) – 按站点的初始化函数。有关可用函数,请参见初始化策略部分。
init_scale (float) – 每个(无约束变换的)隐变量基础变分分布标准差的初始比例。默认为 0.1。