深度强化学习实战:Tensorflow实现DDPG - PaperWeekly 第48期
作者丨李國(guó)豪
學(xué)校丨中國(guó)科學(xué)院大學(xué)&上海科技大學(xué)
研究方向丨無(wú)人駕駛,強(qiáng)化學(xué)習(xí)
指導(dǎo)老師丨林寶軍教授
1. 前言
本文主要講解 DeepMind 發(fā)布在 ICLR 2016 的文章 Continuous control with deep reinforcement learning,時(shí)間稍微有點(diǎn)久遠(yuǎn),但因?yàn)樗惴ń?jīng)典,還是值得去實(shí)現(xiàn)。
2. 環(huán)境
這次實(shí)驗(yàn)環(huán)境是 Openai Gym 的 Pendulum-v0,state 是 3 維連續(xù)的表示桿的位置方向信息,action 是 1 維的連續(xù)動(dòng)作,大小是 -2.0 到 2.0,表示對(duì)桿施加的力和方向。目標(biāo)是讓桿保持直立,所以 reward 在桿保持直立不動(dòng)的時(shí)候最大。筆者所用的環(huán)境為:?
Tensorflow (1.2.1)?
gym (0.9.2)?
請(qǐng)先安裝 Tensorflow 和 gym,Tensorflow 和 gym 的安裝就不贅述了,下面是網(wǎng)絡(luò)收斂后的結(jié)果。
class="video_iframe" data-vidtype="2" allowfullscreen="" frameborder="0" data-ratio="1" data-w="272" data-src="http://v.qq.com/iframe/player.html?vid=y1325jlix3j&width=650&height=487.5&auto=0" style="display: block; width: 650px !important; height: 487.5px !important;" width="650" height="487.5" data-vh="487.5" data-vw="650" src="http://v.qq.com/iframe/player.html?vid=y1325jlix3j&width=650&height=487.5&auto=0"/>
3. 代碼詳解
先貼一張 DeepMind 文章中的偽代碼,分析一下實(shí)現(xiàn)它,我們需要實(shí)現(xiàn)哪些東西:
4. 網(wǎng)絡(luò)結(jié)構(gòu)(model)
首先,我們需要實(shí)現(xiàn)一個(gè) critic network 和一個(gè) actor network,然后再實(shí)現(xiàn)一個(gè) target critic network 和 target actor network,并且對(duì)應(yīng)初始化為相同的 weights。下面來(lái)看看這部分代碼怎么實(shí)現(xiàn):
critic network & target critic network
上面是 critic network 的實(shí)現(xiàn),critic network是一個(gè)用神經(jīng)網(wǎng)絡(luò)去近似的一個(gè)函數(shù),輸入是 s-state,a-action,輸出是 Q 函數(shù),網(wǎng)絡(luò)參數(shù)是,在這里我的實(shí)現(xiàn)和原文類(lèi)似,state 經(jīng)過(guò)一個(gè)全連接層得到隱藏層特征 h1,action 經(jīng)過(guò)另外一個(gè)全連接層得到隱藏層特征 h2,然后特征串聯(lián)在一起得到 h_concat,之后 h_concat 再經(jīng)過(guò)一層全連接層得到 h3,最后 h3 經(jīng)過(guò)一個(gè)沒(méi)有激活函數(shù)的全連接層得到 q_output。這就簡(jiǎn)單得實(shí)現(xiàn)了一個(gè) critic network。
上面是target critic network的實(shí)現(xiàn),target critic network網(wǎng)絡(luò)結(jié)構(gòu)和 critic network 一樣,也參數(shù)初始化為一樣的權(quán)重,思路是先把 critic network 的權(quán)重取出來(lái)初始化,再調(diào)用一遍 self.__create_critic_network() 創(chuàng)建 target network,最后把 critic network 初始化的權(quán)重賦值給 target critic network。?
這樣我們就得到了 critic network 和 critic target network。?
actor network & actor target network?
actor network和 actor target network的實(shí)現(xiàn)與 critic 幾乎一樣,區(qū)別在于網(wǎng)絡(luò)結(jié)構(gòu)和激活函數(shù)。
這里用了 3 層全連接層,最后激活函數(shù)是 tanh,把輸出限定在 -1 到 1 之間。這樣大體的網(wǎng)絡(luò)結(jié)構(gòu)就實(shí)現(xiàn)完了。
5. Replay Buffer & Random Process(Mechanism)
接下來(lái),偽代碼提到 replay buffer 和 random process,這部分代碼比較簡(jiǎn)單也很短,主要參考了 openai 的 rllab 的實(shí)現(xiàn),大家可以直接看看源碼。
6.?網(wǎng)絡(luò)更新和損失函數(shù)(Model)
用梯度下降更新網(wǎng)絡(luò),先需要定義我們的 loss 函數(shù)。?
critic nework 更新
這里 critic 只是很簡(jiǎn)單的是一個(gè) L2 loss。不過(guò)由于 transition 是 s, a, r, s'。要得到 y 需要一步處理,下面是預(yù)處理 transition 的代碼。
訓(xùn)練模型是,從 Replay buffer 里取出一個(gè) mini-batch,在經(jīng)過(guò)預(yù)處理就可以更新我們的網(wǎng)絡(luò)了,是不是很簡(jiǎn)單。y 經(jīng)過(guò)下面這行代碼處理得到。
actor nework更新
actor network 的更新也很簡(jiǎn)單,我們需要求的梯度如上圖,首先我們需要critic network對(duì)動(dòng)作 a 的導(dǎo)數(shù),其中 a 是由 actor network 根據(jù)狀態(tài) s 估計(jì)出來(lái)的。代碼如下:
先根據(jù) actor network 估計(jì)出 action,再用 critic network 的輸出 q 對(duì)估計(jì)出來(lái)的 action 求導(dǎo)。?
然后我們把得到的這部分梯度,和 actor network 的輸出對(duì) actor network 的權(quán)重求導(dǎo)的梯度,相乘就能得到最后的梯度,代碼如下:
也就是說(shuō)我們需要求的 policy gradient 主要由下面這一行代碼求得,由于我們需要梯度下降去更新網(wǎng)絡(luò),所以需要加個(gè)負(fù)號(hào):
之后就是更新我們的 target network,target network 采用 soft update 的方式去穩(wěn)定網(wǎng)絡(luò)的變化,算法如下:
就這樣我們的整體網(wǎng)絡(luò)更新需要的東西都實(shí)現(xiàn)了,下面是整體網(wǎng)絡(luò)更新的代碼:
總體的細(xì)節(jié)都介紹完了,希望大家有所收獲。另外,完整代碼已放出,大家可以點(diǎn)擊“閱讀原文”訪問(wèn)我的 Github。
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號(hào)后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點(diǎn)擊 |?閱讀原文?| 查看完整代碼
總結(jié)
以上是生活随笔為你收集整理的深度强化学习实战:Tensorflow实现DDPG - PaperWeekly 第48期的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 来自闪闪宝石的光芒 - “宝石迷阵” x
- 下一篇: 评测任务实战:中文文本分类技术实践与分享