Recalling that, we can use the following update rule to do Q-learning with function approximation:

$$w_{t+1} = w_t + \alpha_t \Big[ r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat{q}(s_{t+1}, a, w_t) - \hat{q}(s_t, a_t, w_t) \Big] \nabla_w \hat{q}(s_t, a_t, w_t).$$

However, this process is not so easy to implement with modern deep learning tools1. In contrast, deep Q-learning aims to minimize the loss function:

$$J(w) = \mathbb{E}\Big[ \big( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w) \big)^2 \Big],$$

where $(S, A, R, S')$ are random variables.

- This is actually the Bellman optimality error. That is because
  $$q(s, a) = \mathbb{E}\Big[ R_{t+1} + \gamma \max_{a \in \mathcal{A}(S_{t+1})} q(S_{t+1}, a) \,\Big|\, S_t = s, A_t = a \Big], \quad \forall s, a.$$
  The value of $R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w)$ should be zero in the expectation sense.
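To see concretely why the raw update rule is awkward with deep learning libraries, here is a minimal sketch (mine, not from the post) of applying it to a single transition $(s, a, r, s')$; `q_net` is an assumed PyTorch module mapping a state tensor to a vector of action values. The parameter update has to be written by hand instead of coming from a loss and an optimizer.

```python
import torch

# Sketch of the semi-gradient Q-learning update for one transition (s, a, r, s').
# `q_net` is assumed to map a state tensor of shape [1, num_inputs] to [1, num_actions].
def q_learning_step(q_net, s, a, r, s_next, gamma=0.99, alpha=1e-3):
    q_sa = q_net(s.unsqueeze(0))[0, a]                            # q_hat(s, a, w)
    with torch.no_grad():
        td_target = r + gamma * q_net(s_next.unsqueeze(0)).max()  # r + gamma * max_a q_hat(s', a, w)
    td_error = (td_target - q_sa).detach()

    q_net.zero_grad()
    q_sa.backward()                                               # fills p.grad with grad_w q_hat(s, a, w)
    with torch.no_grad():
        for p in q_net.parameters():
            p += alpha * td_error * p.grad                        # w <- w + alpha * delta * grad_w q_hat(s, a, w)
```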
Techniques
The target network
The gradient of the loss function is hard to compute since, in the equation, the parameter $w$ not only appears in $y \doteq R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w)$ but also in $\hat{q}(S, A, w)$. For the sake of simplicity, we assume that $w$ in $y$ is fixed (at least for a while) when we calculate the gradient.
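Under this assumption the chain rule only has to pass through $\hat{q}(S, A, w)$, and the gradient reduces (up to a constant factor of $2$ that can be absorbed into the learning rate) to

$$\nabla_w J(w) \propto -\mathbb{E}\Big[ \big( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w) \big) \nabla_w \hat{q}(S, A, w) \Big],$$

where the $w$ inside the $\max$ term is treated as a constant.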
NOTE: This assumption does not look rigorous; in fact, I think it's an engineering choice, and I don't have a mathematical proof of the convergence of DQN under this assumption. There are some intuitive explanations for it, though.
To do that, we introduce two networks.

- One is a main network representing $\hat{q}(s, a, w)$.
- The other is a target network $\hat{q}(s, a, w_T)$.
The loss function in this case degenerates to

$$J = \mathbb{E}\Big[ \big( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w_T) - \hat{q}(S, A, w) \big)^2 \Big],$$

where $w_T$ is the target network parameter.
Implementation details:

- Let $w$ and $w_T$ denote the parameters of the main and target networks, respectively. They are set to be the same initially.
- In every iteration, we draw a mini-batch of samples $\{(s, a, r, s')\}$ from the replay buffer (will be explained later).
- The inputs of the networks include state $s$ and action $a$. The target output is $y_T \doteq r + \gamma \max_{a \in \mathcal{A}(s')} \hat{q}(s', a, w_T)$. Then, we directly minimize the TD error, i.e., the loss function $(y_T - \hat{q}(s, a, w))^2$, over the mini-batch $\{(s, a, y_T)\}$.
- As in Q-learning, the action used in $\hat{q}(s, a, w)$ is called the current action, and the action used in $\max_{a \in \mathcal{A}(s')} \hat{q}(s', a, w_T)$ (or equivalently in $y_T$) is called the target action, as the sketch below illustrates.
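A minimal sketch of one such update step, assuming `main_net` and `target_net` are two networks of the same architecture that map a batch of states to per-action values (the function and variable names here are mine):

```python
import torch
import torch.nn.functional as F

def dqn_update(main_net, target_net, optimizer, batch, gamma=0.99):
    # batch: tensors of shape [B, ...]; `done` is a 0/1 float tensor marking terminal transitions.
    state, action, reward, next_state, done = batch

    # Current action value q_hat(s, a, w): pick the column of the action actually taken.
    q_value = main_net(state).gather(1, action.unsqueeze(1)).squeeze(1)

    # Target y_T = r + gamma * max_a q_hat(s', a, w_T), computed with the frozen target network.
    with torch.no_grad():
        next_q_value = target_net(next_state).max(1)[0]
        y_T = reward + gamma * next_q_value * (1 - done)

    loss = F.mse_loss(q_value, y_T)   # (y_T - q_hat(s, a, w))^2 averaged over the mini-batch
    optimizer.zero_grad()
    loss.backward()                   # gradients only flow through the main network
    optimizer.step()
    return loss.item()
```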
Replay buffer
Replay buffer (or experience replay, as it is called in the original paper) is used in DQN to break the correlation between consecutive transitions, so that the mini-batches sampled from it are approximately i.i.d.
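The training loop below pushes transitions into a `replay_buffer` and samples mini-batches from it; a minimal buffer along these lines can be built on `collections.deque` (this particular class is a sketch, not necessarily the exact one used in the post):

```python
import random
from collections import deque

import numpy as np


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are discarded automatically

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform random sampling breaks the temporal correlation between consecutive transitions.
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)
```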
Firstly, it is possible to build a DQN with a single Q Network and no Target Network. In that case, we do two passes through the Q Network, first to output the Predicted Q value, and then to output the Target Q value.
But that could create a potential problem. The Q Network’s weights get updated at each time step, which improves the prediction of the Predicted Q value. However, since the network and its weights are the same, it also changes the direction of our predicted Target Q values. They do not remain steady but can fluctuate after each update. This is like chasing a moving target.
By employing a second network that doesn’t get trained, we ensure that the Target Q values remain stable, at least for a short period. But those Target Q values are also predictions after all and we do want them to improve, so a compromise is made. After a pre-configured number of time-steps, the learned weights from the Q Network are copied over to the Target Network.
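In code, that hard update is just a periodic parameter copy (a sketch; `target_update_freq` and the network names are mine, and note that the single-network code below does not include this step):

```python
# Every `target_update_freq` steps, copy the main network's weights into the target network.
target_update_freq = 1000  # illustrative value

if frame_idx % target_update_freq == 0:
    target_net.load_state_dict(main_net.state_dict())
```

Some implementations instead use a soft (Polyak) update that blends a small fraction of the main network's weights into the target network at every step; both serve the same purpose of keeping the target slowly moving.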
```python
import random

import torch
import torch.nn as nn


class DQN(nn.Module):
    def __init__(self, num_inputs, num_actions, device='cuda'):
        super(DQN, self).__init__()
        self.device = device
        self.num_inputs = num_inputs
        self.num_actions = num_actions
        self.layers = nn.Sequential(
            nn.Linear(num_inputs, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions)
        ).to(self.device)

    def forward(self, x):
        return self.layers(x)

    def act(self, state, epsilon):
        if random.random() > epsilon:
            with torch.no_grad():  # Ensures that no gradients are computed, which saves memory and computation
                state = torch.FloatTensor(state).unsqueeze(0).to(self.device)  # Convert state to tensor and add batch dimension
                q_value = self.forward(state)  # Get Q-values for all actions
                action = q_value.max(1)[1].item()  # Get the action with the maximum Q-value and convert to integer
        else:
            action = random.randrange(self.num_actions)
        return action
```
```python
# Compute Q-values for current states
q_values = model(state)
```
```python
# Compute Q-values for next states with no gradient computation to speed up and reduce memory usage
with torch.no_grad():
    next_q_values = model(next_state)
    next_q_value = next_q_values.max(1)[0]  # Get the max Q-value along the action dimension
expected_q_value = reward + gamma * next_q_value * (1 - done)  # TD target r + gamma * max_a q_hat(s', a)
```
```python
# Compute the loss between actual Q-values and the expected Q-values
q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
loss = (q_value - expected_q_value.detach()).pow(2).mean()  # Detach expected_q_value to prevent gradients from flowing through the target
```
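The snippets above are fragments of the `compute_td_loss` function that the training loop calls; assembled into one function (a sketch consistent with that call signature, using the single-network variant shown in the fragments), it might look like this:

```python
import numpy as np
import torch


def compute_td_loss(batch_size, replay_buffer, model, gamma, optimizer):
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    state = torch.FloatTensor(np.float32(state)).to(model.device)
    next_state = torch.FloatTensor(np.float32(next_state)).to(model.device)
    action = torch.LongTensor(action).to(model.device)
    reward = torch.FloatTensor(np.float32(reward)).to(model.device)
    done = torch.FloatTensor(np.float32(done)).to(model.device)

    # Q-values of the actions actually taken in the sampled transitions
    q_values = model(state)
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)

    # TD target r + gamma * max_a q_hat(s', a); the bootstrap term is zeroed at terminal states
    with torch.no_grad():
        next_q_value = model(next_state).max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    loss = (q_value - expected_q_value.detach()).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss
```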
```python
losses = []
all_rewards = []
episode_reward = 0

state, info = env.reset()
for frame_idx in range(1, num_frames + 1):
    epsilon = epsilon_by_frame(frame_idx)
    action = model.act(state, epsilon)
    # print(f"Select action: {action}. type: {type(action)}")
    next_state, reward, terminated, truncated, info = env.step(action)
    replay_buffer.push(state, action, reward, next_state, terminated)
    state = next_state
    episode_reward += reward

    if terminated:
        state, info = env.reset()
        all_rewards.append(episode_reward)
        episode_reward = 0

    if len(replay_buffer) > batch_size:
        loss = compute_td_loss(batch_size, replay_buffer, model, gamma, optimizer)
        losses.append(loss.data.item())

    if frame_idx % 200 == 0:
        plot(frame_idx, all_rewards, losses)
```
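The loop uses an `epsilon_by_frame` schedule and a `plot` helper that are not reproduced here. A common choice for the schedule (and my assumption for the particular constants) is an exponential decay from an initial to a final exploration rate:

```python
import math

# Exponentially decaying epsilon-greedy schedule (constants are illustrative).
epsilon_start = 1.0
epsilon_final = 0.01
epsilon_decay = 500


def epsilon_by_frame(frame_idx):
    return epsilon_final + (epsilon_start - epsilon_final) * math.exp(-frame_idx / epsilon_decay)
```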
Okay, I know this explanation is ridiculous. In my opinion, the true reason we favor DQN over Q-learning with function approximation is simply the fact that DQN works better.↩︎