5.2.2 变分推断实例:高斯分布近似
接下来,我们介绍一个变分推断的编程实例,以高斯分布作为近似分布。假设目标采样分布p(x)为一个混合高斯分布,由于目标采样分布作为后验分布是已知的,我们可以利用 5.2.1 中的公式(6),直接对KL散度进行最小化的优化求解。
def p(x):
return ((0.3 * np.exp(-(x-0.3)**2))+0.7*np.exp(-(x-2.)**2/0.3))/1.2113
我们尝试采用一个高斯分布q(x)去近似p(x),假设分布的均值为u,标准差为
对于给定的参数组合(和 ),我们可以使用蒙特卡洛方法对KL散度进行计算:我们从 进行采样,假设抽取个符合分布的样本 ,即可计算每个样本 在两个分布中概率密度的对数差,对所有样本的概率密度对数差求平均,即为KL散度的离散形式。
eps = np.finfo(float).eps #一个大于0的极小值ε,保证在进行对数计算时大于0.
def lossfun(params):
samplesize = 500 #蒙特卡洛法的采样数
# q-高斯分布的均值与标准差
u = params[0]
sigma = params[1]
# 从q分布中采样
ss = norm.rvs(loc=u,scale=sigma,size = samplesize)
# q中所采样本在p分布的概率密度
pp1 = p(ss) * 0.999 + eps
# q中所采样本在q分布的概率密度
pp2 = norm.pdf(ss,loc=u,scale=sigma) * 0.999 + eps
dd = np.log(pp2)-np.log(pp1)
return dd.sum()/samplesize
接下来便可使用一般参数优化的方式,在参数空间 中求解KL散度最小对应的参数。
res = minimize(fun=lossfun,x0=(1,1),method='Powell',bounds=((-3,3),eps,3))
print(res.x)
参数优化的结果为 ,两个分布的示意图如下。
x = np.arange(-3., 5., 0.01)
plt.plot(x, p(x), color='r', label='p(x)')
plt.plot(x, norm.pdf(x,loc=res.x[0], scale=res.x[1]), color='b', label='q(x)')
plt.legend()
plt.ylabel('pdf(x)')

最后更新于