交互式在线版本: 在 Colab 中打开

离散协变量中缺失值的贝叶斯归因

在实际应用中,缺失数据是一个非常普遍的问题,无论是在协变量(“解释变量”)还是结果变量中。使用 MCMC 进行贝叶斯推断时,无法使用汉密尔顿蒙特卡罗技术对离散缺失值进行归因。解决这个问题的一种方法是创建一个新模型,该模型枚举离散变量并对新模型进行推断,对于单个离散变量而言,这构成了一个混合模型。(例如,参见 Stan 用户指南关于潜离散参数的内容)枚举离散潜变量站点需要一些手动数学工作,对于复杂模型来说可能会很繁琐。NumPyro 中实现了离散变量自动枚举的推断功能,这为处理离散缺失数据提供了一种非常便捷的方法。

[ ]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro funsor
[1]:
from math import inf

from graphviz import Digraph

from jax import numpy as jnp, random
from jax.scipy.special import expit

import numpyro
from numpyro import distributions as dist, sample
from numpyro.infer.hmc import NUTS
from numpyro.infer.mcmc import MCMC

simkeys = random.split(random.PRNGKey(0), 10)
nsim = 5000
mcmc_key = random.PRNGKey(1)

首先,我们将模拟具有相关二元协变量的数据。假设我们希望无偏地估计某个参数模型的参数(例如,用于推断因果效应)。对于几种不同的缺失数据模式,我们将看到如何对缺失值进行归因以获得无偏模型。

基本数据结构如下。Z 是一个潜变量,它导致观测到的协变量 A 和 B 之间存在边际依赖关系。我们将考虑变量 A 的不同缺失数据机制,其中变量 B 和 Y 完全观测到。A 和 B 对 Y 的影响是我们感兴趣的影响。

[2]:
dot = Digraph()
dot.node("A")
dot.node("B")
dot.node("Z")
dot.node("Y")
dot.edges(["ZA", "ZB", "AY", "BY"])
dot
[2]:
../_images/tutorials_discrete_imputation_4_0.svg
[3]:
b_A = 0.25
b_B = 0.25
s_Y = 0.25
Z = random.normal(simkeys[0], (nsim,))
A = random.bernoulli(simkeys[1], expit(Z))
B = random.bernoulli(simkeys[2], expit(Z))
Y = A * b_A + B * b_B + s_Y * random.normal(simkeys[3], (nsim,))

取决于结果的 MAR

根据 Rubin 的经典定义,缺失数据机制有 3 种不同类型

  1. 完全随机缺失 (MCAR)

  2. 随机缺失,取决于观测数据 (MAR)

  3. 非随机缺失,即使在取决于观测数据后 (MNAR)

缺失数据机制 1 和 2 易于处理,因为它们仅依赖于观测数据。机制 3 (MNAR) 则更棘手,因为它依赖于未观测到的数据,但这些数据可能仍与您正在建模的结果相关(参见下文的具体示例)。

首先,我们将根据 Y 的值生成 A 中的缺失值(因此这是一种 MAR 机制)。

[4]:
dot_mnar_y = Digraph()
with dot_mnar_y.subgraph() as s:
    s.attr(rank="same")
    s.node("Y")
    s.node("M")
dot_mnar_y.node("A")
dot_mnar_y.node("B")
dot_mnar_y.node("Z")
dot_mnar_y.node("M")
dot_mnar_y.edges(["YM", "ZA", "ZB", "AY", "BY"])
dot_mnar_y
[4]:
../_images/tutorials_discrete_imputation_7_0.svg

此图描绘了数据生成机制,其中 Y 是导致 A 中缺失的唯一原因,标记为 M。这意味着 M 中的缺失是随机的,取决于 Y。

举例来说,考虑这个简化场景

  • A 代表心脏病史

  • B 代表患者年龄

  • Y 代表患者是否会去看全科医生

一位全科医生想找出为什么分配到她诊所的患者会或不会来看诊。她认为心脏病史和年龄是看诊的潜在原因。患者年龄数据可以通过他们的登记表获得,但先前的心脏病信息可能只有在他们来诊后才能获得。这使得 A (心脏病史) 中的缺失取决于结果 (来诊)。

[5]:
A_isobs = random.bernoulli(simkeys[4], expit(3 * (Y - Y.mean())))
Aobs = jnp.where(A_isobs, A, -1)
A_obsidx = jnp.where(A_isobs)

# generate complete case arrays
Acc = Aobs[A_obsidx]
Bcc = B[A_obsidx]
Ycc = Y[A_obsidx]

我们将评估两种方法

  1. 完整病例分析(将导致有偏推断)

  2. 带归因(取决于 B)

请注意,在 A 的归因模型中明确包含 Y 是不必要的。对 A 的采样归因将间接取决于 Y,因为 Y 的似然取决于 A。因此,对 Y 产生高似然的 A 值将比其他值被更频繁地采样。

[6]:
def ccmodel(A, B, Y):
    ntotal = A.shape[0]
    # get parameters of outcome model
    b_A = sample("b_A", dist.Normal(0, 2.5))
    b_B = sample("b_B", dist.Normal(0, 2.5))
    s_Y = sample("s_Y", dist.HalfCauchy(2.5))

    with numpyro.plate("obs", ntotal):
        ### outcome model
        eta_Y = b_A * A + b_B * B
        sample("obs_Y", dist.Normal(eta_Y, s_Y), obs=Y)
[7]:
cckernel = NUTS(ccmodel)
ccmcmc = MCMC(cckernel, num_warmup=250, num_samples=750)
ccmcmc.run(mcmc_key, Acc, Bcc, Ycc)
ccmcmc.print_summary()
sample: 100%|██████████| 1000/1000 [00:02<00:00, 348.50it/s, 3 steps of size 4.27e-01. acc. prob=0.94]
                mean       std    median      5.0%     95.0%     n_eff     r_hat
       b_A      0.30      0.01      0.30      0.29      0.31    500.83      1.00
       b_B      0.28      0.01      0.28      0.27      0.29    546.34      1.00
       s_Y      0.25      0.00      0.25      0.24      0.25    559.55      1.00

Number of divergences: 0
[8]:
def impmodel(A, B, Y):
    ntotal = A.shape[0]
    A_isobs = A >= 0

    # get parameters of imputation model
    mu_A = sample("mu_A", dist.Normal(0, 2.5))
    b_B_A = sample("b_B_A", dist.Normal(0, 2.5))

    # get parameters of outcome model
    b_A = sample("b_A", dist.Normal(0, 2.5))
    b_B = sample("b_B", dist.Normal(0, 2.5))
    s_Y = sample("s_Y", dist.HalfCauchy(2.5))

    with numpyro.plate("obs", ntotal):
        ### imputation model
        # get linear predictor for missing values
        eta_A = mu_A + B * b_B_A

        # sample imputation values for A
        # mask out to not add log_prob to total likelihood right now
        Aimp = sample(
            "A",
            dist.Bernoulli(logits=eta_A).mask(False),
            infer={"enumerate": "parallel"},
        )

        # 'manually' calculate the log_prob
        log_prob = dist.Bernoulli(logits=eta_A).log_prob(Aimp)

        # cancel out enumerated values that are not equal to observed values
        log_prob = jnp.where(A_isobs & (Aimp != A), -inf, log_prob)

        # add to total likelihood for sampler
        numpyro.factor("A_obs", log_prob)

        ### outcome model
        eta_Y = b_A * Aimp + b_B * B
        sample("obs_Y", dist.Normal(eta_Y, s_Y), obs=Y)
[9]:
impkernel = NUTS(impmodel)
impmcmc = MCMC(impkernel, num_warmup=250, num_samples=750)
impmcmc.run(mcmc_key, Aobs, B, Y)
impmcmc.print_summary()
sample: 100%|██████████| 1000/1000 [00:05<00:00, 174.83it/s, 7 steps of size 4.41e-01. acc. prob=0.91]
                mean       std    median      5.0%     95.0%     n_eff     r_hat
       b_A      0.25      0.01      0.25      0.24      0.27    447.79      1.01
       b_B      0.25      0.01      0.25      0.24      0.26    570.66      1.01
     b_B_A      0.74      0.08      0.74      0.60      0.86    316.36      1.00
      mu_A     -0.39      0.06     -0.39     -0.48     -0.29    290.86      1.00
       s_Y      0.25      0.00      0.25      0.25      0.25    527.97      1.00

Number of divergences: 0

正如我们所见,当数据取决于 Y 缺失时,归因能够一致地估计感兴趣的参数 (b_A 和 b_B)。

取决于协变量的 MNAR

当数据取决于未观测数据缺失时,事情会变得更棘手。这里我们将根据 A 本身的值生成 A 中的缺失值(非随机缺失 (MNAR),但取决于 A 随机缺失)。

举例来说,考虑癌症患者

  • A 代表体重减轻

  • B 代表年龄

  • Y 代表总生存时间

A 和 B 都可能与癌症患者的生存时间有关。对于体重减轻极度的患者,医生更有可能注意到这一点并将其记录在电子病历中。对于没有或轻微体重减轻的患者,医生可能忘记询问,因此没有将其记录在案。

[10]:
dot_mnar_x = Digraph()
with dot_mnar_y.subgraph() as s:
    s.attr(rank="same")
    s.node("A")
    s.node("M")
dot_mnar_x.node("B")
dot_mnar_x.node("Z")
dot_mnar_x.node("Y")
dot_mnar_x.edges(["AM", "ZA", "ZB", "AY", "BY"])
dot_mnar_x
[10]:
../_images/tutorials_discrete_imputation_17_0.svg
[11]:
A_isobs = random.bernoulli(simkeys[5], 0.9 - 0.8 * A)
Aobs = jnp.where(A_isobs, A, -1)
A_obsidx = jnp.where(A_isobs)

# generate complete case arrays
Acc = Aobs[A_obsidx]
Bcc = B[A_obsidx]
Ycc = Y[A_obsidx]
[12]:
cckernel = NUTS(ccmodel)
ccmcmc = MCMC(cckernel, num_warmup=250, num_samples=750)
ccmcmc.run(mcmc_key, Acc, Bcc, Ycc)
ccmcmc.print_summary()
sample: 100%|██████████| 1000/1000 [00:02<00:00, 342.07it/s, 3 steps of size 5.97e-01. acc. prob=0.92]
                mean       std    median      5.0%     95.0%     n_eff     r_hat
       b_A      0.27      0.02      0.26      0.24      0.29    667.08      1.01
       b_B      0.25      0.01      0.25      0.24      0.26    811.49      1.00
       s_Y      0.25      0.00      0.25      0.24      0.25    547.51      1.00

Number of divergences: 0
[13]:
impkernel = NUTS(impmodel)
impmcmc = MCMC(impkernel, num_warmup=250, num_samples=750)
impmcmc.run(mcmc_key, Aobs, B, Y)
impmcmc.print_summary()
sample: 100%|██████████| 1000/1000 [00:06<00:00, 166.36it/s, 7 steps of size 4.10e-01. acc. prob=0.94]
                mean       std    median      5.0%     95.0%     n_eff     r_hat
       b_A      0.34      0.01      0.34      0.32      0.35    576.15      1.00
       b_B      0.33      0.01      0.33      0.32      0.34    800.58      1.00
     b_B_A      0.32      0.12      0.32      0.12      0.51    342.21      1.01
      mu_A     -1.81      0.09     -1.81     -1.95     -1.67    288.57      1.00
       s_Y      0.26      0.00      0.26      0.25      0.26    820.20      1.00

Number of divergences: 0

也许令人惊讶的是,当缺失机制取决于变量本身时,归因缺失值实际上会导致偏差,而完整病例分析是无偏的!例如,参见 Bias and efficiency of multiple imputation compared with complete‐case analysis for missing covariate values

然而,完整病例分析也可能不理想。例如,由于导致估计 B 到 Y 的参数精度较低,或者当 A 的值与 A 到 Y 的参数之间存在预期差异交互作用时。为了处理这种情况,需要建立一个明确的缺失(/观测)原因模型。我们将在下面添加一个。

[14]:
def impmissmodel(A, B, Y):
    ntotal = A.shape[0]
    A_isobs = A >= 0

    # get parameters of imputation model
    mu_A = sample("mu_A", dist.Normal(0, 2.5))
    b_B_A = sample("b_B_A", dist.Normal(0, 2.5))

    # get parameters of outcome model
    b_A = sample("b_A", dist.Normal(0, 2.5))
    b_B = sample("b_B", dist.Normal(0, 2.5))
    s_Y = sample("s_Y", dist.HalfCauchy(2.5))

    # get parameter of model of missingness
    with numpyro.plate("obsmodel", 2):
        p_Aobs = sample("p_Aobs", dist.Beta(1, 1))

    with numpyro.plate("obs", ntotal):
        ### imputation model
        # get linear predictor for missing values
        eta_A = mu_A + B * b_B_A

        # sample imputation values for A
        # mask out to not add log_prob to total likelihood right now
        Aimp = sample(
            "A",
            dist.Bernoulli(logits=eta_A).mask(False),
            infer={"enumerate": "parallel"},
        )

        # 'manually' calculate the log_prob
        log_prob = dist.Bernoulli(logits=eta_A).log_prob(Aimp)

        # cancel out enumerated values that are not equal to observed values
        log_prob = jnp.where(A_isobs & (Aimp != A), -inf, log_prob)

        # add to total likelihood for sampler
        numpyro.factor("obs_A", log_prob)

        ### outcome model
        eta_Y = b_A * Aimp + b_B * B
        sample("obs_Y", dist.Normal(eta_Y, s_Y), obs=Y)

        ### missingness / observationmodel
        eta_Aobs = jnp.where(Aimp, p_Aobs[0], p_Aobs[1])
        sample("obs_Aobs", dist.Bernoulli(probs=eta_Aobs), obs=A_isobs)
[15]:
impmisskernel = NUTS(impmissmodel)
impmissmcmc = MCMC(impmisskernel, num_warmup=250, num_samples=750)
impmissmcmc.run(mcmc_key, Aobs, B, Y)
impmissmcmc.print_summary()
sample: 100%|██████████| 1000/1000 [00:09<00:00, 106.81it/s, 7 steps of size 2.86e-01. acc. prob=0.91]
                mean       std    median      5.0%     95.0%     n_eff     r_hat
       b_A      0.26      0.01      0.26      0.24      0.27    267.57      1.00
       b_B      0.25      0.01      0.25      0.24      0.26    537.10      1.00
     b_B_A      0.74      0.07      0.74      0.62      0.84    421.54      1.00
      mu_A     -0.45      0.08     -0.45     -0.58     -0.31    241.11      1.00
 p_Aobs[0]      0.10      0.01      0.10      0.09      0.11    451.90      1.00
 p_Aobs[1]      0.86      0.03      0.86      0.82      0.91    244.47      1.00
       s_Y      0.25      0.00      0.25      0.24      0.25    375.51      1.00

Number of divergences: 0

现在我们可以无偏地估计参数 b_A 和 b_B,同时仍然利用所有观测值。显然,对缺失机制进行建模依赖于需要通过先验证据证实或可能通过敏感性分析来分析的假设。

有关贝叶斯推断中缺失数据的更多阅读材料,请参见