5.2.3 变分推断在参数估计中的应用
如何使用变分推断的方法估计Weibull心理物理曲线的阈值?
Weibull函数为
其中
为该正确率下的阈值,为函数的斜率,为基线概率水平的概率值, 为阈值对应的正确率。
假设我们当前的参数组合 的真值为(0.25,3,0.5,0.82),在实验设置的一致性水平中对应的心理物理曲线如下,

我们根据实验设置与Weibull函数,利用伯努利随机过程生成了5000个模拟数据,即5种一致性水平下,各生成1000个试次数据,每个试次数据代表被试做对(1)和做错(0)。接下来,我们尝试对这批模拟数据进行参数估计。为了简化例子,这里我们只假定阈值 是需要估计的自由参数,其它参数皆为已知的固定参数。
根据二项选择任务常用的对数似然公式,计算出模拟数据在给定参数的Weibull函数中的似然值
eps = np.finfo(float).eps #大于0的极小数,防止计算对数时出现0
def loglikeli(alpha): # 计算模拟数据在给定参数的Weibull函数中的对数似然值
g = 0.82
beta = 3
gamma = 0.5
prob = np.zeros(cohTrial.size)
prob = weibullfun(cohTrial, alpha, beta, gamma, g)
prob = prob*data + (1-prob)*(1-data)
prob = prob * 0.999 + eps #防止后面计算log时出现0值
return np.log(prob).sum()
接着我们从假定 分布为高斯分布,并从中对参数 的值进行采样。接着,分别计算出 ,根据公式9汇总并对所有采样进行平均得到 值。
samplesize = 500 # 从q(x)里面取多少
from scipy.stats import norm,uniform
eps = np.finfo(float).eps
def lossfun(params):
# q分布的参数
u = params[0] # q高斯函数的均值
sigma = params[1] # q高斯函数的std
# 对q分布进行采样
ss = norm.rvs(loc=u, scale=sigma, size=samplesize) # ss是(-inf, inf) # 从q分布中采500个样本
# 计算这些样本的 logq
llq = np.log(norm.pdf(ss, loc=u, scale=sigma)*0.999 + eps) li
# 将采样的范围(-inf, inf)转换为系数的阈限(0,1)内
ss = 1./(1+np.exp(-ss))
llp = [loglikeli(s) for s in ss] # 求这些log(p(data|alpha))
llp = np.array(llp) # 对于所有的ss,计算针对data的likelihood
#计算先验值 log(p(alpha))
llprior = np.log(uniform.pdf(ss,loc=0, scale=1)*0.999+eps)
# 构建优化目标,根据公式9汇总得到L值,并取相反数
dd = llq-llp - llprior
return dd.sum()/samplesize/cohTrial.size # 我们再除以trial数量,以保证返回数量不会太大
接下来优化这个负对数似然函数,求解最小化负对数似然(等价于最大化L值)对应的参数 的分布,输出计算所得参数 的均值并画出分布
from scipy.optimize import minimize
res = minimize(fun=lossfun, x0=(0, 1), method='Powell', bounds=((-5, 0), (eps, 5)))
print('The estimated parameters is', res.x)
# Gaussian distribution的均值转化成(0,1)的阈限为
print('The mean of threshold is', 1./(1+np.exp(-res.x[0])))
The estimated parameters is [-1.0848973 0.02625092]
The mean of threshold is 0.25258036744019224
# 完成之后, 我们画出分布
x = np.arange(-5., 0., 0.01)
xx = 1/(1+np.exp(-x))
plt.plot(xx, norm.pdf(x, loc=res.x[0], scale=res.x[1]), color='b', label='q(x)')
plt.legend()
plt.ylabel('q(x)=p(x|data)')

最后更新于