示例:跨空间和时间建模死亡率

此示例改编自 [1]。论文中的模型估计了英格兰 6791 个小区域在 2002-19 年间 19 个年龄组(0 岁,1-4 岁,5-9 岁,10-14 岁,…,80-84 岁,85 岁以上)的死亡率。

在高空间分辨率下建模死亡率时,每个年龄组、空间单元和年份的死亡人数很少,这意味着从观察数据计算出的死亡率具有明显的变异性,这比真实的死亡风险差异要大。贝叶斯多级建模框架可以通过在年龄、空间和时间上共享信息来克服小样本问题,从而获得平滑的死亡率并捕获估计中的不确定性。

除了全局截距 (\(\alpha_0\)) 和斜率 (\(\beta_0\)) 外,模型还包括以下效应

  • 年龄 (\(\alpha_{2a}\), \(\beta_{2a}\))。每个年龄组都有不同的截距和斜率,并在年龄组上具有随机游走结构,以允许非线性年龄关联。

  • 空间 (\(\alpha_{1s}\))。每个空间单元都有一个截距。空间效应由遵循地方政府行政层级的嵌套随机效应层次结构定义。较低级别单元的空间项以包含该较低级别单元的较高级别单元的空间项(例如,\(\alpha_{1s_1}\))为中心。

模型在时间上还具有随机游走效应 (\(\pi_{t}\))。

死亡率使用二项式似然函数与死亡和人口数据相关联。死亡率的完整生成模型写为

\begin{align} \alpha_{1s_1} & \sim \text{Normal}(0,\sigma_{\alpha_{s_1}}^2) \\ \alpha_{1s} & \sim \text{Normal}(\alpha_{1s_1(s_2)},\sigma_{\alpha_{s_2}}^2) \\ \alpha_{2a} & \sim \text{Normal}(\alpha_{2,a-1},\sigma_{\alpha_a}^2) \quad \alpha_{2,0} = \alpha_0 \\ \beta_{2a} & \sim \text{Normal}(\beta_{2,a-1},\sigma_{\beta_a}^2) \quad \beta_{2,0} = \beta_0 \\ \pi_{t} & \sim \text{Normal}(\pi_{t-1},\sigma_{\pi}^2), \quad \pi_{0} = 0 \\ \text{logit}(m_{ast}) & = \alpha_{1s} + \alpha_{2a} + \beta_{2a} t + \pi_{t} \end{align}

使用超先验

\begin{align} \alpha_0 & \sim \text{Normal}(0,10), \\ \beta_0 & \sim \text{Normal}(0,10), \\ \sigma_i & \sim \text{Half-Normal}(1) \end{align}

有关模型项的更多详细信息可在 [1] 中找到。

下面的 NumPyro 实现使用 plate 符号来声明年龄、空间和时间变量的批处理维度。这使得我们可以在似然函数中高效地广播数组。

如上所述,模型包含了许多中心化随机效应。NUTS 算法受益于非中心化重参数化,以克服困难的后验几何形状 [2]。我们不是手动写出非中心化参数化,而是利用 NumPyro 在 LocScaleReparam 中的自动重参数化。

[1] 中空间分辨率下的死亡数据是可识别的,因此在此示例中,我们使用模拟数据。与 [1] 相比,模拟数据具有更少的空间单元和两层(而不是三层)空间层次结构。与原始研究一样,仍有 19 个年龄组和 18 年。此处的数据具有 (19, 113, 18) (年龄,空间,时间) 的(事件)维度。

nimble 中的原始实现可在 [3] 找到。

参考文献

  1. Rashid, T., Bennett, J.E. et al. (2021). Life expectancy and risk of death in 6791 communities in England from 2002 to 2019: high-resolution spatiotemporal analysis of civil registration data. The Lancet Public Health, 6, e805 - e816.

  2. Stan 用户指南。 https://mc-stan.org/docs/2_28/stan-users-guide/reparameterization.html

  3. 使用贝叶斯分层模型建模死亡率。 https://github.com/theorashid/mortality-statsmodel

import argparse
import os

import numpy as np

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import MORTALITY, load_dataset
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import LocScaleReparam


def create_lookup(s1, s2):
    """
    Create a map between s1 indices and unique s2 indices
    """
    lookup = np.column_stack([s1, s2])
    lookup = np.unique(lookup, axis=0)
    lookup = lookup[lookup[:, 1].argsort()]
    return lookup[:, 0]


reparam_config = {
    k: LocScaleReparam(0)
    for k in [
        "alpha_s1",
        "alpha_s2",
        "alpha_age_drift",
        "beta_age_drift",
        "pi_drift",
    ]
}


@numpyro.handlers.reparam(config=reparam_config)
def model(age, space, time, lookup, population, deaths=None):
    N_s1 = len(np.unique(lookup))
    N_s2 = len(np.unique(space))
    N_age = len(np.unique(age))
    N_t = len(np.unique(time))
    N = len(population)

    # plates
    age_plate = numpyro.plate("age_groups", N_age, dim=-3)
    space_plate = numpyro.plate("space", N_s2, dim=-2)
    year_plate = numpyro.plate("year", N_t - 1, dim=-1)

    # hyperparameters
    sigma_alpha_s1 = numpyro.sample("sigma_alpha_s1", dist.HalfNormal(1.0))
    sigma_alpha_s2 = numpyro.sample("sigma_alpha_s2", dist.HalfNormal(1.0))
    sigma_alpha_age = numpyro.sample("sigma_alpha_age", dist.HalfNormal(1.0))
    sigma_beta_age = numpyro.sample("sigma_beta_age", dist.HalfNormal(1.0))
    sigma_pi = numpyro.sample("sigma_pi", dist.HalfNormal(1.0))

    # spatial hierarchy
    with numpyro.plate("s1", N_s1, dim=-2):
        alpha_s1 = numpyro.sample("alpha_s1", dist.Normal(0, sigma_alpha_s1))
    with space_plate:
        alpha_s2 = numpyro.sample(
            "alpha_s2", dist.Normal(alpha_s1[lookup], sigma_alpha_s2)
        )

    # age
    with age_plate:
        alpha_age_drift_scale = jnp.pad(
            jnp.broadcast_to(sigma_alpha_age, N_age - 1),
            (1, 0),
            constant_values=10.0,  # pad so first term is alpha0, prior N(0, 10)
        )[:, jnp.newaxis, jnp.newaxis]
        alpha_age_drift = numpyro.sample(
            "alpha_age_drift", dist.Normal(0, alpha_age_drift_scale)
        )
        alpha_age = jnp.cumsum(alpha_age_drift, -3)

        beta_age_drift_scale = jnp.pad(
            jnp.broadcast_to(sigma_beta_age, N_age - 1), (1, 0), constant_values=10.0
        )[:, jnp.newaxis, jnp.newaxis]
        beta_age_drift = numpyro.sample(
            "beta_age_drift", dist.Normal(0, beta_age_drift_scale)
        )
        beta_age = jnp.cumsum(beta_age_drift, -3)
        beta_age_cum = jnp.outer(beta_age, jnp.arange(N_t))[:, jnp.newaxis, :]

    # random walk over time
    with year_plate:
        pi_drift = numpyro.sample("pi_drift", dist.Normal(0, sigma_pi))
        pi = jnp.pad(jnp.cumsum(pi_drift, -1), (1, 0))

    # likelihood
    latent_rate = alpha_age + beta_age_cum + alpha_s2 + pi
    with numpyro.plate("N", N):
        mu_logit = latent_rate[age, space, time]
        numpyro.sample("deaths", dist.Binomial(population, logits=mu_logit), obs=deaths)


def print_model_shape(model, age, space, time, lookup, population):
    with numpyro.handlers.seed(rng_seed=1):
        trace = numpyro.handlers.trace(model).get_trace(
            age=age,
            space=space,
            time=time,
            lookup=lookup,
            population=population,
        )
    print(numpyro.util.format_shapes(trace))


def run_inference(model, age, space, time, lookup, population, deaths, rng_key, args):
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(rng_key, age, space, time, lookup, population, deaths)
    mcmc.print_summary()
    return mcmc.get_samples()


def main(args):
    print("Fetching simulated data...")
    _, fetch = load_dataset(MORTALITY, shuffle=False)
    a, s1, s2, t, deaths, population = fetch()

    lookup = create_lookup(s1, s2)

    print("Model shape:")
    print_model_shape(model, a, s2, t, lookup, population)

    print("Starting inference...")
    rng_key = random.PRNGKey(args.rng_seed)
    run_inference(model, a, s2, t, lookup, population, deaths, rng_key, args)


if __name__ == "__main__":
    assert numpyro.__version__.startswith("0.18.0")

    parser = argparse.ArgumentParser(description="Mortality regression model")
    parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=200, type=int)
    parser.add_argument("--num-chains", nargs="?", default=1, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    parser.add_argument(
        "--rng_seed", default=21, type=int, help="random number generator seed"
    )
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.enable_x64()

    main(args)

由 Sphinx-Gallery 生成的画廊