Bayesian Imputation

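Real-world datasets often contain many missing values. In those situations, we have to either drop the rows with missing data (an approach known as "complete case" analysis) or replace the missing entries with some values. Complete case analysis is straightforward, but it only works when the number of missing entries is small enough that discarding them does not noticeably weaken the analysis we are conducting. The second strategy, known as imputation, is more widely applicable and is the focus of this tutorial.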

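Probably the most popular way to perform imputation is to fill a missing value with the mean, median, or mode of its corresponding feature. In doing so, we implicitly assume that the feature containing missing values is not correlated with the remaining features of the dataset. This is a pretty strong assumption and often does not hold. In addition, it does not encode any uncertainty that we may have about those values. Below, we will construct a Bayesian setting to resolve those issues. In particular, given a model on the dataset, we will

  • create a generative model for the feature with missing values

  • and consider the missing values as unobserved latent variables.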
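
To make the drawback of mean filling concrete, here is a minimal sketch on synthetic data (the array and the missingness pattern are made up for illustration and are not part of the Titanic dataset). Filling every missing entry with the observed mean leaves the mean unchanged but shrinks the spread, which is one way the uncertainty about the missing values gets lost.

```python
import numpy as np

rng = np.random.default_rng(0)
# synthetic feature with roughly 30% of the entries missing (illustrative only)
x = rng.normal(loc=30.0, scale=10.0, size=1000)
missing = rng.random(1000) < 0.3
x_missing = np.where(missing, np.nan, x)

# mean imputation: every missing entry receives the same value
observed_mean = np.nanmean(x_missing)
x_filled = np.where(np.isnan(x_missing), observed_mean, x_missing)

std_observed = np.nanstd(x_missing)
std_filled = x_filled.std()
print("std of observed entries:", std_observed)
print("std after mean filling: ", std_filled)
```

The filled column looks less variable than the data that were actually observed, even though nothing new was learned about the missing entries.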

[ ]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[1]:
# first, we need some imports
import os

from IPython.display import set_matplotlib_formats
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

from jax import numpy as jnp, random

import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

plt.style.use("seaborn")
if "NUMPYRO_SPHINXBUILD" in os.environ:
    set_matplotlib_formats("svg")

assert numpyro.__version__.startswith("0.18.0")

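Dataset

The data is taken from the competition Titanic: Machine Learning from Disaster hosted on Kaggle. It contains information about the passengers of the Titanic disaster, such as name, age, and gender. Our target is to predict whether a person is more likely to survive.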

[2]:
train_df = pd.read_csv(
    "https://raw.githubusercontent.com/agconti/kaggle-titanic/master/data/train.csv"
)
train_df.info()
train_df.head()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype
---  ------       --------------  -----
 0   PassengerId  891 non-null    int64
 1   Survived     891 non-null    int64
 2   Pclass       891 non-null    int64
 3   Name         891 non-null    object
 4   Sex          891 non-null    object
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64
 7   Parch        891 non-null    int64
 8   Ticket       891 non-null    object
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object
 11  Embarked     889 non-null    object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
[2]:
   PassengerId  Survived  Pclass  Name                                                Sex     Age   SibSp  Parch  Ticket            Fare     Cabin  Embarked
0  1            0         3       Braund, Mr. Owen Harris                             male    22.0  1      0      A/5 21171         7.2500   NaN    S
1  2            1         1       Cumings, Mrs. John Bradley (Florence Briggs Th...)  female  38.0  1      0      PC 17599          71.2833  C85    C
2  3            1         3       Heikkinen, Miss. Laina                              female  26.0  0      0      STON/O2. 3101282  7.9250   NaN    S
3  4            1         1       Futrelle, Mrs. Jacques Heath (Lily May Peel)        female  35.0  1      0      113803            53.1000  C123   S
4  5            0         3       Allen, Mr. William Henry                            male    35.0  0      0      373450            8.0500   NaN    S

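Looking at the data info, we know that there are missing data in the Age, Cabin, and Embarked columns. Although Cabin is an important feature (because the position of a cabin in the ship can affect the chance of people in that cabin to survive), we will skip it in this tutorial for simplicity. The dataset contains many categorical columns and two numerical columns, Age and Fare. Let's first look at the distribution of those categorical columns: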

[3]:
for col in ["Survived", "Pclass", "Sex", "SibSp", "Parch", "Embarked"]:
    print(train_df[col].value_counts(), end="\n\n")
0    549
1    342
Name: Survived, dtype: int64

3    491
1    216
2    184
Name: Pclass, dtype: int64

male      577
female    314
Name: Sex, dtype: int64

0    608
1    209
2     28
4     18
3     16
8      7
5      5
Name: SibSp, dtype: int64

0    678
1    118
2     80
3      5
5      5
4      4
6      1
Name: Parch, dtype: int64

S    644
C    168
Q     77
Name: Embarked, dtype: int64

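Prepare data

First, we will merge rare groups in the SibSp and Parch columns together. In addition, we'll fill the 2 missing entries in Embarked with the mode S. Note that we could build a generative model for those missing entries in Embarked, but we skip doing so for simplicity.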

[4]:
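# assign the results back instead of using inplace=True on an attribute-accessed
# column, which may not modify train_df under pandas copy-on-write
train_df["SibSp"] = train_df["SibSp"].clip(0, 1)
train_df["Parch"] = train_df["Parch"].clip(0, 2)
train_df["Embarked"] = train_df["Embarked"].fillna("S")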
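
As a small illustration of what the clipping step does, the sketch below applies the same bounds to a few toy counts (the arrays are made up, not taken from the dataset). Every value above the upper bound collapses into the top category, which is how the rare groups get merged.

```python
import numpy as np

sibsp = np.array([0, 1, 2, 4, 8])   # toy SibSp values
parch = np.array([0, 1, 2, 3, 6])   # toy Parch values

# values above the upper bound collapse into the top category
sibsp_merged = np.clip(sibsp, 0, 1)
parch_merged = np.clip(parch, 0, 2)
print(sibsp_merged)
print(parch_merged)
```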

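Looking closer at the data, we can observe that each name contains a title. We know that age is correlated with the title of the name: for example, those with Mrs. would be older than those with Miss. (on average), so it might be good to create that feature. The distribution of titles is: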

[5]:
train_df.Name.str.split(", ").str.get(1).str.split(" ").str.get(0).value_counts()
[5]:
Mr.          517
Miss.        182
Mrs.         125
Master.       40
Dr.            7
Rev.           6
Mlle.          2
Col.           2
Major.         2
Lady.          1
Sir.           1
the            1
Ms.            1
Capt.          1
Mme.           1
Jonkheer.      1
Don.           1
Name: Name, dtype: int64

We will make a new column Title, where rare titles are merged into a single group, Misc.

[6]:
train_df["Title"] = (
    train_df.Name.str.split(", ")
    .str.get(1)
    .str.split(" ")
    .str.get(0)
    .apply(lambda x: x if x in ["Mr.", "Miss.", "Mrs.", "Master."] else "Misc.")
)
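
The string operations above can be traced on individual names with plain Python. This is only a sketch: the helper names extract_title and group_title are introduced here for illustration, and the third name is hypothetical, chosen to show how a token such as "the" can end up in the title counts.

```python
COMMON_TITLES = {"Mr.", "Miss.", "Mrs.", "Master."}


def extract_title(name):
    """Return the first token after the comma, e.g. 'Mr.' for 'Braund, Mr. Owen Harris'."""
    return name.split(", ")[1].split(" ")[0]


def group_title(title):
    """Keep the four common titles and merge everything else into 'Misc.'."""
    return title if title in COMMON_TITLES else "Misc."


names = [
    "Braund, Mr. Owen Harris",
    "Heikkinen, Miss. Laina",
    "Doe, the Countess. of Example",  # hypothetical name with a multi-word title
]
raw_titles = [extract_title(n) for n in names]
grouped_titles = [group_title(t) for t in raw_titles]
print(raw_titles)
print(grouped_titles)
```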

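Now the dataframe, which contains categorical values, is ready to be turned into numpy arrays. We also standardize the Age column (a good practice for regression models).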

[7]:
title_cat = pd.CategoricalDtype(
    categories=["Mr.", "Miss.", "Mrs.", "Master.", "Misc."], ordered=True
)
embarked_cat = pd.CategoricalDtype(categories=["S", "C", "Q"], ordered=True)
age_mean, age_std = train_df.Age.mean(), train_df.Age.std()
data = dict(
    age=train_df.Age.pipe(lambda x: (x - age_mean) / age_std).values,
    pclass=train_df.Pclass.values - 1,
    title=train_df.Title.astype(title_cat).cat.codes.values,
    sex=(train_df.Sex == "male").astype(int).values,
    sibsp=train_df.SibSp.values,
    parch=train_df.Parch.values,
    embarked=train_df.Embarked.astype(embarked_cat).cat.codes.values,
)
survived = train_df.Survived.values
# compute the age mean for each title
age_notnan = data["age"][jnp.isfinite(data["age"])]
title_notnan = data["title"][jnp.isfinite(data["age"])]
age_mean_by_title = jnp.stack([age_notnan[title_notnan == i].mean() for i in range(5)])
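
The standardization and its inverse can be checked on a toy array. The sketch below uses NumPy only and made-up ages; np.nanmean and np.nanstd with ddof=1 are used to mirror the pandas defaults, which skip NaN and use the sample standard deviation.

```python
import numpy as np

age = np.array([22.0, 38.0, np.nan, 35.0, np.nan, 4.0])  # toy values

# statistics over the observed entries only; ddof=1 matches the pandas default for std
mean = np.nanmean(age)
std = np.nanstd(age, ddof=1)

age_standardized = (age - mean) / std        # missing entries stay NaN
age_restored = mean + std * age_standardized  # back to the original scale

print(age_standardized)
```

Missing entries remain NaN after standardization, so they can still be located later and replaced by imputed values.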

Modelling

First, we want to note that in NumPyro, the following models

def model1a():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([10]))

def model1b():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([10]).mask(False))
    numpyro.sample("x_obs", dist.Normal(0, 1).expand([10]), obs=x)

are equivalent in the sense that both of them have

  • the same latent sites x drawn from the dist.Normal(0, 1) prior,

  • and the same log densities dist.Normal(0, 1).log_prob(x).

Now, assuming that we observed the last 6 values of x (the unobserved entries take the value NaN), the typical model will be

def model2a(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]))
    x_obs = numpyro.sample("x_obs", dist.Normal(0, 1).expand([6]), obs=x[4:])
    x_imputed = jnp.concatenate([x_impute, x_obs])

or with the usage of mask,

def model2b(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))
    x_imputed = jnp.concatenate([x_impute, x[4:]])
    numpyro.sample("x", dist.Normal(0, 1).expand([10]), obs=x_imputed)

Both approaches to modelling the partially observed data x are equivalent. For the model below, we will use the latter method.

[8]:
def model(
    age, pclass, title, sex, sibsp, parch, embarked, survived=None, bayesian_impute=True
):
    b_pclass = numpyro.sample("b_Pclass", dist.Normal(0, 1).expand([3]))
    b_title = numpyro.sample("b_Title", dist.Normal(0, 1).expand([5]))
    b_sex = numpyro.sample("b_Sex", dist.Normal(0, 1).expand([2]))
    b_sibsp = numpyro.sample("b_SibSp", dist.Normal(0, 1).expand([2]))
    b_parch = numpyro.sample("b_Parch", dist.Normal(0, 1).expand([3]))
    b_embarked = numpyro.sample("b_Embarked", dist.Normal(0, 1).expand([3]))

    # impute age by Title
    isnan = np.isnan(age)
    age_nanidx = np.nonzero(isnan)[0]
    if bayesian_impute:
        age_mu = numpyro.sample("age_mu", dist.Normal(0, 1).expand([5]))
        age_mu = age_mu[title]
        # a scale parameter must be positive, so use a prior with positive support
        age_sigma = numpyro.sample("age_sigma", dist.Gamma(1, 1).expand([5]))
        age_sigma = age_sigma[title]
        age_impute = numpyro.sample(
            "age_impute",
            dist.Normal(age_mu[age_nanidx], age_sigma[age_nanidx]).mask(False),
        )
        age = jnp.asarray(age).at[age_nanidx].set(age_impute)
        numpyro.sample("age", dist.Normal(age_mu, age_sigma), obs=age)
    else:
        # fill missing data by the mean of ages for each title
        age_impute = age_mean_by_title[title][age_nanidx]
        age = jnp.asarray(age).at[age_nanidx].set(age_impute)

    a = numpyro.sample("a", dist.Normal(0, 1))
    b_age = numpyro.sample("b_Age", dist.Normal(0, 1))
    logits = a + b_age * age
    logits = logits + b_title[title] + b_pclass[pclass] + b_sex[sex]
    logits = logits + b_sibsp[sibsp] + b_parch[parch] + b_embarked[embarked]
    numpyro.sample("survived", dist.Bernoulli(logits=logits), obs=survived)

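Note that in the model, the prior for age is dist.Normal(age_mu, age_sigma), where the values of age_mu and age_sigma depend on title. Because there are missing values in age, we will encode those missing values in the latent parameter age_impute. Then we can replace the NaN entries in age with the vector age_impute.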
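
The index bookkeeping behind that replacement can be sketched with NumPy alone. The ages and the stand-in for a draw of age_impute below are made-up numbers; inside the model the same update is done functionally on a JAX array.

```python
import numpy as np

age = np.array([0.5, np.nan, -1.2, np.nan, 0.1])   # toy standardized ages
age_nanidx = np.nonzero(np.isnan(age))[0]          # positions of the missing entries

# stand-in for one draw of the latent vector age_impute (made-up numbers)
age_impute = np.array([0.3, -0.7])

age_filled = age.copy()
age_filled[age_nanidx] = age_impute
# the model performs the same update as: jnp.asarray(age).at[age_nanidx].set(age_impute)
print(age_nanidx, age_filled)
```

The latent vector has one entry per missing age, in the order in which the missing rows appear in the data.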

Sampling

We will use MCMC with the NUTS kernel to sample both the regression coefficients and the imputed values.

[9]:
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), **data, survived=survived)
mcmc.print_summary()
sample: 100%|██████████| 2000/2000 [00:15<00:00, 132.15it/s, 63 steps of size 5.68e-02. acc. prob=0.95]
                     mean       std    median      5.0%     95.0%     n_eff     r_hat
              a      0.12      0.82      0.11     -1.21      1.49    887.50      1.00
  age_impute[0]      0.20      0.84      0.18     -1.22      1.53   1346.09      1.00
  age_impute[1]     -0.06      0.86     -0.08     -1.41      1.26   1057.70      1.00
  age_impute[2]      0.38      0.73      0.39     -0.80      1.58   1570.36      1.00
  age_impute[3]      0.25      0.84      0.23     -0.99      1.86   1027.43      1.00
  age_impute[4]     -0.63      0.91     -0.59     -1.99      0.87   1183.66      1.00
  age_impute[5]      0.21      0.89      0.19     -1.02      1.97   1456.79      1.00
  age_impute[6]      0.45      0.82      0.46     -0.90      1.73   1239.22      1.00
  age_impute[7]     -0.62      0.86     -0.62     -2.13      0.72   1406.09      1.00
  age_impute[8]     -0.13      0.90     -0.14     -1.64      1.38   1905.07      1.00
  age_impute[9]      0.24      0.84      0.26     -1.06      1.77   1471.12      1.00
 age_impute[10]      0.20      0.89      0.21     -1.26      1.65   1588.79      1.00
 age_impute[11]      0.17      0.91      0.19     -1.59      1.48   1446.52      1.00
 age_impute[12]     -0.65      0.89     -0.68     -2.12      0.77   1457.47      1.00
 age_impute[13]      0.21      0.85      0.18     -1.24      1.53   1057.77      1.00
 age_impute[14]      0.05      0.92      0.05     -1.40      1.65   1207.08      1.00
 age_impute[15]      0.37      0.94      0.37     -1.02      1.98   1326.55      1.00
 age_impute[16]     -1.74      0.26     -1.74     -2.13     -1.32   1320.08      1.00
 age_impute[17]      0.21      0.89      0.22     -1.30      1.60   1545.73      1.00
 age_impute[18]      0.18      0.90      0.18     -1.26      1.58   2013.12      1.00
 age_impute[19]     -0.67      0.86     -0.66     -1.97      0.85   1499.50      1.00
 age_impute[20]      0.23      0.89      0.27     -1.19      1.71   1712.24      1.00
 age_impute[21]      0.21      0.87      0.20     -1.11      1.68   1400.55      1.00
 age_impute[22]      0.19      0.90      0.18     -1.26      1.63   1400.37      1.00
 age_impute[23]     -0.15      0.85     -0.15     -1.57      1.24   1205.10      1.00
 age_impute[24]     -0.71      0.89     -0.73     -2.05      0.82   1085.52      1.00
 age_impute[25]      0.20      0.85      0.19     -1.20      1.62   1708.01      1.00
 age_impute[26]      0.21      0.88      0.21     -1.20      1.68   1363.75      1.00
 age_impute[27]     -0.69      0.91     -0.73     -2.20      0.77   1224.06      1.00
 age_impute[28]      0.60      0.77      0.60     -0.61      1.95   1312.44      1.00
 age_impute[29]      0.20      0.89      0.17     -1.23      1.71    938.19      1.00
 age_impute[30]      0.24      0.87      0.23     -1.14      1.60   1324.50      1.00
 age_impute[31]     -1.72      0.26     -1.72     -2.11     -1.28   1425.46      1.00
 age_impute[32]      0.44      0.77      0.43     -0.83      1.58   1587.41      1.00
 age_impute[33]      0.34      0.89      0.32     -1.14      1.73   1375.14      1.00
 age_impute[34]     -1.72      0.26     -1.71     -2.11     -1.26   1007.71      1.00
 age_impute[35]     -0.45      0.90     -0.47     -2.06      0.92   1329.44      1.00
 age_impute[36]      0.30      0.84      0.30     -1.03      1.73   1080.80      1.00
 age_impute[37]      0.33      0.88      0.32     -1.10      1.81   1033.30      1.00
 age_impute[38]      0.33      0.76      0.35     -0.94      1.56   1550.68      1.00
 age_impute[39]      0.19      0.93      0.21     -1.32      1.82   1203.79      1.00
 age_impute[40]     -0.67      0.88     -0.69     -1.94      0.88   1382.98      1.00
 age_impute[41]      0.17      0.89      0.14     -1.30      1.43   1438.18      1.00
 age_impute[42]      0.23      0.82      0.25     -1.12      1.48   1499.59      1.00
 age_impute[43]      0.22      0.82      0.21     -1.19      1.45   1236.67      1.00
 age_impute[44]     -0.41      0.85     -0.42     -1.96      0.78    812.53      1.00
 age_impute[45]     -0.36      0.89     -0.35     -2.01      0.94   1488.83      1.00
 age_impute[46]     -0.33      0.91     -0.32     -1.76      1.27   1628.61      1.00
 age_impute[47]     -0.71      0.85     -0.69     -2.12      0.64   1363.89      1.00
 age_impute[48]      0.21      0.85      0.24     -1.21      1.64   1552.65      1.00
 age_impute[49]      0.42      0.82      0.41     -0.83      1.77    754.08      1.00
 age_impute[50]      0.26      0.86      0.24     -1.18      1.63   1155.49      1.00
 age_impute[51]     -0.29      0.91     -0.30     -1.83      1.15   1212.08      1.00
 age_impute[52]      0.36      0.85      0.34     -1.12      1.68   1190.99      1.00
 age_impute[53]     -0.68      0.89     -0.65     -2.09      0.75   1104.75      1.00
 age_impute[54]      0.27      0.90      0.25     -1.24      1.68   1331.19      1.00
 age_impute[55]      0.36      0.89      0.36     -0.96      1.86   1917.52      1.00
 age_impute[56]      0.38      0.86      0.40     -1.00      1.75   1862.00      1.00
 age_impute[57]      0.01      0.91      0.03     -1.33      1.56   1285.43      1.00
 age_impute[58]     -0.69      0.91     -0.66     -2.13      0.78   1438.41      1.00
 age_impute[59]     -0.14      0.85     -0.16     -1.44      1.37   1135.79      1.00
 age_impute[60]     -0.59      0.94     -0.61     -2.19      0.93   1222.88      1.00
 age_impute[61]      0.24      0.92      0.25     -1.35      1.65   1341.95      1.00
 age_impute[62]     -0.55      0.91     -0.57     -2.01      0.96    753.85      1.00
 age_impute[63]      0.21      0.90      0.19     -1.42      1.60   1238.50      1.00
 age_impute[64]     -0.66      0.88     -0.68     -2.04      0.73   1214.85      1.00
 age_impute[65]      0.44      0.78      0.48     -0.93      1.57   1174.41      1.00
 age_impute[66]      0.22      0.94      0.20     -1.35      1.69   1910.00      1.00
 age_impute[67]      0.33      0.76      0.34     -0.85      1.63   1210.24      1.00
 age_impute[68]      0.31      0.84      0.33     -1.08      1.60   1756.60      1.00
 age_impute[69]      0.26      0.91      0.25     -1.29      1.75   1155.87      1.00
 age_impute[70]     -0.67      0.86     -0.70     -2.02      0.70   1186.22      1.00
 age_impute[71]     -0.70      0.90     -0.69     -2.21      0.75   1469.35      1.00
 age_impute[72]      0.24      0.86      0.24     -1.07      1.66   1604.16      1.00
 age_impute[73]      0.34      0.72      0.35     -0.77      1.55   1144.55      1.00
 age_impute[74]     -0.64      0.85     -0.64     -2.10      0.77   1513.79      1.00
 age_impute[75]      0.41      0.78      0.42     -0.96      1.60    796.47      1.00
 age_impute[76]      0.18      0.89      0.21     -1.19      1.74    755.44      1.00
 age_impute[77]      0.21      0.84      0.22     -1.22      1.63   1371.73      1.00
 age_impute[78]     -0.36      0.87     -0.33     -1.81      1.01   1017.23      1.00
 age_impute[79]      0.20      0.84      0.19     -1.35      1.37   1677.57      1.00
 age_impute[80]      0.23      0.84      0.24     -1.09      1.61   1545.61      1.00
 age_impute[81]      0.28      0.90      0.32     -1.08      1.83   1735.91      1.00
 age_impute[82]      0.61      0.80      0.60     -0.61      2.03   1353.67      1.00
 age_impute[83]      0.24      0.89      0.26     -1.22      1.66   1165.03      1.00
 age_impute[84]      0.21      0.91      0.21     -1.35      1.65   1584.00      1.00
 age_impute[85]      0.24      0.92      0.21     -1.33      1.63   1271.37      1.00
 age_impute[86]      0.31      0.81      0.30     -0.86      1.76   1198.70      1.00
 age_impute[87]     -0.11      0.84     -0.10     -1.42      1.23   1248.38      1.00
 age_impute[88]      0.21      0.94      0.22     -1.31      1.77   1082.82      1.00
 age_impute[89]      0.24      0.86      0.23     -1.08      1.67   2141.98      1.00
 age_impute[90]      0.41      0.84      0.45     -0.88      1.90   1518.73      1.00
 age_impute[91]      0.21      0.86      0.20     -1.21      1.58   1723.50      1.00
 age_impute[92]      0.21      0.84      0.20     -1.21      1.57   1742.44      1.00
 age_impute[93]      0.22      0.87      0.23     -1.29      1.50   1359.74      1.00
 age_impute[94]      0.22      0.87      0.18     -1.09      1.70    906.55      1.00
 age_impute[95]      0.22      0.87      0.23     -1.16      1.65   1112.58      1.00
 age_impute[96]      0.30      0.84      0.26     -1.18      1.57   1680.70      1.00
 age_impute[97]      0.23      0.87      0.25     -1.22      1.63   1408.40      1.00
 age_impute[98]     -0.36      0.91     -0.37     -1.96      1.03   1083.67      1.00
 age_impute[99]      0.15      0.87      0.14     -1.22      1.61   1644.46      1.00
age_impute[100]      0.27      0.85      0.30     -1.27      1.45   1266.96      1.00
age_impute[101]      0.25      0.87      0.25     -1.19      1.57   1220.96      1.00
age_impute[102]     -0.29      0.85     -0.28     -1.70      1.10   1392.91      1.00
age_impute[103]      0.01      0.89      0.01     -1.46      1.39   1137.34      1.00
age_impute[104]      0.21      0.86      0.24     -1.16      1.64   1018.70      1.00
age_impute[105]      0.24      0.93      0.21     -1.14      1.90   1479.67      1.00
age_impute[106]      0.21      0.83      0.21     -1.09      1.55   1471.11      1.00
age_impute[107]      0.22      0.85      0.22     -1.09      1.64   1941.83      1.00
age_impute[108]      0.31      0.88      0.30     -1.10      1.76   1342.10      1.00
age_impute[109]      0.22      0.86      0.23     -1.25      1.56   1198.01      1.00
age_impute[110]      0.33      0.78      0.35     -0.95      1.62   1267.01      1.00
age_impute[111]      0.22      0.88      0.21     -1.11      1.71   1404.51      1.00
age_impute[112]     -0.03      0.90     -0.02     -1.38      1.55   1625.35      1.00
age_impute[113]      0.24      0.85      0.23     -1.17      1.62   1361.84      1.00
age_impute[114]      0.36      0.86      0.37     -0.99      1.76   1155.67      1.00
age_impute[115]      0.26      0.96      0.28     -1.37      1.81   1245.97      1.00
age_impute[116]      0.21      0.86      0.24     -1.18      1.69   1565.59      1.00
age_impute[117]     -0.31      0.94     -0.33     -1.91      1.19   1593.65      1.00
age_impute[118]      0.21      0.87      0.22     -1.20      1.64   1315.42      1.00
age_impute[119]     -0.69      0.88     -0.74     -2.00      0.90   1536.44      1.00
age_impute[120]      0.63      0.81      0.66     -0.65      1.89    899.61      1.00
age_impute[121]      0.27      0.90      0.26     -1.16      1.74   1744.32      1.00
age_impute[122]      0.18      0.87      0.18     -1.23      1.60   1625.58      1.00
age_impute[123]     -0.39      0.88     -0.38     -1.71      1.12   1266.58      1.00
age_impute[124]     -0.62      0.95     -0.63     -2.03      1.01   1600.28      1.00
age_impute[125]      0.23      0.88      0.23     -1.15      1.71   1604.27      1.00
age_impute[126]      0.18      0.91      0.18     -1.24      1.63   1527.38      1.00
age_impute[127]      0.32      0.85      0.36     -1.08      1.73   1074.98      1.00
age_impute[128]      0.25      0.88      0.25     -1.10      1.69   1486.79      1.00
age_impute[129]     -0.70      0.87     -0.68     -2.20      0.56   1506.55      1.00
age_impute[130]      0.21      0.88      0.20     -1.16      1.68   1451.63      1.00
age_impute[131]      0.22      0.87      0.23     -1.22      1.61    905.86      1.00
age_impute[132]      0.33      0.83      0.33     -1.01      1.66   1517.67      1.00
age_impute[133]      0.18      0.86      0.18     -1.19      1.59   1050.00      1.00
age_impute[134]     -0.14      0.92     -0.15     -1.77      1.24   1386.20      1.00
age_impute[135]      0.19      0.85      0.18     -1.22      1.53   1290.94      1.00
age_impute[136]      0.16      0.92      0.16     -1.35      1.74   1767.36      1.00
age_impute[137]     -0.71      0.90     -0.68     -2.24      0.82   1154.14      1.00
age_impute[138]      0.18      0.91      0.16     -1.30      1.67   1160.90      1.00
age_impute[139]      0.24      0.90      0.24     -1.15      1.76   1289.37      1.00
age_impute[140]      0.41      0.80      0.39     -1.05      1.53   1532.92      1.00
age_impute[141]      0.27      0.83      0.29     -1.04      1.60   1310.29      1.00
age_impute[142]     -0.28      0.89     -0.29     -1.68      1.22   1088.65      1.00
age_impute[143]     -0.12      0.91     -0.11     -1.56      1.40   1324.74      1.00
age_impute[144]     -0.65      0.87     -0.63     -1.91      0.93   1672.31      1.00
age_impute[145]     -1.73      0.26     -1.74     -2.11     -1.26   1502.96      1.00
age_impute[146]      0.40      0.85      0.40     -0.85      1.84   1443.81      1.00
age_impute[147]      0.23      0.87      0.20     -1.37      1.49   1220.62      1.00
age_impute[148]     -0.70      0.88     -0.70     -2.08      0.87   1846.67      1.00
age_impute[149]      0.27      0.87      0.29     -1.11      1.76   1451.79      1.00
age_impute[150]      0.21      0.90      0.20     -1.10      1.78   1409.94      1.00
age_impute[151]      0.25      0.87      0.26     -1.21      1.63   1224.08      1.00
age_impute[152]      0.05      0.85      0.05     -1.42      1.39   1164.23      1.00
age_impute[153]      0.18      0.90      0.15     -1.19      1.72   1697.92      1.00
age_impute[154]      1.05      0.93      1.04     -0.24      2.84   1212.82      1.00
age_impute[155]      0.20      0.84      0.18     -1.18      1.54   1398.45      1.00
age_impute[156]      0.23      0.95      0.19     -1.19      1.87   1773.79      1.00
age_impute[157]      0.19      0.85      0.22     -1.13      1.64   1123.21      1.00
age_impute[158]      0.22      0.86      0.22     -1.18      1.60   1307.64      1.00
age_impute[159]      0.18      0.84      0.18     -1.09      1.59   1499.97      1.00
age_impute[160]      0.24      0.89      0.28     -1.23      1.65   1100.08      1.00
age_impute[161]     -0.45      0.88     -0.45     -1.86      1.05   1414.97      1.00
age_impute[162]      0.39      0.89      0.40     -1.00      1.87   1525.80      1.00
age_impute[163]      0.34      0.89      0.35     -1.14      1.75   1600.03      1.00
age_impute[164]      0.21      0.94      0.19     -1.13      1.91   1090.05      1.00
age_impute[165]      0.22      0.85      0.20     -1.11      1.60   1330.87      1.00
age_impute[166]     -0.13      0.91     -0.15     -1.69      1.28   1284.90      1.00
age_impute[167]      0.22      0.89      0.24     -1.15      1.76   1261.93      1.00
age_impute[168]      0.20      0.90      0.18     -1.18      1.83   1217.16      1.00
age_impute[169]      0.07      0.89      0.05     -1.29      1.60   2007.16      1.00
age_impute[170]      0.23      0.90      0.24     -1.25      1.67    937.57      1.00
age_impute[171]      0.41      0.80      0.42     -0.82      1.82   1404.02      1.00
age_impute[172]      0.23      0.87      0.20     -1.33      1.51   2032.72      1.00
age_impute[173]     -0.44      0.88     -0.44     -1.81      1.08   1006.62      1.00
age_impute[174]      0.19      0.84      0.19     -1.11      1.63   1495.21      1.00
age_impute[175]      0.20      0.85      0.20     -1.17      1.63   1551.22      1.00
age_impute[176]     -0.43      0.92     -0.44     -1.83      1.21   1477.58      1.00
      age_mu[0]      0.19      0.04      0.19      0.12      0.26    749.16      1.00
      age_mu[1]     -0.54      0.07     -0.54     -0.66     -0.42    786.30      1.00
      age_mu[2]      0.43      0.08      0.42      0.31      0.55   1134.72      1.00
      age_mu[3]     -1.73      0.04     -1.73     -1.79     -1.65   1194.53      1.00
      age_mu[4]      0.85      0.17      0.85      0.58      1.13   1111.96      1.00
   age_sigma[0]      0.88      0.03      0.88      0.82      0.93    766.67      1.00
   age_sigma[1]      0.90      0.06      0.90      0.81      0.99    992.72      1.00
   age_sigma[2]      0.79      0.05      0.78      0.71      0.87    708.34      1.00
   age_sigma[3]      0.26      0.03      0.25      0.20      0.31    959.62      1.00
   age_sigma[4]      0.93      0.13      0.93      0.74      1.15   1092.88      1.00
          b_Age     -0.45      0.14     -0.44     -0.66     -0.22    744.95      1.00
  b_Embarked[0]     -0.28      0.58     -0.30     -1.28      0.64    496.51      1.00
  b_Embarked[1]      0.30      0.60      0.29     -0.74      1.20    495.25      1.00
  b_Embarked[2]      0.04      0.61      0.03     -0.93      1.02    482.67      1.00
     b_Parch[0]      0.45      0.57      0.47     -0.45      1.42    336.02      1.02
     b_Parch[1]      0.12      0.58      0.14     -0.91      1.00    377.61      1.02
     b_Parch[2]     -0.49      0.58     -0.45     -1.48      0.41    358.61      1.01
    b_Pclass[0]      1.22      0.57      1.24      0.33      2.17    371.15      1.00
    b_Pclass[1]      0.06      0.57      0.07     -0.84      1.03    369.58      1.00
    b_Pclass[2]     -1.18      0.57     -1.16     -2.18     -0.31    373.55      1.00
       b_Sex[0]      1.15      0.74      1.18     -0.03      2.31    568.65      1.00
       b_Sex[1]     -1.05      0.74     -1.02     -2.18      0.21    709.29      1.00
     b_SibSp[0]      0.28      0.66      0.26     -0.86      1.25    585.03      1.00
     b_SibSp[1]     -0.17      0.67     -0.18     -1.28      0.87    596.44      1.00
     b_Title[0]     -0.94      0.54     -0.96     -1.86     -0.11    437.32      1.00
     b_Title[1]     -0.33      0.61     -0.33     -1.32      0.60    570.32      1.00
     b_Title[2]      0.53      0.62      0.53     -0.52      1.46    452.87      1.00
     b_Title[3]      1.48      0.59      1.48      0.60      2.48    562.71      1.00
     b_Title[4]     -0.68      0.58     -0.66     -1.71      0.15    472.57      1.00

Number of divergences: 0

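To double check that the assumption "age is correlated with title" is reasonable, let's look at the inferred age by title. Recall that we standardized age, so here we need to scale it back to the original range.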

[10]:
age_by_title = age_mean + age_std * mcmc.get_samples()["age_mu"].mean(axis=0)
dict(zip(title_cat.categories, age_by_title))
[10]:
{'Mr.': 32.434227,
 'Miss.': 21.763992,
 'Mrs.': 35.852997,
 'Master.': 4.6297398,
 'Misc.': 42.081936}

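The inferred result confirms our assumption that Age is correlated with Title:

  • those with the Master. title are very young (in other words, they are children on the ship) compared to the other groups,

  • those with the Mrs. title are older than those with the Miss. title (on average).

We can also see that the result is similar to the actual statistical mean of Age given Title in our training dataset: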

[11]:
train_df.groupby("Title")["Age"].mean()
[11]:
Title
Master.     4.574167
Misc.      42.384615
Miss.      21.773973
Mr.        32.368090
Mrs.       35.898148
Name: Age, dtype: float64

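So far so good: we have obtained a lot of information about the regression coefficients together with the imputed values and their uncertainties. Let's inspect those results a bit:

  • The mean value -0.45 of b_Age implies that those with smaller ages have a better chance of surviving.

  • The mean values (1.15, -1.05) of b_Sex imply that female passengers have a higher chance of surviving than male passengers.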
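
To put the b_Sex coefficients on a more familiar scale, the gap between the two logit offsets can be turned into an odds ratio. This is a rough back-of-the-envelope sketch: it plugs in the posterior means from the summary table, ignores the posterior uncertainty and the correlation between the two coefficients, and holds all other covariates (including Title) fixed.

```python
import math

# posterior means read off the summary table above
b_sex_female = 1.15   # b_Sex[0], sex == 0 encodes female
b_sex_male = -1.05    # b_Sex[1], sex == 1 encodes male

logit_gap = b_sex_female - b_sex_male
odds_ratio = math.exp(logit_gap)
print(f"logit gap: {logit_gap:.2f}, odds ratio: {odds_ratio:.1f}")
```

Under these simplifications, the odds of survival for a female passenger are roughly nine times those of an otherwise identical male passenger.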

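Prediction

In NumPyro, we can use the Predictive utility to make predictions from posterior samples. Let's check how well the model performs on the training dataset. For simplicity, we will get a survived prediction for each posterior sample and apply the majority rule to the predictions.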

[12]:
posterior = mcmc.get_samples()
survived_pred = Predictive(model, posterior)(random.PRNGKey(1), **data)["survived"]
survived_pred = (survived_pred.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred == survived).sum() / survived.shape[0])
confusion_matrix = pd.crosstab(
    pd.Series(survived, name="actual"), pd.Series(survived_pred, name="predict")
)
# normalize each row (actual class) so that it sums to 1
confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)
Accuracy: 0.8271605
[12]:
predict         0         1
actual
0        0.876138  0.123862
1        0.251462  0.748538
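
The majority rule used in the cell above can be illustrated with NumPy on made-up numbers. In this sketch, the probabilities and the Bernoulli draws are hypothetical stand-ins for the array returned by Predictive(...)["survived"], which has one row per posterior sample and one column per passenger.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical per-passenger survival probabilities and 1000 Bernoulli draws each
probs = np.array([0.9, 0.2, 0.6, 0.4])
draws = rng.random((1000, probs.size)) < probs   # shape (num_samples, num_passengers)

vote_share = draws.mean(axis=0)                  # fraction of draws predicting survival
majority = (vote_share >= 0.5).astype(np.uint8)
print(vote_share, majority)
```

Averaging over the sample axis gives a per-passenger vote share, and thresholding it at 0.5 yields a single hard prediction per passenger.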

This is a pretty good result for a simple logistic regression model. Let's see how the model performs if we don't use Bayesian imputation here.

[13]:
mcmc.run(random.PRNGKey(2), **data, survived=survived, bayesian_impute=False)
posterior_1 = mcmc.get_samples()
survived_pred_1 = Predictive(model, posterior_1)(random.PRNGKey(2), **data)["survived"]
survived_pred_1 = (survived_pred_1.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred_1 == survived).sum() / survived.shape[0])
confusion_matrix = pd.crosstab(
    pd.Series(survived, name="actual"), pd.Series(survived_pred_1, name="predict")
)
# normalize each row (actual class) so that it sums to 1
confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)
sample: 100%|██████████| 2000/2000 [00:11<00:00, 166.79it/s, 63 steps of size 7.18e-02. acc. prob=0.93]
Accuracy: 0.82042646
[13]:
predict         0         1
actual
0        0.872495  0.127505
1        0.263158  0.736842

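We can see that Bayesian imputation does slightly better here (accuracy 0.827 versus 0.820).

Remark. When using posterior samples to make predictions on new data, we need to marginalize out age_impute because those imputed values are specific to the training data: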

posterior.pop("age_impute")
survived_pred = Predictive(model, posterior)(random.PRNGKey(3), **new_data)
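
Note that pop removes the site from the posterior dictionary in place. If the training-time samples of age_impute are still needed afterwards, one alternative is to build a filtered copy, as in the plain-Python sketch below (the dictionary contents are made-up stand-ins for mcmc.get_samples(), and posterior_new is a name introduced here for illustration).

```python
# stand-in for mcmc.get_samples(); the real values are arrays of posterior draws
posterior = {"a": [0.1, 0.2], "b_Age": [-0.4, -0.5], "age_impute": [[0.2], [0.3]]}

# build a copy without the training-specific site instead of popping in place
posterior_new = {k: v for k, v in posterior.items() if k != "age_impute"}

print(sorted(posterior_new))
```

The filtered dictionary can then be passed to Predictive in place of posterior, while the original samples stay intact.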

References

  1. McElreath, R. (2016). Statistical Rethinking: A Bayesian Course with Examples in R and Stan.

  2. Kaggle competition: Titanic: Machine Learning from Disaster