深度强化学习之:PPO训练红白机1942
本篇是深度強化學習動手系列文章,自MyEncyclopedia公眾號文章深度強化學習之:DQN訓練超級瑪麗闖關發布后收到不少關注和反饋,這一期,讓我們實現目前主流深度強化學習算法PPO來打另一個紅白機經典游戲1942。
相關文章鏈接如下:
強化學習開源環境集
視頻論文解讀:PPO算法
視頻論文解讀:組合優化的強化學習方法
解讀TRPO論文,深度強化學習結合傳統優化方法
解讀深度強化學習基石論文:函數近似的策略梯度方法
NES 1942 環境安裝
紅白機游戲環境可以由OpenAI Retro來模擬,OpenAI Retro還在 Gym 集成了其他的經典游戲環境,包括Atari 2600,GBA,SNES等。
不過,受到版權原因,除了一些基本的rom,大部分游戲需要自行獲取rom。
環境準備部分相關代碼如下
pip?install?gym-retro python?-m?retro.import?/path/to/your/ROMs/directory/OpenAI Gym 輸入動作類型
在創建 retro 環境時,可以在retro.make中通過參數use_restricted_actions指定 action space,即按鍵的配置。
env?=?retro.make(game='1942-Nes',?use_restricted_actions=retro.Actions.FILTERED)可選參數如下,FILTERED,DISCRETE和MULTI_DISCRETE 都可以指定過濾的動作,過濾動作需要通過配置文件加載。
class?Actions(Enum):"""Different?settings?for?the?action?space?of?the?environment"""ALL?=?0??#:?MultiBinary?action?space?with?no?filtered?actionsFILTERED?=?1??#:?MultiBinary?action?space?with?invalid?or?not?allowed?actions?filtered?outDISCRETE?=?2??#:?Discrete?action?space?for?filtered?actionsMULTI_DISCRETE?=?3??#:?MultiDiscete?action?space?for?filtered?actionsDISCRETE和MULTI_DISCRETE 是 Gym 里的 Action概念,它們的基類都是gym.spaces.Space,可以通過 sample()方法采樣,下面具體一一介紹。
Discrete:對應一維離散空間,例如,Discrete(n=4) 表示 [0, 3] 范圍的整數。
輸出是
3Box:對應多維連續空間,每一維的范圍可以用 [low,high] 指定。舉例,Box(low=-1.0, high=2, shape=(3, 4,), dtype=np.float32) 表示 shape 是 [3, 4],每個范圍在 [-1, 2] 的float32型 tensor。
輸出是
[[-0.7538084???0.96901214??0.38641307?-0.05045208][-0.85486996??1.3516271???0.3222616???1.2540635?][-0.29908678?-0.8970335???1.4869047???0.7007356?]]MultiBinary: 0或1的多維離散空間。例如,MultiBinary([3,2]) 表示 shape 是3x2的0或1的tensor。
輸出是
[[1?0][1?1][0?0]]MultiDiscrete:多維整型離散空間。例如,MultiDiscrete([5,2,2]) 表示三維Discrete空間,第一維范圍在 [0-4],第二,三維范圍在[0-1]。
輸出是
[2?1?0]Tuple:組合成 tuple 復合空間。舉例來說,可以將 Box,Discrete,Discrete組成tuple 空間:Tuple(spaces=(Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32), Discrete(n=3), Discrete(n=2)))
輸出是
(array([?0.22640526,??0.75286865,?-0.6309239?],?dtype=float32),?0,?1)Dict:組合成有名字的復合空間。例如,Dict({'position':Discrete(2), 'velocity':Discrete(3)})
輸出是
OrderedDict([('position',?1),?('velocity',?1)])NES 1942 動作空間配置
了解了 gym/retro 的動作空間,我們來看看1942的默認動作空間
env?=?retro.make(game='1942-Nes') print("The?size?of?action?is:?",?env.action_space.shape) The?size?of?action?is:??(9,)表示有9個 Discrete 動作,包括 start, select這些控制鍵。
從訓練1942角度來說,我們希望指定最少的有效動作取得最好的成績。根據經驗,我們知道這個游戲最重要的鍵是4個方向加上 fire 鍵。限定游戲動作空間,官方的做法是在創建游戲環境時,指定預先生成的動作輸入配置文件。但是這個方式相對麻煩,我們采用了直接指定按鍵的二進制表示來達到同樣的目的,此時,需要設置 use_restricted_actions=retro.Actions.FILTERED。
下面的代碼限制了6種按鍵,并隨機play。
action_list?=?[#?No?Operation[0,?0,?0,?0,?0,?0,?0,?0,?0,?0,?0,?0],#?Left[0,?0,?0,?0,?0,?0,?1,?0,?0,?0,?0,?0],#?Right[0,?0,?0,?0,?0,?0,?0,?1,?0,?0,?0,?0],#?Down[0,?0,?0,?0,?0,?1,?0,?0,?0,?0,?0,?0],#?Up[0,?0,?0,?0,?1,?0,?0,?0,?0,?0,?0,?0],#?B[1,?0,?0,?0,?0,?0,?0,?0,?0,?0,?0,?0], ]def?random_play(env,?action_list,?sleep_seconds=0.01):env.viewer?=?Nonestate?=?env.reset()score?=?0for?j?in?range(10000):env.render()time.sleep(sleep_seconds)action?=?np.random.randint(len(action_list))next_state,?reward,?done,?_?=?env.step(action_list[action])state?=?next_statescore?+=?rewardif?done:print("Episode?Score:?",?score)env.reset()breakenv?=?retro.make(game='1942-Nes',?use_restricted_actions=retro.Actions.FILTERED) random_play(env,?action_list)來看看其游戲效果,全隨機死的還是比較快。
?圖像輸入處理
一般對于通過屏幕像素作為輸入的RL end-to-end訓練來說,對圖像做預處理很關鍵。因為原始圖像較大,一方面我們希望能盡量壓縮圖像到比較小的tensor,另一方面又要保證關鍵信息不丟失,比如子彈的圖像不能因為圖片縮小而消失。另外的一個通用技巧是將多個連續的frame合并起來組成立體的frame,這樣可以有效表示連貫動作。
下面的代碼通過 pipeline 將游戲每幀原始圖像從shape (224, 240, 3) 轉換成 (4, 84, 84),也就是原始的 width=224,height=240,rgb=3轉換成 width=84,height=240,stack_size=4的黑白圖像。具體 pipeline為
MaxAndSkipEnv:每兩幀過濾一幀圖像,減少數據量。
FrameDownSample:down sample 圖像到指定小分辨率 84x84,并從彩色降到黑白。
FrameBuffer:合并連續的4幀,形成 (4, 84, 84) 的圖像輸入
觀察圖像維度變換
env?=?retro.make(game='1942-Nes',?use_restricted_actions=retro.Actions.FILTERED) print("Initial?shape:?",?env.observation_space.shape)env?=?build_env(env) print("Processed?shape:?",?env.observation_space.shape)確保shape 從 (224, 240, 3) 轉換成 (4, 84, 84)
Initial?shape:??(224,?240,?3) Processed?shape:??(4,?84,?84)FrameDownSample實現如下,我們使用了 cv2 類庫來完成黑白化和圖像縮放
class?FrameDownSample(ObservationWrapper):def?__init__(self,?env,?exclude,?width=84,?height=84):super(FrameDownSample,?self).__init__(env)self.exclude?=?excludeself.observation_space?=?Box(low=0,high=255,shape=(width,?height,?1),dtype=np.uint8)self._width?=?widthself._height?=?heightdef?observation(self,?observation):#?convert?image?to?gray?scalescreen?=?cv2.cvtColor(observation,?cv2.COLOR_RGB2GRAY)#?crop?screen?[up:?down,?left:?right]screen?=?screen[self.exclude[0]:self.exclude[2],?self.exclude[3]:self.exclude[1]]#?to?float,?and?normalizedscreen?=?np.ascontiguousarray(screen,?dtype=np.float32)?/?255#?resize?imagescreen?=?cv2.resize(screen,?(self._width,?self._height),?interpolation=cv2.INTER_AREA)return?screenMaxAndSkipEnv,每兩幀過濾一幀
class?MaxAndSkipEnv(Wrapper):def?__init__(self,?env=None,?skip=4):super(MaxAndSkipEnv,?self).__init__(env)self._obs_buffer?=?deque(maxlen=2)self._skip?=?skipdef?step(self,?action):total_reward?=?0.0done?=?Nonefor?_?in?range(self._skip):obs,?reward,?done,?info?=?self.env.step(action)self._obs_buffer.append(obs)total_reward?+=?rewardif?done:breakmax_frame?=?np.max(np.stack(self._obs_buffer),?axis=0)return?max_frame,?total_reward,?done,?infodef?reset(self):self._obs_buffer.clear()obs?=?self.env.reset()self._obs_buffer.append(obs)return?obsFrameBuffer,將最近的4幀合并起來
class?FrameBuffer(ObservationWrapper):def?__init__(self,?env,?num_steps,?dtype=np.float32):super(FrameBuffer,?self).__init__(env)obs_space?=?env.observation_spaceself._dtype?=?dtypeself.observation_space?=?Box(low=0,?high=255,?shape=(num_steps,?obs_space.shape[0],?obs_space.shape[1]),?dtype=self._dtype)def?reset(self):frame?=?self.env.reset()self.buffer?=?np.stack(arrays=[frame,?frame,?frame,?frame])return?self.bufferdef?observation(self,?observation):self.buffer[:-1]?=?self.buffer[1:]self.buffer[-1]?=?observationreturn?self.buffer最后,visualize 處理后的圖像,同樣還是在隨機play中,確保關鍵信息不丟失
def?random_play_preprocessed(env,?action_list,?sleep_seconds=0.01):import?matplotlib.pyplot?as?pltenv.viewer?=?Nonestate?=?env.reset()score?=?0for?j?in?range(10000):time.sleep(sleep_seconds)action?=?np.random.randint(len(action_list))plt.imshow(state[-1],?cmap="gray")plt.title('Pre?Processed?image')plt.pause(sleep_seconds)next_state,?reward,?done,?_?=?env.step(action_list[action])state?=?next_statescore?+=?rewardif?done:print("Episode?Score:?",?score)env.reset()breakmatplotlib 動畫輸出
?CNN Actor & Critic
Actor 和 Critic 模型相同,輸入是 (4, 84, 84) 的圖像,輸出是 [0, 5] 的action index。
class?Actor(nn.Module):def?__init__(self,?input_shape,?num_actions):super(Actor,?self).__init__()self.input_shape?=?input_shapeself.num_actions?=?num_actionsself.features?=?nn.Sequential(nn.Conv2d(input_shape[0],?32,?kernel_size=8,?stride=4),nn.ReLU(),nn.Conv2d(32,?64,?kernel_size=4,?stride=2),nn.ReLU(),nn.Conv2d(64,?64,?kernel_size=3,?stride=1),nn.ReLU())self.fc?=?nn.Sequential(nn.Linear(self.feature_size(),?512),nn.ReLU(),nn.Linear(512,?self.num_actions),nn.Softmax(dim=1))def?forward(self,?x):x?=?self.features(x)x?=?x.view(x.size(0),?-1)x?=?self.fc(x)dist?=?Categorical(x)return?distPPO核心代碼
先計算 ,這里采用了一個技巧,對 取 log,相減再取 exp,這樣可以增強數值穩定性。
dist?=?self.actor_net(state) new_log_probs?=?dist.log_prob(action) ratio?=?(new_log_probs?-?old_log_probs).exp() surr1?=?ratio?*?advantagesurr1 對應PPO論文中的
?然后計算 surr2,對應 中的 clip 部分,clip可以由 torch.clamp 函數實現。 則對應 actor_loss。
surr2?=?torch.clamp(ratio,?1.0?-?self.clip_param,?1.0?+?self.clip_param)?*?advantage actor_loss?=?-?torch.min(surr1,?surr2).mean()?
最后,計算總的 loss ,包括 actor_loss,critic_loss 和 policy的 entropy。
entropy?=?dist.entropy().mean()critic_loss?=?(return_?-?value).pow(2).mean() loss?=?actor_loss?+?0.5?*?critic_loss?-?0.001?*?entropy?上述完整代碼如下
補充一下 GAE 的計算,advantage 根據公式
可以轉換成如下代碼
def?compute_gae(self,?next_value):gae?=?0returns?=?[]values?=?self.values?+?[next_value]for?step?in?reversed(range(len(self.rewards))):delta?=?self.rewards[step]?+?self.gamma?*?values[step?+?1]?*?self.masks[step]?-?values[step]gae?=?delta?+?self.gamma?*?self.tau?*?self.masks[step]?*?gaereturns.insert(0,?gae?+?values[step])return?returns外層 Training 代碼
外層調用代碼基于隨機 play 的邏輯,agent.act()封裝了采樣和 forward prop,agent.step() 則封裝了 backprop 和參數學習迭代的邏輯。
for?i_episode?in?range(start_epoch?+?1,?n_episodes?+?1):state?=?env.reset()score?=?0timestamp?=?0while?timestamp?<?10000:action,?log_prob,?value?=?agent.act(state)next_state,?reward,?done,?info?=?env.step(action_list[action])score?+=?rewardtimestamp?+=?1agent.step(state,?action,?value,?log_prob,?reward,?done,?next_state)if?done:breakelse:state?=?next_state訓練結果
讓我們來看看學習的效果吧,注意我們的飛機學到了一些關鍵的技巧,躲避子彈;飛到角落盡快擊斃敵機;一定程度預測敵機出現的位置并預先走到位置。
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯 本站qq群851320808,加入微信群請掃碼:
總結
以上是生活随笔為你收集整理的深度强化学习之:PPO训练红白机1942的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 火狐浏览器如何设置启动页面
- 下一篇: Win11如何跳过开机更新 Win11跳