示例:广义线性混合模型

UCBadmit 数据集来源于一项针对 1973 年秋季加州大学伯克利分校研究生招生中性别偏见的研究 [1]。

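UCBadmit 数据集

| 系别 | 男性 | 申请数 | 录取数 |
|------|------|--------|--------|
| 0    | 1    | 825    | 512    |
| 0    | 0    | 108    | 89     |
| 1    | 1    | 560    | 353    |
| 1    | 0    | 25     | 17     |
| 2    | 1    | 325    | 120    |
| 2    | 0    | 593    | 202    |
| 3    | 1    | 417    | 138    |
| 3    | 0    | 375    | 131    |
| 4    | 1    | 191    | 53     |
| 4    | 0    | 393    | 94     |
| 5    | 1    | 373    | 22     |
| 5    | 0    | 341    | 24     |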

本示例复现了 [3] 中的多层模型 m_glmm5,用于评估数据是否提供了各系招生中存在性别偏见的证据。这是一个用于二项回归问题的广义线性混合模型,其建模内容包括

  • 跨系的变截距,

  • 跨系的变斜率(或男性效应),

  • 截距和斜率之间的相关性,

并使用非中心化参数化(或白化)。

关于二项回归和非中心化参数化的更全面解释,请参见 [2] 的第 10 章(计数和分类)和第 13 章(协方差探险)。

参考文献

  1. Bickel, P. J., Hammel, E. A., and O’Connell, J. W. (1975), “Sex Bias in Graduate Admissions: Data from Berkeley”(研究生招生中的性别偏见:来自伯克利的数据), Science, 187(4175), 398-404.

  2. McElreath, R. (2018), “Statistical Rethinking: A Bayesian Course with Examples in R and Stan”(统计学反思:一本包含 R 和 Stan 示例的贝叶斯课程), Chapman and Hall/CRC.

  3. https://github.com/rmcelreath/rethinking/tree/Experimental#multilevel-model-formulas

../_images/ucbadmit.png
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np

from jax import random
import jax.numpy as jnp
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import UCBADMIT, load_dataset
from numpyro.infer import MCMC, NUTS, Predictive


def glmm(dept, male, applications, admit=None):
    """Binomial regression with correlated per-department intercepts and slopes.

    The logit of the admission probability is a population-level intercept
    plus a population-level coefficient on ``male``, each shifted by a
    per-department offset. The two offsets of a department are generated
    jointly from standard-normal draws via a scaled Cholesky factor
    (non-centered parameterization).

    :param dept: per-row department index, used to index the rows of the
        per-department offsets (presumably integers ``0..num_dept-1`` --
        verify against the dataset loader).
    :param male: per-row value that multiplies the slope term.
    :param applications: per-row total count of the Binomial likelihood.
    :param admit: observed counts for the ``"admit"`` site. When ``None``,
        ``"admit"`` is left unobserved and an extra ``"probs"`` site holding
        ``expit(logits)`` is recorded.
    """
    # Population-level effects: v_mu[0] is the intercept (prior scale 4.0),
    # v_mu[1] is the coefficient on `male` (prior scale 1.0).
    v_mu = numpyro.sample("v_mu", dist.Normal(0, jnp.array([4.0, 1.0])))

    # One scale per varying effect (intercept offset, slope offset).
    sigma = numpyro.sample("sigma", dist.HalfNormal(jnp.ones(2)))
    # 2x2 LKJ-Cholesky draw coupling the intercept and slope offsets.
    L_Rho = numpyro.sample("L_Rho", dist.LKJCholesky(2, concentration=2))
    # Multiply row i of L_Rho by sigma[i].
    scale_tril = sigma[..., jnp.newaxis] * L_Rho
    # Non-centered parameterization: sample standard-normal z and transform it
    # with scale_tril instead of sampling the correlated offsets directly.
    num_dept = len(np.unique(dept))
    z = numpyro.sample("z", dist.Normal(jnp.zeros((num_dept, 2)), 1))
    # v has shape (num_dept, 2): column 0 is added to the intercept,
    # column 1 to the `male` coefficient (see `logits` below).
    v = jnp.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # No observations were supplied: record the admission probabilities
        # through a Delta site so they are returned with the predictions.
        probs = expit(logits)
        numpyro.sample("probs", dist.Delta(probs), obs=probs)
    numpyro.sample("admit", dist.Binomial(applications, logits=logits), obs=admit)


def run_inference(dept, male, applications, admit, rng_key, args):
    """Sample the posterior of ``glmm`` with NUTS and return the draws.

    :param dept: forwarded to ``glmm``.
    :param male: forwarded to ``glmm``.
    :param applications: forwarded to ``glmm``.
    :param admit: forwarded to ``glmm`` as the observed counts.
    :param rng_key: random key handed to ``MCMC.run``.
    :param args: object providing ``num_warmup``, ``num_samples`` and
        ``num_chains`` attributes.
    :return: the result of ``MCMC.get_samples()``.
    """
    # The progress bar is switched off whenever NUMPYRO_SPHINXBUILD is present
    # in the environment, regardless of its value.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    sampler = MCMC(
        NUTS(glmm),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    sampler.run(rng_key, dept, male, applications, admit)
    return sampler.get_samples()


def print_results(header, preds, dept, male, probs):
    """Print a per-row table of actual rates and predicted quartiles.

    :param header: banner printed above the table.
    :param preds: predicted values; quantiles are taken along axis 0, so the
        remaining axis is indexed by row.
    :param dept: per-row department value, printed with no decimals.
    :param male: per-row male indicator, printed with no decimals.
    :param probs: per-row actual rate, printed with two decimals.
    """
    column_names = (
        "Dept",
        "Male",
        "ActualProb",
        "Pred(p25)",
        "Pred(p50)",
        "Pred(p75)",
    )
    # Six right-aligned, 10-character columns separated by single spaces.
    title_fmt = " ".join(["{:>10}"] * 6)
    value_fmt = " ".join(["{:>10.0f}"] * 2 + ["{:>10.2f}"] * 4)
    # Row 0/1/2 of `quartiles` holds the 25th/50th/75th percentile.
    quartiles = jnp.quantile(preds, jnp.array([0.25, 0.5, 0.75]), axis=0)

    print(f"\n {header} \n")
    print(title_fmt.format(*column_names))
    for idx, dept_id in enumerate(dept):
        line = value_fmt.format(dept_id, male[idx], probs[idx], *quartiles[:, idx])
        # Each row is followed by a space and an extra blank line.
        print(line + " \n")


def main(args):
    """Fit the GLMM to the UCBadmit data, print a summary and save a plot.

    Loads the training split of the UCBADMIT dataset, runs inference with
    ``run_inference``, draws the ``"probs"`` site from ``Predictive``, prints
    the quartile table via ``print_results`` and writes a posterior
    predictive check figure to ``ucbadmit_plot.pdf``.

    :param args: parsed command-line options; passed through to
        ``run_inference``.
    """
    _, fetch_train = load_dataset(UCBADMIT, split="train", shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    # `admit` is omitted here so the model records the "probs" site.
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)[
        "probs"
    ]
    header = "=" * 30 + "glmm - TRAIN" + "=" * 30
    actual_rate = admit / applications
    print_results(header, pred_probs, dept, male, actual_rate)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    # One x position per dataset row, numbered from 1. Derived from the data
    # rather than hard-coded so the plot stays valid if the row count changes.
    cases = range(1, len(dept) + 1)

    ax.plot(cases, actual_rate, "o", ms=7, label="actual rate")
    ax.errorbar(
        cases,
        jnp.mean(pred_probs, 0),
        jnp.std(pred_probs, 0),
        fmt="o",
        c="k",
        mfc="none",
        ms=7,
        elinewidth=1,
        label=r"mean $\pm$ std",
    )
    # The 5th and 95th percentiles mark the 90% interval named in the title.
    ax.plot(cases, jnp.percentile(pred_probs, 5, 0), "k+")
    ax.plot(cases, jnp.percentile(pred_probs, 95, 0), "k+")
    ax.set(
        xlabel="cases",
        ylabel="admit rate",
        title="Posterior Predictive Check with 90% CI",
    )
    ax.legend()

    plt.savefig("ucbadmit_plot.pdf")


if __name__ == "__main__":
    # The script only runs against a NumPyro version starting with "0.18.0".
    assert numpyro.__version__.startswith("0.18.0")

    cli = argparse.ArgumentParser(
        description="UCBadmit gender discrimination using HMC"
    )
    # Integer options: (flags, default). All take an optional value.
    int_options = (
        (("-n", "--num-samples"), 2000),
        (("--num-warmup",), 500),
        (("--num-chains",), 1),
    )
    for flags, default in int_options:
        cli.add_argument(*flags, nargs="?", default=default, type=int)
    cli.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = cli.parse_args()

    # Configure the platform and the host device count before running.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)

本画廊由 Sphinx-Gallery 生成