Deep Learning Summary: DQN Principles, Algorithm, and PyTorch Implementation
Table of Contents
- Q-learning schematic
- Q-learning algorithm description:
- PyTorch implementation:
- Q-network implementation:
- DQN implementation:
- Two Q-networks, one of which is the target Q-network;
- take action: obtains the next action; this is the part that interacts with the environment, with actions selected ε-greedily;
- store transitions: saves the collected data for experience replay;
- The most important part is the learning process, the core of the algorithm description: it processes a minibatch, performs a regression step to update the Q-network, and periodically updates the target Q-network.
- Training implementation: reshape the game environment's reward, and implement the algorithm's "for each episode" (controlled by a for ... in range(...) loop) and "for each time step" (terminated by the done flag returned by the environment)
Q-learning schematic
Q-learning algorithm description:
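A minimal sketch of the update rule behind the algorithm, assuming the standard tabular Q-learning formulation:

```latex
Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]
```

DQN keeps the same target r + γ max Q(s', a') but replaces the table with the neural network implemented below, trained by regression on replayed transitions.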
PyTorch implementation:
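The snippets below use several names (N_STATES, N_ACTIONS, EPSILON, GAMMA, MEMORY_CAPACITY, TARGET_REPLACE_ITER, BATCH_SIZE, LR, ENV_A_SHAPE, env) whose definitions are not shown in this post. A minimal setup sketch, assuming the CartPole-v0 environment and illustrative hyperparameter values rather than the original ones:

```python
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hyperparameters (illustrative values; the original post's values may differ)
BATCH_SIZE = 32
LR = 0.01                   # learning rate
EPSILON = 0.9               # probability of choosing the greedy action
GAMMA = 0.9                 # reward discount factor
TARGET_REPLACE_ITER = 100   # copy eval_net into target_net every 100 learn steps
MEMORY_CAPACITY = 2000      # number of transitions kept for experience replay

# Environment (CartPole-v0 is assumed; the reward shaping in the training loop matches it)
env = gym.make('CartPole-v0').unwrapped
N_ACTIONS = env.action_space.n              # 2 discrete actions: push left / push right
N_STATES = env.observation_space.shape[0]   # 4 state variables: x, x_dot, theta, theta_dot
ENV_A_SHAPE = 0                             # 0 because the action is a single discrete index
```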
Q-network implementation:
The network takes the state s as input and outputs Q(s, a_i), i.e. the Q-value of every action a_i in state s.
```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(N_STATES, 50)
        self.fc1.weight.data.normal_(0, 0.1)    # initialization
        self.out = nn.Linear(50, N_ACTIONS)
        self.out.weight.data.normal_(0, 0.1)    # initialization

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        actions_value = self.out(x)
        return actions_value
```

DQN implementation:
The DQN class contains:
two Q-networks, one of which serves as the target Q-network;
```python
class DQN(object):
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()

        self.learn_step_counter = 0    # for target updating
        self.memory_counter = 0        # for storing memory
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))  # initialize memory
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
```

take action: obtains the next action; this is the part that interacts with the environment, and actions are selected ε-greedily;
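A sketch of the selection rule that choose_action below implements; note that in this code EPSILON is the probability of acting greedily, not the exploration probability:

```latex
a = \begin{cases}
\arg\max_{a'} Q_{\text{eval}}(s, a') & \text{with probability } \epsilon, \\
\text{a uniformly random action} & \text{with probability } 1 - \epsilon.
\end{cases}
```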
```python
    # method of the DQN class defined above
    def choose_action(self, x):
        x = torch.unsqueeze(torch.FloatTensor(x), 0)
        # input only one sample
        if np.random.uniform() < EPSILON:   # greedy
            actions_value = self.eval_net.forward(x)
            action = torch.max(actions_value, 1)[1].data.numpy()
            action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)  # return the argmax index
        else:   # random
            action = np.random.randint(0, N_ACTIONS)
            action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
        return action
```

store transitions: saves the collected data for experience replay;
```python
    # method of the DQN class defined above
    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # replace the old memory with new memory
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1
```

The most important part is the learning process, the core of the algorithm description: it processes a minibatch, performs a regression step to update the Q-network, and periodically updates the target Q-network.
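A sketch of the regression step that learn below performs, assuming the standard DQN formulation: for each sampled transition (s, a, r, s'), the eval network is fitted toward a target built from the frozen target network,

```latex
y = r + \gamma \max_{a'} Q_{\text{target}}(s', a'), \qquad
L = \big( Q_{\text{eval}}(s, a) - y \big)^2 .
```

Detaching q_next from the computation graph is what keeps the target network fixed within a single update.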
```python
    # method of the DQN class defined above
    def learn(self):
        # target parameter update
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # sample batch transitions
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[sample_index, :]
        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))
        b_r = torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2])
        b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

        # q_eval w.r.t the action in experience
        q_eval = self.eval_net(b_s).gather(1, b_a)      # shape (batch, 1)
        q_next = self.target_net(b_s_).detach()         # detach from graph, don't backpropagate
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)  # shape (batch, 1)
        loss = self.loss_func(q_eval, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
```

Training implementation: reshape the game environment's reward, and implement the algorithm's "for each episode" (controlled by a for ... in range(...) loop) and "for each time step" (terminated by the done flag returned by the environment).
```python
dqn = DQN()

print('\nCollecting experience...')
for i_episode in range(400):
    s = env.reset()
    ep_r = 0
    while True:
        env.render()
        a = dqn.choose_action(s)

        # take action
        s_, r, done, info = env.step(a)

        # reshape the reward: r1 favors the cart staying near the track center,
        # r2 favors the pole staying upright
        x, x_dot, theta, theta_dot = s_
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        r = r1 + r2

        dqn.store_transition(s, a, r, s_)

        ep_r += r
        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()
            if done:
                print('Ep: ', i_episode, '| Ep_r: ', round(ep_r, 2))

        if done:
            break
        s = s_
```