分布
基础分布
Distribution
- class Distribution(batch_shape=(), event_shape=(), *, validate_args=None)[source]
基类:
object
NumPyro 中概率分布的基类。设计主要借鉴了
torch.distributions
。- 参数:
batch_shape – 分布的批次形状。这指定了从分布中抽取的样本的独立(可能不完全相同)维度。对于一个分布实例,这是固定的,并从分布参数的形状推断出来。
event_shape – 分布的事件形状。这指定了从分布中抽取的样本的依赖维度。当我们使用 .log_prob 计算一批样本的对数概率密度时,这些维度会被折叠。
validate_args – 是否启用分布参数和 .log_prob 方法参数的验证。
例如
>>> import jax.numpy as jnp >>> import numpyro.distributions as dist >>> d = dist.Dirichlet(jnp.ones((2, 3, 4))) >>> d.batch_shape (2, 3) >>> d.event_shape (4,)
- arg_constraints = {}
- support = None
- has_enumerate_support = False
- reparametrized_params = []
- pytree_data_fields = ()
- pytree_aux_fields = ('_batch_shape', '_event_shape')
- validate_args(strict: bool = True) None [source]
验证分布的参数。
- 参数:
strict – 要求严格验证,如果在 JIT 编译的代码内调用该函数,则会引发错误。
- rsample(key, sample_shape=()) Array | ndarray | bool_ | number | bool | int | float | complex [source]
- shape(sample_shape=()) tuple[int, ...] [source]
此分布的样本的张量形状。
样本的形状为
d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
- sample(key: Array | ndarray | bool_ | number | bool | int | float | complex, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool_ | number | bool | int | float | complex [source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- sample_with_intermediates(key, sample_shape=())[source]
与
sample
相同,但会返回所有中间计算结果(对 TransformedDistribution 有用)。- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(value)[source]
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- to_event(reinterpreted_batch_ndims=None)[source]
将最右侧的 reinterpreted_batch_ndims 个批次维度解释为相关的事件维度。
- 参数:
reinterpreted_batch_ndims – 要解释为事件维度的最右侧批次维度数量。
- 返回:
Independent 分布的实例。
- 返回类型:
- expand(batch_shape)[source]
返回一个批次维度扩展到 batch_shape 的新的
ExpandedDistribution
实例。- 参数:
batch_shape (tuple) – 要扩展到的批次形状。
- 返回:
ExpandedDistribution 的实例。
- 返回类型:
- expand_by(sample_shape)[source]
通过在分布的
batch_shape
左侧添加sample_shape
来扩展分布。要将self.batch_shape
的内部维度从 1 扩展到更大的值,请改用expand()
。- 参数:
sample_shape (tuple) – 从分布中抽取的 iid 批次大小。
- 返回:
此分布的扩展版本。
- 返回类型:
- mask(mask)[source]
通过一个可广播到分布的
Distribution.batch_shape
的布尔值或布尔值数组对分布进行掩码操作。- 参数:
mask (jnp.ndarray 或 bool) – 布尔值或布尔值数组(True 包含一个站点,False 排除一个站点)。
- 返回:
此分布的掩码副本。
- 返回类型:
示例
>>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.distributions import constraints >>> from numpyro.infer import SVI, Trace_ELBO >>> def model(data, m): ... f = numpyro.sample("latent_fairness", dist.Beta(1, 1)) ... with numpyro.plate("N", data.shape[0]): ... # only take into account the values selected by the mask ... masked_dist = dist.Bernoulli(f).mask(m) ... numpyro.sample("obs", masked_dist, obs=data) >>> def guide(data, m): ... alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive) ... beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive) ... numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q)) >>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)]) >>> # select values equal to one >>> masked_array = jnp.where(data == 1, True, False) >>> optimizer = numpyro.optim.Adam(step_size=0.05) >>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) >>> svi_result = svi.run(random.PRNGKey(0), 300, data, masked_array) >>> params = svi_result.params >>> # inferred_mean is closer to 1 >>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
- classmethod infer_shapes(*args, **kwargs)[source]
根据
__init__()
参数的形状推断batch_shape
和event_shape
。注意
这假设分布形状仅依赖于张量输入的形状,而不依赖于输入中包含的数据。
- 参数:
*args – 位置参数,每个输入参数由表示每个张量输入大小的元组代替。
**kwargs – 关键字参数,将输入参数名称映射到表示每个张量输入大小的元组。
- 返回:
一对表示分布形状的元组
(batch_shape, event_shape)
,该分布将使用给定形状的输入参数创建。- 返回类型:
- cdf(value: Array | ndarray | bool_ | number | bool | int | float | complex) Array | ndarray | bool_ | number | bool | int | float | complex [source]
此分布的累积分布函数。
- 参数:
value – 此分布的样本。
- 返回:
在 value 处评估的累积分布函数的输出。
- icdf(q: Array | ndarray | bool_ | number | bool | int | float | complex) Array | ndarray | bool_ | number | bool | int | float | complex [source]
此分布的逆累积分布函数。
- 参数:
q – 分位数,应属于 [0, 1]。
- 返回:
其 cdf 值等于 q 的样本。
- property is_discrete
ExpandedDistribution
- class ExpandedDistribution(base_dist, batch_shape=())[source]
基类:
Distribution
- arg_constraints = {}
- pytree_data_fields = ('base_dist',)
- pytree_aux_fields = ('_expanded_sizes', '_interstitial_sizes')
- property has_enumerate_support
bool(x) -> bool
当参数 x 为 True 时返回 True,否则返回 False。内置的 True 和 False 是 bool 类的唯二实例。bool 类是 int 类的子类,不能再被继承。
- property has_rsample
- property support
- sample_with_intermediates(key, sample_shape=())[source]
与
sample
相同,但会返回所有中间计算结果(对 TransformedDistribution 有用)。- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(value, intermediates=None)[source]
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
FoldedDistribution
- class FoldedDistribution(base_dist, *, validate_args=None)[source]
-
等同于
TransformedDistribution(base_dist, AbsTransform())
,但额外支持log_prob()
。- 参数:
base_dist (Distribution) – 要反射的单变量分布。
- support = Positive(lower_bound=0.0)
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
ImproperUniform
- class ImproperUniform(support, batch_shape, event_shape, *, validate_args=None)[source]
基类:
Distribution
一个辅助分布,在 support 域上
log_prob()
为零。注意
此分布未实现 sample 方法。在 autoguide 和 mcmc 中,非正则站点(improper sites)的初始参数来自 init_to_uniform 或 init_to_value 策略。
用法
>>> from numpyro import sample >>> from numpyro.distributions import ImproperUniform, Normal, constraints >>> >>> def model(): ... # ordered vector with length 10 ... x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,))) ... ... # real matrix with shape (3, 4) ... y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4))) ... ... # a shape-(6, 8) batch of length-5 vectors greater than 3 ... z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))
如果你想对所有大于 a 的值设置非正则先验,其中 a 是另一个随机变量,你可以使用
>>> def model(): ... a = sample('a', Normal(0, 1)) ... x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))
或者如果你想对其进行重参数化
>>> from numpyro.distributions import TransformedDistribution, transforms >>> from numpyro.handlers import reparam >>> from numpyro.infer.reparam import TransformReparam >>> >>> def model(): ... a = sample('a', Normal(0, 1)) ... with reparam(config={'x': TransformReparam()}): ... x = sample('x', ... TransformedDistribution(ImproperUniform(constraints.positive, (), ()), ... transforms.AffineTransform(a, 1)))
- 参数:
support (Constraint) – 此分布的支持集。
batch_shape (tuple) – 此分布的批次形状。通常可以安全地设置为 batch_shape=()。
event_shape (tuple) – 此分布的事件形状。
- arg_constraints = {}
- pytree_data_fields = ('support',)
- support = Dependent()
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
Independent
- class Independent(base_dist, reinterpreted_batch_ndims, *, validate_args=None)[source]
基类:
Distribution
通过将批次-事件维度边界进一步向左移动,将分布的批次维度重新解释为事件维度。
从实践角度来看,这在改变
log_prob()
的结果时很有用。例如,单变量正态分布可以解释为具有对角协方差的多元正态分布>>> import numpyro.distributions as dist >>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3)) >>> [normal.batch_shape, normal.event_shape] [(3,), ()] >>> diag_normal = dist.Independent(normal, 1) >>> [diag_normal.batch_shape, diag_normal.event_shape] [(), (3,)]
- 参数:
base_distribution (numpyro.distribution.Distribution) – 一个分布实例。
reinterpreted_batch_ndims (int) – 要重新解释为事件维度的批次维度数量。
- arg_constraints = {}
- pytree_data_fields = ('base_dist',)
- pytree_aux_fields = ('reinterpreted_batch_ndims',)
- property support
- property has_enumerate_support
bool(x) -> bool
当参数 x 为 True 时返回 True,否则返回 False。内置的 True 和 False 是 bool 类的唯二实例。bool 类是 int 类的子类,不能再被继承。
- property reparametrized_params
内置的可变序列。
如果没有给出参数,构造函数将创建一个新的空列表。如果指定了参数,则参数必须是可迭代的。
- property mean
分布的均值。
- property variance
分布的方差。
- property has_rsample
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(value)[source]
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- expand(batch_shape)[source]
返回一个批次维度扩展到 batch_shape 的新的
ExpandedDistribution
实例。- 参数:
batch_shape (tuple) – 要扩展到的批次形状。
- 返回:
ExpandedDistribution 的实例。
- 返回类型:
MaskedDistribution
- class MaskedDistribution(base_dist, mask)[source]
基类:
Distribution
通过一个可广播到分布的
Distribution.batch_shape
的布尔数组对分布进行掩码操作。在特殊情况下,如果mask is False
,则跳过log_prob()
的计算,并返回常量零值。- 参数:
mask (jnp.ndarray 或 bool) – 布尔值或布尔值数组。
- arg_constraints = {}
- pytree_data_fields = ('base_dist', '_mask')
- pytree_aux_fields = ('_mask',)
- property has_enumerate_support
bool(x) -> bool
当参数 x 为 True 时返回 True,否则返回 False。内置的 True 和 False 是 bool 类的唯二实例。bool 类是 int 类的子类,不能再被继承。
- property has_rsample
- property support
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(value)[source]
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
TransformedDistribution
- class TransformedDistribution(base_distribution, transforms, *, validate_args=None)[source]
基类:
Distribution
返回通过对基础分布应用一系列变换获得的分布实例。例如,参见
LogNormal
和HalfNormal
。- 参数:
base_distribution – 应用变换的基础分布。
transforms – 单个变换或变换列表。
validate_args – 是否启用分布参数和 .log_prob 方法参数的验证。
- arg_constraints = {}
- pytree_data_fields = ('base_dist', 'transforms')
- property has_rsample
- property support
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- sample_with_intermediates(key, sample_shape=())[source]
与
sample
相同,但会返回所有中间计算结果(对 TransformedDistribution 有用)。- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
Delta
- class Delta(v=0.0, log_density=0.0, event_dim=0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'log_density': 实数(), 'v': 依赖()}
- reparametrized_params = ['v', 'log_density']
- property support
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
Unit
- class Unit(log_factor, *, validate_args=None)[source]
基类:
Distribution
表示单位类型的非标准化琐碎分布。
单位类型只有一个无数据的值,即
value.size == 0
。这用于
numpyro.factor()
语句。- arg_constraints = {'log_factor': 实数()}
- support = 实数()
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
连续分布
AsymmetricLaplace
- class AsymmetricLaplace(loc=0.0, scale=1.0, asymmetry=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'asymmetry': 正数(下界=0.0), 'loc': 实数(), 'scale': 正数(下界=0.0)}
- reparametrized_params = ['loc', 'scale', 'asymmetry']
- support = 实数()
- log_prob(value)[source]
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
AsymmetricLaplaceQuantile
- class AsymmetricLaplaceQuantile(loc=0.0, scale=1.0, quantile=0.5, *, validate_args=None)[source]
基类:
Distribution
非对称拉普拉斯 (AsymmetricLaplace) 分布的一种替代参数化,常用于贝叶斯分位数回归。
与 AsymmetricLaplace 使用的 asymmetry 参数来定义分布左右两边的平衡不同,此类使用 quantile 参数,它描述了落在分布左侧的概率密度比例。
scale 参数的解释也与 AsymmetricLaplace 略有不同。当
loc=0
和scale=1
时,AsymmetricLaplace(0,1,1) 等同于 Laplace(0,1),而 AsymmetricLaplaceQuantile(0,1,0.5) 等同于 Laplace(0,2)。- arg_constraints = {'loc': 实数(), 'quantile': 开区间(下界=0.0, 上界=1.0), 'scale': 正数(下界=0.0)}
- reparametrized_params = ['loc', 'scale', 'quantile']
- support = 实数()
- pytree_data_fields = ('loc', 'scale', 'quantile', '_ald')
- log_prob(value)[source]
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
Beta
- class Beta(concentration1, concentration0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'concentration0': 正数(下界=0.0), 'concentration1': 正数(下界=0.0)}
- reparametrized_params = ['concentration1', 'concentration0']
- support = 单位区间(下界=0.0, 上界=1.0)
- pytree_data_fields = ('concentration0', 'concentration1', '_dirichlet')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
BetaProportion
- class BetaProportion(mean, concentration, *, validate_args=None)[source]
基类:
Beta
BetaProportion 分布是传统 Beta 分布的一种重新参数化,它使用变量均值和精度参数。
- 参考文献
- 用于建模比率和比例的 Beta 回归,Ferrari Silvia,以及
Francisco Cribari-Neto。Journal of Applied Statistics 31.7 (2004): 799-815。
- arg_constraints = {'concentration': 正数(下界=0.0), 'mean': 开区间(下界=0.0, 上界=1.0)}
- reparametrized_params = ['mean', 'concentration']
- support = 单位区间(下界=0.0, 上界=1.0)
- pytree_data_fields = ('concentration',)
CAR
- class CAR(loc, correlation, conditional_precision, adj_matrix, *, is_sparse=False, validate_args=None)[source]
基类:
Distribution
条件自回归 (CAR) 分布是多元正态分布的一个特例,其中精度矩阵根据站点的邻接矩阵构建。站点之间的自相关量由
correlation
控制。该分布是区域空间数据的流行先验。- 参数:
loc (float 或 ndarray) – 多元正态分布的均值
correlation (float) – 自回归参数。在大多数情况下,该值应介于 0(站点独立,退化为 iid 多元正态分布)和 1(站点间完美自相关)之间,但规范允许负相关。
conditional_precision (float) – 多元正态分布的正精度
adj_matrix (ndarray 或 scipy.sparse.csr_matrix) – 对称邻接矩阵,其中 1 表示站点间邻接,0 表示不邻接。
jax.numpy.ndarray
adj_matrix
受支持,但不推荐使用,建议使用numpy.ndarray
或scipy.sparse.spmatrix
。is_sparse (bool) – 在计算中是否使用 `adj_matrix` 的稀疏形式(如果 `adj_matrix` 是
scipy.sparse.spmatrix
,则必须为 True)
- arg_constraints = {'adj_matrix': 依赖(), 'conditional_precision': 正数(下界=0.0), 'correlation': 开区间(下界=-1, 上界=1), 'loc': 实数向量(实数(), 1)}
- support = 实数向量(实数(), 1)
- reparametrized_params = ['loc', 'correlation', 'conditional_precision', 'adj_matrix']
- pytree_aux_fields = ('is_sparse', 'adj_matrix')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- static infer_shapes(loc, correlation, conditional_precision, adj_matrix)[source]
根据
__init__()
参数的形状推断batch_shape
和event_shape
。注意
这假设分布形状仅依赖于张量输入的形状,而不依赖于输入中包含的数据。
- 参数:
*args – 位置参数,每个输入参数由表示每个张量输入大小的元组代替。
**kwargs – 关键字参数,将输入参数名称映射到表示每个张量输入大小的元组。
- 返回:
一对表示分布形状的元组
(batch_shape, event_shape)
,该分布将使用给定形状的输入参数创建。- 返回类型:
Cauchy
- class Cauchy(loc=0.0, scale=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'loc': 实数(), 'scale': 正数(下界=0.0)}
- support = 实数()
- reparametrized_params = ['loc', 'scale']
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
Chi2
CirculantNormal
- class CirculantNormal(loc: Array, covariance_row: Array = None, covariance_rfft: Array = None, *, validate_args=None)[source]
-
具有正定循环协方差矩阵 \(\mathbf{C}\) 的多元正态分布 [1],即具有周期性边界条件。样本 \(\mathbf{x}\in\mathbb{R}^n\) 的密度是标准多元正态密度
\[p\left(\mathbf{x}\mid\boldsymbol{\mu},\mathbf{C}\right) = \frac{\left(\mathrm{det}\,\mathbf{C}\right)^{-1/2}}{\left(2\pi\right)^{n / 2}} \exp\left(-\frac{1}{2}\left(\mathbf{x}-\boldsymbol{\mu}\right)^\intercal \mathbf{C}^{-1}\left(\mathbf{x}-\boldsymbol{\mu}\right)\right),\]其中 \(\mathrm{det}\) 表示行列式,\(^\intercal\) 表示转置。循环矩阵可以使用离散傅里叶变换高效地对角化 [1],这允许以 \(n \log n\) 的时间为 \(n\) 个观测值评估对数似然 [2]。
- 参数:
loc – 分布的均值 \(\boldsymbol{\mu}\)。
covariance_row – 循环协方差矩阵 \(\boldsymbol{C}\) 的第一行。由于周期性边界条件,协方差矩阵完全由其第一行确定(有关详细信息,请参见
jax.scipy.linalg.toeplitz()
)。covariance_rfft –
covariance_row
(循环协方差矩阵 \(\boldsymbol{C}\) 的第一行)的实快速傅里叶变换的实部。
参考文献
Wikipedia. (n.d.). Circulant matrix. Retrieved March 6, 2025, from https://en.wikipedia.org/wiki/Circulant_matrix
Wood, A. T. A., & Chan, G. (1994). Simulation of Stationary Gaussian Processes in \(\left[0, 1\right]^d\). Journal of Computational and Graphical Statistics, 3(4), 409–432. https://doi.org/10.1080/10618600.1994.10474655
- arg_constraints = {'covariance_rfft': 独立约束(正数(下界=0.0), 1), 'covariance_row': 正定循环向量(), 'loc': 实数向量(实数(), 1)}
- support = 实数向量(实数(), 1)
- static infer_shapes(loc: tuple = (), covariance_row: tuple | None = None, covariance_rfft: tuple | None = None)[source]
根据
__init__()
参数的形状推断batch_shape
和event_shape
。注意
这假设分布形状仅依赖于张量输入的形状,而不依赖于输入中包含的数据。
- 参数:
*args – 位置参数,每个输入参数由表示每个张量输入大小的元组代替。
**kwargs – 关键字参数,将输入参数名称映射到表示每个张量输入大小的元组。
- 返回:
一对表示分布形状的元组
(batch_shape, event_shape)
,该分布将使用给定形状的输入参数创建。- 返回类型:
Dirichlet
- class Dirichlet(concentration, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'concentration': 独立约束(正数(下界=0.0), 1)}
- reparametrized_params = ['concentration']
- support = 单纯形()
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
EulerMaruyama
- class EulerMaruyama(t, sde_fn, init_dist, *, validate_args=None)[source]
基类:
Distribution
Euler–Maruyama 方法是求解随机微分方程 (SDE) 的近似数值方法。
- 参数:
t (ndarray) – 离散化时间
sde_fn (callable) – 返回 SDE 漂移和扩散系数的函数
init_dist (Distribution) – 初始值分布。
参考文献
[1] https://en.wikipedia.org/wiki/Euler-Maruyama_method
- arg_constraints = {'t': 有序向量()}
- pytree_data_fields = ('t', 'init_dist')
- pytree_aux_fields = ('sde_fn',)
- property support
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
Exponential
- class Exponential(rate=1.0, *, validate_args=None)[source]
基类:
Distribution
- reparametrized_params = ['rate']
- arg_constraints = {'rate': 正数(下界=0.0)}
- support = 正数(下界=0.0)
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
Gamma
- class Gamma(concentration, rate=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'concentration': 正数(下界=0.0), 'rate': 正数(下界=0.0)}
- support = 正数(下界=0.0)
- reparametrized_params = ['concentration', 'rate']
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
GaussianCopula
- class GaussianCopula(marginal_dist, correlation_matrix=None, correlation_cholesky=None, *, validate_args=None)[source]
基类:
Distribution
将边缘分布 marginal_dist 的
batch_shape[:-1]
与模拟轴之间相关性的多元高斯 Copula 连接起来的分布。- 参数:
marginal_dist (Distribution) – 其最后一个 batch 轴需要连接的分布。
correlation_matrix (array_like) – 耦合多元正态分布的相关矩阵。
correlation_cholesky (array_like) – 耦合多元正态分布的相关 Cholesky 因子。
- arg_constraints = {'correlation_cholesky': 相关矩阵 Cholesky 分解(), 'correlation_matrix': 相关矩阵()}
- reparametrized_params = ['correlation_matrix', 'correlation_cholesky']
- pytree_data_fields = ('marginal_dist', 'base_dist')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- property support
GaussianCopulaBeta
- class GaussianCopulaBeta(concentration1, concentration0, correlation_matrix=None, correlation_cholesky=None, *, validate_args=False)[source]
继承自:
GaussianCopula
- arg_constraints = {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0), 'correlation_cholesky': CorrCholesky(), 'correlation_matrix': CorrMatrix()}
- support = IndependentConstraint(UnitInterval(lower_bound=0.0, upper_bound=1.0), 1)
- pytree_data_fields = ('concentration1', 'concentration0')
高斯随机游走
- class GaussianRandomWalk(scale=1.0, num_steps=1, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'scale': Positive(lower_bound=0.0)}
- support = RealVector(Real(), 1)
- reparametrized_params = ['scale']
- pytree_aux_fields = ('num_steps',)
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
高斯状态空间
- class GaussianStateSpace(num_steps, transition_matrix, covariance_matrix=None, precision_matrix=None, scale_tril=None, *, validate_args=None)[source]
-
高斯状态空间模型。
\[\begin{split}\mathbf{z}_{t} &= \mathbf{A} \mathbf{z}_{t - 1} + \boldsymbol{\epsilon}_t\\ &=\sum_{k=1} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_t,\end{split}\]其中 \(\mathbf{z}_t\) 是步骤 \(t\) 时的状态向量,\(\mathbf{A}\) 是转移矩阵,\(\boldsymbol\epsilon\) 是创新噪声。
- 参数:
num_steps – 步数。
transition_matrix – 状态空间转移矩阵 \(\mathbf{A}\)。
covariance_matrix – 创新噪声 \(\boldsymbol\epsilon\) 的协方差。
precision_matrix – 创新噪声 \(\boldsymbol\epsilon\) 的精度矩阵。
scale_tril – 创新噪声 \(\boldsymbol\epsilon\) 的尺度矩阵。
- arg_constraints = {'covariance_matrix': PositiveDefinite(), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky(), 'transition_matrix': RealMatrix(Real(), 2)}
- support = RealMatrix(Real(), 2)
- pytree_aux_fields = ('num_steps',)
- property mean
分布的均值。
- property variance
分布的方差。
Gompertz 分布
- class Gompertz(concentration, rate=1.0, *, validate_args=None)[source]
基类:
Distribution
Gompertz 分布。
Gompertz 分布是一种支持范围在正实数轴上的分布,与 Gumbel 分布密切相关。本实现遵循 Gompertz 分布维基百科条目中使用的符号。参阅 https://en.wikipedia.org/wiki/Gompertz_distribution。
然而,我们将参数“eta”称为 concentration(集中度)参数,将参数“b”称为 rate(比率)参数(与维基百科描述中的 scale 参数不同)。
累积分布函数(CDF),用 concentration (con) 和 rate 表示,为:
\[F(x) = 1 - \exp \left\{ - \text{con} * \left [ \exp\{x * rate \} - 1 \right ] \right\}\]- arg_constraints = {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}
- support = Positive(lower_bound=0.0)
- reparametrized_params = ['concentration', 'rate']
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
Gumbel 分布
- class Gumbel(loc=0.0, scale=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
- support = Real()
- reparametrized_params = ['loc', 'scale']
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
半柯西分布
- class HalfCauchy(scale=1.0, *, validate_args=None)[source]
基类:
Distribution
- reparametrized_params = ['scale']
- support = Positive(lower_bound=0.0)
- arg_constraints = {'scale': Positive(lower_bound=0.0)}
- pytree_data_fields = ('_cauchy', 'scale')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
半正态分布
- class HalfNormal(scale=1.0, *, validate_args=None)[source]
基类:
Distribution
- reparametrized_params = ['scale']
- support = Positive(lower_bound=0.0)
- arg_constraints = {'scale': Positive(lower_bound=0.0)}
- pytree_data_fields = ('_normal', 'scale')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
逆Gamma分布
- class InverseGamma(concentration, rate=1.0, *, validate_args=None)[source]
-
注意
我们保留了与 Pyro 中相同的 rate 符号,但在文献中(例如维基百科:https://en.wikipedia.org/wiki/Inverse-gamma_distribution),它扮演着逆Gamma分布的尺度参数的角色。
- arg_constraints = {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}
- reparametrized_params = ['concentration', 'rate']
- support = Positive(lower_bound=0.0)
- property mean
分布的均值。
- property variance
分布的方差。
Kumaraswamy 分布
- class Kumaraswamy(concentration1, concentration0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0)}
- reparametrized_params = ['concentration1', 'concentration0']
- support = UnitInterval(lower_bound=0.0, upper_bound=1.0)
- KL_KUMARASWAMY_BETA_TAYLOR_ORDER = 10
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
拉普拉斯分布
- class Laplace(loc=0.0, scale=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
- support = Real()
- reparametrized_params = ['loc', 'scale']
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
莱维分布
- class Levy(loc, scale, *, validate_args=None)[source]
基类:
Distribution
莱维(Lévy)分布是莱维 alpha 稳定分布的一个特例。其概率密度函数由下式给出:
\[f(x\mid \mu, c) = \sqrt{\frac{c}{2\pi(x-\mu)^{3}}} \exp\left(-\frac{c}{2(x-\mu)}\right), \qquad x > \mu\]其中 \(\mu\) 是位置参数,\(c\) 是尺度参数。
- 参数:
loc – 位置参数。
scale – 尺度参数。
- arg_constraints = {'loc': Positive(lower_bound=0.0), 'scale': Positive(lower_bound=0.0)}
- property support
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- sample(key: Array | ndarray | bool_ | number | bool | int | float | complex, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool_ | number | bool | int | float | complex [source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- icdf(q: Array | ndarray | bool_ | number | bool | int | float | complex) Array | ndarray | bool_ | number | bool | int | float | complex [source]
莱维(Lévy)分布的逆累积分布函数由下式给出:
\[F^{-1}(q\mid \mu, c) = \mu + c\left(\Phi^{-1}(1-q/2)\right)^{-2}\]其中 \(\Phi^{-1}\) 是标准正态累积分布函数的逆函数。
- 参数:
q – 分位数,应属于 [0, 1]。
- 返回:
其 cdf 值等于 q 的样本。
- cdf(value: Array | ndarray | bool_ | number | bool | int | float | complex) Array | ndarray | bool_ | number | bool | int | float | complex [source]
莱维(Lévy)分布的累积分布函数由下式给出:
\[F(x\mid \mu, c) = 2 - 2\Phi\left(\sqrt{\frac{c}{x-\mu}}\right)\]其中 \(\Phi\) 是标准正态累积分布函数。
- 参数:
value – 来自莱维分布的样本值。
- 返回:
在 value 处评估的累积分布函数的输出。
LKJ 分布
- class LKJ(dimension, concentration=1.0, sample_method='onion', *, validate_args=None)[source]
-
用于相关矩阵的 LKJ 分布。该分布由
concentration
参数 \(\eta\) 控制,使得相关矩阵 \(M\) 的概率与 \(\det(M)^{\eta - 1}\) 成比例。因此,当concentration == 1
时,我们在相关矩阵上具有均匀分布。当
concentration > 1
时,该分布倾向于具有较大行列式的样本。当我们先验地知道基础变量不相关时,这很有用。当
concentration < 1
时,该分布倾向于具有较小行列式的样本。当我们先验地知道一些基础变量相关时,这很有用。在多元正态样本上下文中使用 LKJ 的示例代码
def model(y): # y has dimension N x d d = y.shape[1] N = y.shape[0] # Vector of variances for each of the d variables theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d))) concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration)) sigma = jnp.sqrt(theta) # we can also use a faster formula `cov_mat = jnp.outer(sigma, sigma) * corr_mat` cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma)) # Vector of expectations mu = jnp.zeros(d) with numpyro.plate("observations", N): obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y) return obs
- 参数:
参考文献
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe
- arg_constraints = {'concentration': Positive(lower_bound=0.0)}
- reparametrized_params = ['concentration']
- support = CorrMatrix()
- pytree_aux_fields = ('dimension', 'sample_method')
- property mean
分布的均值。
LKJCholesky 分解分布
- class LKJCholesky(dimension, concentration=1.0, sample_method='onion', *, validate_args=None)[source]
基类:
Distribution
用于相关矩阵的下三角 Cholesky 分解的 LKJ 分布。该分布由
concentration
参数 \(\eta\) 控制,使得由 Cholesky 分解生成的相关矩阵 \(M\) 的概率与 \(\det(M)^{\eta - 1}\) 成比例。因此,当concentration == 1
时,我们在相关矩阵的 Cholesky 分解上具有均匀分布。当
concentration > 1
时,该分布倾向于具有较大对角线元素(因此行列式较大)的样本。当我们先验地知道基础变量不相关时,这很有用。当
concentration < 1
时,该分布倾向于具有较小对角线元素(因此行列式较小)的样本。当我们先验地知道一些基础变量相关时,这很有用。在多元正态样本上下文中使用 LKJCholesky 分解的示例代码
def model(y): # y has dimension N x d d = y.shape[1] N = y.shape[0] # Vector of variances for each of the d variables theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d))) # Lower cholesky factor of a correlation matrix concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration)) # Lower cholesky factor of the covariance matrix sigma = jnp.sqrt(theta) # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega` L_Omega = jnp.matmul(jnp.diag(sigma), L_omega) # Vector of expectations mu = jnp.zeros(d) with numpyro.plate("observations", N): obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y) return obs
- 参数:
参考文献
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe
- arg_constraints = {'concentration': Positive(lower_bound=0.0)}
- reparametrized_params = ['concentration']
- support = CorrCholesky()
- pytree_data_fields = ('_beta', 'concentration')
- pytree_aux_fields = ('dimension', 'sample_method')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
对数正态分布
对数均匀分布
- class LogUniform(low, high, *, validate_args=None)[source]
-
- arg_constraints = {'high': Positive(lower_bound=0.0), 'low': Positive(lower_bound=0.0)}
- reparametrized_params = ['low', 'high']
- pytree_data_fields = ('low', 'high', '_support')
- property support
- property mean
分布的均值。
- property variance
分布的方差。
Logistic 分布
- class Logistic(loc=0.0, scale=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
- support = Real()
- reparametrized_params = ['loc', 'scale']
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
LowRankMultivariateNormal
- class LowRankMultivariateNormal(loc, cov_factor, cov_diag, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'cov_diag': IndependentConstraint(Positive(lower_bound=0.0), 1), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': RealVector(Real(), 1)}
- support = RealVector(Real(), 1)
- reparametrized_params = ['loc', 'cov_factor', 'cov_diag']
- pytree_data_fields = ('loc', 'cov_factor', 'cov_diag', '_capacitance_tril')
- property mean
分布的均值。
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
矩阵正态分布 (MatrixNormal)
- class MatrixNormal(loc, scale_tril_row, scale_tril_column, validate_args=None)[source]
基类:
Distribution
矩阵变量正态分布,如[1]中所述,但采用下三角参数化,即 \(U=scale_tril_row @ scale_tril_row^{T}\) 和 \(V=scale_tril_column @ scale_tril_column^{T}\)。该分布与多元正态分布的关系如下:如果 \(X ~ MN(loc,U,V)\),则 \(vec(X) ~ MVN(vec(loc), kron(V,U) )\)。
- 参数:
loc (array_like) – 分布的位置参数。
scale_tril_row (array_like) – 行相关矩阵的下三角乔利斯基因子。
scale_tril_column (array_like) – 列相关矩阵的下三角乔利斯基因子。
参考文献
[1] https://en.wikipedia.org/wiki/Matrix_normal_distribution
- arg_constraints = {'loc': RealVector(Real(), 1), 'scale_tril_column': LowerCholesky(), 'scale_tril_row': LowerCholesky()}
- support = RealMatrix(Real(), 2)
- reparametrized_params = ['loc', 'scale_tril_row', 'scale_tril_column']
- property mean
分布的均值。
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
多元正态分布 (MultivariateNormal)
- class MultivariateNormal(loc=0.0, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'covariance_matrix': PositiveDefinite(), 'loc': RealVector(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
- support = RealVector(Real(), 1)
- reparametrized_params = ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- static infer_shapes(loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None)[source]
根据
__init__()
参数的形状推断batch_shape
和event_shape
。注意
这假设分布形状仅依赖于张量输入的形状,而不依赖于输入中包含的数据。
- 参数:
*args – 位置参数,每个输入参数由表示每个张量输入大小的元组代替。
**kwargs – 关键字参数,将输入参数名称映射到表示每个张量输入大小的元组。
- 返回:
一对表示分布形状的元组
(batch_shape, event_shape)
,该分布将使用给定形状的输入参数创建。- 返回类型:
多元学生 t 分布 (MultivariateStudentT)
- class MultivariateStudentT(df, loc=0.0, scale_tril=None, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'df': Positive(lower_bound=0.0), 'loc': RealVector(Real(), 1), 'scale_tril': LowerCholesky()}
- support =RealVector(Real(), 1)
- reparametrized_params = ['df', 'loc', 'scale_tril']
- pytree_data_fields = ('df', 'loc', 'scale_tril', '_chi2')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
正态分布 (Normal)
- class Normal(loc=0.0, scale=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
- support = Real()
- reparametrized_params = ['loc', 'scale']
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
帕累托分布 (Pareto)
松弛伯努利分布 (RelaxedBernoulli)
基于 Logits 的松弛伯努利分布 (RelaxedBernoulliLogits)
柔性拉普拉斯分布 (SoftLaplace)
- class SoftLaplace(loc, scale, *, validate_args=None)[source]
基类:
Distribution
具有拉普拉斯分布尾部行为的光滑分布。
该分布对应于对数凸密度函数
z = (value - loc) / scale log_prob = log(2 / pi) - log(scale) - logaddexp(z, -z)
与拉普拉斯密度函数一样,该密度函数具有最重的尾部(渐近意义下),同时仍保持对数凸性。与拉普拉斯分布不同,该分布处处无限可微,因此适用于 HMC 和拉普拉斯近似。
- 参数:
loc – 位置参数。
scale – 尺度参数。
- arg_constraints = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
- support = Real()
- reparametrized_params = ['loc', 'scale']
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
学生 t 分布 (StudentT)
- class StudentT(df, loc=0.0, scale=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'df': Positive(lower_bound=0.0), 'loc': Real(), 'scale': Positive(lower_bound=0.0)}
- support = Real()
- reparametrized_params = ['df', 'loc', 'scale']
- pytree_data_fields = ('df', 'loc', 'scale', '_chi2')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
均匀分布 (Uniform)
- class Uniform(low=0.0, high=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'high': Dependent(), 'low': Dependent()}
- reparametrized_params = ['low', 'high']
- pytree_data_fields = ('low', 'high', '_support')
- property support
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
威布尔分布 (Weibull)
- class Weibull(scale, concentration, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'concentration': Positive(lower_bound=0.0), 'scale': Positive(lower_bound=0.0)}
- support = Positive(lower_bound=0.0)
- reparametrized_params = ['scale', 'concentration']
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
威沙特分布 (Wishart)
- class Wishart(concentration, scale_matrix=None, rate_matrix=None, scale_tril=None, *, validate_args=None)[source]
-
用于协方差矩阵的威沙特分布。
- 参数:
- arg_constraints = {'concentration': Dependent(), 'rate_matrix': PositiveDefinite(), 'scale_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
- support = PositiveDefinite()
- reparametrized_params = ['scale_matrix', 'rate_matrix', 'scale_tril']
- static infer_shapes(concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None)[source]
根据
__init__()
参数的形状推断batch_shape
和event_shape
。注意
这假设分布形状仅依赖于张量输入的形状,而不依赖于输入中包含的数据。
- 参数:
*args – 位置参数,每个输入参数由表示每个张量输入大小的元组代替。
**kwargs – 关键字参数,将输入参数名称映射到表示每个张量输入大小的元组。
- 返回:
一对表示分布形状的元组
(batch_shape, event_shape)
,该分布将使用给定形状的输入参数创建。- 返回类型:
威沙特分布的乔利斯基因子 (WishartCholesky)
- class WishartCholesky(concentration, scale_matrix=None, rate_matrix=None, scale_tril=None, *, validate_args=None)[source]
基类:
Distribution
协方差矩阵的威沙特分布的乔利斯基因子。
- 参数:
- arg_constraints = {'concentration': Dependent(), 'rate_matrix': PositiveDefinite(), 'scale_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
- support = LowerCholesky()
- reparametrized_params = ['scale_matrix', 'rate_matrix', 'scale_tril']
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- static infer_shapes(concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None)[source]
根据
__init__()
参数的形状推断batch_shape
和event_shape
。注意
这假设分布形状仅依赖于张量输入的形状,而不依赖于输入中包含的数据。
- 参数:
*args – 位置参数,每个输入参数由表示每个张量输入大小的元组代替。
**kwargs – 关键字参数,将输入参数名称映射到表示每个张量输入大小的元组。
- 返回:
一对表示分布形状的元组
(batch_shape, event_shape)
,该分布将使用给定形状的输入参数创建。- 返回类型:
零和正态分布 (ZeroSumNormal)
- class ZeroSumNormal(scale, event_shape, *, validate_args=None)[source]
-
零和正态分布,根据[2,3]中描述,改编自 PyMC [1]。这是一种正态分布,其中一个或多个轴被约束为总和为零(默认为最后一个轴)。
\[\begin{split}\begin{align*} ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\ \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\ n = \text{number of zero-sum axes} \end{align*}\end{split}\]- 参数:
scale (array_like) – 在施加零和约束之前,基础正态分布的标准差。
event_shape (tuple) – 分布的事件形状,其轴被约束为总和为零。
示例
>>> from numpy.testing import assert_allclose >>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import MCMC, NUTS >>> N = 1000 >>> n_categories = 20 >>> rng_key = random.PRNGKey(0) >>> key1, key2, key3 = random.split(rng_key, 3) >>> category_ind = random.choice(key1, jnp.arange(n_categories), shape=(N,)) >>> beta = random.normal(key2, shape=(n_categories,)) >>> beta -= beta.mean(-1) >>> y = 5 + beta[category_ind] + random.normal(key3, shape=(N,)) >>> def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories ... N = len(category_ind) ... alpha = numpyro.sample("alpha", dist.Normal(0, 2.5)) ... beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(n_categories,))) ... sigma = numpyro.sample("sigma", dist.Exponential(1)) ... with numpyro.plate("observations", N): ... mu = alpha + beta[category_ind] ... obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y) ... return obs >>> nuts_kernel = NUTS(model=model, target_accept_prob=0.9) >>> mcmc = MCMC( ... sampler=nuts_kernel, ... num_samples=1_000, num_warmup=1_000, num_chains=4 ... ) >>> mcmc.run(random.PRNGKey(0), category_ind=category_ind, y=y) >>> posterior_samples = mcmc.get_samples() >>> # Confirm everything along last axis sums to zero >>> assert_allclose(posterior_samples['beta'].sum(-1), 0, atol=1e-3)
参考资料 [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 [2] https://pymc.cn/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
- arg_constraints = {'scale': Positive(lower_bound=0.0)}
- reparametrized_params = ['scale']
- property support
- property mean
分布的均值。
- property variance
分布的方差。
离散分布 (Discrete Distributions)
伯努利分布 (Bernoulli)
基于 Logits 的伯努利分布 (BernoulliLogits)
- class BernoulliLogits(logits=None, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'logits': Real()}
- support = Boolean()
- has_enumerate_support = True
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
基于 Probabilities 的伯努利分布 (BernoulliProbs)
- class BernoulliProbs(probs, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0)}
- support = Boolean()
- has_enumerate_support = True
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
Beta-二项分布 (BetaBinomial)
- class BetaBinomial(concentration1, concentration0, total_count=1, *, validate_args=None)[source]
基类:
Distribution
由beta-binomial对组成的复合分布。
Binomial
分布的成功概率(probs
参数)未知,并在由total_count
给定的伯努利试验次数之前,从Beta
分布中随机抽取。- 参数:
concentration1 (numpy.ndarray) – Beta分布的第1个集中度参数 (alpha)。
concentration0 (numpy.ndarray) – Beta分布的第2个集中度参数 (beta)。
total_count (numpy.ndarray) – 伯努利试验次数。
- arg_constraints = {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0), 'total_count': IntegerNonnegative(lower_bound=0)}
- has_enumerate_support = True
- enumerate_support(expand=True)
返回一个形状为 len(support) x batch_shape 的数组,其中包含支持集中的所有值。
- pytree_data_fields = ('concentration1', 'concentration0', 'total_count', '_beta')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- property support
二项分布
Logits形式的二项分布
- class BinomialLogits(logits, total_count=1, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'logits': Real(), 'total_count': IntegerNonnegative(lower_bound=0)}
- has_enumerate_support = True
- enumerate_support(expand=True)
返回一个形状为 len(support) x batch_shape 的数组,其中包含支持集中的所有值。
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- property support
概率形式的二项分布
- class BinomialProbs(probs, total_count=1, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerNonnegative(lower_bound=0)}
- has_enumerate_support = True
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- property support
范畴分布
Logits形式的范畴分布
- class CategoricalLogits(logits, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'logits': RealVector(Real(), 1)}
- has_enumerate_support = True
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- property support
概率形式的范畴分布
- class CategoricalProbs(probs, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'probs': Simplex()}
- has_enumerate_support = True
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- property support
狄利克雷-多项分布
- class DirichletMultinomial(concentration, total_count=1, *, validate_args=None)[source]
基类:
Distribution
由dirichlet-multinomial对组成的复合分布。
Multinomial
分布的类别概率(probs
参数)未知,并在由total_count
给定的范畴试验次数之前,从Dirichlet
分布中随机抽取。- 参数:
concentration (numpy.ndarray) – Dirichlet分布的集中度参数 (alpha)。
total_count (numpy.ndarray) – 范畴试验次数。
- arg_constraints = {'concentration': IndependentConstraint(Positive(lower_bound=0.0), 1), 'total_count': IntegerNonnegative(lower_bound=0)}
- pytree_data_fields = ('concentration', '_dirichlet')
- pytree_aux_fields = ('total_count',)
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- property support
离散均匀分布
- class DiscreteUniform(low=0, high=1, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'high': Dependent(), 'low': Dependent()}
- has_enumerate_support = True
- pytree_data_fields = ('low', 'high', '_support')
- property support
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
伽马-泊松分布
- class GammaPoisson(concentration, rate=1.0, *, validate_args=None)[source]
基类:
Distribution
由gamma-poisson对组成的复合分布,也称为伽马-泊松混合。
Poisson
分布的rate
参数未知,从Gamma
分布中随机抽取。- 参数:
concentration (numpy.ndarray) – Gamma分布的形状参数 (alpha)。
rate (numpy.ndarray) – Gamma分布的率参数 (beta)。
- arg_constraints = {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}
- support = IntegerNonnegative(lower_bound=0)
- pytree_data_fields = ('concentration', 'rate', '_gamma')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
几何分布
Logits形式的几何分布
- class GeometricLogits(logits, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'logits': Real()}
- support =IntegerNonnegative(lower_bound=0)
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
概率形式的几何分布
- class GeometricProbs(probs, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0)}
- support =IntegerNonnegative(lower_bound=0)
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
多项分布
Logits形式的多项分布
- class MultinomialLogits(logits, total_count=1, *, total_count_max=None, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'logits': RealVector(Real(), 1), 'total_count': IntegerNonnegative(lower_bound=0)}
- pytree_data_fields = ('logits',)
- pytree_aux_fields = ('total_count', 'total_count_max')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- property support
概率形式的多项分布
- class MultinomialProbs(probs, total_count=1, *, total_count_max=None, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'probs': Simplex(), 'total_count': IntegerNonnegative(lower_bound=0)}
- pytree_data_fields = ('probs',)
- pytree_aux_fields = ('total_count', 'total_count_max')
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
- property support
有序Logistic分布
- class OrderedLogistic(predictor, cutpoints, *, validate_args=None)[source]
基类:
CategoricalProbs
具有有序结果的范畴分布。
参考文献
Stan 函数参考手册, v2.20 第12.6节, Stan 开发团队
- 参数:
predictor (numpy.ndarray) – 实数域中的预测值;通常这是线性模型的输出。
cutpoints (numpy.ndarray) – 实数域中用于分隔类别的位置。
- arg_constraints = {'cutpoints': OrderedVector(), 'predictor': Real()}
负二项分布
Logits形式的负二项分布
- class NegativeBinomialLogits(total_count, logits, *, validate_args=None)[source]
基类:
GammaPoisson
- arg_constraints = {'logits': Real(), 'total_count': Positive(lower_bound=0.0)}
- support =IntegerNonnegative(lower_bound=0)
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
概率形式的负二项分布
- class NegativeBinomialProbs(total_count, probs, *, validate_args=None)[source]
基类:
GammaPoisson
- arg_constraints = {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': Positive(lower_bound=0.0)}
- support =IntegerNonnegative(lower_bound=0)
负二项分布2
- class NegativeBinomial2(mean, concentration, *, validate_args=None)[source]
基类:
GammaPoisson
GammaPoisson 的另一种参数化形式,其中 rate 被 mean 替换。
- arg_constraints = {'concentration': Positive(lower_bound=0.0), 'mean': Positive(lower_bound=0.0)}
- support =IntegerNonnegative(lower_bound=0)
- pytree_data_fields = ('concentration',)
泊松分布
- class Poisson(rate, *, is_sparse=False, validate_args=None)[source]
基类:
Distribution
创建一个由率参数
rate
参数化的泊松分布。样本是非负整数,其pmf(概率质量函数)由下式给出
\[\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}\]- 参数:
rate (numpy.ndarray) – 率参数
is_sparse (bool) – 计算
log_prob()
时是否假定值主要为零,这可以在数据稀疏时加快计算速度。
- arg_constraints = {'rate': Positive(lower_bound=0.0)}
- support =IntegerNonnegative(lower_bound=0)
- pytree_aux_fields = ('is_sparse',)
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property variance
分布的方差。
零膨胀分布
- ZeroInflatedDistribution(base_dist, *, gate=None, gate_logits=None, validate_args=None)[source]
通用零膨胀分布。
- 参数:
base_dist (Distribution) – 基分布。
gate (numpy.ndarray) – 通过伯努利分布给出的额外零的概率。
gate_logits (numpy.ndarray) – 通过伯努利分布给出的额外零的对数概率。
零膨胀泊松分布
- class ZeroInflatedPoisson(gate, rate=1.0, *, validate_args=None)[source]
基类:
ZeroInflatedProbs
零膨胀泊松分布。
- 参数:
gate (numpy.ndarray) – 额外零的概率。
rate (numpy.ndarray) – 泊松分布的速率。
- arg_constraints = {'gate': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'rate': Positive(lower_bound=0.0)}
- support = IntegerNonnegative(lower_bound=0)
- pytree_data_fields = ('rate',)
零膨胀负二项分布2
混合分布
混合
- Mixture(mixing_distribution, component_distributions, *, validate_args=None)[source]
分量分布的边缘化有限混合
返回的分布将是以下之一:
MixtureGeneral
,当component_distributions
是一个列表时,或者MixtureSameFamily
,当component_distributions
是一个单一分布时。
更多详细信息请参阅这些类的文档。
- 参数:
mixing_distribution – 一个
Categorical
分布,指定每个混合分量的权重。此分布的大小指定了混合中的分量数量mixture_size
。component_distributions – 可以是分量分布列表,也可以是单一向量化分布。提供列表时,元素数量必须等于
mixture_size
。否则,分布的最后一个批量维度必须等于mixture_size
。
- 返回:
混合分布。
同族混合分布
- class MixtureSameFamily(mixing_distribution, component_distribution, *, validate_args=None)[source]
基类:
_MixtureBase
同族分量分布的有限混合。
这种混合仅支持同族分量分布的混合。不同的分量是沿输入
component_distribution
的最后一个批量维度指定的。如果需要不同族分布的混合,请使用更通用的MixtureGeneral
实现。- 参数:
mixing_distribution – 一个
Categorical
分布,指定每个混合分量的权重。此分布的大小指定了混合中的分量数量mixture_size
。component_distribution – 一个单一的向量化
Distribution
,其最后一个批量维度等于mixing_distribution
指定的mixture_size
。
示例
>>> import jax >>> import jax.numpy as jnp >>> import numpyro.distributions as dist >>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.) >>> component_dist = dist.Normal(loc=jnp.zeros(3), scale=jnp.ones(3)) >>> mixture = dist.MixtureSameFamily(mixing_dist, component_dist) >>> mixture.sample(jax.random.PRNGKey(42)).shape ()
- pytree_data_fields = ('_mixing_distribution', '_component_distribution')
- pytree_aux_fields = ('_mixture_size',)
- property component_distribution
返回被混合的分量的向量化分布。
- 返回:
分量分布
- 返回类型:
- property support
- property is_discrete
- property component_mean
- property component_variance
通用混合分布
- class MixtureGeneral(mixing_distribution, component_distributions, *, support=None, validate_args=None)[source]
基类:
_MixtureBase
不同族分量分布的有限混合。
如果所有分量分布都来自同一族,MixtureSameFamily 中更具体的实现会更有效一些。
- 参数:
mixing_distribution – 一个
Categorical
分布,指定每个混合分量的权重。此分布的大小指定了混合中的分量数量mixture_size
。component_distributions – 一个包含
mixture_size
个Distribution
对象的列表。support – 一个
Constraint
对象,指定混合分布的支持集。如果未提供,支持集将从分量分布推断。
示例
>>> import jax >>> import jax.numpy as jnp >>> import numpyro.distributions as dist >>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.) >>> component_dists = [ ... dist.Normal(loc=0.0, scale=1.0), ... dist.Normal(loc=-0.5, scale=0.3), ... dist.Normal(loc=0.6, scale=1.2), ... ] >>> mixture = dist.MixtureGeneral(mixing_dist, component_dists) >>> mixture.sample(jax.random.PRNGKey(42)).shape ()
>>> import jax >>> import jax.numpy as jnp >>> import numpyro.distributions as dist >>> mixing_dist = dist.Categorical(probs=jnp.ones(2) / 2.) >>> component_dists = [ ... dist.Normal(loc=0.0, scale=1.0), ... dist.HalfNormal(scale=0.3), ... ] >>> mixture = dist.MixtureGeneral(mixing_dist, component_dists, support=dist.constraints.real) >>> mixture.sample(jax.random.PRNGKey(42)).shape ()
- pytree_data_fields = ('_mixing_distribution', '_component_distributions', '_support')
- pytree_aux_fields = ('_mixture_size',)
- property component_distributions
混合中的分量分布列表
- 返回:
分量分布列表
- 返回类型:
- property support
- property is_discrete
- property component_mean
- property component_variance
方向分布
投影正态分布
- class ProjectedNormal(concentration, *, validate_args=None)[source]
基类:
Distribution
任意维度的各向同性投影正态分布。
这种方向数据的分布在定性上类似于 von Mises 和 von Mises-Fisher 分布,但允许通过重新参数化梯度进行可行的变分推断。
要将此分布与 autoguides 和 HMC 一起使用,请在模型中使用
handlers.reparam
和ProjectedNormalReparam
重新参数化器,例如:@handlers.reparam(config={"direction": ProjectedNormalReparam()}) def model(): direction = numpyro.sample("direction", ProjectedNormal(zeros(3))) ...
注意
这仅对维度 {2,3} 实现了
log_prob()
。- [1] D. Hernandez-Stumpfhauser, F.J. Breidt, M.J. van der Woerd (2017)
“任意维度的通用投影正态分布:建模与贝叶斯推断” https://projecteuclid.org/euclid.ba/1453211962
- arg_constraints = {'concentration': RealVector(Real(), 1)}
- reparametrized_params = ['concentration']
- support = Sphere()
- property mean
注意:这是在最小化预期平方测地距离的子流形质心意义上的均值。
- property mode
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
正弦双变量 von Mises 分布
- class SineBivariateVonMises(phi_loc, psi_loc, phi_concentration, psi_concentration, correlation=None, weighted_correlation=None, validate_args=None)[source]
基类:
Distribution
给定以下公式的,在 2-torus (\(S^1 \otimes S^1\)) 上两个相关角度的单峰分布
\[C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2))\]和
\[C = (2\pi)^2 \sum_{i=0} {2i \choose i} \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),\]其中 \(I_i(\cdot)\) 是第一类修正贝塞尔函数,mu 是分布的位置,kappa 是集中度,rho 给出角度 \(x_1\) 和 \(x_2\) 之间的相关性。此分布有助于建模耦合角度,例如肽链中的扭转角。
要推断参数,请使用
NUTS
或HMC
并采用避免分布变成双峰的参数化先验;参见下面的注意事项。注意
采样效率随以下情况降低
\[\frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1\]因为分布变得越来越双峰。为避免低效采样,请使用 weighted_correlation 参数并使其偏离 1(例如 TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2)))。weighted_correlation 应在 [-1,1] 范围内。
注意
correlation 和 weighted_correlation 参数互斥。
注意
在
SVI
的上下文中,此分布可用作似然函数,但不能用于潜在变量。注意
对于高达 10,000 的集中度,归一化保持准确。与 Pyro 不同,初始化期间没有断言来验证这一点,因为 JIT 编译会使此类检查无效。
- ** 参考资料:**
Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002)
- 参数:
phi_loc (np.ndarray) – 第一个角度的位置
psi_loc (np.ndarray) – 第二个角度的位置
phi_concentration (np.ndarray) – 第一个角度的集中度
psi_concentration (np.ndarray) – 第二个角度的集中度
correlation (np.ndarray) – 两个角度之间的相关性
weighted_correlation (np.ndarray) – 将 correlation 设置为 weighted_corr * sqrt(phi_conc*psi_conc) 以避免双峰性(参见注意事项)。weighted_correlation 应在 [0,1] 范围内。
- arg_constraints = {'correlation': Real(), 'phi_concentration': Positive(lower_bound=0.0), 'phi_loc': Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 'psi_concentration': Positive(lower_bound=0.0), 'psi_loc': Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)}
- support = IndependentConstraint(Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 1)
- max_sample_iter = 1000
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- sample(key, sample_shape=())[source]
- ** 参考资料:**
A New Unified Approach for the Simulation of a Wide Class of Directional Distributions John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
- property mean
计算分布的圆形均值。注意:当映射到支持集 [-pi, pi] 时与位置相同
正弦偏斜分布
- class SineSkewed(base_dist: Distribution, skewness, *, validate_args=None)[source]
基类:
Distribution
正弦偏斜 [1] 是一种用于生成打破环面分布的点对称性的分布的过程。新的分布称为正弦偏斜 X 分布,其中 X 是(对称)基分布的名称。环面分布是支持集在圆的乘积(即 \(\otimes S^1\),其中 \(S^1 = [-pi,pi)\))上的分布。因此,0-环面是一个点,1-环面是一个圆,而 2-环面通常与甜甜圈形状相关。
正弦偏斜 X 分布通过 X 事件的每个维度的权重参数进行参数化。例如,对于圆(1-环面)上的 von Mises 分布,正弦偏斜 von Mises 分布有一个偏斜参数。可以使用
HMC
或NUTS
推断偏斜参数。例如,以下代码将生成 2-环面的偏斜度先验,@numpyro.handlers.reparam(config={'phi_loc': CircularReparam(), 'psi_loc': CircularReparam()}) def model(obs): # Sine priors phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.)) psi_loc = numpyro.sample('psi_loc', VonMises(-pi / 2, 2.)) phi_conc = numpyro.sample('phi_conc', Beta(1., 1.)) psi_conc = numpyro.sample('psi_conc', Beta(1., 1.)) corr_scale = numpyro.sample('corr_scale', Beta(2., 5.)) # Skewing prior ball_trans = L1BallTransform() skewness = numpyro.sample('skew_phi', Normal(0, 0.5).expand((2,))) skewness = ball_trans(skewness) # constraint sum |skewness_i| <= 1 with numpyro.plate('obs_plate'): sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc, phi_concentration=70 * phi_conc, psi_concentration=70 * psi_conc, weighted_correlation=corr_scale) return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)
为了确保偏斜不会改变(正弦双变量 von Mises)基分布的归一化常数,对偏斜参数进行了约束。该约束要求偏斜度绝对值的总和小于或等于一。我们可以使用
L1BallTransform
来实现这一点。在
SVI
的上下文中,此分布可以自由地用作似然函数,但用作潜在变量会导致 2 维及更高维环面的推断变慢。这是因为 base_dist 无法重新参数化。注意
基分布中的事件必须在 d-环面上,因此 event_shape 必须是 (d,)。
注意
对于 skewness 参数,必须满足其权重对事件的绝对值之和小于或等于一。参见 [1] 中的公式 2.1。
- ** 参考资料:**
- 正弦偏斜环面分布及其在蛋白质生物信息学中的应用
Ameijeiras-Alonso, J., Ley, C. (2019)
- 参数:
base_dist (numpyro.distributions.Distribution) – d 维环面上的基密度。支持的基分布包括:一维
VonMises
、SineBivariateVonMises
、一维ProjectedNormal
和Uniform
(-pi, pi)。skewness (jax.numpy.array) – 分布的偏斜度。
- arg_constraints = {'skewness': L1Ball()}
- pytree_data_fields = ('base_dist', 'skewness')
- support = IndependentConstraint(Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 1)
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(value)[source]
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
基分布的均值
Von Mises 分布
- class VonMises(loc, concentration, *, validate_args=None)[source]
基类:
Distribution
Von Mises 分布,也称为圆形正态分布。
此分布由从 -pi 到 +pi 的圆形约束支持。默认情况下,圆形支持集的行为类似于
constraints.interval(-math.pi, math.pi)
。为了避免在此区间的边界进行采样时出现问题,您应该在模型中使用handlers.reparam
和CircularReparam
重新参数化器对该分布进行重新参数化,例如:@handlers.reparam(config={"direction": CircularReparam()}) def model(): direction = numpyro.sample("direction", VonMises(0.0, 4.0)) ...
- arg_constraints = {'concentration': Positive(lower_bound=0.0), 'loc': Real()}
- reparametrized_params = ['loc']
- support = Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)
- sample(key, sample_shape=())[source]
从 von Mises 分布生成样本
- 参数:
key – 随机数生成器密钥
sample_shape – 样本的形状
- 返回:
从 von Mises 分布生成的样本
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
计算分布的圆形均值。注意:当映射到支持集 [-pi, pi] 时与位置相同
- property variance
计算分布的圆形方差
截断分布
双重截断幂律分布
- class DoublyTruncatedPowerLaw(alpha, low, high, *, validate_args=None)[source]
基类:
Distribution
具有 \(\alpha\) 指数以及上下界的幂律分布。我们可以将幂律分布定义为:
\[f(x; \alpha, a, b) = \frac{x^{\alpha}}{Z(\alpha, a, b)},\]其中,\(a\) 和 \(b\) 分别是下界和上界,\(Z(\alpha, a, b)\) 是归一化常数。其定义为:
\[\begin{split}Z(\alpha, a, b) = \begin{cases} \log(b) - \log(a) & \text{if } \alpha = -1, \\ \frac{b^{1 + \alpha} - a^{1 + \alpha}}{1 + \alpha} & \text{otherwise}. \end{cases}\end{split}\]- 参数:
alpha – 幂律分布的指数
low – 分布的下界
high – 分布的上界
- arg_constraints = {'alpha': Real(), 'high': GreaterThan(lower_bound=0), 'low': GreaterThanEq(lower_bound=0)}
- reparametrized_params = ['alpha', 'low', 'high']
- pytree_aux_fields = ('_support',)
- pytree_data_fields = ('alpha', 'low', 'high')
- property support
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- cdf(value)[source]
累积概率分布:Z 不等于负一
\[\frac{x^{\alpha + 1} - a^{\alpha + 1}}{b^{\alpha + 1} - a^{1 + \alpha}}\]Z 等于负一
\[\frac{\log(x) - \log(a)}{\log(b) - \log(a)}\]推导是根据雅可比矩阵由 Wolfram Alpha 计算得出。
- icdf(q)[source]
逆累积概率分布:Z 不等于负一
\[a \left(\frac{b}{a}\right)^{q}\]Z 等于负一
\[\left(a^{1 + \alpha} + q (b^{1 + \alpha} - a^{1 + \alpha})\right)^{\frac{1}{1 + \alpha}}\]推导是根据雅可比矩阵由 Wolfram Alpha 计算得出。
左截断分布
- class LeftTruncatedDistribution(base_dist, low=0.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'low': Real()}
- reparametrized_params = ['low']
- supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
- pytree_data_fields = ('base_dist', 'low', '_support')
- property support
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property var
下截断幂律分布
- class LowerTruncatedPowerLaw(alpha, low, *, validate_args=None)[source]
基类:
Distribution
具有 \(\alpha\) 指数的下截断幂律分布。我们可以将幂律分布定义为:
\[f(x; \alpha, a) = (-\alpha-1)a^{-\alpha - 1}x^{-\alpha}, \qquad x \geq a, \qquad \alpha < -1,\]其中,\(a\) 是下界。分布的 cdf 由下式给出:
\[F(x; \alpha, a) = 1 - \left(\frac{x}{a}\right)^{1+\alpha}.\]分布的 k 阶矩由下式给出:
\[\begin{split}E[X^k] = \begin{cases} \frac{-\alpha-1}{-\alpha-1-k}a^k & \text{if } k < -\alpha-1, \\ \infty & \text{otherwise}. \end{cases}\end{split}\]- 参数:
alpha – 幂律分布的指数
low – 分布的下界
- arg_constraints = {'alpha': LessThan(upper_bound=-1.0), 'low': GreaterThan(lower_bound=0.0)}
- reparametrized_params = ['alpha', 'low']
- pytree_aux_fields = ('_support',)
- property support
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
右截断分布
- class RightTruncatedDistribution(base_dist, high=0.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'high': Real()}
- reparametrized_params = ['high']
- supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
- pytree_data_fields = ('base_dist', 'high', '_support')
- property support
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property var
截断柯西分布
截断分布
TruncatedNormal
TruncatedPolyaGamma
- class TruncatedPolyaGamma(batch_shape=(), *, validate_args=None)[source]
基类:
Distribution
- truncation_point = 2.5
- num_log_prob_terms = 7
- num_gamma_variates = 8
- arg_constraints = {}
- support = Interval(lower_bound=0.0, upper_bound=2.5)
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
TwoSidedTruncatedDistribution
- class TwoSidedTruncatedDistribution(base_dist, low=0.0, high=1.0, *, validate_args=None)[source]
基类:
Distribution
- arg_constraints = {'high': Dependent(), 'low': Dependent()}
- reparametrized_params = ['low', 'high']
- supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
- pytree_data_fields = ('base_dist', 'low', 'high', '_support')
- property support
- sample(key, sample_shape=())[source]
返回形状为 sample_shape + batch_shape + event_shape 的分布样本。注意,当 sample_shape 非空时,返回样本的前导维度(大小为 sample_shape)将填充从分布实例中独立同分布抽取的样本。
- 参数:
key (jax.random.PRNGKey) – 用于分布的 rng_key。
sample_shape (tuple) – 分布的样本形状。
- 返回:
形状为 sample_shape + batch_shape + event_shape 的数组
- 返回类型:
- log_prob(*args, **kwargs)
评估由 value 给出的一批样本的对数概率密度。
- 参数:
value – 分布的一批样本。
- 返回:
形状为 value.shape[:-self.event_shape] 的数组
- 返回类型:
- property mean
分布的均值。
- property var
TensorFlow Distributions
TensorFlow Probability (TFP) 分布的轻量级封装。有关 TFP 分布接口的详细信息,请参阅其分布文档。
BijectorConstraint
BijectorTransform
TFPDistribution
- class TFPDistribution(batch_shape=(), event_shape=(), *, validate_args=None)[source]
TensorFlow Probability (TFP) 分布的轻量级封装。其构造函数与对应的 TFP 分布具有相同的签名。
此类可用于将 TFP 分布转换为与 NumPyro 兼容的分布,如下所示
d = TFPDistribution[tfd.Normal](0, 1)
请注意,典型用例无需显式调用此封装器,因为 NumPyro 在模型代码中会自动封装 TFP 分布,例如
from tensorflow_probability.substrates.jax import distributions as tfd def model(): numpyro.sample("x", tfd.Normal(0, 1))
Constraints
Constraint
boolean
- boolean = Boolean()
circular
- circular = Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)
corr_cholesky
- corr_cholesky = CorrCholesky()
corr_matrix
- corr_matrix = CorrMatrix()
dependent
greater_than
- greater_than(lower_bound)
约束的抽象基类。
约束对象表示变量有效的区域,例如变量可以在其中进行优化的区域。
integer_interval
- integer_interval(lower_bound, upper_bound)
约束的抽象基类。
约束对象表示变量有效的区域,例如变量可以在其中进行优化的区域。
integer_greater_than
- integer_greater_than(lower_bound)
约束的抽象基类。
约束对象表示变量有效的区域,例如变量可以在其中进行优化的区域。
interval
- interval(lower_bound, upper_bound)
约束的抽象基类。
约束对象表示变量有效的区域,例如变量可以在其中进行优化的区域。
l1_ball
- l1_ball(x)
约束到任意维度的 L1 球。
less_than
- less_than(upper_bound)
约束的抽象基类。
约束对象表示变量有效的区域,例如变量可以在其中进行优化的区域。
lower_cholesky
- lower_cholesky = LowerCholesky()
multinomial
- multinomial(upper_bound)
约束的抽象基类。
约束对象表示变量有效的区域,例如变量可以在其中进行优化的区域。
nonnegative_integer
- nonnegative_integer = IntegerNonnegative(lower_bound=0)
ordered_vector
- ordered_vector = OrderedVector()
positive
- positive = Positive(lower_bound=0.0)
positive_definite
- positive_definite = PositiveDefinite()
positive_integer
- positive_integer = IntegerPositive(lower_bound=1)
positive_ordered_vector
- positive_ordered_vector = PositiveOrderedVector()
约束到一个正实值张量,其中元素沿 event_shape 维度单调递增。
real
- real = Real()
real_vector
- real_vector = RealVector(Real(), 1)
scaled_unit_lower_cholesky
- scaled_unit_lower_cholesky = ScaledUnitLowerCholesky()
softplus_positive
- softplus_positive = SoftplusPositive(lower_bound=0.0)
softplus_lower_cholesky
- softplus_lower_cholesky = SoftplusLowerCholesky()
simplex
- simplex = Simplex()
sphere
- sphere = Sphere()
约束到任意维度的欧几里得球面。
unit_interval
- unit_interval = UnitInterval(lower_bound=0.0, upper_bound=1.0)
zero_sum
- zero_sum = <class 'numpyro.distributions.constraints._ZeroSum'>
约束的抽象基类。
约束对象表示变量有效的区域,例如变量可以在其中进行优化的区域。
Transforms
biject_to
- biject_to(constraint)
Transform
AbsTransform
AffineTransform
CholeskyTransform
ComplexTransform
ComposeTransform
CorrCholeskyTransform
- class CorrCholeskyTransform[source]
基类:
ParameterFreeTransform
将长度为 \(D*(D-1)/2\) 的无约束实向量 \(x\) 变换为 D 维相关矩阵的 Cholesky 因子。此 Cholesky 因子是一个下三角矩阵,具有正对角线元素且每行具有单位欧几里得范数。变换过程如下
首先,我们将 \(x\) 按以下顺序转换为下三角矩阵
\[\begin{split}\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\end{split}\]2. 对于下三角部分的每一行 \(X_i\),我们应用类
StickBreakingTransform
的一个 *带符号* 版本,使用以下步骤将 \(X_i\) 变换为单位欧几里得长度向量缩放到区间 \((-1, 1)\) 域:\(r_i = \tanh(X_i)\)。
变换为无符号域:\(z_i = r_i^2\)。
应用 \(s_i = StickBreakingTransform(z_i)\)。
变换回带符号域:\(y_i = (sign(r_i), 1) * \sqrt{s_i}\)。
- domain = RealVector(Real(), 1)
- codomain = CorrCholesky()
CorrMatrixCholeskyTransform
ExpTransform
IdentityTransform
L1BallTransform
LowerCholeskyAffine
- class LowerCholeskyAffine(loc, scale_tril)[source]
基类:
Transform
通过映射 \(y = loc + scale\_tril\ @\ x\) 进行变换。
- 参数:
loc – 一个实向量。
scale_tril – 一个具有正对角线元素的下三角矩阵。
示例
>>> import jax.numpy as jnp >>> from numpyro.distributions.transforms import LowerCholeskyAffine >>> base = jnp.ones(2) >>> loc = jnp.zeros(2) >>> scale_tril = jnp.array([[0.3, 0.0], [1.0, 0.5]]) >>> affine = LowerCholeskyAffine(loc=loc, scale_tril=scale_tril) >>> affine(base) Array([0.3, 1.5], dtype=float32)
- domain = RealVector(Real(), 1)
- codomain = RealVector(Real(), 1)
LowerCholeskyTransform
OrderedTransform
- class OrderedTransform[source]
基类:
ParameterFreeTransform
将实向量变换为有序向量。
参考文献
Stan Reference Manual v2.20, section 10.6, Stan Development Team (Stan 参考手册 v2.20,第 10.6 节,Stan 开发团队)
示例
>>> import jax.numpy as jnp >>> from numpyro.distributions.transforms import OrderedTransform >>> base = jnp.ones(3) >>> transform = OrderedTransform() >>> assert jnp.allclose(transform(base), jnp.array([1., 3.7182817, 6.4365635]), rtol=1e-3, atol=1e-3)
- domain = RealVector(Real(), 1)
- codomain = OrderedVector()
PackRealFastFourierCoefficientsTransform
置换变换
幂变换
实数快速傅里叶变换
- class RealFastFourierTransform(transform_shape=None, transform_ndims=1)[source]
基类:
Transform
用于实数输入的 N 维离散快速傅里叶变换。
- 参数:
transform_shape – 输入中用于变换的每个轴的长度,默认为输入大小。
transform_ndims – 要变换的尾随维度的数量。
- property domain: Constraint
- property codomain: Constraint
递归线性变换
- class RecursiveLinearTransform(transition_matrix: Array)[source]
基类:
Transform
递归地应用线性变换,使得当 \(t > 0\) 时,\(y_t = A y_{t - 1} + x_t\),其中 \(x_t\) 和 \(y_t\) 是向量,\(A\) 是一个方阵形式的转移矩阵。该序列由 \(y_0 = 0\) 初始化。
- 参数:
transition_matrix – 用于连续状态的方阵形式的转移矩阵 \(A\) 或一批转移矩阵。
示例
>>> from jax import random >>> from jax import numpy as jnp >>> import numpyro >>> from numpyro import distributions as dist >>> >>> def cauchy_random_walk(): ... return numpyro.sample( ... "x", ... dist.TransformedDistribution( ... dist.Cauchy(0, 1).expand([10, 1]).to_event(1), ... dist.transforms.RecursiveLinearTransform(jnp.eye(1)), ... ), ... ) >>> >>> numpyro.handlers.seed(cauchy_random_walk, 0)().shape (10, 1) >>> >>> def rocket_trajectory(): ... scale = numpyro.sample( ... "scale", ... dist.HalfCauchy(1).expand([2]).to_event(1), ... ) ... transition_matrix = jnp.array([[1, 1], [0, 1]]) ... return numpyro.sample( ... "x", ... dist.TransformedDistribution( ... dist.Normal(0, scale).expand([10, 2]).to_event(1), ... dist.transforms.RecursiveLinearTransform(transition_matrix), ... ), ... ) >>> >>> numpyro.handlers.seed(rocket_trajectory, 0)().shape (10, 2)
- domain = RealMatrix(Real(), 2)
- codomain = RealMatrix(Real(), 2)
缩放单位下三角 Cholesky 变换
- class ScaledUnitLowerCholeskyTransform[source]
-
与 LowerCholeskyTransform 类似,这个 Transform 将实数向量转换为下三角 Cholesky 因子。然而,它通过以下分解来实现:
\(y = loc + unit\_scale\_tril\ @\ scale\_diag\ @\ x\).
其中 \(unit\_scale\_tril\) 的对角线上是 1,而 \(scale\_diag\) 是一个所有元素都为正的对角矩阵,它通过 softplus 变换参数化。
- domain = RealVector(Real(), 1)
- codomain = ScaledUnitLowerCholesky()
Sigmoid 变换
单形到有序变换
- class SimplexToOrderedTransform(anchor_point=0.0)[source]
基类:
Transform
将单形变换为有序向量(通过分界点之间的 Logistic CDF 差)。在 [1] 中用于通过变换有序类别概率来在潜在分界点上引入先验。
- 参数:
anchor_point – 锚点是一个干扰参数,用于提高变换的可识别性。为简单起见,我们假设它是一个标量值,但它可以广播到 x.shape[:-1]。更多详细信息请参阅 [1] 的第 2.2 节。
参考文献
有序回归案例研究,第 2.2 节, M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html
示例
>>> import jax.numpy as jnp >>> from numpyro.distributions.transforms import SimplexToOrderedTransform >>> base = jnp.array([0.3, 0.1, 0.4, 0.2]) >>> transform = SimplexToOrderedTransform() >>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
- domain = Simplex()
- codomain = OrderedVector()
Softplus 下三角 Cholesky 变换
Softplus 变换
Stick Breaking 变换
零和变换
- class ZeroSumTransform(transform_ndims: int = 1)[source]
基类:
Transform
一种将数组约束为总和为零的变换,改编自 PyMC [1],如 [2,3] 中所述。
- 参数:
transform_ndims – 要变换的尾随维度的数量。
References [1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266 [2] https://pymc.cn/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
- property domain: Constraint
- property codomain: Constraint
流
逆自回归变换
- class InverseAutoregressiveTransform(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)[source]
基类:
Transform
逆自回归流 (Inverse Autoregressive Flow) 的实现,使用了 Kingma 等人 (2016) 论文中的公式 (10):
\(\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\)
其中 \(\mathbf{x}\) 是输入,\(\mathbf{y}\) 是输出,\(\mu_t,\sigma_t\) 是从基于 \(\mathbf{x}\) 的自回归网络计算得出的,且 \(\sigma_t>0\)。
参考文献
使用逆自回归流改进变分推断 [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
- domain = RealVector(Real(), 1)
- codomain = RealVector(Real(), 1)
- log_abs_det_jacobian(x, y, intermediates=None)[source]
计算对数雅可比矩阵的元素级行列式。
- 参数:
x (numpy.ndarray) – 变换的输入
y (numpy.ndarray) – 变换的输出
块神经网络自回归变换
- class BlockNeuralAutoregressiveTransform(bn_arn)[source]
基类:
Transform
块神经网络自回归流的实现。
参考文献
块神经网络自回归流, Nicola De Cao, Ivan Titov, Wilker Aziz
- domain = RealVector(Real(), 1)
- codomain = RealVector(Real(), 1)
- log_abs_det_jacobian(x, y, intermediates=None)[source]
计算对数雅可比矩阵的元素级行列式。
- 参数:
x (numpy.ndarray) – 变换的输入
y (numpy.ndarray) – 变换的输出
实用工具
log1mexp
logdiffexp
- logdiffexp(a: Array | ndarray | bool_ | number | bool | int | float | complex, b: Array | ndarray | bool_ | number | bool | int | float | complex) Array | ndarray | bool_ | number | bool | int |float | complex [source]
数值稳定的计算量 \(\log(\exp(a) - \exp(b))\),前提是 \(+\infty > a \ge b\),遵循了 Mächler 2012 的算法。
当
a == b
时返回-jnp.inf
,包括当a == b == -jnp.inf
时,因为这对应于jnp.log(0)
。当a < b
或a == jnp.inf
时返回jnp.nan
。- 参数:
a – 一个数字或数字数组。
b – 一个数字或数字数组。
- 返回:
\(\log(\exp(a) - \exp(b))\) 的值。