NumPyro 模型自动渲染
在本教程中,我们将演示如何使用 numpyro.render_model 创建您的概率图模型的精美可视化效果。
[1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[1]:
import numpy as np
import flax.linen as flax_nn
from jax import nn
import jax.numpy as jnp
import numpyro
from numpyro.contrib.module import flax_module
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
# Guard: stop early unless the installed NumPyro version string starts with "0.18.0".
assert numpyro.__version__.startswith("0.18.0")
一个简单示例
可视化界面可以方便地与您的模型一起使用
[2]:
def model(data):
    """Normal likelihood whose mean and scale are both latent.

    Sample sites, in order: ``m`` ~ Normal(0, 1), ``sd`` ~ LogNormal(m, 1),
    and ``obs`` ~ Normal(m, sd), observed at ``data`` inside plate ``N`` of
    size ``len(data)``.
    """
    mean_prior = dist.Normal(0, 1)
    mean = numpyro.sample("m", mean_prior)

    # The prior of the scale depends on the sampled mean.
    scale_prior = dist.LogNormal(mean, 1)
    scale = numpyro.sample("sd", scale_prior)

    n_obs = len(data)
    with numpyro.plate("N", n_obs):
        likelihood = dist.Normal(mean, scale)
        numpyro.sample("obs", likelihood, obs=data)
[3]:
# Ten dummy observations, all equal to one.
data = jnp.ones(10)
# Build the graph of `model`, calling it with `data` as its only positional argument.
numpyro.render_model(model, model_args=(data,))
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[3]:
通过向 `numpyro.render_model` 提供 `filename='path'`,可以将可视化保存到文件。您可以通过更改文件名的后缀来使用不同的格式,例如 PDF 或 PNG。当不保存到文件(`filename=None`)时,您还可以通过 `graph.format = 'pdf'` 更改格式,其中 `graph` 是 `numpyro.render_model` 返回的对象。
[4]:
# Passing `filename` also saves the rendered graph to "model.pdf"; the returned object is kept in `graph`.
graph = numpyro.render_model(model, model_args=(data,), filename="model.pdf")
微调可视化
由于 `numpyro.render_model` 返回类型为 `graphviz.dot.Digraph` 的对象,您可以进一步改进此图的可视化。例如,您可以使用 unflatten 预处理器来改进更复杂模型的布局纵横比。
[5]:
def mace(positions, annotations):
    """
    Annotation model matching the plate diagram in Figure 3 of
    https://www.aclweb.org/anthology/Q18-1040.pdf.

    Sample sites, in order:
      * ``epsilon`` (Dirichlet) and ``theta`` (Beta) inside plate ``annotator``;
      * ``c`` (discrete uniform over the classes) inside plate ``item``;
      * ``s`` (Bernoulli) and the observed ``y`` (Categorical) inside the
        nested plates ``item`` and ``position``.

    :param positions: integer array used to index ``theta`` and ``epsilon``;
        its maximum plus one is the size of the ``annotator`` plate.
    :param annotations: 2-D integer array unpacked as (items, positions);
        its maximum plus one is the number of classes, and it is the
        observed value of site ``y``.
    """
    n_annotators = 1 + int(np.max(positions))
    n_classes = 1 + int(np.max(annotations))
    n_items, n_positions = annotations.shape

    with numpyro.plate("annotator", n_annotators):
        concentration = jnp.full(n_classes, 10)
        epsilon = numpyro.sample("epsilon", dist.Dirichlet(concentration))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    with numpyro.plate("item", n_items, dim=-2):
        class_prior = dist.DiscreteUniform(0, n_classes - 1)
        c = numpyro.sample("c", class_prior)

        with numpyro.plate("position", n_positions):
            s = numpyro.sample("s", dist.Bernoulli(1 - theta[positions]))
            # Where s == 0 the label distribution is the one-hot encoding of c;
            # elsewhere it is the epsilon row selected by `positions`.
            one_hot_c = nn.one_hot(c, n_classes)
            selected_epsilon = epsilon[positions]
            probs = jnp.where(s[..., None] == 0, one_hot_c, selected_epsilon)
            numpyro.sample("y", dist.Categorical(probs), obs=annotations)
# One entry per annotation column; `mace` uses these values to index its
# per-annotator sites. Values are 1-based here and shifted to 0-based below.
positions = np.array([1, 1, 1, 2, 3, 4, 5])
# Seven rows of 45 labels each; the transpose gives the array shape (45, 7),
# which `mace` unpacks as (num_items, num_positions).
# fmt: off
annotations = np.array([
[1, 3, 1, 2, 2, 2, 1, 3, 2, 2, 4, 2, 1, 2, 1,
1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
1, 3, 1, 2, 2, 4, 2, 2, 3, 1, 1, 1, 2, 1, 2],
[1, 3, 1, 2, 2, 2, 2, 3, 2, 3, 4, 2, 1, 2, 2,
1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 3, 1, 1, 1,
1, 3, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 3, 2, 2],
[1, 3, 2, 2, 2, 2, 2, 3, 2, 2, 4, 2, 1, 2, 1,
1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2,
1, 3, 1, 2, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 2],
[1, 4, 2, 3, 3, 3, 2, 3, 2, 2, 4, 3, 1, 3, 1,
2, 1, 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 1,
1, 3, 1, 2, 3, 4, 2, 3, 3, 1, 1, 2, 2, 1, 2],
[1, 3, 1, 1, 2, 3, 1, 4, 2, 2, 4, 3, 1, 2, 1,
1, 1, 1, 2, 3, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
1, 2, 1, 2, 2, 3, 2, 2, 4, 1, 1, 1, 2, 1, 2],
[1, 3, 2, 2, 2, 2, 1, 3, 2, 2, 4, 4, 1, 1, 1,
1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2,
1, 3, 1, 2, 3, 4, 3, 3, 3, 1, 1, 1, 2, 1, 2],
[1, 4, 2, 1, 2, 2, 1, 3, 3, 3, 4, 3, 1, 2, 1,
1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1,
1, 3, 1, 2, 2, 3, 2, 3, 2, 1, 1, 1, 2, 1, 2],
]).T
# fmt: on
# The literals above are 1-based; shift both arrays in place to 0-based values,
# since `mace` uses them as array indices and class labels.
positions -= 1
annotations -= 1
# Build the graph of `mace` for this data and keep the returned object for the cells below.
mace_graph = numpyro.render_model(mace, model_args=(positions, annotations))
[6]:
# Show the graph as returned by render_model, i.e. with the default layout.
mace_graph
[6]:
[7]:
# Show the same graph after running the unflatten preprocessor with stagger=2.
mace_graph.unflatten(stagger=2)
[7]:
渲染参数
我们可以通过在 numpyro.render_model
中设置 render_params=True
来渲染定义为 numpyro.param
的参数。
[8]:
def model(data):
    """Exponential likelihood whose rate has a LogNormal prior with learnable parameters.

    ``m`` (initial value 0.0) and ``sd`` (initial value 1.0, constrained to
    be positive) are ``numpyro.param`` sites. Site ``lambda`` is drawn from
    LogNormal(m, sd), and ``obs`` ~ Exponential(lambda) is observed at
    ``data`` inside plate ``N`` of size ``len(data)``.
    """
    loc = numpyro.param("m", 0.0)
    scale = numpyro.param("sd", 1.0, constraint=constraints.positive)

    rate_prior = dist.LogNormal(loc, scale)
    rate = numpyro.sample("lambda", rate_prior)

    n_obs = len(data)
    with numpyro.plate("N", n_obs):
        likelihood = dist.Exponential(rate)
        numpyro.sample("obs", likelihood, obs=data)
[9]:
# Ten dummy observations, all equal to one.
data = jnp.ones(10)
# render_params=True also draws the sites declared with numpyro.param.
numpyro.render_model(model, model_args=(data,), render_params=True)
[9]:
分布和约束注解
通过在调用 `numpyro.render_model` 时提供 `render_distributions=True`,可以在生成的图中显示每个随机变量(RV)的分布。当 `render_distributions=True` 时,与参数相关的约束也会显示。
[10]:
# render_distributions=True annotates the graph with each random variable's
# distribution; combined with render_params=True, parameter constraints are shown too.
numpyro.render_model(
model, model_args=(data,), render_params=True, render_distributions=True
)
[10]:
在上面的图中,'~' 表示随机变量(RV)的分布,而 '∈' 表示参数的约束。
渲染确定性站点
我们还可以渲染通过 `numpyro.deterministic` 定义的确定性站点。这些站点将用虚线绘制,以区别于随机站点。下面的示例说明了这一点。
[11]:
def model(data):
    """Normal likelihood whose mean passes through a deterministic site.

    Sample sites ``m`` ~ Normal(0, 1) and ``sd`` ~ LogNormal(m, 1) are drawn
    first; ``m_transformed`` records ``m + 1`` as a deterministic site; and
    ``obs`` ~ Normal(m_transformed, sd) is observed at ``data`` inside plate
    ``N`` of size ``len(data)``.
    """
    mean = numpyro.sample("m", dist.Normal(0, 1))
    scale_prior = dist.LogNormal(mean, 1)
    scale = numpyro.sample("sd", scale_prior)

    # Deterministic site: a fixed function of the sampled mean.
    shifted_mean = numpyro.deterministic("m_transformed", mean + 1)

    n_obs = len(data)
    with numpyro.plate("N", n_obs):
        likelihood = dist.Normal(shifted_mean, scale)
        numpyro.sample("obs", likelihood, obs=data)
[12]:
# Ten dummy observations, all equal to one.
data = jnp.ones(10)
# Build the graph of `model`, calling it with `data` as its only positional argument.
numpyro.render_model(model, model_args=(data,))
[12]:
渲染神经网络参数
[13]:
def model(data):
    """Exponential likelihood whose rate comes from a small Flax network.

    Site ``lambda`` ~ Normal(0, 1) is given a trailing axis, passed through
    the ``flax_nn.Dense(1)`` module registered as ``affine_net``, squeezed
    back and exponentiated. The result is the rate of
    ``obs`` ~ Exponential(rate), observed at ``data`` inside plate ``N`` of
    size ``len(data)``.
    """
    base = numpyro.sample("lambda", dist.Normal(0, 1))
    affine = flax_module("affine_net", flax_nn.Dense(1), input_shape=(1,))

    # Add a trailing axis for the module call, then drop it again.
    net_input = jnp.expand_dims(base, -1)
    net_output = affine(net_input).squeeze(-1)
    rate = jnp.exp(net_output)

    n_obs = len(data)
    with numpyro.plate("N", n_obs):
        numpyro.sample("obs", dist.Exponential(rate), obs=data)
[14]:
# Render with both distribution annotations and parameter nodes enabled.
numpyro.render_model(
model, model_args=(data,), render_distributions=True, render_params=True
)
[14]:
重叠的非嵌套 Plate
请注意,重叠的非嵌套 Plate 可能会被绘制成多个矩形。
[15]:
def model():
    """Three sites spread over two overlapping, non-nested plates.

    ``plate1`` (size 2, dim=-2) and ``plate2`` (size 3, dim=-1) are created
    once and re-entered: ``x`` is sampled inside ``plate1`` only, ``y``
    inside both plates, and the observed ``z`` inside ``plate2`` only.
    """
    rows = numpyro.plate("plate1", 2, dim=-2)
    cols = numpyro.plate("plate2", 3, dim=-1)

    with rows:
        x = numpyro.sample("x", dist.Normal(0, 1))

    with rows:
        with cols:
            y = numpyro.sample("y", dist.Normal(x, 1))

    with cols:
        # Sum y over axis -2, keeping that axis with size one.
        y_total = y.sum(-2, keepdims=True)
        numpyro.sample("z", dist.Normal(y_total, 1), obs=jnp.zeros(3))
[16]:
# This model takes no arguments, so no model_args are passed.
numpyro.render_model(model)
[16]:
[ ]: