Rescorlar-wagner模型
Rescor
1. 动物中的强化学习
动物心理学家发现,当多个刺激同时出现时,强化学习中存在三种不同的效应:遮蔽效应(Overshadowing)、阻断效应(Blocking)和抑制效应(Inhibition)。
遮蔽效应是指当在训练阶段两个刺激(A和B)同时出现并与一个结果相关联时,相对于单独对A或B进行条件反射的情况,对A或B的学习效果会因为与另一个刺激的复合条件反射而减弱。例如,如果光(A)和声音(B)同时作为条件刺激(CS)与食物(无条件刺激,UCS)配对,动物对声音(B)的条件反应(CR)会比单独呈现声音时弱。
阻断效应是指当一个条件刺激(CS1)已经与无条件刺激(UCS)建立了强烈的关联后,再引入一个新的条件刺激(CS2)时,CS2很难再与UCS建立关联。也就是说,CS1“阻断”了CS2的学习。例如,如果声音(CS1)已经与食物(UCS)配对并引发了唾液分泌(CR),然后再引入光(CS2)与声音一起出现并继续与食物配对,动物只会对声音反应而不会对光产生条件反应。
抑制效应是指一个刺激通过与另一个刺激的关联,抑制了对某个结果的预期。例如,如果光(CS1)与食物(UCS)配对引发唾液分泌(CR),然后再引入声音(CS2)与光一起出现,但这次没有食物出现,动物在测试阶段只会对CS1产生条件反应,而对CS2无反应,甚至会产生相反的条件反应(见表1)。
效应
阶段 1
阶段 2
学习测试
遮蔽效应
[A + B] →食物(UCS)
[A]中等; [B]中等
阻断效应
[A] → UCS
[A + B] → UCS
[A]强; [B]无
抑制效应
[A] → UCS
[A + B] → 0
[A]强[B]无/ [A]强[B]负
A:光; B: 声音; UCS: 无条件刺激; [A] 或者 [B] 代表光或者声音单独出现;[A + B]代表光和声音一起出现。
2. Rescorlar-wagner模型
由之前的章节得到心理学中强化学习的最简递推公式:,被称为简单强化学习模型(navie reinforcement learning model)。但是简单的强化学习(PEi = ri − ;i = Vi -1+ )无法解释以上三种条件反射现象。
以下通过Python代码模拟基于简单强化学习模型的遮蔽效应的形成过程。
# Naïve model
from scipy.stats import bernoulli
import numpy as np
import matplotlib.pyplot as plt
def overshadow(T):
'''
T: number of trials
'''
r = np.ones(T) # (T,)维全是1的数组
A = np.ones(T) # (T,)维全是1的数组
B = np.ones(T) # (T,)维全是1的数组
return r, A, B
a = 0.05 # learning rate
T = 100 # number of trials
r, A, B = overshadow(T)
Va = np.empty(T)
Vb = np.empty(T)
for i in range(T): # loop trials
if i == 0:
Va[0] = 0 + a*(r[0]-0)
Vb[0] = 0 + a*(r[0]-0)
else:
Va[i] = Va[i-1] + a*(r[i]-Va[i-1])
Vb[i] = Vb[i-1] + a*(r[i]-Vb[i-1])
plt.plot(r, '-k', label='reward')
plt.plot(Va, linewidth=10, label='Va')
plt.plot(Vb, linewidth=2, label='Vb')
plt.xlabel('Trials')
plt.ylabel('Value')
plt.legend()
```

基于此,1972年Robert A. Rescorla和Allan R. Wagner提出了瑞思考勒-瓦格纳(Rescorlar-Wagner, RW)模型。根据RW模型,当多个刺激存在的时候,预测误差(Prediction error)应该为奖赏减去所有出现刺激的价值( ,其中第i个trial,一种出现N个刺激)。在更新刺激的价值时,只更新出现了刺激的价值。因此对于同时出现两种刺激的刺激价值的更新过程为:
PE[i] = r[i]-(Va[i-1]A[i]+Vb[i-1]B[i])
Va[i] = Va[i-1] + aPE[i] * A[i]
Vb[i] = Vb[i-1] + aPE[i] * B[i]
RW模型能成功解释上述三种现象。以下通过Python代码模拟基于RW模型的遮蔽效应,阻断效应和抑制效应的形成过程。
# RW model-Overshadowing
a = 0.05 # learning rate
T = 100 # how many trials
r, A, B = overshadow(T)
Va = np.empty(T)
Vb = np.empty(T)
for i in range(T): # loop trials
if i == 0:
Va[0] = 0 + a*(r[0] - 0 * A[i] - 0 * B[i])
Vb[0] = 0 + a*(r[0] - 0 * A[i] - 0 * B[i])
else:
PE = r[i]-(Va[i-1]*A[i]+Vb[i-1]*B[i])
Va[i] = Va[i-1] + a*PE * A[i]
Vb[i] = Vb[i-1] + a*PE * B[i]
plt.plot(r, '-k', label='reward')
plt.plot(Va, linewidth=10, label='Va')
plt.plot(Vb, linewidth=2, label='Vb')
plt.xlabel('Trials')
plt.ylabel('Value')
plt.legend()

def block(T):
T1 = round(T/2) # round是取整数的意思
# phase 1: A -> R
r1 = np.ones(T1)
A1 = np.ones(T1)
B1 = np.zeros(T1)
# phase 2: A, B -> R
r2 = np.ones(T-T1)
A2 = np.ones(T-T1)
B2 = np.ones(T-T1)
# combine two phases
r = np.hstack((r1, r2))
A = np.hstack((A1, A2))
B = np.hstack((B1, B2))
return r, A, B
a = 0.05 # learning rate
T = 1000 # how many trials
r, A, B = block(T)
Va = np.empty(T)
Vb = np.empty(T)
for i in range(T): # loop trials
if i == 0:
Va[0] = 0 + a*(r[0] - 0 * A[i] - 0 * B[i])
Vb[0] = 0 + a*(r[0] - 0 * A[i] - 0 * B[i])
else:
PE = r[i]-Va[i-1]*A[i]-Vb[i-1]*B[i]
Va[i] = Va[i-1] + a*PE * A[i]
Vb[i] = Vb[i-1] + a*PE * B[i]
plt.plot(r, '-k', label='reward')
plt.plot(Va, linewidth=10, label='Va')
plt.plot(Vb, linewidth=2, label='Vb')
plt.xlabel('Trials')
plt.ylabel('Value')
plt.legend()

def inhibition(T):
r = np.empty(T)
A = np.empty(T)
B = np.empty(T)
for i in range(T):
if bernoulli.rvs(0.5, 1)==1:
# A -> R trial
r[i] = 1
A[i] = 1
B[i] = 0
else:
# A,B -> 0 trial
r[i] = 0
A[i] = 1
B[i] = 1
return r, A, B
a = 0.05 # learning rate
T = 1000 # how many trials
r, A, B = inhibition(T)
Va = np.empty(T)
Vb = np.empty(T)
for i in range(T): # loop trials
if i == 0:
Va[0] = 0 + a*(r[0] - 0 * A[i] - 0 * B[i])
Vb[0] = 0 + a*(r[0] - 0 * A[i] - 0 * B[i])
else:
PE = r[i]-Va[i-1]*A[i]-Vb[i-1]*B[i]
Va[i] = Va[i-1] + a*PE * A[i]
Vb[i] = Vb[i-1] + a*PE * B[i]
#plt.plot(r, '-k', label='reward')
plt.plot(Va, linewidth=10, label='Va')
plt.plot(Vb, linewidth=2, label='Vb')
plt.xlabel('Trials')
plt.ylabel('Value')
plt.legend()

最后更新于