优化器
此处定义的优化器类是对来源于 jax.example_libraries.optimizers
相应优化器的轻量级封装,其接口更适用于与 NumPyro 推断算法配合使用。
Adam
- class Adam(*args, **kwargs)[source]
JAX 优化器
adam()
的封装类- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
与
eval_and_update()
类似,但当目标函数的值或梯度不是有限值时,我们将不更新输入 state,并将目标输出设置为 nan。- 参数:
fn – 目标函数。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如
Minimize
),更新是通过多次重新评估函数来获取最优参数进行的。- 参数:
fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any
获取当前参数值。
- 参数:
state – 当前优化器状态。
- 返回:
包含当前参数值的集合。
- init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
使用指定用于优化的参数初始化优化器。
- 参数:
params – numpy 数组的集合。
- 返回:
初始优化器状态。
- update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
优化器的梯度更新。
- 参数:
g – 参数的梯度信息。
state – 当前优化器状态。
- 返回:
更新后的新优化器状态。
Adagrad
- class Adagrad(*args, **kwargs)[source]
JAX 优化器
adagrad()
的封装类- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
与
eval_and_update()
类似,但当目标函数的值或梯度不是有限值时,我们将不更新输入 state,并将目标输出设置为 nan。- 参数:
fn – 目标函数。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如
Minimize
),更新是通过多次重新评估函数来获取最优参数进行的。- 参数:
fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any
获取当前参数值。
- 参数:
state – 当前优化器状态。
- 返回:
包含当前参数值的集合。
- init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
使用指定用于优化的参数初始化优化器。
- 参数:
params – numpy 数组的集合。
- 返回:
初始优化器状态。
- update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
优化器的梯度更新。
- 参数:
g – 参数的梯度信息。
state – 当前优化器状态。
- 返回:
更新后的新优化器状态。
ClippedAdam
- class ClippedAdam(*args, clip_norm: float = 10.0, **kwargs)[source]
带有梯度裁剪的
Adam
优化器。- 参数:
clip_norm (float) – 所有梯度值将被裁剪到 [-clip_norm, clip_norm] 之间。
参考
A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba https://arxiv.org/abs/1412.6980
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
与
eval_and_update()
类似,但当目标函数的值或梯度不是有限值时,我们将不更新输入 state,并将目标输出设置为 nan。- 参数:
fn – 目标函数。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如
Minimize
),更新是通过多次重新评估函数来获取最优参数进行的。- 参数:
fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any
获取当前参数值。
- 参数:
state – 当前优化器状态。
- 返回:
包含当前参数值的集合。
Minimize
- class Minimize(method='BFGS', **kwargs)[source]
JAX 最小化器
minimize()
的封装类。示例
>>> from numpy.testing import assert_allclose >>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import SVI, Trace_ELBO >>> from numpyro.infer.autoguide import AutoLaplaceApproximation >>> def model(x, y): ... a = numpyro.sample("a", dist.Normal(0, 1)) ... b = numpyro.sample("b", dist.Normal(0, 1)) ... with numpyro.plate("N", y.shape[0]): ... numpyro.sample("obs", dist.Normal(a + b * x, 0.1), obs=y) >>> x = jnp.linspace(0, 10, 100) >>> y = 3 * x + 2 >>> optimizer = numpyro.optim.Minimize() >>> guide = AutoLaplaceApproximation(model) >>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) >>> init_state = svi.init(random.PRNGKey(0), x, y) >>> optimal_state, loss = svi.update(init_state, x, y) >>> params = svi.get_params(optimal_state) # get guide's parameters >>> quantiles = guide.quantiles(params, 0.5) # get means of posterior samples >>> assert_allclose(quantiles["a"], 2., atol=1e-3) >>> assert_allclose(quantiles["b"], 3., atol=1e-3)
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
类似于
eval_and_update()
,但当目标函数值或梯度不是有限值时,我们将不更新输入 state 并将目标函数输出设置为 nan。- 参数:
fn – 目标函数。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, None], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]] [source]
对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如
Minimize
),更新是通过多次重新评估函数来获取最优参数进行的。- 参数:
fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any
获取当前参数值。
- 参数:
state – 当前优化器状态。
- 返回:
包含当前参数值的集合。
- init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
使用指定用于优化的参数初始化优化器。
- 参数:
params – numpy 数组的集合。
- 返回:
初始优化器状态。
- update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
优化器的梯度更新。
- 参数:
g – 参数的梯度信息。
state – 当前优化器状态。
- 返回:
更新后的新优化器状态。
动量
- class Momentum(*args, **kwargs)[source]
JAX 优化器
momentum()
的包装类。- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
类似于
eval_and_update()
,但当目标函数值或梯度不是有限值时,我们将不更新输入 state 并将目标函数输出设置为 nan。- 参数:
fn – 目标函数。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如
Minimize
),更新是通过多次重新评估函数来获取最优参数进行的。- 参数:
fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any
获取当前参数值。
- 参数:
state – 当前优化器状态。
- 返回:
包含当前参数值的集合。
- init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
使用指定用于优化的参数初始化优化器。
- 参数:
params – numpy 数组的集合。
- 返回:
初始优化器状态。
- update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
优化器的梯度更新。
- 参数:
g – 参数的梯度信息。
state – 当前优化器状态。
- 返回:
更新后的新优化器状态。
RMSProp
- class RMSProp(*args, **kwargs)[source]
JAX 优化器
rmsprop()
的包装类。- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
类似于
eval_and_update()
,但当目标函数值或梯度不是有限值时,我们将不更新输入 state 并将目标函数输出设置为 nan。- 参数:
fn – 目标函数。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如
Minimize
),更新是通过多次重新评估函数来获取最优参数进行的。- 参数:
fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any
获取当前参数值。
- 参数:
state – 当前优化器状态。
- 返回:
包含当前参数值的集合。
- init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
使用指定用于优化的参数初始化优化器。
- 参数:
params – numpy 数组的集合。
- 返回:
初始优化器状态。
- update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
优化器的梯度更新。
- 参数:
g – 参数的梯度信息。
state – 当前优化器状态。
- 返回:
更新后的新优化器状态。
RMSPropMomentum
- class RMSPropMomentum(*args, **kwargs)[source]
JAX 优化器
rmsprop_momentum()
的包装类。- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
类似于
eval_and_update()
,但当目标函数值或梯度不是有限值时,我们将不更新输入 state 并将目标函数输出设置为 nan。- 参数:
fn – 目标函数。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如
Minimize
),更新是通过多次重新评估函数来获取最优参数进行的。- 参数:
fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any
获取当前参数值。
- 参数:
state – 当前优化器状态。
- 返回:
包含当前参数值的集合。
- init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
使用指定用于优化的参数初始化优化器。
- 参数:
params – numpy 数组的集合。
- 返回:
初始优化器状态。
- update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
优化器的梯度更新。
- 参数:
g – 参数的梯度信息。
state – 当前优化器状态。
- 返回:
更新后的新优化器状态。
SGD
- class SGD(*args, **kwargs)[source]
JAX 优化器
sgd()
的封装类。- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
类似于
eval_and_update()
,但当目标函数值或梯度不是有限值时,我们不会更新输入的 state,并将目标输出设置为 nan。- 参数:
fn – 目标函数。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如
Minimize
),更新是通过多次重新评估函数来获取最优参数进行的。- 参数:
fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any
获取当前参数值。
- 参数:
state – 当前优化器状态。
- 返回:
包含当前参数值的集合。
- init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
使用指定用于优化的参数初始化优化器。
- 参数:
params – numpy 数组的集合。
- 返回:
初始优化器状态。
- update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
优化器的梯度更新。
- 参数:
g – 参数的梯度信息。
state – 当前优化器状态。
- 返回:
更新后的新优化器状态。
SM3
- class SM3(*args, **kwargs)[source]
JAX 优化器
sm3()
的封装类。- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
类似于
eval_and_update()
,但当目标函数值或梯度不是有限值时,我们不会更新输入的 state,并将目标输出设置为 nan。- 参数:
fn – 目标函数。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]
对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如
Minimize
),更新是通过多次重新评估函数来获取最优参数进行的。- 参数:
fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。
state – 当前优化器状态。
forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。
- 返回:
目标函数输出和新优化器状态的对。
- get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any
获取当前参数值。
- 参数:
state – 当前优化器状态。
- 返回:
包含当前参数值的集合。
- init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
使用指定用于优化的参数初始化优化器。
- 参数:
params – numpy 数组的集合。
- 返回:
初始优化器状态。
- update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]
优化器的梯度更新。
- 参数:
g – 参数的梯度信息。
state – 当前优化器状态。
- 返回:
更新后的新优化器状态。
Optax 支持
- optax_to_numpyro(transformation) _NumPyroOptim [source]
此函数从一个
optax.GradientTransformation
实例生成一个numpyro.optim._NumPyroOptim
实例,以便可以将其与numpyro.infer.svi.SVI
一起使用。这是一个轻量级封装器,它重新创建了由jax.example_libraries.optimizers
定义的(init_fn, update_fn, get_params_fn)
接口。- 参数:
transformation — 要封装的
optax.GradientTransformation
实例。- 返回:
封装提供的 Optax 优化器的
numpyro.optim._NumPyroOptim
实例。