5.1.2 采样
什么是采样?
采样(sampling)指的是从一个概率分布中抽取样本的过程。
采出来的样本对应的是概率分布的x轴
采出来很多样本之后,画histogram的图,应该近似概率分布
采样方法
一般来说,我们在python或matlab等语言中可以很轻松的借助内置函数来从一些常见分布中采样:
from scipy.stats import norm,gamma
a = norm.rvs() # 从高斯分布中采样
print(a)
b = gamma.rvs(1) # 从伽马分布中采样
print(b)
但关键问题是:假如不允许使用这些函数,我们又该如何采样呢?
接下来我们以高斯采样的例子介绍几种常见的采样方法。
问题:只给定均匀分布的样本,如何利用该均匀分布从一个高斯分布中采样?
方法1:逆变化采样(Inverse Transform Sampling)
几乎所有的已知分布都可以用这种方法采样(matlab和python内部就是这么做的)。
但是,如果一个复杂分布,其累积分布函数的反函数未知,如何采样?
方法2:接受-拒绝采样(Accept-Reject Sampling)
以下是代码示例:
首先,设置目标采样分布的概率密度函数:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm,uniform
# 目标采样分布的概率密度函数(已知但是很复杂)
def p(x):
return ((0.3 * np.exp(-(x-0.3)**2))+0.7*np.exp(-(x-2.)**2/0.3))/1.2113
选定值和提议分布函数,并画出分布。
# 设定C值
C = 2.5
# g(x)为均值=1.4, 标准差=1.2的高斯分布
# 画出两个分布
x = np.arange(-4., 6., 0.01)
plt.plot(x, p(x), color='r', label='p(x)')
plt.plot(x, C*norm.pdf(x, loc=1.4, scale=1.2), color='b', label='C*g(x)')
plt.legend()
plt.xlabel('x')
plt.ylabel('pdf')

现在我们来采样:
sample = []
nSample = 10000
for i in range(nSample):
x = np.random.normal(loc=1.4, scale=1.2)
u = np.random.rand()
if u <= p(x)/(C*norm.pdf(x, loc=1.4, scale=1.2)):
sample.append(x)
完成之后画出分布:
x = np.arange(-3., 5., 0.01)
plt.plot(x, p(x), color='r', label='p(x)')
plt.plot(x, C*norm.pdf(x, loc=1.4, scale=1.2), color='b', label='C*g(x)')
plt.hist(sample, color='k', bins=150, density=True)
plt.legend()
plt.ylabel('p(x)')

接受拒绝采样的直观理解:
在每一个的点上, 都有,即
从图上看,就是蓝线和红线的比值。
但是,接受拒绝采样需要找到一个准确的常数,如果这个常数无法直接得到怎么办?
方法3:Metropolis-Hasting采样
在MCMC的证明中会详细介绍该算法并解释它为什么成立,在这里,我们先来看看代码如何实现。
首先,设置目标分布的概率密度函数:
from scipy.stats import norm
import numpy as np
# 目标采样分布的概率密度函数
def p(x):
return ((0.3 * np.exp(-(x-0.3)**2))+0.7*np.exp(-(x-2.)**2/0.3))/1.2113
nSample = 100000
sample2 = []
y = 1 # 设置初始值
#进行MH采样
for i in range(nSample):
y2 = norm(loc=y).rvs() # 从提议分布采样一个
a = np.min([1, p(y2)/p(y)]) # 因为提议分布高斯是对称的, 所以可以上下约掉
u = np.random.rand() # 从[0,1]均匀分布获取一个值
if u <= a:
sample2.append(y2) # 接受并跳转到新的sample
y = y2
else:
sample2.append(y) # 保留原来那个sample不变
# 绘制采样结果
x = np.arange(-3., 5., 0.01)
plt.plot(x, p(x), color='r', label='p(x)')
a,_,_=plt.hist(sample2, color='k', bins=150, density=True)
plt.legend()
plt.xlabel('X')
plt.ylabel('p(x)')

最后更新于