交互式在线版本: Open In Colab

NumPyro 模型自动渲染

在本教程中,我们将演示如何使用 numpyro.render_model 创建您的概率图模型的精美可视化效果。

[1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[1]:
import numpy as np

import flax.linen as flax_nn
from jax import nn
import jax.numpy as jnp

import numpyro
from numpyro.contrib.module import flax_module
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints

assert numpyro.__version__.startswith("0.18.0")

一个简单示例

可视化界面可以方便地与您的模型一起使用

[2]:
def model(data):
    m = numpyro.sample("m", dist.Normal(0, 1))
    sd = numpyro.sample("sd", dist.LogNormal(m, 1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Normal(m, sd), obs=data)
[3]:
data = jnp.ones(10)
numpyro.render_model(model, model_args=(data,))
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[3]:
../_images/tutorials_model_rendering_5_1.svg

通过向 numpyro.render_model 提供 filename='path',可以将可视化保存到文件。您可以通过更改文件名的后缀来使用不同的格式,例如 PDF 或 PNG。当不保存到文件(filename=None)时,您还可以通过 graph.format = 'pdf' 更改格式,其中 graphnumpyro.render_model 返回的对象。

[4]:
graph = numpyro.render_model(model, model_args=(data,), filename="model.pdf")

微调可视化

由于 numpyro.render_model 返回类型为 graphviz.dot.Digraph 的对象,您可以进一步改进此图的可视化。例如,您可以使用 unflatten 预处理器 来改进更复杂模型的布局纵横比。

[5]:
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of https://www.aclweb.org/anthology/Q18-1040.pdf.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon", dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.DiscreteUniform(0, num_classes - 1))

        with numpyro.plate("position", num_positions):
            s = numpyro.sample("s", dist.Bernoulli(1 - theta[positions]))
            probs = jnp.where(
                s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions]
            )
            numpyro.sample("y", dist.Categorical(probs), obs=annotations)


positions = np.array([1, 1, 1, 2, 3, 4, 5])
# fmt: off
annotations = np.array([
    [1, 3, 1, 2, 2, 2, 1, 3, 2, 2, 4, 2, 1, 2, 1,
     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
     1, 3, 1, 2, 2, 4, 2, 2, 3, 1, 1, 1, 2, 1, 2],
    [1, 3, 1, 2, 2, 2, 2, 3, 2, 3, 4, 2, 1, 2, 2,
     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 3, 1, 1, 1,
     1, 3, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 3, 2, 2],
    [1, 3, 2, 2, 2, 2, 2, 3, 2, 2, 4, 2, 1, 2, 1,
     1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2,
     1, 3, 1, 2, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 2],
    [1, 4, 2, 3, 3, 3, 2, 3, 2, 2, 4, 3, 1, 3, 1,
     2, 1, 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 1,
     1, 3, 1, 2, 3, 4, 2, 3, 3, 1, 1, 2, 2, 1, 2],
    [1, 3, 1, 1, 2, 3, 1, 4, 2, 2, 4, 3, 1, 2, 1,
     1, 1, 1, 2, 3, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
     1, 2, 1, 2, 2, 3, 2, 2, 4, 1, 1, 1, 2, 1, 2],
    [1, 3, 2, 2, 2, 2, 1, 3, 2, 2, 4, 4, 1, 1, 1,
     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2,
     1, 3, 1, 2, 3, 4, 3, 3, 3, 1, 1, 1, 2, 1, 2],
    [1, 4, 2, 1, 2, 2, 1, 3, 3, 3, 4, 3, 1, 2, 1,
     1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1,
     1, 3, 1, 2, 2, 3, 2, 3, 2, 1, 1, 1, 2, 1, 2],
]).T
# fmt: on

# we subtract 1 because the first index starts with 0 in Python
positions -= 1
annotations -= 1

mace_graph = numpyro.render_model(mace, model_args=(positions, annotations))
[6]:
# default layout
mace_graph
[6]:
../_images/tutorials_model_rendering_10_0.svg
[7]:
# layout after processing the layout with unflatten
mace_graph.unflatten(stagger=2)
[7]:
../_images/tutorials_model_rendering_11_0.svg

渲染参数

我们可以通过在 numpyro.render_model 中设置 render_params=True 来渲染定义为 numpyro.param 的参数。

[8]:
def model(data):
    m = numpyro.param("m", 0.0)
    sd = numpyro.param("sd", 1.0, constraint=constraints.positive)
    lambd = numpyro.sample("lambda", dist.LogNormal(m, sd))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Exponential(lambd), obs=data)
[9]:
data = jnp.ones(10)
numpyro.render_model(model, model_args=(data,), render_params=True)
[9]:
../_images/tutorials_model_rendering_15_0.svg

分布和约束注解

通过在调用 numpyro.render_model 时提供 render_distributions=True,可以在生成的图中显示每个 RV 的分布。当 render_distributions=True 时,与参数相关的约束也会显示。

[10]:
numpyro.render_model(
    model, model_args=(data,), render_params=True, render_distributions=True
)
[10]:
../_images/tutorials_model_rendering_17_0.svg

在上面的图中,'~' 表示 RV 的分布,而 ':math:`in`' 表示参数的约束。

渲染确定性站点

我们还可以渲染通过 numpyro.deterministic 定义的确定性站点。这些站点将用虚线绘制,以区别于随机站点。下面的示例说明了这一点。

[11]:
def model(data):
    m = numpyro.sample("m", dist.Normal(0, 1))
    sd = numpyro.sample("sd", dist.LogNormal(m, 1))
    # deterministic site
    m_transformed = numpyro.deterministic("m_transformed", m + 1)
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Normal(m_transformed, sd), obs=data)
[12]:
data = jnp.ones(10)
numpyro.render_model(model, model_args=(data,))
[12]:
../_images/tutorials_model_rendering_22_0.svg

渲染神经网络参数

[13]:
def model(data):
    lambda_base = numpyro.sample("lambda", dist.Normal(0, 1))
    net = flax_module("affine_net", flax_nn.Dense(1), input_shape=(1,))
    lambd = jnp.exp(net(jnp.expand_dims(lambda_base, -1)).squeeze(-1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Exponential(lambd), obs=data)
[14]:
numpyro.render_model(
    model, model_args=(data,), render_distributions=True, render_params=True
)
[14]:
../_images/tutorials_model_rendering_25_0.svg

重叠的非嵌套 Plate

请注意,重叠的非嵌套 Plate 可能会被绘制成多个矩形。

[15]:
def model():
    plate1 = numpyro.plate("plate1", 2, dim=-2)
    plate2 = numpyro.plate("plate2", 3, dim=-1)
    with plate1:
        x = numpyro.sample("x", dist.Normal(0, 1))
    with plate1, plate2:
        y = numpyro.sample("y", dist.Normal(x, 1))
    with plate2:
        numpyro.sample("z", dist.Normal(y.sum(-2, keepdims=True), 1), obs=jnp.zeros(3))
[16]:
numpyro.render_model(model)
[16]:
../_images/tutorials_model_rendering_29_0.svg
[ ]: