序数回归
有些数据是离散的,但本质上是有序的,这被称为序数数据。一个例子是问卷中的李克特量表(“这是一个有用的教程”:1. 非常不同意 / 2. 不同意 / 3. 不确定 / 4. 同意 / 5. 非常同意)。序数数据在医学领域也无处不在(例如用于测量神经功能障碍的格拉斯哥昏迷量表)。
这对统计建模提出了挑战,因为这些数据不符合最常用的建模方法(例如线性回归)。将数据建模为分类数据是一种可能性,但这忽略了数据固有的排序,并且在统计上效率可能较低。有多种方法可以对有序数据进行建模。在这里,我们将展示如何使用 OrderedLogistic 分布,其切点是从非正常先验、正态分布以及通过 Dirichlet 分布的类别概率推导出来的。关于序数数据的贝叶斯建模的更深入讨论,请参阅例如Michael Betancourt 的序数回归案例研究
参考文献
Betancourt, M. (2019), “序数回归”, (https://betanalpha.github.io/assets/case_studies/ordinal_regression.html)
[1]:
# !pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[2]:
import pandas as pd
import seaborn as sns
from jax import numpy as np, random
import numpyro
from numpyro import handlers, sample
from numpyro.distributions import (
Categorical,
Dirichlet,
ImproperUniform,
Normal,
OrderedLogistic,
TransformedDistribution,
constraints,
transforms,
)
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import TransformReparam
assert numpyro.__version__.startswith("0.18.0")
数据生成
首先,生成一些具有序数结构的数据
[3]:
simkeys = random.split(random.PRNGKey(1), 2)
nsim = 50
nclasses = 3
Y = Categorical(logits=np.zeros(nclasses)).sample(simkeys[0], sample_shape=(nsim,))
X = Normal().sample(simkeys[1], sample_shape=(nsim,))
X += Y
print("value counts of Y:")
df = pd.DataFrame({"X": X, "Y": Y})
print(df.Y.value_counts())
for i in range(nclasses):
print(f"mean(X) for Y == {i}: {X[np.where(Y == i)].mean():.3f}")
value counts of Y:
1 19
2 16
0 15
Name: Y, dtype: int64
mean(X) for Y == 0: 0.042
mean(X) for Y == 1: 0.832
mean(X) for Y == 2: 1.448
[4]:
sns.violinplot(x="Y", y="X", data=df);

非正常先验
我们将结果 Y 建模为在 X 条件下来自 OrderedLogistic 分布。numpyro 中的 OrderedLogistic
分布需要有序的切点。我们可以使用 ImproperUniform
分布引入一个具有任意支持但完全无信息的参数,然后添加 ordered_vector
约束。
[5]:
def model1(X, Y, nclasses=3):
b_X_eta = sample("b_X_eta", Normal(0, 5))
c_y = sample(
"c_y",
ImproperUniform(
support=constraints.ordered_vector,
batch_shape=(),
event_shape=(nclasses - 1,),
),
)
with numpyro.plate("obs", X.shape[0]):
eta = X * b_X_eta
sample("Y", OrderedLogistic(eta, c_y), obs=Y)
mcmc_key = random.PRNGKey(1234)
kernel = NUTS(model1)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(mcmc_key, X, Y, nclasses)
mcmc.print_summary()
sample: 100%|██████████| 1000/1000 [00:03<00:00, 258.56it/s, 7 steps of size 5.02e-01. acc. prob=0.94]
mean std median 5.0% 95.0% n_eff r_hat
b_X_eta 1.44 0.37 1.42 0.83 2.05 349.38 1.01
c_y[0] -0.10 0.38 -0.10 -0.71 0.51 365.63 1.00
c_y[1] 2.15 0.49 2.13 1.38 2.99 376.45 1.01
Number of divergences: 0
ImproperUniform
分布使我们能够使用定义域受限的参数,而无需添加任何额外信息,例如关于该参数的先验分布的位置或尺度信息。
如果我们想纳入这些信息,例如切点的值不应离零太远,我们可以添加一个额外的 sample
语句,该语句使用另一个先验,并结合一个 obs
参数。在下面的示例中,我们首先像之前一样从带有 constraints.ordered_vector
的 ImproperUniform
分布中抽取切点 c_y
,然后从 Normal
分布中 sample
一个虚拟参数,同时使用 obs=c_y
对 c_y
进行条件限制。实际上,我们创建了一个非正常/未归一化的先验,该先验是通过将 Normal
分布的支持限制到有序域而产生的。
[6]:
def model2(X, Y, nclasses=3):
b_X_eta = sample("b_X_eta", Normal(0, 5))
c_y = sample(
"c_y",
ImproperUniform(
support=constraints.ordered_vector,
batch_shape=(),
event_shape=(nclasses - 1,),
),
)
sample("c_y_smp", Normal(0, 1), obs=c_y)
with numpyro.plate("obs", X.shape[0]):
eta = X * b_X_eta
sample("Y", OrderedLogistic(eta, c_y), obs=Y)
kernel = NUTS(model2)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(mcmc_key, X, Y, nclasses)
mcmc.print_summary()
sample: 100%|██████████| 1000/1000 [00:03<00:00, 256.41it/s, 7 steps of size 5.31e-01. acc. prob=0.92]
mean std median 5.0% 95.0% n_eff r_hat
b_X_eta 1.23 0.31 1.23 0.64 1.68 501.31 1.01
c_y[0] -0.24 0.34 -0.23 -0.76 0.38 492.91 1.00
c_y[1] 1.77 0.40 1.76 1.11 2.42 628.46 1.00
Number of divergences: 0
正常先验
如果对这些切点 c_y
设置一个正常先验是可取的(例如从该先验中采样并获得先验预测),我们可以如下使用带有OrderedTransform变换的TransformedDistribution。
[7]:
def model3(X, Y, nclasses=3):
b_X_eta = sample("b_X_eta", Normal(0, 5))
c_y = sample(
"c_y",
TransformedDistribution(
Normal(0, 1).expand([nclasses - 1]), transforms.OrderedTransform()
),
)
with numpyro.plate("obs", X.shape[0]):
eta = X * b_X_eta
sample("Y", OrderedLogistic(eta, c_y), obs=Y)
kernel = NUTS(model3)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(mcmc_key, X, Y, nclasses)
mcmc.print_summary()
sample: 100%|██████████| 1000/1000 [00:04<00:00, 244.78it/s, 7 steps of size 5.54e-01. acc. prob=0.93]
mean std median 5.0% 95.0% n_eff r_hat
b_X_eta 1.40 0.34 1.41 0.86 1.98 300.30 1.03
c_y[0] -0.03 0.35 -0.03 -0.57 0.54 395.98 1.00
c_y[1] 2.06 0.47 2.04 1.26 2.83 475.16 1.01
Number of divergences: 0
使用 Dirichlet 分布的原则性先验
将我们的专业知识应用于潜在空间中的切点并非易事(尤其是在应用 OrderedTransform 之前必须提供先验的情况下)。
自然的倾向是对序数概率应用 Dirichlet 先验模型。我们将遵循 M.Betancourt 的建议([1],第 2.2 节),并使用Dirichlet先验模型通过SimplexToOrderedTransform间接引入切点。当需要向我们的序数模型添加强大的先验知识时,这种方法应该是有利的,例如,当我们的数据集中缺少某个类别时,或者当某些类别严重分离(导致切点不可识别)时。此外,这种参数化允许我们对模型进行采样并进行先验预测检查(与使用ImproperUniform
的model1
不同)。
我们可以直接从 TransformedDistribution(Dirichlet(concentration),transforms.SimplexToOrderedTransform(anchor_point))
中抽样切点。但是,如果我们在 reparam handler
上下文中使用变换,我们不仅可以捕获导出的切点,还可以捕获由 concentration
参数隐含的抽样的序数概率。anchor_point
是一个干扰参数,用于提高我们变换的可识别性(详情请参阅 [1],第 2.2 节)。
请注意,我们不能单独比较各种模型中的潜在切点或 b_X_eta,因为它们是内在关联的。
[8]:
# We will apply a nudge towards equal probability for each category
# (corresponds to equal logits of the true data generating process)
concentration = np.ones((nclasses,)) * 10.0
[9]:
def model4(X, Y, nclasses, concentration, anchor_point=0.0):
b_X_eta = sample("b_X_eta", Normal(0, 5))
with handlers.reparam(config={"c_y": TransformReparam()}):
c_y = sample(
"c_y",
TransformedDistribution(
Dirichlet(concentration),
transforms.SimplexToOrderedTransform(anchor_point),
),
)
with numpyro.plate("obs", X.shape[0]):
eta = X * b_X_eta
sample("Y", OrderedLogistic(eta, c_y), obs=Y)
kernel = NUTS(model4)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(mcmc_key, X, Y, nclasses, concentration)
# with exclude_deterministic=False, we will also show the ordinal probabilities sampled from Dirichlet (vis. `c_y_base`)
mcmc.print_summary(exclude_deterministic=False)
sample: 100%|██████████| 1000/1000 [00:05<00:00, 193.88it/s, 7 steps of size 7.00e-01. acc. prob=0.93]
mean std median 5.0% 95.0% n_eff r_hat
b_X_eta 1.01 0.26 1.01 0.59 1.42 388.46 1.00
c_y[0] -0.42 0.26 -0.42 -0.88 -0.05 491.73 1.00
c_y[1] 1.34 0.29 1.34 0.86 1.80 617.53 1.00
c_y_base[0] 0.40 0.06 0.40 0.29 0.49 488.71 1.00
c_y_base[1] 0.39 0.06 0.39 0.29 0.48 523.65 1.00
c_y_base[2] 0.21 0.05 0.21 0.13 0.29 610.33 1.00
Number of divergences: 0