注意
转到末尾以下载完整的示例代码。
示例:高斯壳的嵌套采样
本示例演示了 contrib 类 NestedSampler 的用法,它是 jaxns 库 ([1]) 的一个包装器,用于 NumPyro 模型。
在这里,我们将复现 [2] 中的高斯壳演示,并与 NUTS 采样器进行比较。
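参考文献
[1] jaxns 库:https://github.com/Joshuaalbert/jaxns
[2] dynesty 的高斯壳(Gaussian Shells)演示:https://github.com/joshspeagle/dynesty/blob/master/demos/Examples%20--%20Gaussian%20Shells.ipynb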
import argparse
import matplotlib.pyplot as plt
from jax import random
import jax.numpy as jnp
import numpyro
from numpyro.contrib.nested_sampling import NestedSampler
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
class GaussianShell(dist.Distribution):
    """Density concentrated on a thin shell of a given radius around ``loc``.

    The log-density of a point is the log-density of a one-dimensional
    Normal(``radius``, ``width``) evaluated at the Euclidean distance between
    that point and ``loc``.

    Args:
        loc: center of the shell; its last axis is the event dimension and
            any leading axes form the batch shape.
        radius: distance from ``loc`` at which the density peaks.
        width: standard deviation of the radial Gaussian profile.
    """

    # Values are unconstrained real vectors.
    support = dist.constraints.real_vector

    def __init__(self, loc, radius, width):
        self.loc = loc
        self.radius = radius
        self.width = width
        super().__init__(batch_shape=loc.shape[:-1], event_shape=loc.shape[-1:])

    def sample(self, key, sample_shape=()):
        """Return an all-zeros placeholder; ``key`` is not used.

        This is only a dummy sample used to initialize the samplers, not a
        draw from the shell density.
        """
        placeholder_shape = sample_shape + self.shape()
        return jnp.zeros(placeholder_shape)

    def log_prob(self, value):
        """Return the log-density of ``value`` under the shell."""
        # Distance from the shell center along the event axis.
        distance = jnp.linalg.norm(value - self.loc, axis=-1)
        # Log normalizing constant of a 1-D Gaussian with std ``width``.
        log_norm = (-0.5) * (jnp.log(2.0 * jnp.pi) + 2.0 * jnp.log(self.width))
        standardized = (distance - self.radius) / self.width
        return log_norm - 0.5 * standardized**2
def model(center1, center2, radius, width, enum=False):
    """NumPyro model with two Gaussian shells selected by a binary latent.

    Args:
        center1: center of the shell used when ``z == 0``.
        center2: center of the shell used when ``z == 1``.
        radius: shell radius passed to ``GaussianShell``.
        width: shell width passed to ``GaussianShell``.
        enum: when true, the discrete site ``"z"`` is marked for parallel
            enumeration; otherwise no inference hints are attached.
    """
    if enum:
        z_infer = {"enumerate": "parallel"}
    else:
        z_infer = {}
    z = numpyro.sample("z", dist.Bernoulli(0.5), infer=z_infer)

    # Two-dimensional uniform box on [-6, 6], treated as a single event.
    box_prior = dist.Uniform(-6.0, 6.0).expand([2]).to_event(1)
    x = numpyro.sample("x", box_prior)

    # Select the shell center indexed by the discrete latent.
    centers = jnp.stack([center1, center2])
    shell = GaussianShell(centers[z], radius, width)
    # Observing the shell site at ``x`` scores ``x`` under the shell density.
    numpyro.sample("shell", shell, obs=x)
def run_inference(args, data):
    """Run nested sampling and MCMC on ``model`` and return the ``"x"`` draws.

    Args:
        args: parsed options; reads ``enum``, ``num_samples`` and
            ``num_warmup``.
        data: keyword arguments forwarded to ``model``.

    Returns:
        A pair ``(ns_x, mcmc_x)`` with the ``"x"`` samples from the nested
        sampler and from MCMC, in that order.
    """
    print("=== Performing Nested Sampling ===")
    nested = NestedSampler(model)
    nested.run(random.PRNGKey(0), **data, enum=args.enum)
    nested.print_summary()
    # Nested-sampling draws are weighted, so a random key is needed to
    # resample from those weighted samples.
    ns_samples = nested.get_samples(random.PRNGKey(1), num_samples=args.num_samples)

    print("\n=== Performing MCMC Sampling ===")
    # With enumeration NUTS is used directly; otherwise it is wrapped in
    # DiscreteHMCGibbs.
    kernel = NUTS(model)
    if not args.enum:
        kernel = DiscreteHMCGibbs(kernel)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(random.PRNGKey(2), **data, enum=args.enum)
    mcmc.print_summary()
    mcmc_samples = mcmc.get_samples()

    return ns_samples["x"], mcmc_samples["x"]
def main(args):
    """Run both samplers on the two-shell problem and save a comparison plot.

    The figure has two stacked scatter plots (MCMC on top, nested sampling
    below) and is written to ``gaussian_shells_plot.pdf``.

    Args:
        args: parsed command-line options, forwarded to ``run_inference``.
    """
    shell_config = {
        "radius": 2.0,
        "width": 0.1,
        "center1": jnp.array([-3.5, 0.0]),
        "center2": jnp.array([3.5, 0.0]),
    }
    ns_samples, mcmc_samples = run_inference(args, shell_config)

    # plotting
    fig, axes = plt.subplots(
        2, 1, sharex=True, figsize=(8, 8), constrained_layout=True
    )
    top_ax, bottom_ax = axes
    # Both panels share the same viewing window.
    window = {"xlim": (-6, 6), "ylim": (-2.5, 2.5)}

    top_ax.plot(mcmc_samples[:, 0], mcmc_samples[:, 1], "ro", alpha=0.2)
    top_ax.set(
        **window,
        ylabel="x[1]",
        title="Gaussian-shell samples using NUTS",
    )

    bottom_ax.plot(ns_samples[:, 0], ns_samples[:, 1], "ro", alpha=0.2)
    # Only the bottom panel gets an x-axis label.
    bottom_ax.set(
        **window,
        xlabel="x[0]",
        ylabel="x[1]",
        title="Gaussian-shell samples using Nested Sampler",
    )

    plt.savefig("gaussian_shells_plot.pdf")
if __name__ == "__main__":
    # The example is pinned to one specific NumPyro release.
    assert numpyro.__version__.startswith("0.18.0")

    arg_parser = argparse.ArgumentParser(
        description="Nested sampler for Gaussian shells"
    )
    arg_parser.add_argument(
        "-n", "--num-samples", nargs="?", default=10000, type=int
    )
    arg_parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
    arg_parser.add_argument(
        "--enum",
        action="store_true",
        default=False,
        help="whether to enumerate over the discrete latent variable",
    )
    arg_parser.add_argument(
        "--device", default="cpu", type=str, help='use "cpu" or "gpu".'
    )
    args = arg_parser.parse_args()

    # Configure precision and backend before any inference is run.
    numpyro.enable_x64()
    numpyro.set_platform(args.device)

    main(args)