论文笔记之:Deep Attention Recurrent Q-Network
Deep Attention Recurrent Q-Network
5vision groups?
?
?摘要:本文將 DQN 引入了 Attention 機(jī)制,使得學(xué)習(xí)更具有方向性和指導(dǎo)性。(前段時(shí)間做一個(gè)工作打算就這么干,誰想到,這么快就被這幾個(gè)孩子給實(shí)現(xiàn)了,自愧不如啊( ⊙ o ⊙ ))
??引言:我們知道 DQN 是將連續(xù) 4幀的視頻信息輸入到 CNN 當(dāng)中,那么,這么做雖然取得了不錯(cuò)的效果,但是,仍然只是能記住這 4 幀的信息,之前的就會(huì)遺忘。所以就有研究者提出了 Deep Recurrent Q-Network (DRQN),一個(gè)結(jié)合 LSTM 和 DQN 的工作:
1. the fully connected layer in the latter is replaced for a LSTM one ,?
2. only the last visual frame at each time step is used as DQN's input.?
作者指出雖然只是使用了一幀的信息,但是 DRQN 仍然抓住了幀間的相關(guān)信息。盡管如此,仍然沒有看到在 Atari game上有系統(tǒng)的提升。
?
另一個(gè)缺點(diǎn)是:長時(shí)間的訓(xùn)練時(shí)間。據(jù)說,在單個(gè) GPU 上訓(xùn)練時(shí)間達(dá)到 12-14天。于是,有人就提出了并行版本的算法來提升訓(xùn)練速度。作者認(rèn)為并行計(jì)算并不是唯一的,最有效的方法來解決這個(gè)問題。
最近 visual attention models 在各個(gè)任務(wù)上都取得了驚人的效果。利用這個(gè)機(jī)制的優(yōu)勢在于:僅僅需要選擇然后注意一個(gè)較小的圖像區(qū)域,可以幫助降低參數(shù)的個(gè)數(shù),從而幫助加速訓(xùn)練和測試。對比 DRQN,本文的 LSTM 機(jī)制存儲(chǔ)的數(shù)據(jù)不僅用于下一個(gè) actions 的選擇,也用于 選擇下一個(gè) Attention 區(qū)域。此外,除了計(jì)算速度上的改進(jìn)之外,Attention-based models 也可以增加 Deep Q-Learning 的可讀性,提供給研究者一個(gè)機(jī)會(huì)去觀察 agent 的集中區(qū)域在哪里以及是什么,(where and what)。
?
?
Deep Attention Recurrent Q-Network:
?
?
?如上圖所示,DARQN 結(jié)構(gòu)主要由 三種類型的網(wǎng)絡(luò)構(gòu)成:convolutional (CNN), attention, and recurrent . 在每一個(gè)時(shí)間步驟 t,CNN 收到當(dāng)前游戲狀態(tài) $s_t$ 的一個(gè)表示,根據(jù)這個(gè)狀態(tài)產(chǎn)生一組 D feature maps,每一個(gè)的維度是 m * m。Attention network 將這些 maps 轉(zhuǎn)換成一組向量 $v_t = \{ v_t^1, ... , v_t^L \}$,L = m*m,然后輸出其線性組合 $z_t$,稱為 a context vector. 這個(gè) recurrent network,在我們這里是 LSTM,將 context vector 作為輸入,以及 之前的 hidden state $h_{t-1}$,memory state $c_{t-1}$,產(chǎn)生 hidden state $h_t$ 用于:
1. a linear layer for evaluating Q-value of each action $a_t$ that the agent can take being in state $s_t$ ;?
2. the attention network for generating a context vector at the next time step t+1.?
?
?
Soft attention?:?
這一小節(jié)提到的 "soft" Attention mechanism 假設(shè) the context vector $z_t$ 可以表示為 所有向量 $v_t^i$ 的加權(quán)和,每一個(gè)對應(yīng)了從圖像不同區(qū)域提取出來的 CNN 特征。權(quán)重 和 這個(gè) vector 的重要程度成正比例,并且是通過 Attention network g 衡量的。g network 包含兩個(gè) fc layer 后面是一個(gè) softmax layer。其輸出可以表示為:
其中,Z是一個(gè)normalizing constant。W 是權(quán)重矩陣,Linear(x) = Ax + b 是一個(gè)放射變換,權(quán)重矩陣是A,偏差是 b。我們一旦定義出了每一個(gè)位置向量的重要性,我們可以計(jì)算出 context vector 為:
另一個(gè)網(wǎng)絡(luò)在第三小節(jié)進(jìn)行詳細(xì)的介紹。整個(gè) DARQN model 是通過最小化序列損失函數(shù)完成訓(xùn)練:
其中,$Y_t$ 是一個(gè)近似的 target value,為了優(yōu)化這個(gè)損失函數(shù),我們利用標(biāo)準(zhǔn)的 Q-learning 更新規(guī)則:
DARQN 中的 functions 都是可微分的,所以每一個(gè)參數(shù)都有梯度,整個(gè)模型可以 end-to-end 的進(jìn)行訓(xùn)練。本文的算法也借鑒了 target network 和 experience replay 的技術(shù)。
?
?
Hard Attention:
此處的 hard attention mechanism 采樣的時(shí)候要求僅僅從圖像中采樣一個(gè)圖像 patch。
假設(shè) $s_t$ 從環(huán)境中采樣的時(shí)候,受到了 attention policy 的影響,attention network g 的softmax layer 給出了帶參數(shù)的類別分布(categorical distribution)。然后,在策略梯度方法,策略參數(shù)的更新可以表示為:
其中 $R_t$ 是將來的折扣的損失。為了估計(jì)這個(gè)值,另一個(gè)網(wǎng)絡(luò) $G_t = Linear(h_t)$ 才引入進(jìn)來。這個(gè)網(wǎng)絡(luò)通過朝向 期望值 $Y_t$ 進(jìn)行網(wǎng)絡(luò)訓(xùn)練。Attention network 參數(shù)最終的更新采用如下的方式進(jìn)行:
?其中 $G_t - Y_t$ 是advantage function estimation。
作者提供了源代碼:https://github.com/5vision/DARQN ?
實(shí)驗(yàn)部分:
?
?
?
?
?
總結(jié): ?
?
?
?
?
?
?
?
?
總結(jié)
以上是生活随笔為你收集整理的论文笔记之:Deep Attention Recurrent Q-Network的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: mdp框架_强化学习:MDP(Marko
- 下一篇: arm rtx教程_ARM RTX操作系