[Hands-on Reinforcement Learning] 2. Solving CartPole-v0 with DQN
How to get started with reinforcement learning:
How do I get started with reinforcement learning? (www.zhihu.com)

Recently, while reorganizing some reinforcement learning code I wrote earlier, I noticed it was still written against an old version of PyTorch. PyTorch shipped a major release this year, 0.4, which broke a lot of old code, so I rewrote the DQN code for the CartPole-v0 environment on top of the latest version.
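The incompatibilities are mostly the 0.4 API changes around `Variable` and scalar extraction; a rough before/after sketch of the pattern that shows up in the DQN code below (my own illustration, not from the original post):

```python
# Rough before/after illustration of the PyTorch 0.4 changes that broke older
# DQN code (my own sketch, not from the original post).
#
# Pre-0.4 style: tensors had to be wrapped in Variable, and scalars were pulled
# out via .data indexing:
#
#     from torch.autograd import Variable
#     s0 = Variable(torch.FloatTensor(state)).view(1, -1)
#     a0 = q_values.max(1)[1].data[0]
#
# Since 0.4, Variable is merged into Tensor, torch.tensor() builds tensors
# directly from Python data, and .item() extracts a Python scalar:
import torch

state = [0.1, 0.0, -0.02, 0.0]                     # a fake CartPole observation
s0 = torch.tensor(state, dtype=torch.float).view(1, -1)
q_values = torch.randn(1, 2)                       # stand-in for eval_net(s0)
a0 = torch.argmax(q_values).item()                 # plain Python int
print(a0)
```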
- The code is simplified; most other examples online are either outdated or messy.
- A dynamic plotting function is added.
- With these changes the agent reaches 200 steps quite quickly, but it becomes unstable later in training; the exploration-exploitation trade-off still needs careful tuning (see the epsilon-decay sketch right after this list).
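The exploration schedule used by the agent below anneals epsilon exponentially from `epsi_high` toward `epsi_low`. A standalone sketch of just that schedule, using the parameter values from the training script later in this post:

```python
# Exponential epsilon decay used by Agent.act() below: epsilon starts at
# epsi_high and decays toward epsi_low as the step counter grows.
import math

epsi_high, epsi_low, decay = 0.9, 0.05, 200

def epsilon(steps):
    return epsi_low + (epsi_high - epsi_low) * math.exp(-1.0 * steps / decay)

for steps in (0, 200, 400, 1000):
    print(steps, round(epsilon(steps), 3))   # 0.9 -> ~0.363 -> ~0.165 -> ~0.056
```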
The CartPole-v0 environment:
Gym: A toolkit for developing and comparing reinforcement learning algorithms (gym.openai.com)

The DQN CartPole-v0 source code is here, forks and stars welcome:
https://github.com/hangsz/reinforcement_learning

You will need the gym library and PyTorch.
Install gym with: pip install gym
PyTorch (pick the build that matches your setup): https://pytorch.org/get-started/locally/
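To check that both packages are installed correctly, a minimal sanity check might look like this (my own snippet; it assumes the same pre-0.26 gym reset/step API that the article's code uses):

```python
# Quick sanity check for gym + PyTorch before running the DQN code below.
import gym
import torch

env = gym.make('CartPole-v0')
s = env.reset()                        # 4-dimensional observation
print('observation:', s)
print('torch version:', torch.__version__)

for _ in range(5):
    a = env.action_space.sample()      # random action: 0 (push left) or 1 (push right)
    s, r, done, info = env.step(a)
    if done:
        s = env.reset()
env.close()
```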
Animation:
https://www.zhihu.com/video/1193285883359604736

dqn.py (the network and the agent; the training script below imports `Agent` from this module):

```python
# coding: utf-8
__author__ = 'zhenhang.sun@gmail.com'
__version__ = '1.0.0'

import gym
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    """Two-layer MLP that maps a state to one Q-value per action."""
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x


class Agent(object):
    def __init__(self, **kwargs):
        # Hyperparameters are passed in as keyword arguments (see the training script).
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.eval_net = Net(self.state_space_dim, 256, self.action_space_dim)
        self.optimizer = optim.Adam(self.eval_net.parameters(), lr=self.lr)
        self.buffer = []   # replay buffer of (s0, a0, r1, s1) tuples
        self.steps = 0

    def act(self, s0):
        # Epsilon-greedy action selection; epsilon decays exponentially from
        # epsi_high toward epsi_low over `decay` steps.
        self.steps += 1
        epsi = self.epsi_low + (self.epsi_high - self.epsi_low) * math.exp(-1.0 * self.steps / self.decay)
        if random.random() < epsi:
            a0 = random.randrange(self.action_space_dim)
        else:
            s0 = torch.tensor(s0, dtype=torch.float).view(1, -1)
            a0 = torch.argmax(self.eval_net(s0)).item()
        return a0

    def put(self, *transition):
        # Store a transition, dropping the oldest one once the buffer is full.
        if len(self.buffer) == self.capacity:
            self.buffer.pop(0)
        self.buffer.append(transition)

    def learn(self):
        if len(self.buffer) < self.batch_size:
            return

        samples = random.sample(self.buffer, self.batch_size)
        s0, a0, r1, s1 = zip(*samples)
        s0 = torch.tensor(s0, dtype=torch.float)
        a0 = torch.tensor(a0, dtype=torch.long).view(self.batch_size, -1)
        r1 = torch.tensor(r1, dtype=torch.float).view(self.batch_size, -1)
        s1 = torch.tensor(s1, dtype=torch.float)

        # One-step TD target; the bootstrap term is detached so gradients only
        # flow through the prediction for (s0, a0).
        y_true = r1 + self.gamma * torch.max(self.eval_net(s1).detach(), dim=1)[0].view(self.batch_size, -1)
        y_pred = self.eval_net(s0).gather(1, a0)

        loss_fn = nn.MSELoss()
        loss = loss_fn(y_pred, y_true)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
```

Training script (run it from the same directory as dqn.py):

```python
# coding: utf-8
__author__ = 'zhenhang.sun@gmail.com'
__version__ = '1.0.0'

import gym
from IPython import display
import matplotlib.pyplot as plt

from dqn import Agent


def plot(score, mean):
    # Dynamic training curve: per-episode duration and its 100-episode mean.
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.figure(figsize=(20, 10))
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(score)
    plt.plot(mean)
    plt.text(len(score) - 1, score[-1], str(score[-1]))
    plt.text(len(mean) - 1, mean[-1], str(mean[-1]))


if __name__ == '__main__':
    env = gym.make('CartPole-v0')
    params = {
        'gamma': 0.8,
        'epsi_high': 0.9,
        'epsi_low': 0.05,
        'decay': 200,
        'lr': 0.001,
        'capacity': 10000,
        'batch_size': 64,
        'state_space_dim': env.observation_space.shape[0],
        'action_space_dim': env.action_space.n
    }
    agent = Agent(**params)

    score = []
    mean = []

    for episode in range(1000):
        s0 = env.reset()
        total_reward = 1
        while True:
            env.render()
            a0 = agent.act(s0)
            s1, r1, done, _ = env.step(a0)

            # Penalize the terminal transition so the agent learns to avoid falling.
            if done:
                r1 = -1

            agent.put(s0, a0, r1, s1)

            if done:
                break

            total_reward += r1
            s0 = s1
            agent.learn()

        score.append(total_reward)
        mean.append(sum(score[-100:]) / 100)
        plot(score, mean)
```
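For reference, the regression target that `learn()` computes for each sampled transition is the standard one-step DQN target; written out, this is just a restatement of the code above:

$$y = r_1 + \gamma \max_{a'} Q_\theta(s_1, a'), \qquad \mathcal{L}(\theta) = \bigl(Q_\theta(s_0, a_0) - y\bigr)^2$$

Note that, as written, the code bootstraps from $s_1$ even on terminal transitions and relies on the $r_1 = -1$ penalty instead of masking the terminal state, which may contribute to the later-stage instability mentioned above.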