注意
转到末尾 下载完整示例代码。
示例:枚举隐马尔可夫模型
本示例移植自 [1],展示了如何在 Pyro 中边际化离散模型变量。
这结合了 MCMC 和变量消除算法,其中我们使用枚举来精确边际化联合密度中的某些变量。
要边际化离散变量 x,需要:
1. 验证您的模型中的变量依赖结构允许可处理的推断,即枚举变量之间的依赖图应具有较窄的树宽。
2. 确保您的模型可以处理这些变量的样本值的广播。
请注意,与使用 Python 循环的 [1] 不同,此处我们使用 scan() 来减少模型的编译时间(只需编译一步)。在底层,scan 将所有先验的参数和值堆叠到额外的时间维度中,这使得我们可以并行计算联合密度。此外,堆叠的形式允许我们使用 [2] 中的并行扫描算法,该算法将并行复杂度从 O(长度) 降低到 O(log(长度))。
数据取自 [3]。然而,数据的原始来源似乎是卡尔斯鲁厄大学的 Institut fuer Algorithmen und Kognitive Systeme。
参考文献
[1] Pyro 的隐马尔可夫模型示例 (https://pyro.ai/examples/hmm.html)
[2] Temporal Parallelization of Bayesian Smoothers, Simo Särkkä, Ángel F. García-Fernández (https://arxiv.org/abs/1905.13002)
[3] Modeling Temporal Dependencies in High-Dimensional Sequences: Application to Polyphonic Music Generation and Transcription, Boulanger-Lewandowski, N., Bengio, Y. and Vincent, P.
[4] Tensor Variable Elimination for Plated Factor Graphs, Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu, Neeraj Pradhan, Alexander Rush, Noah Goodman (https://arxiv.org/abs/1902.03210)
import argparse
import logging
import os
import time
from jax import random
import jax.numpy as jnp
import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist
from numpyro.examples.datasets import JSB_CHORALES, load_dataset
from numpyro.handlers import mask
from numpyro.infer import HMC, MCMC, NUTS
from numpyro.ops.indexing import Vindex
# Logger used by main() for progress messages. Its own threshold is INFO
# (given here by level name, which setLevel accepts) so info records pass
# this logger's level check.
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
让我们从一个简单的隐马尔可夫模型开始。
# x[t-1] --> x[t] --> x[t+1]
# | | |
# V V V
# y[t-1] y[t] y[t+1]
#
# This model includes a plate over the data_dim keys on the piano (only the
# notes played at least once in the training data are kept; see main). This
# model has two "style" parameters probs_x and probs_y that we'll draw from a
# prior. The latent state is x, and the observed state is y.
def model_1(sequences, lengths, args, include_prior=True):
    """Plain HMM: a single latent chain ``x`` emitting the observed ``y``.

    ``probs_x`` (indexed by the previous state to give the Categorical over the
    next state) and ``probs_y`` (one Bernoulli probability per state and tone)
    are sample sites whose prior log-density is masked by ``include_prior``.
    The latent ``x`` is marked for parallel enumeration.

    :param sequences: observations, unpacked as ``(num_sequences, max_length, data_dim)``.
    :param lengths: per-sequence lengths; steps with ``t >= lengths`` are masked out.
    :param args: namespace providing ``hidden_dim``.
    :param include_prior: if False, the prior log-density of the parameters is masked out.
    """
    n_seqs, _, n_tones = sequences.shape
    n_hidden = args.hidden_dim

    with mask(mask=include_prior):
        transition_prior = dist.Dirichlet(0.9 * jnp.eye(n_hidden) + 0.1)
        probs_x = numpyro.sample("probs_x", transition_prior.to_event(1))
        emission_prior = dist.Beta(0.1, 0.9).expand([n_hidden, n_tones])
        probs_y = numpyro.sample("probs_y", emission_prior.to_event(2))

    def transition_fn(carry, y):
        # One time step: draw x[t] given x[t-1], then score y[t] given x[t].
        x_prev, t = carry
        active = (t < lengths)[..., None]  # hide steps past each sequence's length
        with numpyro.plate("sequences", n_seqs, dim=-2), mask(mask=active):
            x = numpyro.sample(
                "x",
                dist.Categorical(probs_x[x_prev]),
                infer={"enumerate": "parallel"},
            )
            with numpyro.plate("tones", n_tones, dim=-1):
                numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
        return (x, t + 1), None

    x_start = jnp.zeros((n_seqs, 1), dtype=jnp.int32)
    # Put the time axis first so that scan iterates over time steps.
    time_major = jnp.swapaxes(sequences, 0, 1)
    scan(transition_fn, (x_start, 0), time_major)
接下来让我们添加 y[t] 对 y[t-1] 的依赖。
# x[t-1] --> x[t] --> x[t+1]
# | | |
# V V V
# y[t-1] --> y[t] --> y[t+1]
def model_2(sequences, lengths, args, include_prior=True):
    """HMM whose emissions y[t] additionally depend on the previous emissions y[t-1].

    ``probs_y`` is a table indexed by (x[t], y[t-1], tone); the middle axis has
    size 2 for the two possible previous binary observations.

    :param sequences: observations, unpacked as ``(num_sequences, max_length, data_dim)``.
    :param lengths: per-sequence lengths; steps with ``t >= lengths`` are masked out.
    :param args: namespace providing ``hidden_dim``.
    :param include_prior: if False, the prior log-density of the parameters is masked out.
    """
    n_seqs, _, n_tones = sequences.shape
    n_hidden = args.hidden_dim

    with mask(mask=include_prior):
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(0.9 * jnp.eye(n_hidden) + 0.1).to_event(1)
        )
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([n_hidden, 2, n_tones]).to_event(3),
        )

    def transition_fn(carry, y):
        x_prev, y_prev, t = carry
        active = (t < lengths)[..., None]
        with numpyro.plate("sequences", n_seqs, dim=-2), mask(mask=active):
            x = numpyro.sample(
                "x",
                dist.Categorical(probs_x[x_prev]),
                infer={"enumerate": "parallel"},
            )
            # Indexing probs_y needs an index tensor for every axis; the "tones"
            # plate yields one for the last axis, which broadcasts with x and y_prev.
            with numpyro.plate("tones", n_tones, dim=-1) as tone_idx:
                y_t = numpyro.sample(
                    "y", dist.Bernoulli(probs_y[x, y_prev, tone_idx]), obs=y
                )
        return (x, y_t, t + 1), None

    initial_carry = (
        jnp.zeros((n_seqs, 1), dtype=jnp.int32),
        jnp.zeros((n_seqs, n_tones), dtype=jnp.int32),
        0,
    )
    scan(transition_fn, initial_carry, jnp.swapaxes(sequences, 0, 1))
接下来考虑一个具有两个隐藏状态的分解式 HMM。
# w[t-1] ----> w[t] ---> w[t+1]
# \ x[t-1] --\-> x[t] --\-> x[t+1]
# \ / \ / \ /
# \/ \/ \/
# y[t-1] y[t] y[t+1]
#
# Note that since the joint distribution of each y[t] depends on two variables,
# those two variables become dependent. Therefore during enumeration, the
# entire joint space of these variables w[t],x[t] needs to be enumerated.
# For that reason, we set the dimension of each to the square root of the
# target hidden dimension.
def model_3(sequences, lengths, args, include_prior=True):
    """Factorial HMM: two independent latent chains ``w`` and ``x`` jointly emit ``y``.

    Because y[t] depends on both w[t] and x[t], enumeration has to cover their
    joint state space, so each chain gets ``int(args.hidden_dim ** 0.5)`` states.

    :param sequences: observations, unpacked as ``(num_sequences, max_length, data_dim)``.
    :param lengths: per-sequence lengths; steps with ``t >= lengths`` are masked out.
    :param args: namespace providing ``hidden_dim``.
    :param include_prior: if False, the prior log-density of the parameters is masked out.
    """
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
        )
        # Emission table indexed by (w[t], x[t], tone). Both w and x take values
        # in [0, hidden_dim), so both leading axes must have size hidden_dim.
        # A [args.hidden_dim, 2, data_dim] table is wrong: x >= 2 would index past
        # the second axis (JAX silently clamps out-of-bounds gathers) and the
        # extra w rows would be sampled but never used.
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),
        )

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample(
                    "w",
                    dist.Categorical(probs_w[w_prev]),
                    infer={"enumerate": "parallel"},
                )
                x = numpyro.sample(
                    "x",
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                # Broadcasting trick: to index probs_y on tensors w and x we also
                # need an index tensor for the tones dimension, which the plate
                # associated with that dimension conveniently provides.
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
        return (w, x, t + 1), None

    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # Move the time axis of `sequences` to the front so scan iterates over it.
    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))
通过添加 x 对 w 的依赖,我们将其泛化为动态贝叶斯网络。
# w[t-1] ----> w[t] ---> w[t+1]
# | \ | \ | \
# | x[t-1] ----> x[t] ----> x[t+1]
# | / | / | /
# V / V / V /
# y[t-1] y[t] y[t+1]
#
# Note that message passing here has roughly the same cost as with the
# Factorial HMM, but this model has more parameters.
def model_4(sequences, lengths, args, include_prior=True):
    """Dynamic Bayesian network: two latent chains where x[t] also depends on w[t].

    Each chain has ``int(args.hidden_dim ** 0.5)`` states. ``probs_x`` holds one
    transition matrix for x per value of w, and ``probs_y`` is indexed by
    (w[t], x[t], tone).

    :param sequences: observations, unpacked as ``(num_sequences, max_length, data_dim)``.
    :param lengths: per-sequence lengths; steps with ``t >= lengths`` are masked out.
    :param args: namespace providing ``hidden_dim``.
    :param include_prior: if False, the prior log-density of the parameters is masked out.
    """
    n_seqs, _, n_tones = sequences.shape
    k = int(args.hidden_dim**0.5)  # states per chain, shared between w and x

    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w", dist.Dirichlet(0.9 * jnp.eye(k) + 0.1).to_event(1)
        )
        x_prior = dist.Dirichlet(0.9 * jnp.eye(k) + 0.1).expand_by([k])
        probs_x = numpyro.sample("probs_x", x_prior.to_event(2))
        y_prior = dist.Beta(0.1, 0.9).expand([k, k, n_tones])
        probs_y = numpyro.sample("probs_y", y_prior.to_event(3))

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        active = (t < lengths)[..., None]
        with numpyro.plate("sequences", n_seqs, dim=-2), mask(mask=active):
            w = numpyro.sample(
                "w",
                dist.Categorical(probs_w[w_prev]),
                infer={"enumerate": "parallel"},
            )
            # Vindex gives broadcasting semantics when indexing with w and x_prev.
            x = numpyro.sample(
                "x",
                dist.Categorical(Vindex(probs_x)[w, x_prev]),
                infer={"enumerate": "parallel"},
            )
            with numpyro.plate("tones", n_tones, dim=-1) as tone_idx:
                numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tone_idx]), obs=y)
        return (w, x, t + 1), None

    start_state = jnp.zeros((n_seqs, 1), dtype=jnp.int32)
    scan(transition_fn, (start_state, start_state, 0), jnp.swapaxes(sequences, 0, 1))
接下来考虑一个二阶 HMM 模型,其中 x[t+1] 依赖于 x[t] 和 x[t-1]。
# _______>______
# _____>_____/______ \
# / / \ \
# x[t-1] --> x[t] --> x[t+1] --> x[t+2]
# | | | |
# V V V V
# y[t-1] y[t] y[t+1] y[t+2]
#
# Note that in this model (in contrast to the previous model) we treat
# the transition and emission probabilities as parameters (so they have no prior).
#
# Note that this is the "2HMM" model in reference [4].
def model_6(sequences, lengths, args, include_prior=False):
    """Second-order HMM: x[t] depends on both x[t-1] and x[t-2].

    The carry holds the last two states; ``probs_x`` is indexed by that pair to
    obtain the Categorical over the next state, and ``scan`` is told to keep a
    history of 2 steps. ``include_prior`` defaults to False here.

    :param sequences: observations, unpacked as ``(num_sequences, max_length, data_dim)``.
    :param lengths: per-sequence lengths; steps with ``t >= lengths`` are masked out.
    :param args: namespace providing ``hidden_dim``.
    :param include_prior: if False, the prior log-density of the parameters is masked out.
    """
    n_seqs, _, n_tones = sequences.shape
    n_hidden = args.hidden_dim

    with mask(mask=include_prior):
        # The full transition tensor has n_hidden ** 3 entries: one distribution
        # over the next state for every (older, newer) pair of previous states.
        x_prior = (
            dist.Dirichlet(0.9 * jnp.eye(n_hidden) + 0.1)
            .expand([n_hidden, n_hidden])
            .to_event(2)
        )
        probs_x = numpyro.sample("probs_x", x_prior)
        y_prior = dist.Beta(0.1, 0.9).expand([n_hidden, n_tones]).to_event(2)
        probs_y = numpyro.sample("probs_y", y_prior)

    def transition_fn(carry, y):
        x_older, x_newer, t = carry
        active = (t < lengths)[..., None]
        with numpyro.plate("sequences", n_seqs, dim=-2), mask(mask=active):
            next_probs = Vindex(probs_x)[x_older, x_newer]
            x_next = numpyro.sample(
                "x",
                dist.Categorical(next_probs),
                infer={"enumerate": "parallel"},
            )
            with numpyro.plate("tones", n_tones, dim=-1):
                numpyro.sample("y", dist.Bernoulli(probs_y[x_next.squeeze(-1)]), obs=y)
        # Slide the two-state window forward by one step.
        return (x_newer, x_next, t + 1), None

    start_state = jnp.zeros((n_seqs, 1), dtype=jnp.int32)
    scan(
        transition_fn,
        (start_state, start_state, 0),
        jnp.swapaxes(sequences, 0, 1),
        history=2,
    )
进行推断
# Registry of every module-level ``model_*`` function, keyed by the text after
# the first underscore (e.g. "1" for model_1); used to resolve --model.
models = {
    fn_name.split("_", 1)[1]: fn
    for fn_name, fn in globals().items()
    if fn_name.startswith("model_")
}
def main(args):
    """Run MCMC for the model selected by ``args.model`` on the JSB chorales.

    Loads the unshuffled training split, optionally keeps only the first
    ``args.num_sequences`` sequences, drops note columns that are never on,
    optionally truncates to ``args.truncate`` time steps, then runs the kernel
    named by ``args.kernel`` and prints the posterior summary.

    :param args: parsed command-line namespace.
    """
    model = models[args.model]

    _, fetch = load_dataset(JSB_CHORALES, split="train", shuffle=False)
    lengths, sequences = fetch()
    if args.num_sequences:
        head = slice(0, args.num_sequences)
        sequences = sequences[head]
        lengths = lengths[head]

    logger.info("-" * 40)
    logger.info("Training {} on {} sequences".format(model.__name__, len(sequences)))

    # Boolean mask over the last axis: True for notes that equal 1 in at least
    # one (sequence, time step). Columns that are never on are dropped.
    played = (sequences == 1).any(axis=(0, 1))
    sequences = sequences[:, :, played]

    if args.truncate:
        lengths = lengths.clip(0, args.truncate)
        sequences = sequences[:, : args.truncate]

    logger.info("Each sequence has shape {}".format(sequences[0].shape))
    logger.info("Starting inference...")
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel_cls = {"nuts": NUTS, "hmc": HMC}[args.kernel]
    mcmc = MCMC(
        kernel_cls(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # No progress bar when NUMPYRO_SPHINXBUILD is set (presumably the docs build).
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, sequences, lengths, args=args)
    mcmc.print_summary()
    logger.info("\nMCMC elapsed time: {}".format(time.time() - start))
if __name__ == "__main__":
    assert numpyro.__version__.startswith("0.18.0")
    parser = argparse.ArgumentParser(description="HMC for HMMs")
    parser.add_argument(
        "-m",
        "--model",
        default="1",
        type=str,
        choices=sorted(models),
        help="one of: {}".format(", ".join(sorted(models.keys()))),
    )
    # No nargs="?" on the numeric options: without a `const`, a bare flag such
    # as `-n` would set the value to None and only fail later inside MCMC.
    parser.add_argument("-n", "--num-samples", default=1000, type=int)
    parser.add_argument("-d", "--hidden-dim", default=16, type=int)
    parser.add_argument("-t", "--truncate", type=int)
    parser.add_argument("--num-sequences", type=int)
    parser.add_argument("--kernel", default="nuts", type=str, choices=["nuts", "hmc"])
    parser.add_argument("--num-warmup", default=500, type=int)
    parser.add_argument("--num-chains", default=1, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # main() reports progress with logger.info(). With no handler configured,
    # Python falls back to logging.lastResort, which only emits WARNING and
    # above, so those messages were silently dropped. A root handler with no
    # threshold lets this module's INFO records through (the module logger is
    # set to INFO), while the root level stays at its WARNING default for
    # loggers that do not set their own level.
    logging.basicConfig(format="%(message)s")

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)
    main(args)