为什么需要变分推断?
在贝叶斯推断中,高维且复杂的后验分布往往难以进行解析计算。
前一节介绍了马尔可夫链蒙特卡洛(MCMC)算法,它通过构造马尔可夫链并进行采样,来避免繁琐的解析计算,从而获得后验分布的样本。
然而,当参数空间维度较高时,MCMC 仍面临收敛速度慢、计算成本高以及收敛判断困难等问题。
在机器学习研究和实际应用中,变分推断(Variational Inference, VI) 因此成为一种更高效的替代方案。
从理论上讲,我们往往无法确切知道真实的生成模型,因此使用一个近似分布来代替真实后验是一种合理且可行的做法。
变分推断是什么?
变分推断是一种近似推断技术,通过求解特定的优化问题寻找一个近似的后验分布。
不同于解析求解与MCMC对后验分布 p(θ∣d) 进行准确估计,变分推断的核心思想是引入一个简单的参数分布 q(θ),来近似复杂的真实后验分布。这种方法将求解计算后验分布问题转化成了优化问题。
现在,让我们具体介绍变分优化的概念是如何应用于推断问题中的:
假设我们有一个贝叶斯模型,每个参数都有给定的先验分布。这个模型具有一些潜变量及参数,我们将模型的潜变量与参数的集合记为 θ, 将所有观测变量记为d,概率模型指定了联合分布 p(θ,d)。
我们做贝叶斯推断的目标是求解后验分布 p(θ∣d),由于后验分布有时难以计算,我们就会使用变分推断,目标是找到后验分布的近似分布。
p(θ∣d)=p(d)p(θ,d)(1) 该公式也能写成另一个形式
p(d)=p(θ∣d)p(θ,d)(2) 此时在公式右边分子分母同时引入一个分布 q(θ),无论 q(θ)是什么,只要不为0,等式依然成立
p(d)=p(θ∣d)/q(θ)p(θ,d)/q(θ))(3) 两边同时求对数,
log(p(d))=log(q(θ)p(θ,d))−log(q(θ)p(θ∣d))(4) 在公式两边,我们对 q(θ)求期望
∫log(p(d))q(θ)dθ=∫log(q(θ)p(θ,d))q(θ)dθ−∫log(q(θ)p(θ∣d))q(θ)dθ log(p(d))=∫log(q(θ)p(θ,d))q(θ)dθ+∫log(p(θ∣d)q(θ))q(θ)dθ(5) 其中,公式(5)右边第二项为 q(θ)与 p(θ∣d)两个分布之间的KL散度
DKL(q(θ)∣∣p(θ∣d))=∫log(p(θ∣d)q(θ))q(θ)dθ(6) 同时,定义公式右边第一项为
L=∫log(q(θ)p(θ,d))q(θ)dθ(7) 由此可得,
log(p(d))=L+DKL(q(θ)∣∣p(θ∣d))(8) 我们的目标是求解后验概率分布 p(θ∣d),由于这个分布很难计算,我们引入了另一个关于 θ 的分布 q(θ),那么我们只要让 q(θ)尽可能地去近似 p(θ∣d),就可以使用 q(θ)代替 p(θ∣d)作为我们要求的解。
让 q(θ)尽可能近似 p(θ∣d),就是要最小化KL散度DKL(q(θ)∣∣p(θ∣d))。不过,直接从公式(6)求解最小值是不可行的,其中p(θ∣d)正是需要求解的部分。
考虑到 log(p(d))是一个定值,根据公式(8)最小化KL散度的求解可以等价转化为最大化 L的优化问题,这里的 L就叫做证据下界 (Evidence Lower Bound, ELBO)。
q(θ)^=q(θ)argmaxL=q(θ)argmax∫log(q(θ)p(θ,d))q(θ)dθ=q(θ)argmax∫log(q(θ)p(θ)∗p(d∣θ))q(θ)dθ=q(θ)argmax∫(logp(θ)+logp(d∣θ)−logq(θ))q(θ)dθ(9) 当ELBO取得最大值时,相应的 q(θ)便是最接近 p(θ∣d)的分布。
如果我们允许 q(θ)的任意取值(不加任何约束),那么证据下界的最大值将会发生在 q(θ)与后验分布 p(θ∣d)相等的时候,也就是KL散度为0时。
然而,我们在实际运算中面对难以解析的后验分布,如果不对 q(θ)的分布进行一定约束,将难以用参数化的方法对 q(θ)进行优化。因此,我们需要考虑一类有限的分布族 q(θ),接着从中寻找使KL散度最小化的分布。
我们添加约束的目标是充分限制分布族,使它们仅包含可解析的分布,同时要保证分布族足够丰富与灵活,以此很好地近似于真实后验分布。必须要强调的是,施加约束纯粹是为了保证分布达到可解析的程度,在这要求之下我们应该使用尽可能丰富的近似分布族。尤其是,使用高度灵活的分布并不存在过拟合的现象,使用更多的灵活分布近似只会让我们更接近真实的后验分布。
其中一种限制近似分布族的方法是,采用一种参数化分布 q(θ∣ω),由一系列参数 ω控制。证据下界L将变成 ω的函数 L(ω),我们便可采用标准的非线性优化技术来求解这些参数的最优取值。这个方法的一个例子是,采用高斯分布作为近似,我们可以对高斯分布的均值和方差这两个参数进行优化。下一小节中,我们会展开这个例子,详细介绍如何实现用变分推断近似高斯分布。