DQN Reinforcement Learning
What Is Reinforcement Learning?
Reinforcement Learning (RL), sometimes also called reward-based, evaluative, or incremental learning, is one of the paradigms and methodologies of machine learning. It describes and solves the problem of an agent learning a policy through interaction with its environment so as to maximize its return or achieve a specific goal.
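Concretely, the DQN built in the rest of this post learns an action-value function with the standard Q-learning target; a brief sketch is added here for reference, with symbols that map onto the GAMMA, q_eval, q_next, and q_target variables in the learn() method later on:

$$
y = r + \gamma \max_{a'} Q_{\text{target}}(s', a'), \qquad
L(\theta) = \big(y - Q_{\text{eval}}(s, a;\, \theta)\big)^2
$$

Here $\gamma$ is the reward discount (GAMMA = 0.9 in the code), $Q_{\text{eval}}$ is the network being trained, and $Q_{\text{target}}$ is a periodically synchronized copy of it that supplies the learning targets.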
Module Imports and Hyperparameters
The code in this post is split into one file per section; later snippets import from the earlier ones by their (Chinese) file names, e.g. 模块导入和超参数, 神经网络, and DQN模型.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym

# Hyperparameters
BATCH_SIZE = 32
LR = 0.01                  # learning rate
EPSILON = 0.9              # fraction of the time the greedy (best) action is chosen
GAMMA = 0.9                # reward discount factor
TARGET_REPLACE_ITER = 100  # how often the target ("Q-reality") network is refreshed
MEMORY_CAPACITY = 2000     # replay memory capacity

env = gym.make('CartPole-v0')               # the pole-balancing game
env = env.unwrapped
N_ACTIONS = env.action_space.n              # number of actions the cart can take
N_STATES = env.observation_space.shape[0]   # number of observation values from the environment
```
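As a quick sanity check (not part of the original post), you can print the spaces to confirm what N_STATES and N_ACTIONS are for CartPole. This assumes the classic gym API used throughout this post (env.step returning four values, i.e. gym older than 0.26):

```python
# Quick sanity check of the environment dimensions (assumes the classic gym API used in this post)
import gym

env = gym.make('CartPole-v0').unwrapped
print(env.observation_space.shape[0])  # 4 observations: x, x_dot, theta, theta_dot
print(env.action_space.n)              # 2 actions: push the cart left or right
```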
Neural Network
```python
from torch import nn
import torch.nn.functional as F

# N_STATES and N_ACTIONS come from the "Module Imports and Hyperparameters" section,
# saved as the file 模块导入和超参数.py
from 模块导入和超参数 import N_STATES, N_ACTIONS


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(N_STATES, 10)
        self.fc1.weight.data.normal_(0, 0.1)   # initialization
        self.out = nn.Linear(10, N_ACTIONS)
        self.out.weight.data.normal_(0, 0.1)   # initialization

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        actions_value = self.out(x)
        return actions_value
```
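A minimal shape check, added here as a hypothetical snippet (it assumes the file layout described above, with the network saved as 神经网络.py): feed one fake observation through the network and confirm that it returns one Q value per action.

```python
# Hypothetical shape check, not in the original post
import torch
from 神经网络 import Net
from 模块导入和超参数 import N_STATES, N_ACTIONS

net = Net()
dummy_state = torch.randn(1, N_STATES)   # one fake observation
q_values = net(dummy_state)              # one Q value per action
print(q_values.shape)                    # torch.Size([1, N_ACTIONS]), i.e. (1, 2) for CartPole
```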
DQN Model
```python
import torch
import torch.nn as nn
import numpy as np

from 神经网络 import Net
from 模块导入和超参数 import (MEMORY_CAPACITY, N_STATES, LR, EPSILON, N_ACTIONS,
                         TARGET_REPLACE_ITER, BATCH_SIZE, GAMMA)


class DQN(object):
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()
        self.learn_step_counter = 0    # counter used to time target-net updates
        self.memory_counter = 0        # counter for the replay memory
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))  # initialize replay memory: (s, a, r, s_)
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)  # optimizer for the eval net
        self.loss_func = nn.MSELoss()  # loss function

    def choose_action(self, x):
        x = torch.unsqueeze(torch.FloatTensor(x), 0)   # only a single sample is fed in here
        if np.random.uniform() < EPSILON:              # greedy: pick the best action
            actions_value = self.eval_net.forward(x)
            action = torch.max(actions_value, 1)[1].data.numpy()[0]  # index of the largest Q value
        else:                                          # otherwise explore: pick a random action
            action = np.random.randint(0, N_ACTIONS)
        return action

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # once the memory is full, overwrite the oldest data
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        # periodically copy the eval-net parameters into the target net
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # sample a batch of transitions from the replay memory
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[sample_index, :]
        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))
        b_r = torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2])
        b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

        # q_eval holds the values of all actions; gather picks out the ones for the actions b_a actually taken
        q_eval = self.eval_net(b_s).gather(1, b_a)     # shape (batch, 1)
        q_next = self.target_net(b_s_).detach()        # detach: no gradients flow back through the target net
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)  # shape (batch, 1)
        loss = self.loss_func(q_eval, q_target)

        # compute gradients and update the eval net
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
```
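Before wiring the agent into the environment, a hypothetical smoke test (not in the original post) can verify that the class above works end to end: fill the replay memory with random transitions and run a single learning step.

```python
# Hypothetical smoke test for the DQN class, assuming the post's file layout
import numpy as np
from DQN模型 import DQN
from 模块导入和超参数 import MEMORY_CAPACITY, N_STATES, N_ACTIONS

dqn = DQN()
for _ in range(MEMORY_CAPACITY):
    s = np.random.randn(N_STATES)      # fake state
    a = np.random.randint(N_ACTIONS)   # fake action
    r = np.random.rand()               # fake reward
    s_ = np.random.randn(N_STATES)     # fake next state
    dqn.store_transition(s, a, r, s_)

dqn.learn()  # one gradient step on a random batch; the first call also syncs the target net
print(dqn.memory_counter, dqn.learn_step_counter)  # MEMORY_CAPACITY, 1
```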
Training
```python
from DQN模型 import DQN
from 模块导入和超参数 import env, MEMORY_CAPACITY

dqn = DQN()  # build the DQN agent

for i_episode in range(400):
    s = env.reset()
    while True:
        env.render()              # render the animation
        a = dqn.choose_action(s)  # choose an action

        # take the action and get the environment's feedback
        s_, r, done, info = env.step(a)

        # reshape the reward so the DQN learns faster
        x, x_dot, theta, theta_dot = s_
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        r = r1 + r2

        # store the transition
        dqn.store_transition(s, a, r, s_)

        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()   # start learning once the replay memory has been filled

        if done:          # episode over, start the next one
            break
        s = s_
```
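After the training loop you may want to keep the learned weights. This small add-on is not part of the original post (the file name is only illustrative); it uses standard PyTorch state_dict saving and loading.

```python
# Optional add-on: save the trained eval net and reuse it later (file name is illustrative)
import torch
from 神经网络 import Net

torch.save(dqn.eval_net.state_dict(), 'dqn_cartpole.pth')  # run after the training loop above

# later: load the weights into a fresh network and use it for greedy action selection
policy_net = Net()
policy_net.load_state_dict(torch.load('dqn_cartpole.pth'))
policy_net.eval()
```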
Author's Note
If you spot any problems, corrections are welcome!
Summary
DQN combines Q-learning with a small neural network: an evaluation net chooses actions epsilon-greedily, transitions are stored in a replay memory, and a periodically synchronized target net provides the learning targets. The training loop above applies this to the CartPole balancing task.