注意
跳至末尾下载完整的示例代码。
示例:贝叶斯标注模型
在此示例中,我们针对 [1] 中的各种众包标注模型运行 MCMC。
所有模型都包含离散潜在变量。在底层,我们在推断中枚举(边缘化)这些离散潜在站点。这些模型的复杂性不同,因此对于刚接触 Pyro/NumPyro 枚举机制的读者来说,它们是很好的参考。我们建议读者将实现与 [1] 中相应的板块图进行比较,以了解 Pyro/NumPyro 程序是多么简洁。
感兴趣的读者也可以参考 [3] 以了解有关枚举的更多解释。
数据取自参考文献 [2] 的表 1。
目前,此示例不包含处理“标签切换”问题([1] 的第 6.2 节中提及)的后处理步骤。
参考文献
[1] Paun, S., Carpenter, B., Chamberlain, J., Hovy, D., Kruschwitz, U., and Poesio, M. (2018). “Comparing Bayesian Models of Annotation” (https://www.aclweb.org/anthology/Q18-1040/)
[2] Dawid, A. P., and Skene, A. M. (1979). “Maximum likelihood estimation of observer error‐rates using the EM algorithm”
[3] “Inference with Discrete Latent Variables” (https://pyro.ai/examples/enumeration.html)
import argparse
import os
import numpy as np
from jax import nn, random, vmap
import jax.numpy as jnp
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam
from numpyro.ops.indexing import Vindex
def get_data():
    """
    Load the crowdsourced annotations from Table 1 of reference [2].

    :return: a tuple ``(positions, annotations)`` of zero-based integer arrays.
        ``positions`` has shape ``(num_positions,)``; entry ``j`` is the index
        (``0`` to ``num_annotators - 1``) of the annotator who produced
        column ``j``. ``annotations`` has shape ``(num_items, num_positions)``
        and its entries are class indices from ``0`` to ``num_classes - 1``.
    """
    # NB: annotator 1 assessed every item three times, so the first three
    # columns all belong to that annotator.
    one_based_positions = np.array([1, 1, 1, 2, 3, 4, 5])
    # fmt: off
    one_based_annotations = np.array(
        [[1, 1, 1, 1, 1, 1, 1], [3, 3, 3, 4, 3, 3, 4], [1, 1, 2, 2, 1, 2, 2],
         [2, 2, 2, 3, 1, 2, 1], [2, 2, 2, 3, 2, 2, 2], [2, 2, 2, 3, 3, 2, 2],
         [1, 2, 2, 2, 1, 1, 1], [3, 3, 3, 3, 4, 3, 3], [2, 2, 2, 2, 2, 2, 3],
         [2, 3, 2, 2, 2, 2, 3], [4, 4, 4, 4, 4, 4, 4], [2, 2, 2, 3, 3, 4, 3],
         [1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 3, 2, 1, 2], [1, 2, 1, 1, 1, 1, 1],
         [1, 1, 1, 2, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 1], [2, 2, 2, 1, 3, 2, 2], [2, 2, 2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2, 2, 1], [2, 2, 2, 3, 2, 2, 2], [2, 2, 1, 2, 2, 2, 2],
         [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [2, 3, 2, 2, 2, 2, 2],
         [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 2, 1, 1, 2, 1],
         [1, 1, 1, 1, 1, 1, 1], [3, 3, 3, 3, 2, 3, 3], [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2], [2, 2, 2, 3, 2, 3, 2], [4, 3, 3, 4, 3, 4, 3],
         [2, 2, 1, 2, 2, 3, 2], [2, 3, 2, 3, 2, 3, 3], [3, 3, 3, 3, 4, 3, 2],
         [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 2, 1, 2, 1, 1, 1],
         [2, 3, 2, 2, 2, 2, 2], [1, 2, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2]])
    # fmt: on
    # The table is 1-based; shift both arrays so they can serve as Python indices.
    zero_based_positions = one_based_positions - 1
    zero_based_annotations = one_based_annotations - 1
    return zero_based_positions, zero_based_annotations
def multinomial(annotations):
    """
    This model corresponds to the plate diagram in Figure 1 of reference [1].

    Each item has a latent class ``c`` drawn from the class prevalence ``pi``.
    Every annotation of that item is then drawn from ``zeta[c]``, a per-class
    label distribution shared by all annotators (annotator identity is not
    used by this model).

    :param annotations: integer array of shape ``(num_items, num_positions)``
        holding zero-based class labels.
    """
    num_items, num_positions = annotations.shape
    num_classes = int(np.max(annotations)) + 1

    # One label distribution per latent class.
    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)))

    # Prevalence of each latent class across items.
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # `c` is enumerated (marginalized) in parallel during inference.
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})
        with numpyro.plate("position", num_positions):
            label_probs = zeta[c]
            numpyro.sample("y", dist.Categorical(label_probs), obs=annotations)
def dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].

    Each annotator has its own confusion matrix: ``beta[a, k]`` is the label
    distribution annotator ``a`` uses for items whose latent class is ``k``.

    :param positions: integer array of shape ``(num_positions,)`` mapping each
        annotation column to a zero-based annotator index.
    :param annotations: integer array of shape ``(num_items, num_positions)``
        holding zero-based class labels.
    """
    num_items, num_positions = annotations.shape
    num_classes = int(np.max(annotations)) + 1
    num_annotators = int(np.max(positions)) + 1

    # beta: (num_annotators, num_classes) batch of Dirichlet draws over labels.
    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones(num_classes)))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})
        with numpyro.plate("position", num_positions):
            # Vindex lets the (enumerated) second index `c` broadcast against
            # `positions`, selecting beta[positions[j], c[i], :] per item/column.
            # ref: https://num.pyro.org.cn/en/stable/utilities.html#numpyro.contrib.indexing.vindex
            label_probs = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(label_probs), obs=annotations)
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of reference [1].

    Every annotation carries a binary switch ``s``. When ``s == 0`` the
    reported label equals the item's latent class ``c``; when ``s == 1`` it is
    drawn from the annotator's own label distribution ``epsilon``. Since
    ``s ~ Bernoulli(1 - theta)``, ``theta[a]`` is the probability that
    annotator ``a`` reports the latent class.

    :param positions: integer array of shape ``(num_positions,)`` mapping each
        annotation column to a zero-based annotator index.
    :param annotations: integer array of shape ``(num_items, num_positions)``
        holding zero-based class labels.
    """
    num_items, num_positions = annotations.shape
    num_classes = int(np.max(annotations)) + 1
    num_annotators = int(np.max(positions)) + 1

    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon", dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    with numpyro.plate("item", num_items, dim=-2):
        # Uniform prior over latent classes; enumerated during inference.
        c = numpyro.sample(
            "c",
            dist.DiscreteUniform(0, num_classes - 1),
            infer={"enumerate": "parallel"},
        )
        with numpyro.plate("position", num_positions):
            s = numpyro.sample(
                "s",
                dist.Bernoulli(1 - theta[positions]),
                infer={"enumerate": "parallel"},
            )
            # Add a trailing class axis so the switch broadcasts over labels.
            reports_truth = s[..., None] == 0
            truth_probs = nn.one_hot(c, num_classes)
            own_probs = epsilon[positions]
            probs = jnp.where(reports_truth, truth_probs, own_probs)
            numpyro.sample("y", dist.Categorical(probs), obs=annotations)
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].

    A Dawid-Skene variant in which each annotator's confusion matrix is given
    by logits ``beta[a, k]`` drawn around class-level hyperpriors (location
    ``zeta``, scale ``Omega``).

    :param positions: integer array of shape ``(num_positions,)`` mapping each
        annotation column to a zero-based annotator index.
    :param annotations: integer array of shape ``(num_items, num_positions)``
        holding zero-based class labels.
    """
    num_items, num_positions = annotations.shape
    num_classes = int(np.max(annotations)) + 1
    num_annotators = int(np.max(positions)) + 1
    # NB: `beta` holds the logits of the `y` likelihood, and logits are only
    # defined up to an additive constant. Following [1], the last logit is
    # fixed to 0 and hyperpriors cover only the first `num_classes - 1` terms.
    num_free_logits = num_classes - 1

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_free_logits]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_free_logits]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # Non-centered parameterization of beta ~ Normal(zeta, omega).
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # Append the fixed zero logit at the end of the last axis.
            pad_width = [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)]
            beta = jnp.pad(beta, pad_width)

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})
        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Each item gets its own label logits ``theta`` drawn around the class-level
    location ``eta[c]`` with scale ``Chi[c]``, where ``c`` is the item's latent
    class. Annotator identity is not used by this model.

    :param annotations: integer array of shape ``(num_items, num_positions)``
        holding zero-based class labels.
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # Logits are only defined up to a constant, so (as in [1]) only the
        # first `num_classes - 1` terms get priors; the last is fixed to 0.
        eta = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})
        # Non-centered parameterization of theta ~ Normal(eta[c], chi[c]).
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
        # Append the fixed zero logit at the end of the last axis.
        theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])
        # Use the already-unpacked `num_positions`, as the other models do.
        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(logits=theta), obs=annotations)
def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 6 of reference [1].

    Combines per-annotator logits ``beta[a, k]`` (drawn around class-level
    hyperpriors ``zeta`` / ``Omega``) with a per-item random effect ``theta``
    drawn from ``Normal(0, Chi[c])``. The label logits for item ``i`` in
    column ``j`` are ``beta[positions[j], c[i]] - theta[i]``.

    :param positions: integer array of shape ``(num_positions,)`` mapping each
        annotation column to a zero-based annotator index.
    :param annotations: integer array of shape ``(num_items, num_positions)``
        holding zero-based class labels.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # Logits are only defined up to a constant, so (as in [1]) only the
        # first `num_classes - 1` terms get priors; the last is fixed to 0.
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # Non-centered parameterization of beta ~ Normal(zeta, omega).
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # Append the fixed zero logit at the end of the last axis.
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})
        # Non-centered parameterization of theta ~ Normal(0, chi[c]).
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
        theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            # Vindex broadcasts the enumerated `c` against `positions`.
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
# Short names accepted by the `--model` command-line option, mapped to the
# model functions defined above.
NAME_TO_MODEL = dict(
    mn=multinomial,
    ds=dawid_skene,
    mace=mace,
    hds=hierarchical_dawid_skene,
    id=item_difficulty,
    lre=logistic_random_effects,
)
def main(args):
    """
    Fit the model named by ``args.model`` with NUTS and report the results.

    Prints the MCMC summary. It then draws posterior samples of the
    marginalized class variable ``c`` via
    ``Predictive(..., infer_discrete=True)`` and prints, for every item, how
    many draws fell into each class.

    :param args: parsed command-line arguments providing ``model``,
        ``num_warmup``, ``num_samples`` and ``num_chains``.
    """
    annotators, annotations = get_data()
    model = NAME_TO_MODEL[args.model]
    # These two models do not use annotator identities.
    data = (
        (annotations,)
        if model in (multinomial, item_difficulty)
        else (annotators, annotations)
    )
    # Derive the class count from the data (as the models do) instead of
    # hard-coding it, so the histogram below fits any annotation set.
    num_classes = int(np.max(annotations)) + 1

    mcmc = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # No progress bar when the docs are being built.
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(random.PRNGKey(0), *data)
    mcmc.print_summary()

    posterior_samples = mcmc.get_samples()
    predictive = Predictive(model, posterior_samples, infer_discrete=True)
    discrete_samples = predictive(random.PRNGKey(1), *data)

    # Drop the trailing singleton dim left by the item plate (dim=-2), then
    # count, for each item (axis 1), how many draws landed in each class.
    item_class = vmap(lambda x: jnp.bincount(x, length=num_classes), in_axes=1)(
        discrete_samples["c"].squeeze(-1)
    )
    print("Histogram of the predicted class of each item:")
    row_format = "{:>10}" * (num_classes + 1)
    print(row_format.format("", *["c={}".format(i) for i in range(num_classes)]))
    for i, row in enumerate(item_class):
        print(row_format.format(f"item[{i}]", *row))
注意
在上面的推断代码中,我们将离散潜在变量 c 边缘化了,因此 mcmc.get_samples(…) 不包含 c 的样本。随后我们使用 Predictive(…, infer_discrete=True) 获取 c 的后验样本,这些样本存储在 discrete_samples 中。要将这些离散样本合并到 mcmc 实例中,可以使用以下模式(注意:该片段调用了 jax.tree.map,因此需要先 import jax;本示例文件只导入了 from jax import nn, random, vmap):
chain_discrete_samples = jax.tree.map(
lambda x: x.reshape((args.num_chains, args.num_samples) + x.shape[1:]),
discrete_samples)
mcmc.get_samples().update(discrete_samples)
mcmc.get_samples(group_by_chain=True).update(chain_discrete_samples)
当我们需要通过 arviz.from_numpyro(mcmc) 将 mcmc 实例传递给 arviz 时,这很有用。
if __name__ == "__main__":
    # This example is pinned to the NumPyro release it was written against.
    assert numpyro.__version__.startswith("0.18.0")
    parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
    parser.add_argument("--num-chains", nargs="?", default=1, type=int)
    parser.add_argument(
        "--model",
        nargs="?",
        default="ds",
        # Reject unknown model names at parse time with a usage message,
        # instead of a bare KeyError from NAME_TO_MODEL inside `main`.
        choices=list(NAME_TO_MODEL),
        help='one of "mn" (multinomial), "ds" (dawid_skene), "mace",'
        ' "hds" (hierarchical_dawid_skene),'
        ' "id" (item_difficulty), "lre" (logistic_random_effects)',
    )
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    # Expose one host device per requested chain.
    numpyro.set_host_device_count(args.num_chains)

    main(args)