5.2.1 变分推断介绍

为什么需要变分推断?

变分推断是什么?

变分推断是一种近似推断技术,通过求解特定的优化问题寻找一个近似的后验分布。

不同于解析求解与MCMC对后验分布 p(θd)p(\theta|d) 进行准确估计,变分推断的核心思想是引入一个简单的参数分布 q(θ)q(\theta),来近似复杂的真实后验分布。这种方法将求解计算后验分布问题转化成了优化问题。

现在,让我们具体介绍变分优化的概念是如何应用于推断问题中的:

假设我们有一个贝叶斯模型,每个参数都有给定的先验分布。这个模型具有一些潜变量及参数,我们将模型的潜变量与参数的集合记为 θθ, 将所有观测变量记为dd,概率模型指定了联合分布 p(θ,d)p(\theta,d)

我们做贝叶斯推断的目标是求解后验分布 p(θd)p(\theta|d),由于后验分布有时难以计算,我们就会使用变分推断,目标是找到后验分布的近似分布。

p(θd)=p(θ,d)p(d)(1)p(θ|d) = \frac{p(θ,d)}{p(d)} \qquad \qquad \tag{1}

该公式也能写成另一个形式

p(d)=p(θ,d)p(θd)(2)p(d)= \frac{p(θ,d)}{p(θ|d) } \qquad \qquad \tag{2}

此时在公式右边分子分母同时引入一个分布 q(θ)q(θ),无论 q(θ)q(θ)是什么,只要不为0,等式依然成立

p(d)=p(θ,d)/q(θ))p(θd)/q(θ)(3)p(d)= \frac{p(θ,d)/q(θ))}{p(θ|d) /q(θ)} \qquad \qquad \tag{3}

两边同时求对数,

log(p(d))=log(p(θ,d)q(θ))log(p(θd)q(θ))(4)log(p(d))= log(\frac{p(θ,d)}{q(θ)})-log(\frac{p(θ|d)} {q(θ)}) \qquad \qquad \tag{4}

在公式两边,我们对 q(θ)q(θ)求期望

log(p(d))q(θ)dθ=log(p(θ,d)q(θ))q(θ)dθlog(p(θd)q(θ))q(θ)dθ\int log(p(d)) q(θ)dθ =\int log(\frac{p(θ,d)}{q(θ)}) q(θ)dθ-\int log(\frac{p(θ|d)}{q(θ)})q(θ)dθ
log(p(d))=log(p(θ,d)q(θ))q(θ)dθ+log(q(θ)p(θd))q(θ)dθ(5)log(p(d)) =\int log(\frac{p(θ,d)}{q(θ)}) q(θ)dθ+\int log(\frac{q(θ)}{p(θ|d)})q(θ)dθ \qquad \qquad \tag{5}

其中,公式(5)右边第二项为 q(θ)q(θ)p(θd)p(θ|d)两个分布之间的KL散度

DKL(q(θ)p(θd))=log(q(θ)p(θd))q(θ)dθ(6)D_{KL}(q(θ)||p(θ|d))=\int log(\frac{q(θ)}{p(θ|d)})q(θ)dθ \qquad \qquad \tag{6}

同时,定义公式右边第一项为

L=log(p(θ,d)q(θ))q(θ)dθ(7)L=\int log(\frac{p(θ,d)}{q(θ)}) q(θ)dθ \qquad \qquad \tag{7}

由此可得,

log(p(d))=L+DKL(q(θ)p(θd))(8)log(p(d)) = L + D_{KL}(q(θ)||p(θ|d)) \qquad \qquad \tag{8}

我们的目标是求解后验概率分布 p(θd)p(θ|d),由于这个分布很难计算,我们引入了另一个关于 θθ 的分布 q(θ)q(θ),那么我们只要让 q(θ)q(θ)尽可能地去近似 p(θd)p(θ|d),就可以使用 q(θ)q(θ)代替 p(θd)p(θ|d)作为我们要求的解。

q(θ)q(θ)尽可能近似 p(θd)p(θ|d),就是要最小化KL散度DKL(q(θ)p(θd))D_{KL}(q(θ)||p(θ|d))。不过,直接从公式(6)求解最小值是不可行的,其中p(θd)p(θ|d)正是需要求解的部分。

考虑到 log(p(d))log(p(d)) 是一个定值,根据公式(8)最小化KL散度的求解可以等价转化为最大化 LL的优化问题,这里的 LL就叫做证据下界 (Evidence Lower Bound, ELBO)

q(θ)^=arg maxq(θ)L=arg maxq(θ)log(p(θ,d)q(θ))q(θ)dθ=arg maxq(θ)log(p(θ)p(dθ)q(θ))q(θ)dθ=arg maxq(θ)(logp(θ)+logp(dθ)logq(θ))q(θ)dθ(9)\hat{q(θ)} =\argmax_{q(θ)}L \\ = \argmax_{q(θ)} \int log(\frac{p(θ,d)}{q(θ)}) q(θ)dθ \\=\argmax_{q(θ)}\int log(\frac{p(θ)*p(d|θ)}{q(θ)}) q(θ)dθ \\=\argmax_{q(θ)} \int (logp(θ)+logp(d|θ)-logq(θ))q(θ)dθ \qquad \qquad \tag{9}

当ELBO取得最大值时,相应的 q(θ)q(θ)便是最接近 p(θd)p(θ|d)的分布。

如果我们允许 q(θ)q(\theta)的任意取值(不加任何约束),那么证据下界的最大值将会发生在 q(θ)q(θ)与后验分布 p(θd)p(θ|d)相等的时候,也就是KL散度为0时。

然而,我们在实际运算中面对难以解析的后验分布,如果不对 q(θ)q(\theta)的分布进行一定约束,将难以用参数化的方法对 q(θ)q(\theta)进行优化。因此,我们需要考虑一类有限的分布族 q(θ)q(\theta),接着从中寻找使KL散度最小化的分布。

我们添加约束的目标是充分限制分布族,使它们仅包含可解析的分布,同时要保证分布族足够丰富与灵活,以此很好地近似于真实后验分布。必须要强调的是,施加约束纯粹是为了保证分布达到可解析的程度,在这要求之下我们应该使用尽可能丰富的近似分布族。尤其是,使用高度灵活的分布并不存在过拟合的现象,使用更多的灵活分布近似只会让我们更接近真实的后验分布。

其中一种限制近似分布族的方法是,采用一种参数化分布 q(θω)q(\theta|\omega),由一系列参数 ω\omega控制。证据下界LL将变成 ω\omega的函数 L(ω)L(\omega),我们便可采用标准的非线性优化技术来求解这些参数的最优取值。这个方法的一个例子是,采用高斯分布作为近似,我们可以对高斯分布的均值和方差这两个参数进行优化。下一小节中,我们会展开这个例子,详细介绍如何实现用变分推断近似高斯分布。

最后更新于