注意
转到末尾以下载完整的示例代码。
示例:Flax 中的条件变分自编码器
本示例使用 Flax 的神经网络 API 在 MNIST 数据集上训练一个 条件变分自编码器 (CVAE) [1]。实现代码可在此处找到:https://github.com/pyro-ppl/numpyro/tree/master/examples/cvae-flax
该模型是 Pyro 优秀 CVAE 示例的移植,该示例详细描述了模型和数据:https://pyro.org.cn/examples/cvae.html
该模型首先训练一个基线模型,用于从图像的单个象限预测整个 MNIST 图像(即,输入是图像的一个象限,输出是整个图像(而不是其他三个象限))。然后,在第二个模型中,训练 CVAE 的生成/先验/识别网络,同时保持基线模型的参数固定/冻结。我们使用 Optax 的 multi_transform 将不同的梯度变换应用于可训练参数和冻结参数。

参考文献
Kihyuk Sohn, Xinchen Yan, Honglak Lee (2015), “使用深度条件生成模型学习结构化输出表示” (https://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models)