优化器

此处定义的优化器类是对来源于 jax.example_libraries.optimizers 相应优化器的轻量级封装,其接口更适用于与 NumPyro 推断算法配合使用。

Adam

class Adam(*args, **kwargs)[source]

JAX 优化器 adam() 的封装类

eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

eval_and_update() 类似,但当目标函数的值或梯度不是有限值时,我们将不更新输入 state,并将目标输出设置为 nan

参数:
  • fn – 目标函数。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如 Minimize),更新是通过多次重新评估函数来获取最优参数进行的。

参数:
  • fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any

获取当前参数值。

参数:

state – 当前优化器状态。

返回:

包含当前参数值的集合。

init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

使用指定用于优化的参数初始化优化器。

参数:

params – numpy 数组的集合。

返回:

初始优化器状态。

update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

优化器的梯度更新。

参数:
  • g – 参数的梯度信息。

  • state – 当前优化器状态。

返回:

更新后的新优化器状态。

Adagrad

class Adagrad(*args, **kwargs)[source]

JAX 优化器 adagrad() 的封装类

eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

eval_and_update() 类似,但当目标函数的值或梯度不是有限值时,我们将不更新输入 state,并将目标输出设置为 nan

参数:
  • fn – 目标函数。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如 Minimize),更新是通过多次重新评估函数来获取最优参数进行的。

参数:
  • fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any

获取当前参数值。

参数:

state – 当前优化器状态。

返回:

包含当前参数值的集合。

init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

使用指定用于优化的参数初始化优化器。

参数:

params – numpy 数组的集合。

返回:

初始优化器状态。

update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

优化器的梯度更新。

参数:
  • g – 参数的梯度信息。

  • state – 当前优化器状态。

返回:

更新后的新优化器状态。

ClippedAdam

class ClippedAdam(*args, clip_norm: float = 10.0, **kwargs)[source]

带有梯度裁剪的 Adam 优化器。

参数:

clip_norm (float) – 所有梯度值将被裁剪到 [-clip_norm, clip_norm] 之间。

参考

A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba https://arxiv.org/abs/1412.6980

eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

eval_and_update() 类似,但当目标函数的值或梯度不是有限值时,我们将不更新输入 state,并将目标输出设置为 nan

参数:
  • fn – 目标函数。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如 Minimize),更新是通过多次重新评估函数来获取最优参数进行的。

参数:
  • fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any

获取当前参数值。

参数:

state – 当前优化器状态。

返回:

包含当前参数值的集合。

init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex , Any]

使用指定用于优化的参数初始化优化器。

参数:

params – numpy 数组的集合。

返回:

初始优化器状态。

update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any][source]

优化器的梯度更新。

参数:
  • g – 参数的梯度信息。

  • state – 当前优化器状态。

返回:

更新后的新优化器状态。

Minimize

class Minimize(method='BFGS', **kwargs)[source]

JAX 最小化器 minimize() 的封装类。

示例

>>> from numpy.testing import assert_allclose
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import SVI, Trace_ELBO
>>> from numpyro.infer.autoguide import AutoLaplaceApproximation

>>> def model(x, y):
...     a = numpyro.sample("a", dist.Normal(0, 1))
...     b = numpyro.sample("b", dist.Normal(0, 1))
...     with numpyro.plate("N", y.shape[0]):
...         numpyro.sample("obs", dist.Normal(a + b * x, 0.1), obs=y)

>>> x = jnp.linspace(0, 10, 100)
>>> y = 3 * x + 2
>>> optimizer = numpyro.optim.Minimize()
>>> guide = AutoLaplaceApproximation(model)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> init_state = svi.init(random.PRNGKey(0), x, y)
>>> optimal_state, loss = svi.update(init_state, x, y)
>>> params = svi.get_params(optimal_state)  # get guide's parameters
>>> quantiles = guide.quantiles(params, 0.5)  # get means of posterior samples
>>> assert_allclose(quantiles["a"], 2., atol=1e-3)
>>> assert_allclose(quantiles["b"], 3., atol=1e-3)
eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

类似于 eval_and_update(),但当目标函数值或梯度不是有限值时,我们将不更新输入 state 并将目标函数输出设置为 nan

参数:
  • fn – 目标函数。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, None], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]][source]

对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如 Minimize),更新是通过多次重新评估函数来获取最优参数进行的。

参数:
  • fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any

获取当前参数值。

参数:

state – 当前优化器状态。

返回:

包含当前参数值的集合。

init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

使用指定用于优化的参数初始化优化器。

参数:

params – numpy 数组的集合。

返回:

初始优化器状态。

update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

优化器的梯度更新。

参数:
  • g – 参数的梯度信息。

  • state – 当前优化器状态。

返回:

更新后的新优化器状态。

动量

class Momentum(*args, **kwargs)[source]

JAX 优化器 momentum() 的包装类。

eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

类似于 eval_and_update(),但当目标函数值或梯度不是有限值时,我们将不更新输入 state 并将目标函数输出设置为 nan

参数:
  • fn – 目标函数。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如 Minimize),更新是通过多次重新评估函数来获取最优参数进行的。

参数:
  • fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any

获取当前参数值。

参数:

state – 当前优化器状态。

返回:

包含当前参数值的集合。

init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

使用指定用于优化的参数初始化优化器。

参数:

params – numpy 数组的集合。

返回:

初始优化器状态。

update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

优化器的梯度更新。

参数:
  • g – 参数的梯度信息。

  • state – 当前优化器状态。

返回:

更新后的新优化器状态。

RMSProp

class RMSProp(*args, **kwargs)[source]

JAX 优化器 rmsprop() 的包装类。

eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

类似于 eval_and_update(),但当目标函数值或梯度不是有限值时,我们将不更新输入 state 并将目标函数输出设置为 nan

参数:
  • fn – 目标函数。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如 Minimize),更新是通过多次重新评估函数来获取最优参数进行的。

参数:
  • fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any

获取当前参数值。

参数:

state – 当前优化器状态。

返回:

包含当前参数值的集合。

init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

使用指定用于优化的参数初始化优化器。

参数:

params – numpy 数组的集合。

返回:

初始优化器状态。

update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

优化器的梯度更新。

参数:
  • g – 参数的梯度信息。

  • state – 当前优化器状态。

返回:

更新后的新优化器状态。

RMSPropMomentum

class RMSPropMomentum(*args, **kwargs)[source]

JAX 优化器 rmsprop_momentum() 的包装类。

eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

类似于 eval_and_update(),但当目标函数值或梯度不是有限值时,我们将不更新输入 state 并将目标函数输出设置为 nan

参数:
  • fn – 目标函数。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如 Minimize),更新是通过多次重新评估函数来获取最优参数进行的。

参数:
  • fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any

获取当前参数值。

参数:

state – 当前优化器状态。

返回:

包含当前参数值的集合。

init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

使用指定用于优化的参数初始化优化器。

参数:

params – numpy 数组的集合。

返回:

初始优化器状态。

update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

优化器的梯度更新。

参数:
  • g – 参数的梯度信息。

  • state – 当前优化器状态。

返回:

更新后的新优化器状态。

SGD

class SGD(*args, **kwargs)[source]

JAX 优化器 sgd() 的封装类。

eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

类似于 eval_and_update(),但当目标函数值或梯度不是有限值时,我们不会更新输入的 state,并将目标输出设置为 nan

参数:
  • fn – 目标函数。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如 Minimize),更新是通过多次重新评估函数来获取最优参数进行的。

参数:
  • fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any

获取当前参数值。

参数:

state – 当前优化器状态。

返回:

包含当前参数值的集合。

init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

使用指定用于优化的参数初始化优化器。

参数:

params – numpy 数组的集合。

返回:

初始优化器状态。

update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

优化器的梯度更新。

参数:
  • g – 参数的梯度信息。

  • state – 当前优化器状态。

返回:

更新后的新优化器状态。

SM3

class SM3(*args, **kwargs)[source]

JAX 优化器 sm3() 的封装类。

eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

类似于 eval_and_update(),但当目标函数值或梯度不是有限值时,我们不会更新输入的 state,并将目标输出设置为 nan

参数:
  • fn – 目标函数。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

eval_and_update(fn: Callable[[Any], tuple], state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], forward_mode_differentiation: bool = False) tuple[tuple[Any, Any], tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]]

对目标函数 fn 执行一个优化步骤。对于大多数优化器,更新是基于目标函数相对于当前状态的梯度进行的。然而,对于某些优化器(例如 Minimize),更新是通过多次重新评估函数来获取最优参数进行的。

参数:
  • fn – 一个目标函数,返回一个对,其中第一个项目是标量损失函数(用于求导),第二个项目是辅助输出。

  • state – 当前优化器状态。

  • forward_mode_differentiation – 布尔标志,指示是否使用前向模式微分。

返回:

目标函数输出和新优化器状态的对。

get_params(state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]) Any

获取当前参数值。

参数:

state – 当前优化器状态。

返回:

包含当前参数值的集合。

init(params: Any) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

使用指定用于优化的参数初始化优化器。

参数:

params – numpy 数组的集合。

返回:

初始优化器状态。

update(g: Any, state: tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any], value: Array | ndarray | bool_ | number | bool | int | float | complex | None = None) tuple[Array | ndarray | bool_ | number | bool | int | float | complex, Any]

优化器的梯度更新。

参数:
  • g – 参数的梯度信息。

  • state – 当前优化器状态。

返回:

更新后的新优化器状态。

Optax 支持

optax_to_numpyro(transformation) _NumPyroOptim[source]

此函数从一个 optax.GradientTransformation 实例生成一个 numpyro.optim._NumPyroOptim 实例,以便可以将其与 numpyro.infer.svi.SVI 一起使用。这是一个轻量级封装器,它重新创建了由 jax.example_libraries.optimizers 定义的 (init_fn, update_fn, get_params_fn) 接口。

参数:

transformation — 要封装的 optax.GradientTransformation 实例。

返回:

封装提供的 Optax 优化器的 numpyro.optim._NumPyroOptim 实例。