Recall that Q-Learning builds a Q-table that maps state-action pairs to Q-values. In a real-world scenario, however, the number of states can be enormous, making it computationally intractable to build such a table.
To address this limitation, we approximate the Q-function with a neural network rather than maintaining an explicit Q-table.
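Concretely, the table lookup is replaced by a parameterized approximator $Q(s, a; \theta)$ (a neural network), and the weights $\theta$ are trained to minimize the squared TD error; this is the same quantity the loss code below computes:

$$\mathcal{L}(\theta) = \mathbb{E}\left[\left(r + \gamma \max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta)\right)^{2}\right]$$

where $\theta^{-}$ are the weights of the (periodically synced) Target Network discussed below.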
Replay buffer
The replay buffer (called experience replay in the original paper) stores past transitions and samples random mini-batches from them for training. Because consecutive transitions within an episode are highly correlated, sampling uniformly from a large buffer breaks that correlation and makes the training samples closer to i.i.d.
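A minimal sketch of such a buffer; the class name and method names are illustrative, but the `(state, action, reward, next_state, done)` tuple layout matches what the loss code below expects:

```python
import random
from collections import deque

import numpy as np


class ReplayBuffer:
    """Fixed-size buffer that stores transitions and samples them uniformly at random."""

    def __init__(self, capacity):
        # Old transitions are dropped automatically once capacity is reached
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform random sampling breaks the temporal correlation between transitions
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return np.stack(state), np.array(action), np.array(reward), np.stack(next_state), np.array(done)

    def __len__(self):
        return len(self.buffer)
```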
Firstly, it is possible to build a DQN with a single Q Network and no Target Network. In that case, we make two passes through the Q Network: first to output the Predicted Q value, and then to output the Target Q value.
But that creates a problem. The Q Network's weights get updated at each time step, which improves the prediction of the Predicted Q value. However, since the network and its weights are the same, the Target Q values shift with every update as well. They do not remain steady but fluctuate after each update. This is like chasing a moving target.
By employing a second network that doesn't get trained, we ensure that the Target Q values remain stable, at least for a short period. But those Target Q values are also predictions after all, and we do want them to improve, so a compromise is made. After a pre-configured number of time steps, the learned weights from the Q Network are copied over to the Target Network.
This is similar in spirit to an exponential moving average (EMA): a soft target update (Polyak averaging) continuously blends the online weights into the target weights, while DQN's periodic hard copy is the all-at-once version of the same idea.
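A sketch of both update rules, assuming `model` and `target_model` are two instances of the same network; `sync_every`, `step`, and `tau` are illustrative hyperparameters:

```python
# Hard update (original DQN): copy the online weights into the target network
# every `sync_every` steps
if step % sync_every == 0:
    target_model.load_state_dict(model.state_dict())

# Soft (Polyak / EMA) update, used e.g. in DDPG: blend a small fraction of the
# online weights into the target weights at every step
tau = 0.005  # illustrative value
for target_param, param in zip(target_model.parameters(), model.parameters()):
    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
```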
DQN
```python
import random

import torch
import torch.nn as nn


class DQN(nn.Module):
    def __init__(self, num_inputs, num_actions, device='cuda'):
        super(DQN, self).__init__()
        self.device = device
        self.num_inputs = num_inputs
        self.num_actions = num_actions
        self.layers = nn.Sequential(
            nn.Linear(num_inputs, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions)
        ).to(self.device)

    def forward(self, x):
        return self.layers(x)

    def act(self, state, epsilon):
        if random.random() > epsilon:
            with torch.no_grad():  # Ensures that no gradients are computed, which saves memory and computation
                state = torch.FloatTensor(state).unsqueeze(0).to(self.device)  # Convert state to tensor and add batch dimension
                q_value = self.forward(state)  # Get Q-values for all actions
                action = q_value.max(1)[1].item()  # Greedy action: index of the maximum Q-value, as an integer
        else:
            action = random.randrange(self.num_actions)  # Explore: pick a random action
        return action
```
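For context, a common way to drive `act` is with an exponentially decaying epsilon, so early training explores heavily and later training is mostly greedy. The schedule constants below are illustrative, and `env` is assumed to be a Gym-style environment:

```python
import math

epsilon_start, epsilon_final, epsilon_decay = 1.0, 0.01, 500  # illustrative values

model = DQN(num_inputs=4, num_actions=2, device='cpu')  # e.g. CartPole-sized dimensions

state = env.reset()  # assumes a Gym-style `env`
for step in range(1, 10001):
    # Anneal epsilon from mostly-random toward mostly-greedy
    epsilon = epsilon_final + (epsilon_start - epsilon_final) * math.exp(-step / epsilon_decay)
    action = model.act(state, epsilon)
    next_state, reward, done, _ = env.step(action)
    state = env.reset() if done else next_state
```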
```python
# state, action, reward, next_state, done: a batch of transitions
# (e.g. sampled from the replay buffer) already converted to tensors

# Compute Q-values for current states
q_values = model(state)

# Compute Q-values for next states; no gradient computation speeds this up and reduces memory usage
with torch.no_grad():
    next_q_values = model(next_state)  # with a Target Network, this would be target_model(next_state)
    next_q_value = next_q_values.max(1)[0]  # Get the max Q-value along the action dimension

# Bellman target: reward plus the discounted max next Q-value, zeroed out for terminal states
expected_q_value = reward + gamma * next_q_value * (1 - done)

# Compute the loss between actual Q-values and the expected Q-values
q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)  # Q-value of the action actually taken
loss = (q_value - expected_q_value.detach()).pow(2).mean()  # Detach expected_q_value to prevent gradients from flowing through the target
```
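To complete the training step, the loss is backpropagated through the online network. Here `optimizer` is assumed to be a standard PyTorch optimizer (e.g. Adam) over `model.parameters()`:

```python
optimizer.zero_grad()  # Clear gradients left over from the previous step
loss.backward()        # Backpropagate the TD error through the Q Network
optimizer.step()       # Apply the gradient update to the Q Network's weights
```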