注意
前往页面末尾下载完整示例代码。
示例:棒球击球平均数
来自 Pyro 的原始示例: https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py
此示例改编自 [1]。它演示了如何在 Pyro 中使用各种 MCMC 核(HMC、NUTS、SA)进行贝叶斯推断,以及一些常用推断工具的用法。
如同 Stan 教程一样,本示例使用 Efron 和 Morris [2] 的小型棒球数据集来估计球员的击球平均数,即球员获得安打次数占其上场击球总次数的比例。
该数据集将最初 45 次上场击球的统计数据与剩余赛季的数据分开。我们使用最初 45 次上场击球的安打数据来估计每位球员的击球平均数。然后,我们使用剩余赛季的数据来验证模型的预测结果。
本示例评估了以下三种模型(代码中的部分合并模型另含一个使用 logit 链接函数的变体):
完全合并模型:获得安打的成功概率在所有球员之间共享。
不合并模型:每个球员的成功概率是独立的,球员之间不共享数据。
部分合并模型:一个具有部分数据共享的分层模型。
对于希望更全面了解 HMC 及其变体的用户,我们推荐 Radford Neal 关于 HMC 的教程 [3];关于 No-U-Turn Sampler(NUTS,无 U 形转弯采样器)的详细信息请参阅 [4],该采样器提供了一种高效且自动化(即需要手动调节的超参数很少)的方式,可在不同问题上运行 HMC。
请注意,基于 [5] 实现的样本自适应 (SA) 核需要较大的 num_warmup 和 num_samples(例如 15,000 和 300,000)。因此最好禁用进度条以避免调度开销。
参考文献
[1] Carpenter B. (2016), “分层部分合并用于重复二元试验”。
[2] Efron B., Morris C. (1975), “使用 Stein 估计量及其推广的数据分析”, J. Amer. Statist. Assoc., 70, 311-319.
[3] Neal, R. (2012), “使用哈密顿动力学的 MCMC”, (https://arxiv.org/pdf/1206.1901.pdf)
[4] Hoffman, M. D. and Gelman, A. (2014), “无 U 形转弯采样器:在哈密顿蒙特卡罗中自适应设置路径长度”, (https://arxiv.org/abs/1111.4246)
[5] Michael Zhu (2019), “样本自适应 MCMC”, (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc)
import argparse
import os
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import logsumexp
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import BASEBALL, load_dataset
from numpyro.infer import HMC, MCMC, NUTS, SA, Predictive, log_likelihood
def fully_pooled(at_bats, hits=None):
    r"""
    Complete-pooling model: a single success probability $\phi$ is shared
    by every player.

    $\phi$ is drawn from ``Uniform(0, 1)`` and each player's hit count is
    modelled as ``Binomial(at_bats, \phi)`` inside a ``num_players`` plate.

    :param (jnp.ndarray) at_bats: Number of at bats for each player.
    :param (jnp.ndarray) hits: Number of hits for the given at bats; passed
        as the ``obs`` argument of the ``obs`` sample site (default ``None``).
    :return: Value of the ``obs`` sample site.
    """
    shared_phi = numpyro.sample("phi", dist.Uniform(0, 1))
    # The plate size is the length of the leading axis of ``at_bats``.
    with numpyro.plate("num_players", at_bats.shape[0]):
        likelihood = dist.Binomial(at_bats, probs=shared_phi)
        observed = numpyro.sample("obs", likelihood, obs=hits)
    return observed
def not_pooled(at_bats, hits=None):
    r"""
    No-pooling model: every player gets an independent success
    probability $\phi_i$.

    Each $\phi_i$ is drawn from ``Uniform(0, 1)`` inside the ``num_players``
    plate, and the hit count is modelled as ``Binomial(at_bats, \phi_i)``.

    :param (jnp.ndarray) at_bats: Number of at bats for each player.
    :param (jnp.ndarray) hits: Number of hits for the given at bats; passed
        as the ``obs`` argument of the ``obs`` sample site (default ``None``).
    :return: Value of the ``obs`` sample site.
    """
    player_count = at_bats.shape[0]
    with numpyro.plate("num_players", player_count):
        # ``phi`` is sampled inside the plate, i.e. once per player.
        per_player_phi = numpyro.sample("phi", dist.Uniform(0, 1))
        likelihood = dist.Binomial(at_bats, probs=per_player_phi)
        observed = numpyro.sample("obs", likelihood, obs=hits)
    return observed
def partially_pooled(at_bats, hits=None):
    r"""
    Partial-pooling (hierarchical) model with a shared Beta prior.

    Population-level parameters are $m \sim Uniform(0, 1)$ and
    $\kappa \sim Pareto(1, 1.5)$. Each player's success probability is
    $\phi_i \sim Beta(m \kappa, (1 - m) \kappa)$ and the hit count is
    modelled as ``Binomial(at_bats, \phi_i)``.

    :param (jnp.ndarray) at_bats: Number of at bats for each player.
    :param (jnp.ndarray) hits: Number of hits for the given at bats; passed
        as the ``obs`` argument of the ``obs`` sample site (default ``None``).
    :return: Value of the ``obs`` sample site.
    """
    # Hyper-parameters are sampled outside the plate, so they are shared.
    m = numpyro.sample("m", dist.Uniform(0, 1))
    kappa = numpyro.sample("kappa", dist.Pareto(1, 1.5))
    with numpyro.plate("num_players", at_bats.shape[0]):
        # Beta concentrations: c_1 = m * kappa, c_2 = (1 - m) * kappa.
        beta_prior = dist.Beta(m * kappa, (1 - m) * kappa)
        per_player_phi = numpyro.sample("phi", beta_prior)
        likelihood = dist.Binomial(at_bats, probs=per_player_phi)
        observed = numpyro.sample("obs", likelihood, obs=hits)
    return observed
def partially_pooled_with_logit(at_bats, hits=None):
    r"""
    Partial-pooling model parameterised on the logit scale.

    Population-level parameters are ``loc ~ Normal(-1, 1)`` and
    ``scale ~ HalfCauchy(1)``. Each player's logit $\alpha_i$ is drawn from
    ``Normal(loc, scale)`` and the hit count is modelled as
    ``Binomial(at_bats, logits=\alpha_i)``.

    :param (jnp.ndarray) at_bats: Number of at bats for each player.
    :param (jnp.ndarray) hits: Number of hits for the given at bats; passed
        as the ``obs`` argument of the ``obs`` sample site (default ``None``).
    :return: Value of the ``obs`` sample site.
    """
    # Shared location and scale of the per-player logits.
    loc = numpyro.sample("loc", dist.Normal(-1, 1))
    scale = numpyro.sample("scale", dist.HalfCauchy(1))
    with numpyro.plate("num_players", at_bats.shape[0]):
        player_logits = numpyro.sample("alpha", dist.Normal(loc, scale))
        likelihood = dist.Binomial(at_bats, logits=player_logits)
        observed = numpyro.sample("obs", likelihood, obs=hits)
    return observed
def run_inference(model, at_bats, hits, rng_key, args):
    """
    Run MCMC on ``model`` with the kernel selected by ``args.algo``.

    :param model: Model callable handed to the MCMC kernel.
    :param at_bats: Number of at bats for each player, passed to the model.
    :param hits: Number of hits for each player, passed to the model.
    :param rng_key: Random key passed to ``MCMC.run``.
    :param args: Parsed command-line namespace; the attributes read here are
        ``algo``, ``num_warmup``, ``num_samples``, ``num_chains`` and
        ``disable_progbar``.
    :return: The result of ``mcmc.get_samples()``.
    :raises ValueError: If ``args.algo`` is not ``"NUTS"``, ``"HMC"`` or
        ``"SA"``.
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    elif args.algo == "SA":
        kernel = SA(model)
    else:
        # Fail loudly instead of hitting an unbound ``kernel`` below.
        raise ValueError(
            'Unknown algo {!r}; expected one of "HMC", "NUTS", or "SA".'.format(
                args.algo
            )
        )

    # The progress bar is turned off when the NUMPYRO_SPHINXBUILD environment
    # variable is set or when the user asked for it explicitly.
    show_progress = not (
        "NUMPYRO_SPHINXBUILD" in os.environ or args.disable_progbar
    )
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
def predict(model, at_bats, hits, z, rng_key, player_names, train=True):
    """
    Print posterior predictions of ``model`` and, for test data, its log
    pointwise predictive density.

    :param model: Model callable; its ``__name__`` is used in the header.
    :param at_bats: Number of at bats for each player.
    :param hits: Number of hits for each player (printed next to the
        predictions and used for the log likelihood).
    :param z: Posterior samples passed to ``Predictive`` and
        ``log_likelihood``.
    :param rng_key: Random key used to draw the predictions.
    :param player_names: Names printed as row labels.
    :param bool train: Selects the ``TRAIN`` / ``TEST`` header label; the
        log predictive density is printed only when this is ``False``.
    """
    split_label = " - TRAIN" if train else " - TEST"
    banner = "=" * 30
    sampler = Predictive(model, posterior_samples=z)
    simulated_hits = sampler(rng_key, at_bats)["obs"]
    print_results(
        banner + model.__name__ + split_label + banner,
        simulated_hits,
        player_names,
        at_bats,
        hits,
    )
    if train:
        return

    pointwise_loglik = log_likelihood(model, z, at_bats, hits)["obs"]
    # Log of the mean likelihood over axis 0, computed stably as
    # logsumexp(...) - log(size of axis 0), for each data point.
    num_draws = jnp.shape(pointwise_loglik)[0]
    lppd_per_point = logsumexp(pointwise_loglik, axis=0) - jnp.log(num_draws)
    # Report the sum over all data points.
    print(
        "\nLog pointwise predictive density: {:.2f}\n".format(lppd_per_point.sum())
    )
def print_results(header, preds, player_names, at_bats, hits):
    """
    Print a table of predicted-hit quartiles for every player.

    :param header: Title printed above the table.
    :param preds: Predictions; the 25%, 50% and 75% quantiles are taken
        along axis 0 and column ``i`` is reported for player ``i``.
    :param player_names: Row labels, one per player.
    :param at_bats: Indexable of at bats, one entry per player.
    :param hits: Indexable of actual hits, one entry per player.
    """
    column_titles = [
        "",
        "At-bats",
        "ActualHits",
        "Pred(p25)",
        "Pred(p50)",
        "Pred(p75)",
    ]
    title_fmt = "{:>20} {:>10} {:>10} {:>10} {:>10} {:>10}"
    line_fmt = "{:>20} {:>10.0f} {:>10.0f} {:>10.2f} {:>10.2f} {:>10.2f}"
    levels = jnp.array([0.25, 0.5, 0.75])
    pred_quartiles = jnp.quantile(preds, levels, axis=0)

    print("\n", header, "\n")
    print(title_fmt.format(*column_titles))
    for idx, name in enumerate(player_names):
        p25, p50, p75 = pred_quartiles[:, idx]
        line = line_fmt.format(name, at_bats[idx], hits[idx], p25, p50, p75)
        # The extra "\n" argument leaves a blank line after each row.
        print(line, "\n")
def main(args):
    """
    Fit each baseball model on the training split and print predictions for
    both the training and the test split.

    :param args: Parsed command-line namespace, forwarded to
        ``run_inference``.
    """
    _, fetch_train = load_dataset(BASEBALL, split="train", shuffle=False)
    train, player_names = fetch_train()
    _, fetch_test = load_dataset(BASEBALL, split="test", shuffle=False)
    test, _ = fetch_test()

    # Column 0 is used as at-bats and column 1 as hits in both splits.
    at_bats = train[:, 0]
    hits = train[:, 1]
    season_at_bats = test[:, 0]
    season_hits = test[:, 1]

    models = (
        fully_pooled,
        not_pooled,
        partially_pooled,
        partially_pooled_with_logit,
    )
    # Each model gets its own seed: 1, 2, 3, ...
    for seed, model in enumerate(models, start=1):
        rng_key, rng_key_predict = random.split(random.PRNGKey(seed))
        zs = run_inference(model, at_bats, hits, rng_key, args)
        # The same predictive key is used for the train and the test report.
        predict(model, at_bats, hits, zs, rng_key_predict, player_names)
        predict(
            model,
            season_at_bats,
            season_hits,
            zs,
            rng_key_predict,
            player_names,
            train=False,
        )
if __name__ == "__main__":
    # The example is pinned to the NumPyro release line it was written for.
    assert numpyro.__version__.startswith("0.18.0")
    parser = argparse.ArgumentParser(description="Baseball batting average using MCMC")
    parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)
    parser.add_argument("--num-chains", nargs="?", default=1, type=int)
    parser.add_argument(
        "--algo",
        default="NUTS",
        type=str,
        # Reject unsupported kernels at parse time rather than mid-run.
        choices=("HMC", "NUTS", "SA"),
        help='whether to run "HMC", "NUTS", or "SA"',
    )
    parser.add_argument(
        "-dp",
        "--disable-progbar",
        action="store_true",
        default=False,
        help="whether to disable progress bar",
    )
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Configure the platform and host device count before running inference.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)