注意
跳转到末尾以下载完整的示例代码。
示例:零膨胀泊松回归模型
在此示例中,我们对州立公园游客捕获的鱼数量进行建模和预测。许多游客群体捕获的鱼数量为零,原因可能是他们根本没有钓鱼,或者是因为运气不好。我们希望明确地对这种双峰行为(零与非零)进行建模,并确定哪些变量对每种行为有贡献。
我们通过拟合零膨胀泊松回归模型来回答这个问题。我们使用 MAP、VI 和 MCMC 作为估计方法。最后,从 MCMC 样本中,我们确定了对零膨胀泊松似然函数的零分量和非零分量有贡献的变量。
import argparse
import os
import random
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
import jax.numpy as jnp
from jax.random import PRNGKey
import jax.scipy as jsp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, autoguide
# Select the Agg backend before any figure is created; this script only
# writes its figure to a file (see ``plt.savefig`` in ``main``).
matplotlib.use("Agg") # noqa: E402
def set_seed(seed):
    """Seed the global Python and NumPy random number generators.

    :param seed: value handed to both ``random.seed`` and ``np.random.seed``.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
def model(X, Y):
    """Zero-inflated Poisson regression with standard-normal coefficient priors.

    Two coefficient vectors with one entry per column of ``X`` are sampled:
    ``b1`` is passed through a logistic link to give the ``gate`` argument of
    ``ZeroInflatedPoisson`` and ``b2`` through a log link to give its ``rate``.

    :param X: design matrix; rows are observations, columns are covariates.
    :param Y: observed counts, one per row of ``X``, or ``None`` so that the
        ``"Y"`` site is sampled instead of conditioned on.
    """
    num_obs = X.shape[0]
    num_features = X.shape[1]

    # Same prior for both coefficient vectors: independent N(0, 1) entries.
    coef_prior = dist.Normal(0.0, 1.0).expand([num_features]).to_event(1)
    b1 = numpyro.sample("b1", coef_prior)
    b2 = numpyro.sample("b2", coef_prior)

    # Linear predictors are flattened to one value per observation.
    gate = jsp.special.expit(jnp.dot(X, b1[:, None])).reshape(-1)
    rate = jnp.exp(jnp.dot(X, b2[:, None]).reshape(-1))

    with numpyro.plate("obs", num_obs):
        numpyro.sample("Y", dist.ZeroInflatedPoisson(gate=gate, rate=rate), obs=Y)
def run_mcmc(model, args, X, Y):
    """Draw posterior samples for ``model`` with the NUTS kernel.

    :param model: NumPyro model callable taking ``(X, Y)``.
    :param args: namespace providing ``num_warmup``, ``num_samples`` and
        ``num_chains``.
    :param X: covariates forwarded to the model.
    :param Y: observations forwarded to the model.
    :return: whatever ``MCMC.get_samples()`` returns after the run.
    """
    # The progress bar is shown unless NUMPYRO_SPHINXBUILD is set in the
    # environment.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    sampler = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    sampler.run(PRNGKey(1), X, Y)
    sampler.print_summary()
    return sampler.get_samples()
def run_svi(model, guide_family, args, X, Y):
    """Fit ``model`` by stochastic variational inference with an autoguide.

    :param model: NumPyro model callable taking keyword arguments ``X`` and ``Y``.
    :param guide_family: name of the autoguide to build, either
        ``"AutoDelta"`` or ``"AutoDiagonalNormal"``.
    :param args: namespace providing ``maxiter``, the number of SVI steps.
    :param X: covariates forwarded to the model.
    :param Y: observations forwarded to the model.
    :return: tuple ``(params, guide)`` with the optimized parameters and the
        guide instance they belong to.
    :raises ValueError: if ``guide_family`` is not a supported name.
    """
    if guide_family == "AutoDelta":
        guide = autoguide.AutoDelta(model)
    elif guide_family == "AutoDiagonalNormal":
        guide = autoguide.AutoDiagonalNormal(model)
    else:
        # Fail early with a clear message instead of hitting an unbound
        # ``guide`` variable further down.
        raise ValueError(
            f"Unsupported guide_family {guide_family!r}; "
            'expected "AutoDelta" or "AutoDiagonalNormal".'
        )

    optimizer = numpyro.optim.Adam(0.001)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    svi_results = svi.run(PRNGKey(1), args.maxiter, X=X, Y=Y)
    return svi_results.params, guide
def main(args):
    """Fit the zero-inflated Poisson model with MAP, VI and MCMC and compare them.

    Loads the fish data set, splits its rows at random into train and test
    sets, fits the model with each estimator, prints the test RMSE of each and
    saves violin plots of the MCMC coefficient samples to ``zip_fish.png``.

    :param args: parsed command line namespace. ``seed``, ``train_size`` and
        ``num_samples`` are read here; the namespace is also forwarded to
        ``run_svi`` and ``run_mcmc``.
    """
    set_seed(args.seed)

    # prepare dataset
    df = pd.read_stata("http://www.stata-press.com/data/r11/fish.dta")
    df["intercept"] = 1
    cols = ["livebait", "camper", "persons", "child", "intercept"]

    # Uniform draws on [0, 1) put a fraction ``train_size`` of the rows in the
    # training set; standard-normal draws would not give that fraction.
    mask = np.random.rand(len(df)) < args.train_size
    df_train = df[mask]
    df_test = df[~mask]
    X_train = jnp.asarray(df_train[cols].values)
    y_train = jnp.asarray(df_train["count"].values)
    X_test = jnp.asarray(df_test[cols].values)
    y_test = jnp.asarray(df_test["count"].values)

    print("run MAP.")
    map_params, map_guide = run_svi(model, "AutoDelta", args, X_train, y_train)

    print("run VI.")
    vi_params, vi_guide = run_svi(model, "AutoDiagonalNormal", args, X_train, y_train)

    print("run MCMC.")
    posterior_samples = run_mcmc(model, args, X_train, y_train)

    # evaluation
    def svi_predict(model, guide, params, args, X):
        """Return rounded posterior-predictive means of ``Y`` from a fitted guide."""
        predictive = Predictive(
            model=model, guide=guide, params=params, num_samples=args.num_samples
        )
        predictions = predictive(PRNGKey(1), X=X, Y=None)
        return jnp.rint(predictions["Y"].mean(0))

    def rmse(y_true, y_pred):
        """Root mean squared error between two array-likes."""
        # ``np.asarray`` converts JAX arrays without the ``.to_py()`` accessor,
        # which newer JAX releases deprecate/remove. Taking the square root
        # here avoids the ``squared=False`` keyword that newer scikit-learn
        # releases deprecate/remove.
        return np.sqrt(mean_squared_error(np.asarray(y_true), np.asarray(y_pred)))

    map_predictions = svi_predict(model, map_guide, map_params, args, X_test)
    vi_predictions = svi_predict(model, vi_guide, vi_params, args, X_test)

    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictions = predictive(PRNGKey(1), X=X_test, Y=None)
    mcmc_predictions = jnp.rint(predictions["Y"].mean(0))

    for label, y_pred in (
        ("MAP", map_predictions),
        ("VI", vi_predictions),
        ("MCMC", mcmc_predictions),
    ):
        print(f"{label} RMSE: ", rmse(y_test, y_pred))

    # make plot
    fig, axes = plt.subplots(2, 1, figsize=(6, 6), constrained_layout=True)

    def add_fig(var_name, title, ax):
        """Draw one violin per covariate for the MCMC samples of ``var_name``."""
        ax.set_title(title)
        samples = np.asarray(posterior_samples[var_name])
        ax.violinplot([samples[:, i] for i in range(len(cols))])
        ax.set_xticks(np.arange(1, len(cols) + 1))
        ax.set_xticklabels(cols, rotation=45, fontsize=10)

    # NOTE(review): "b1" feeds the ``gate`` argument of ZeroInflatedPoisson in
    # ``model``; if gate is the probability of an extra zero, the first title
    # reads inverted -- confirm before changing the label.
    add_fig("b1", "Coefficients for probability of catching fish", axes[0])
    add_fig("b2", "Coefficients for the number of fish caught", axes[1])

    plt.savefig("zip_fish.png")
if __name__ == "__main__":
    # Script entry point: check the NumPyro version, parse the command line,
    # configure the platform and host device count, then run the example.
    assert numpyro.__version__.startswith("0.18.0")

    parser = argparse.ArgumentParser("Zero-Inflated Poisson Regression")

    # Integer-valued options as (flags, default) pairs, in declaration order.
    integer_options = (
        (("--seed",), 42),
        (("-n", "--num-samples"), 2000),
        (("--num-warmup",), 1000),
        (("--num-chains",), 1),
        (("--num-data",), 100),
        (("--maxiter",), 5000),
    )
    for flags, default in integer_options:
        parser.add_argument(*flags, nargs="?", default=default, type=int)

    parser.add_argument("--train-size", nargs="?", default=0.8, type=float)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)