import torch
import torch.nn as nn
import numpy as np
from 神经网络 import Net
from 模块导入和超参数 import MEMORY_CAPACITY, N_STATES, LR, EPSILON, N_ACTIONS, TARGET_REPLACE_ITER, BATCH_SIZE, GAMMA


class DQN(object):
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()

        self.learn_step_counter = 0                                   # counter for timing target net updates
        self.memory_counter = 0                                       # counter for the replay memory
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))   # initialize the replay memory
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)  # torch optimizer
        self.loss_func = nn.MSELoss()                                 # loss function

    def choose_action(self, x):
        x = torch.unsqueeze(torch.FloatTensor(x), 0)                  # only a single sample is fed in here
        if np.random.uniform() < EPSILON:                             # greedy: pick the best action
            actions_value = self.eval_net(x)
            action = torch.max(actions_value, 1)[1].data.numpy()[0]   # return the argmax
        else:                                                         # otherwise pick a random action
            action = np.random.randint(0, N_ACTIONS)
        return action

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # once the memory is full, overwrite the oldest data
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        # update the target net parameters every TARGET_REPLACE_ITER steps
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # sample a batch of transitions from the replay memory
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[sample_index, :]
        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))
        b_r = torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2])
        b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

        # pick the q_eval values of the actions b_a that were actually taken
        # (q_eval originally holds the values of all actions)
        q_eval = self.eval_net(b_s).gather(1, b_a)                      # shape (batch, 1)
        q_next = self.target_net(b_s_).detach()                         # detach so no gradients flow back into the target net
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)
        loss = self.loss_func(q_eval, q_target)

        # compute gradients and update the eval net
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
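
A minimal usage sketch of how this DQN class might be driven. It assumes a Gym environment whose observation and action sizes match N_STATES and N_ACTIONS (CartPole is used here only as an illustration), and the classic pre-0.26 Gym API where env.reset() returns the observation and env.step() returns a 4-tuple; the episode count is illustrative and not part of the code above.

import gym

env = gym.make('CartPole-v0')   # assumed environment, not part of the original code
dqn = DQN()

for i_episode in range(400):    # illustrative episode count
    s = env.reset()
    while True:
        a = dqn.choose_action(s)            # epsilon-greedy action from eval_net
        s_, r, done, info = env.step(a)     # classic (pre-0.26) Gym step API
        dqn.store_transition(s, a, r, s_)   # save the transition to the replay memory

        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()                     # start learning once the replay memory is full

        if done:
            break
        s = s_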