dann的alpha torch_一图解密AlphaZero(附Pytorch实践)
本來打算自己寫寫的,但是發現了David Foster的神作,看了就懂了。我也就不說啥了。
看不清的話,原圖在后面的連接也可以找到。
沒懂?!!!那我再解釋下。
AlphaGo Zero主要由三個部分組成:自我博弈(self-play),訓練和評估。和AlphaGo 比較,AlphaZero最大的區別在于,并沒有采用專家樣本進行訓練。通過自己和自己玩的方式產生出訓練樣本,通過產生的樣本進行訓練;更新的網絡和更新前的網絡比賽進行評估。
在開始的時候,整個系統開始依照當前最好的網絡參數進行自我博弈,那么假設進行了10000局的比賽,收集自我博弈過程中所得到的數據。這些數據當中包括:每一次的棋局狀態以及在此狀態下各個動作的概率(由蒙特卡羅搜索樹得到);每一局的獲勝得分以及所有棋局結束后的累積得分(勝利的+1分,失敗得-1分,最后各自累加得分),得到的數據全部會被放到一個大小為500000的數據存儲當中;然后隨機的從這個數據當中采樣2048個樣本,1000次迭代更新網絡。更新之后對網絡進行評估:采用當前被更新的網絡和未更新的網絡進行比賽400局,根據比賽的勝率來決定是否要接受當前更新的網絡。如果被更新的網絡獲得了超過55%的勝率,那么接收該被更新的網絡,否則不接受。
那么我們首先來看一下AlphaZero的輸入的棋局狀態到底是什么。如圖所示,是一個大小為19*19*17的數據,表示的是17張大小為19*19(和棋盤的大小相等)的特征圖。其中,8張屬于白子,8張屬于黑子,標記為1的地方表示有子,否則標記為0 。剩下的一張用全1或者是全0表示當前輪到 黑子還是白子了。構成的這個數據表示游戲的狀態輸入到網絡當中進行訓練。
那么我們來看一下,AlphaZero的網絡到底是怎么樣的呢?
這個網絡主要由三個部分組成:由40層殘差網絡構成的特征提取網絡(身體),以及價值網絡以及策略網絡(兩個頭)。該網絡當中價值網絡所輸出的值作為當前的狀態的價值估計; 策略網絡的輸出作為一個狀態到動作的映射概率。而這兩個部分的輸出都被引入到蒙特卡羅搜索樹當中,用來指導最終的下棋決策。那么顯然,價值網絡輸出的是一個1D的標量值,在-1到1之間;策略網絡輸出的是一個19*19*1的特征圖,其中的每一個點表示的是下棋到該位置的概率。那我們來看一下,該網絡是如何指導蒙特卡羅搜索樹的。
如圖所示,在圖中的搜索樹當中,黑色的點表示的是從一個狀態過渡到另一個狀態的動作a;其余的節點表示的是棋局的狀態,也就是之前所說的輸入。從一個非葉子節點的狀態開始,往往存在多種可能的行動,而其中的狀態節點a具有4種屬性,他們決定了到底應該如何選擇。具體來講,其中的N表示的是到目前為止,該動作節點被訪問的次數;P表示網絡預測出來的選擇該節點的概率;W表示下一個狀態的總的價值,而價值網絡輸出的動作的價值會被累及到這個值當中;這個值除以被訪問到的次數就等于平均的價值Q。實際上,還會給Q加上一個U來起到探索更多的動作的效果。我想應該是非常清楚的。那么如何根據構建出來的搜索樹進行下棋的步驟呢?在一定的閾值范圍內(比如說,1000個迭代之前),采用最大化Q函數的方式來選擇動作;那么當大于這個閾值之后采用蒙特卡羅搜索樹的方式(例如PUCT算法,也就是根據概率和被訪問的次數)來選擇執行的動作。
那我們來看一下蒙特卡羅搜索樹在這里面時如何實現的。首先是其中的節點:
class Node:
def __init__(self, parent=None, proba=None, move=None):
self.p = proba
self.n = 0
self.w = 0
self.q = 0
self.children = []
self.parent = parent
self.move = move
其中主要為之前所說的4個屬性以及父子節點的指針。而最后一個move指出了在當前狀態下的合法下棋步驟。在訓練的過程中,這些值都會被更新,那么在更新之后如何通過他們來進行動作的選擇呢?
def select(nodes, c_puct=C_PUCT):
" Optimized version of the selection based of the PUCT formula "
total_count = 0
for i in range(nodes.shape[0]):
total_count += nodes[i][1]
action_scores = np.zeros(nodes.shape[0])
for i in range(nodes.shape[0]):
action_scores[i] = nodes[i][0] + c_puct * nodes[i][2] * \
(np.sqrt(total_count) / (1 + nodes[i][1]))
equals = np.where(action_scores == np.max(action_scores))[0]
if equals.shape[0] > 0:
return np.random.choice(equals)
return equals[0]
這里表示的是對于任何一個節點,從其所有的子節點當中,通過PUCT算法找出最大得分的那個節點。在這個得分action_scores[i]的計算過程中,網絡預測的概率和該節點被訪問的次數都有被考慮。對于被訪問到的非葉子節點繼續進行擴展;而如果是葉子節點則進行最終的評估。至于其中的殘差網絡模塊,價值網絡,策略網絡就不再一一敘述了。詳細參考:https://github.com/dylandjian/superGo?github.com
References:
總結
以上是生活随笔為你收集整理的dann的alpha torch_一图解密AlphaZero(附Pytorch实践)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python二维数据读取对齐_从投影的二
- 下一篇: 制作碳排放强度的空间可视化_【科研成果】