注意
跳转到末尾以下载完整示例代码。
示例:用于“高”数据的MCMC方法
本示例展示了适用于“高”数据(tall data,即样本数量远大于特征维度的数据集)的各种MCMC方法的用法:
algo="SA":使用参考文献[1]中的样本自适应MCMC方法
algo="HMCECS":使用参考文献[2]中的能量守恒子采样方法
algo="FlowHMCECS":利用归一化流将后验分布的几何形状变换为近似高斯的形状,然后使用HMCECS抽取后验样本。目前,该方法在上述方法中混合率最佳。
参考文献
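[1] 样本自适应MCMC(Sample Adaptive MCMC),Michael Zhu (2019)
[2] 能量守恒子采样哈密尔顿蒙特卡洛(Hamiltonian Monte Carlo with Energy Conserving Subsampling),Dang, K. D., Quiroz, M., Kohn, R., Minh-Ngoc, T., & Villani, M. (2019)
[3] 利用神经传输中和哈密尔顿蒙特卡洛中的不良几何形状(NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport),Hoffman, M. et al. (2019)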
import argparse
import time
import matplotlib.pyplot as plt
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.reparam import NeuTraReparam
def _load_dataset():
    """Fetch COVTYPE and prepare it as a binary classification problem.

    The feature columns are standardized (zero mean, unit standard deviation
    along axis 0) and a constant column of ones is appended to act as the
    intercept. The labels are turned into a boolean array by comparing them
    with the argmax of the per-class counts.

    Prints the resulting feature shape and the label balance.

    Returns:
        A ``(features, labels)`` tuple.
    """
    _, fetch_fn = load_dataset(COVTYPE, shuffle=False)
    raw_features, raw_labels = fetch_fn()

    # Standardize every column, then append a column of ones for the intercept.
    centered = raw_features - raw_features.mean(0)
    scaled = centered / raw_features.std(0)
    intercept = jnp.ones((scaled.shape[0], 1))
    features = jnp.hstack([scaled, intercept])

    # Collapse the multi-class target into a boolean label.
    # NOTE(review): argmax yields the *position* of the largest count among the
    # sorted unique labels, which equals the class value only when labels are
    # 0..K-1 -- confirm against the dataset before changing; reference
    # parameters elsewhere may rely on the current labelling.
    _, class_counts = jnp.unique(raw_labels, return_counts=True)
    target_category = jnp.argmax(class_counts)
    labels = raw_labels == target_category

    n_rows = features.shape[0]
    n_positive = labels.sum()
    print("Data shape:", features.shape)
    print(
        "Label distribution: {} has label 1, {} has label 0".format(
            n_positive, n_rows - n_positive
        )
    )
    return features, labels
def model(data, labels, subsample_size=None):
    """Logistic regression with a standard-normal prior on the coefficients.

    Args:
        data: feature matrix; its second axis gives the number of coefficients
            and its first axis the size of the ``"N"`` plate.
        labels: observed outcomes, indexed in step with ``data``.
        subsample_size: forwarded to ``numpyro.plate``; ``None`` by default.

    Returns:
        The value of the ``"obs"`` sample site.
    """
    num_coefs = data.shape[1]
    prior = dist.Normal(jnp.zeros(num_coefs), jnp.ones(num_coefs))
    coefs = numpyro.sample("coefs", prior)

    data_plate = numpyro.plate("N", data.shape[0], subsample_size=subsample_size)
    with data_plate as batch_idx:
        # Only the rows selected by the plate contribute to the likelihood.
        likelihood = dist.Bernoulli(logits=jnp.dot(data[batch_idx], coefs))
        return numpyro.sample("obs", likelihood, obs=labels[batch_idx])
def benchmark_hmc(args, features, labels):
    """Run the MCMC algorithm selected by ``args.algo`` and print diagnostics.

    Builds the kernel for one of ``"HMC"``, ``"NUTS"``, ``"HMCECS"``, ``"SA"``
    or ``"FlowHMCECS"``, runs it on ``model`` and prints the mean acceptance
    probability, the posterior summary and the elapsed wall-clock time (the
    timer starts before kernel construction, so for ``"FlowHMCECS"`` it also
    covers the SVI fit and the loss plot).

    Args:
        args: parsed command-line namespace; reads ``algo``, ``num_steps``,
            ``dense_mass``, ``num_warmup`` and ``num_samples``.
        features: feature matrix passed to ``model`` as ``data``.
        labels: labels passed to ``model``.

    Raises:
        ValueError: if ``args.algo`` is not one of the supported algorithms.
    """
    rng_key = random.PRNGKey(1)
    start = time.time()
    # a MAP estimate taken from the following source
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs": jnp.array(
            [
                +2.03420663e00,
                -3.53567265e-02,
                -1.49223924e-01,
                -3.07049364e-01,
                -1.00028366e-01,
                -1.46827862e-01,
                -1.64167881e-01,
                -4.20344204e-01,
                +9.47479829e-02,
                -1.12681836e-02,
                +2.64442056e-01,
                -1.22087866e-01,
                -6.00568838e-02,
                -3.79419506e-01,
                -1.06668741e-01,
                -2.97053963e-01,
                -2.05253899e-01,
                -4.69537191e-02,
                -2.78072730e-02,
                -1.43250525e-01,
                -6.77954629e-02,
                -4.34899796e-03,
                +5.90927452e-02,
                +7.23133609e-02,
                +1.38526391e-02,
                -1.24497898e-01,
                -1.50733739e-02,
                -2.68872194e-02,
                -1.80925727e-02,
                +3.47936489e-02,
                +4.03552800e-02,
                -9.98773426e-03,
                +6.20188080e-02,
                +1.15002751e-01,
                +1.32145107e-01,
                +2.69109547e-01,
                +2.45785132e-01,
                +1.19035013e-01,
                -2.59744357e-02,
                +9.94279515e-04,
                +3.39266285e-02,
                -1.44057125e-02,
                -6.95222765e-02,
                -7.52013028e-02,
                +1.21171586e-01,
                +2.29205526e-02,
                +1.47308692e-01,
                -8.34354162e-02,
                -9.34122875e-02,
                -2.97472421e-02,
                -3.03937674e-01,
                -1.70958012e-01,
                -1.59496680e-01,
                -1.88516974e-01,
                -1.20889175e00,
            ]
        )
    }
    if args.algo == "HMC":
        # Fixed step size scaled by the number of observations; adaptation is
        # disabled so the trajectory is exactly ``num_steps`` leapfrog steps.
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100 and subsample_size=1000, 10 indices are
        # refreshed at each MCMC step, so it takes on the order of 50000 MCMC
        # steps to iterate over the whole dataset
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(ref_params)
        )
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples
        # and running on GPU is much faster than on CPU
        kernel = SA(
            model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params)
        )
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        # Fit a normalizing-flow guide with SVI, then sample in the space
        # reparameterized by that flow.
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()
        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        # The only latent site is "coefs", which has one entry per feature
        # column, so the shared latent has that many dimensions (previously a
        # hard-coded 55).
        latent_dim = features.shape[1]
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(latent_dim)}
        # no need to adapt mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(neutra_ref_params)
        )
    else:
        raise ValueError(
            "Invalid algorithm, either 'HMC', 'NUTS', 'HMCECS', 'SA', or 'FlowHMCECS'."
        )
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
    print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)
def main(args):
    """Load the dataset and benchmark the sampler selected by ``args``."""
    dataset = _load_dataset()
    benchmark_hmc(args, *dataset)
if __name__ == "__main__":
    # The example is pinned to the NumPyro release it was written against.
    assert numpyro.__version__.startswith("0.18.0")
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument(
        "-n", "--num-samples", default=1000, type=int, help="number of samples"
    )
    parser.add_argument(
        "--num-warmup", default=1000, type=int, help="number of warmup steps"
    )
    parser.add_argument(
        "--num-steps", default=10, type=int, help='number of steps (for "HMC")'
    )
    parser.add_argument("--num-chains", nargs="?", default=1, type=int)
    # Restrict --algo to the supported names so a typo is rejected at parse
    # time instead of after the dataset has been loaded.
    parser.add_argument(
        "--algo",
        default="HMCECS",
        type=str,
        choices=["HMC", "NUTS", "HMCECS", "SA", "FlowHMCECS"],
        help='whether to run "HMC", "NUTS", "HMCECS", "SA" or "FlowHMCECS"',
    )
    parser.add_argument("--dense-mass", action="store_true")
    parser.add_argument("--x64", action="store_true")
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Platform, host device count and precision are configured before any
    # computation runs in main().
    # NOTE(review): presumably these must be set before JAX initializes its
    # backend -- confirm against the NumPyro documentation.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)
    if args.x64:
        numpyro.enable_x64()
    main(args)