5.2.2 变分推断实例:高斯分布近似

接下来,我们介绍一个变分推断的编程实例,以高斯分布作为近似分布。假设目标采样分布p(x)为一个混合高斯分布,由于目标采样分布作为后验分布是已知的,我们可以利用 5.2.1 中的公式(6),直接对KL散度进行最小化的优化求解。

def p(x):
return ((0.3 * np.exp(-(x-0.3)**2))+0.7*np.exp(-(x-2.)**2/0.3))/1.2113
p(x)=0.3e(x0.3)2+0.7e(x2)20.31.2113p(x)=\frac{0.3 * e^{-(x-0.3)^2}+0.7*e^{\frac{-(x-2)^2}{0.3}}}{1.2113}

我们尝试采用一个高斯分布q(x)去近似p(x),假设分布的均值为u,标准差为 σ\sigma

q(x)=e(xu)2σ2q(x)=\frac{e^{-(x-u)^2}}{\sigma^2}

对于给定的参数组合(uuσ\sigma),我们可以使用蒙特卡洛方法对KL散度进行计算:我们从 q(xu,σ)q(x|u,\sigma)进行采样,假设抽取NN个符合分布的样本 XX,即可计算每个样本 xix_i在两个分布中概率密度的对数差,对所有样本的概率密度对数差求平均,即为KL散度的离散形式。

DKL=i[log(q(xi))log(p(xi))]ND_{KL}=\frac{\sum_i[{log(q(x_i))-log(p(x_i))}]}{N}
eps = np.finfo(float).eps #一个大于0的极小值ε,保证在进行对数计算时大于0.

def lossfun(params):
    samplesize = 500 #蒙特卡洛法的采样数
    # q-高斯分布的均值与标准差
    u = params[0]
    sigma = params[1]
    # 从q分布中采样
    ss = norm.rvs(loc=u,scale=sigma,size = samplesize)
    # q中所采样本在p分布的概率密度
    pp1 = p(ss) * 0.999 + eps 
    # q中所采样本在q分布的概率密度
    pp2 = norm.pdf(ss,loc=u,scale=sigma) * 0.999 + eps 
    
    dd = np.log(pp2)-np.log(pp1) 
    return dd.sum()/samplesize

接下来便可使用一般参数优化的方式,在参数空间 (u,σ)(u,\sigma)中求解KL散度最小对应的参数。

res = minimize(fun=lossfun,x0=(1,1),method='Powell',bounds=((-3,3),eps,3))
print(res.x)

参数优化的结果为 (u=0.817,σ=1.034)(u=0.817,\sigma=1.034),两个分布的示意图如下。

x = np.arange(-3., 5., 0.01)
plt.plot(x, p(x), color='r', label='p(x)')
plt.plot(x, norm.pdf(x,loc=res.x[0], scale=res.x[1]), color='b', label='q(x)')
plt.legend()
plt.ylabel('pdf(x)')

最后更新于