背景

Q-learning 算法主要是维护一个以状态空间和动作空间为行和列的表格,其中的 QQ 值表示当前状态下采取动作能带来的价值。

DQN可以视为 Q-learning 的进阶版,是针对巨大的状态和动作空间、Q表格的维护和查找不现实所引入神经网络的方法。但是引入非线性函数、使用神经网络近似Q表,训练结果可能不收敛。

DQN有两版,2013版《Playing Atari with Deep Reinforcement Learning》伪代码:(LaTex格式的伪代码显示不了,所以就用图片代替了)

2015版《Human-level control through deep reinforcement learning》伪代码:

Q-learning算法

Q-learning 是一种 model-free 的强化学习算法,用于学习特定状态下动作的价值。该算法的核心为使用贝尔曼方程进行简单的值迭代更新,对当前值和新信息进行加权平均。

model-free 算法只用学习策略即可。model-based 算法让智能体学习环境的模型,在训练之后,智能体必须通过预测下一个状态和 reward 来采取行动。

Bellman方程:
状态s的最优值 V(s)V^*(s)、Q-state的最优值 Q(s,a)Q^*(s,a)
V(s)=maxasT(s,a,s)[R(s,a,s)+rV(s)]=maxaQ(s,a)V^*(s)=\max\limits_{a}\sum\limits_{s'}T(s,a,s')[R(s,a,s')+rV^*(s')]=\max\limits_{a}Q^*(s,a)
Q(s,a)=sT(s,a,s)[R(s,a,s)+rV(s)]Q^*(s,a)=\sum\limits_{s'}T(s,a,s')[R(s,a,s')+rV^*(s')]

值迭代更新:

Qnew(st,at)Q(st,at)+α[rt+1+γmaxaQ(st+1,a)Q(st,at)]Q^{new}(s_t,a_t)\leftarrow Q(s_t,a_t)+\alpha[r_{t+1}+\gamma \, \underset{a}{\mathrm{max}}\, Q(s_{t+1},a)-Q(s_t,a_t)]

其中:

  • Q(st,at)Q(s_t,a_t) 为当前值
  • α\alpha 为学习率
  • rt+1r_{t+1} 为状态从 sts_tst+1s_{t+1} 获得的奖励
  • γ\gamma 为折扣系数(接近于0时,智能体更在意短期回报;接近于1时,更在意长期回报)
  • maxaQ(st+1,a)\max\limits_{a}Q(s_{t+1},a) 表示对未来最优值的估计

DQN 的两大改进

DQN 的算法框架如图所示。对于训练更新规则,将使用一个事实——Every Q function for some policy obeys the Bellman方程:

Qπ(s,a)=r+γQπ(s,π(s))Q^π(s, a) = r + γQ^π(s',π(s'))

时间差分误差(Temporal difference error)为:δ=Q(s,a)Q(s,a)δ=Q(s,a)-Q^*(s,a)。损失函数一般采用均方误差损失。

Replay Memory经验池

DQN 引入了经验池,利用 Q-learning 是 off-policy 的特性,使用经验回放记忆训练 DQN。存储智能体观察到的转换,允许稍后重用这些数据,重复利用过去经验。通过随机抽样,构建批次的转换是 decorrelated。已经表明这极大的稳定和改进了 DQN 训练过程(DQN tutorial那里提到的)。代码实现中需要两个类:

  • Transition - 一个命名元组,代表环境中的单个transition。本质上将(state, action)对映射到它们的(next_state, reward)结果。
  • ReplayMemory - 一个有界大小的循环缓冲区,用于保存最近观察到的transitions。它还实现了一个.sample()的方法,用于随机选择批次的transition进行训练。

固定 Q-target

DQN 中会有两个结构完全相同但是参数却不同的网络,一个(policy_net)用于预测Q估计、一个(target_net)用于预测Q现实。Q(s,a)Q(s,a) 在用表格表示与用网络表示有很大的不同:修改 Q(s,a)Q(s,a) 要让它接近r+γmaxaQ(st+1,a)r+\gamma\cdot\max\limits_{a'}Q(s_{t+1},a'),对于表格表示,修改不会影响别的位置的值;而对于网络表示,修改网络参数会造成全局的变化,会导致 maxaQ(st+1,a)\max\limits_{a'}Q(s_{t+1},a') 的值也改变,即预测值在接近目标值的过程中,目标值也在移动,因此训练比较困难。

为此研究人员提出了固定Q目标网络的方法,一段时间内将目标网络固定住保持稳定,定期拷贝同样的参数过去。就像射箭射兔子,每射一箭兔子会乱动一下,可以把兔子固定住变成靶子。

过程描述:初始化 policy_net 和 target_net,根据损失函数更新 policy_net 的参数,而 target_net 的参数固定不变。在经过多次迭代后,将 policy_net 的参数全部复制给 target_net,并一直如此迭代。这样一段时间内的target_net是固定不变的,从而使得算法更新更加稳定

代码

详细代码略,这里推荐pytorch官方DQN代码,比较具体。

参考资料

Q-learning

DQN简介

REINFORCEMENT LEARNING (DQN) TUTORIAL

Pytorch实现简单DQN