Deep Q Learning

Deep Q-learning, also called deep Q-network (DQN), is one of the earliest and most successful algorithms to introduce deep neural networks into RL.

DQN is not exactly Q-learning, at least not the specific Q-learning algorithm introduced in my post, but it shares the core idea of Q-learning.

Sources:

  1. Shiyu Zhao. Chapter 8: Value Function Approximation. Mathematical Foundations of Reinforcement Learning.
  2. Volodymyr Mnih et al. Playing Atari with Deep Reinforcement Learning (the 2013 DQN paper).
  3. Ketan Doshi. Reinforcement Learning Explained Visually (Part 5): Deep Q Networks, step-by-step.
  4. My GitHub repo for DQN.

Deep Q learning

Recall that we can use the following update rule to do Q-learning with function approximation:

\[ w_{t+1}=w_t+\alpha_t\left[r_{t+1}+\gamma \color{orange}{\max _{a \in \mathcal{A}\left(s_{t+1}\right)} \hat{q}\left(s_{t+1}, a, w_t\right)}-\hat{q}\left(s_t, a_t, w_t\right)\right] \nabla_w \hat{q}\left(s_t, a_t, w_t\right) . \]

However, this update is not easy to implement with modern deep learning tools¹. Deep Q-learning instead aims to minimize the loss function

\[ \color{purple}{J(w)=\mathbb{E}\left[\left(R+\gamma \max _{a \in \mathcal{A}\left(S^{\prime}\right)} \hat{q}\left(S^{\prime}, a, w\right)-\hat{q}(S, A, w)\right)^2\right]}, \]

where \(\left(S, A, R, S^{\prime}\right)\) are random variables.

This is actually the Bellman optimality error, because the Bellman optimality equation states that

\[ q(s, a)=\mathbb{E}\left[R_{t+1}+\gamma \max _{a \in \mathcal{A}\left(S_{t+1}\right)} q\left(S_{t+1}, a\right) \mid S_t=s, A_t=a\right], \quad \forall s, a. \]

Hence, when \(\hat{q}\) equals the optimal action values, \(R+\gamma \max _{a \in \mathcal{A}\left(S^{\prime}\right)} \hat{q}\left(S^{\prime}, a, w\right)-\hat{q}(S, A, w)\) is zero in expectation (conditioned on \(S\) and \(A\)).
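To make "zero in expectation" concrete, here is a minimal NumPy sketch on a hypothetical two-state, two-action deterministic MDP (all numbers are made up for illustration). Because the transitions are deterministic, the expectation is trivial: after value iteration produces the optimal `q`, the residual \(r + \gamma \max_{a'} q(s', a') - q(s, a)\) vanishes for every pair, while an arbitrary guess such as `q = 0` leaves a nonzero residual.

```python
import numpy as np

# Hypothetical deterministic MDP with 2 states and 2 actions (numbers made up):
# next_state[s, a] is the successor of (s, a) and reward[s, a] the immediate reward.
next_state = np.array([[0, 1],
                       [0, 1]])
reward = np.array([[0.0, 1.0],
                   [0.0, 2.0]])
gamma = 0.9


def bellman_residual(q):
    """r + gamma * max_a' q(s', a') - q(s, a) for every (s, a) pair."""
    return reward + gamma * q[next_state].max(axis=-1) - q


# Value iteration: repeatedly apply the Bellman optimality operator
q = np.zeros((2, 2))
for _ in range(500):
    q = reward + gamma * q[next_state].max(axis=-1)

print(np.abs(bellman_residual(q)).max())                 # essentially zero at the optimal q
print(np.abs(bellman_residual(np.zeros((2, 2)))).max())  # 2.0 for the guess q = 0
```

In this toy MDP the fixed point is easy to check by hand: staying in state 1 yields \(q(1,1) = 2/(1-0.9) = 20\).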

Techniques

The target network

Computing the gradient of this loss is difficult because the parameter \(w\) appears not only in \(\hat{q}(S, A, w)\) but also in \[ y \doteq R+\gamma \max _{a \in \mathcal{A}\left(S^{\prime}\right)} \hat{q}\left(S^{\prime}, a, w\right). \] For simplicity, we assume that the \(w\) in \(y\) is fixed (at least for a while) when we calculate the gradient, so the gradient only flows through \(\hat{q}(S, A, w)\).
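A small numerical check of what fixing \(w\) in \(y\) does, sketched under the assumption of a linear approximator \(\hat{q}(s, a, w) = w_a^\top \phi(s)\) with made-up random features. With \(y\) held constant, the gradient of \((y - \hat{q}(s,a,w))^2\) is simply \(-2(y - \hat{q}(s,a,w))\nabla_w \hat{q}(s,a,w)\), which finite differences confirm; the true gradient, where \(w\) also moves inside the \(\max\) in \(y\), picks up an extra term.

```python
import numpy as np

rng = np.random.default_rng(0)
n_actions, dim, gamma = 3, 4, 0.9

# One hypothetical transition (s, a, r, s') with random feature vectors phi(s), phi(s')
phi_s, phi_next = rng.normal(size=dim), rng.normal(size=dim)
a, r = 1, 0.5
w = rng.normal(size=(n_actions, dim))


def q_hat(phi, w):
    """Linear approximator: row w[a] holds the weights of action a."""
    return w @ phi


def loss(w_main, w_in_y):
    """(y - q_hat(s, a, w_main))^2, where y is computed with parameters w_in_y."""
    y = r + gamma * q_hat(phi_next, w_in_y).max()
    return (y - q_hat(phi_s, w_main)[a]) ** 2


def numerical_grad(f, w, eps=1e-6):
    """Central finite differences, one parameter at a time."""
    g = np.zeros_like(w)
    for idx in np.ndindex(*w.shape):
        step = np.zeros_like(w)
        step[idx] = eps
        g[idx] = (f(w + step) - f(w - step)) / (2 * eps)
    return g


# Semi-gradient: y is a constant, so only q_hat(s, a, w) is differentiated
y = r + gamma * q_hat(phi_next, w).max()
semi_grad = np.zeros_like(w)
semi_grad[a] = -2 * (y - q_hat(phi_s, w)[a]) * phi_s

grad_frozen = numerical_grad(lambda v: loss(v, w), w)  # w inside y held fixed
grad_full = numerical_grad(lambda v: loss(v, v), w)    # w varies everywhere

print(np.allclose(semi_grad, grad_frozen, atol=1e-5))  # True
print(np.allclose(semi_grad, grad_full, atol=1e-5))
```

Only the row of the taken action `a` is nonzero in the semi-gradient, which is exactly what a deep learning framework computes when the target is detached.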

NOTE: This assumption does not look rigorous. In fact, I think it is an engineering choice, and I don't have a mathematical proof that DQN converges under it. There are some intuitive explanations for it, though.

To do that, we introduce two networks:

  • a main network representing \(\hat{q}(s, a, w)\);

  • a target network \(\hat{q}\left(s, a, w_T\right)\).

The loss function in this case degenerates to \[ J=\mathbb{E}\left[\left(R+\gamma \max _{a \in \mathcal{A}\left(S^{\prime}\right)} \color{red}{\hat{q}\left(S^{\prime}, a, w_T\right)}-\color{blue}{\hat{q}(S, A, w)}\right)^2\right] \] where \(w_T\) is the target network parameter.

Implementation details:

  • Let \(w\) and \(w_T\) denote the parameters of the main and target networks, respectively. They are initially set to the same values.

  • In every iteration, we draw a mini-batch of samples \(\left\{\left(s, a, r, s^{\prime}\right)\right\}\) from the replay buffer (explained below).

  • The inputs of the networks include state \(s\) and action \(a\). The target output is \[ y_T \doteq r+\gamma \max _{a \in \mathcal{A}\left(s^{\prime}\right)} \color{red}{\hat{q}\left(s^{\prime}, a, w_T\right)} . \] We then minimize the squared TD error \(\left(y_T-\hat{q}(s, a, w)\right)^2\), which serves as the loss function, over the mini-batch \(\left\{\left(s, a, y_T\right)\right\}\).

  • Every \(C\) iterations, the target network is synchronized with the main network by setting \(w_T = w\).

  • As in Q-learning, the action used in \(\hat{q}(s, a, w)\) is called the current action, and the action used in \(\hat{q}\left(s, a, w_T\right)\) (or \(y_T\)) is called the target action.
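The steps above can be sketched in NumPy with a hypothetical linear \(\hat{q}\) and a random mini-batch: compute \(y_T\) from \(w_T\), take one gradient step on \(w\), and check that \(y_T\) stays put until \(w_T\) is overwritten with \(w\). The batch size, learning rate, and features are arbitrary choices for the example.

```python
import numpy as np

rng = np.random.default_rng(1)
n_actions, dim, batch = 2, 3, 8
gamma, lr = 0.9, 0.1

# Hypothetical mini-batch {(s, a, r, s')}: states are given directly as feature vectors
phi_s = rng.normal(size=(batch, dim))
actions = rng.integers(n_actions, size=batch)
rewards = rng.normal(size=batch)
phi_next = rng.normal(size=(batch, dim))

w = rng.normal(size=(n_actions, dim))  # main network: q_hat(s, a, w) = w[a] @ phi(s)
w_T = w.copy()                         # target network, identical at the start


def targets(w_target):
    """y_T = r + gamma * max_a q_hat(s', a, w_T) for every sample in the batch."""
    return rewards + gamma * (phi_next @ w_target.T).max(axis=1)


def td_loss(w, y_T):
    """Mean squared TD error over the batch, with y_T treated as a constant."""
    q_sa = np.einsum("bd,bd->b", phi_s, w[actions])
    return np.mean((y_T - q_sa) ** 2)


def sgd_step(w, y_T):
    """One gradient step on td_loss; only the rows of the actions taken change."""
    q_sa = np.einsum("bd,bd->b", phi_s, w[actions])
    grad = np.zeros_like(w)
    # Accumulate -2 * (y_T - q) * phi(s) / batch into the row of each sample's action
    np.add.at(grad, actions, (-2.0 * (y_T - q_sa) / batch)[:, None] * phi_s)
    return w - lr * grad


y_T = targets(w_T)
loss_before = td_loss(w, y_T)
w = sgd_step(w, y_T)
unchanged = np.allclose(targets(w_T), y_T)  # w_T was not touched by the step

w_T = w.copy()                              # synchronize: w_T <- w
moved = not np.allclose(targets(w_T), y_T)  # only now do the targets change

print(td_loss(w, y_T) < loss_before, unchanged, moved)
```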

Replay buffer

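The replay buffer (called experience replay in the original paper) stores past transitions. Drawing mini-batches uniformly from it breaks the correlation between consecutive samples, so the training samples are much closer to i.i.d. than the raw stream of experience.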
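A quick way to see why uniform sampling helps: the sketch below models a stream of consecutive observations as a hypothetical AR(1) process, whose neighbors are strongly correlated, and compares the lag-1 correlation of the raw stream with that of the same data after `random.sample` (the limiting case of drawing every stored item once from a replay buffer).

```python
import random

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stream of consecutive "observations": x_{t+1} = 0.95 * x_t + noise
n = 10_000
x = np.zeros(n)
for t in range(1, n):
    x[t] = 0.95 * x[t - 1] + rng.normal()


def lag1_corr(seq):
    """Correlation between each element and the next one."""
    seq = np.asarray(seq)
    return np.corrcoef(seq[:-1], seq[1:])[0, 1]


random.seed(0)
replayed = random.sample(list(x), len(x))  # uniform sampling without replacement

print(f"sequential: {lag1_corr(x):.2f}")
print(f"replayed:   {lag1_corr(replayed):.2f}")
```

The sequential stream has a lag-1 correlation near 0.95, while the replayed order is close to zero.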

```python
import random
from collections import deque

import numpy as np


class ReplayBuffer(object):
    def __init__(self, capacity):
        # The oldest transitions are discarded automatically once capacity is reached
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        # Add a batch dimension so sampled states can be concatenated into one array
        state = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)

        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Sample transitions uniformly without replacement, then regroup them by field
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.concatenate(state), action, reward, np.concatenate(next_state), done

    def __len__(self):
        return len(self.buffer)
```
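One possible alternative design, sketched here as a hypothetical `ArrayReplayBuffer` (not from this post or the repo), preallocates NumPy arrays and writes into them circularly, so `sample` returns ready-made batches without concatenating many small arrays. It assumes flat state vectors of a known dimension `state_dim`.

```python
import numpy as np


class ArrayReplayBuffer:
    """Fixed-capacity circular buffer backed by preallocated NumPy arrays."""

    def __init__(self, capacity, state_dim, seed=None):
        self.capacity = capacity
        self.states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.actions = np.zeros(capacity, dtype=np.int64)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.float32)
        self.pos = 0   # index of the next slot to write
        self.size = 0  # number of valid entries
        self.rng = np.random.default_rng(seed)

    def push(self, state, action, reward, next_state, done):
        i = self.pos
        self.states[i] = state
        self.actions[i] = action
        self.rewards[i] = reward
        self.next_states[i] = next_state
        self.dones[i] = done
        # Once full, overwrite the oldest entry, like deque(maxlen=capacity)
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        # Uniform sampling without replacement, mirroring random.sample
        idx = self.rng.choice(self.size, size=batch_size, replace=False)
        return (self.states[idx], self.actions[idx], self.rewards[idx],
                self.next_states[idx], self.dones[idx])

    def __len__(self):
        return self.size


buf = ArrayReplayBuffer(capacity=5, state_dim=4, seed=0)
for t in range(7):
    buf.push(np.full(4, t), t % 2, 1.0, np.full(4, t + 1), False)
states, actions, rewards, next_states, dones = buf.sample(3)
print(len(buf), states.shape)  # 5 (3, 4)
```

After seven pushes into a buffer of capacity five, the transitions from `t = 0` and `t = 1` have been overwritten.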

Pseudocode of DQN

Algorithm 8.3 in Source 1 (Shiyu Zhao, Chapter 8).
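The overall structure of the algorithm (replay buffer, target network, mini-batch updates, periodic sync) can be sketched end to end on a hypothetical five-state chain: action 0 moves left, action 1 moves right, and entering state 4 gives reward 1 and ends the episode. With one-hot state features, a linear \(\hat{q}\) reduces to a table `w[s, a]`, which keeps the sketch dependency-free. The task, the uniformly random behavior policy (an off-policy choice), and every hyperparameter are assumptions for illustration, not values from Source 1.

```python
import random
from collections import deque

import numpy as np

# Hypothetical toy task: a 5-state chain with the goal at the right end
N_STATES, N_ACTIONS, GOAL = 5, 2, 4


def env_step(s, a):
    s_next = max(s - 1, 0) if a == 0 else s + 1
    done = s_next == GOAL
    return s_next, (1.0 if done else 0.0), done


def train_dqn(n_steps=5000, batch_size=32, gamma=0.9, lr=0.5, sync_every=50, seed=0):
    rng = np.random.default_rng(seed)
    random.seed(seed)
    w = np.zeros((N_STATES, N_ACTIONS))  # main "network" (one-hot features -> table)
    w_T = w.copy()                       # target network, same as w initially
    buffer = deque(maxlen=1000)          # replay buffer
    s = 0
    for step in range(1, n_steps + 1):
        # Off-policy: collect experience with a uniformly random behavior policy
        a = int(rng.integers(N_ACTIONS))
        s_next, r, done = env_step(s, a)
        buffer.append((s, a, r, s_next, float(done)))
        s = 0 if done else s_next

        if len(buffer) >= batch_size:
            batch = random.sample(buffer, batch_size)
            s_b, a_b, r_b, s2_b, d_b = map(np.array, zip(*batch))
            # y_T = r + gamma * max_a q_hat(s', a, w_T); no bootstrapping at terminal states
            y_T = r_b + gamma * (1.0 - d_b) * w_T[s2_b].max(axis=1)
            td_error = y_T - w[s_b, a_b]
            # Gradient step on the mean squared TD error with y_T held fixed
            np.add.at(w, (s_b, a_b), lr * 2.0 * td_error / batch_size)

        if step % sync_every == 0:
            w_T = w.copy()               # every C steps: w_T <- w
    return w


w = train_dqn()
print(w[:GOAL].argmax(axis=1))  # greedy action in each non-terminal state
print(np.round(w[:GOAL, 1], 3))
```

After training, the greedy action in every non-terminal state should be 1 (move right), and `w[s, 1]` should be close to \(\gamma^{3-s}\), the optimal value of moving right from state \(s\).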

Target network

Firstly, it is possible to build a DQN with a single Q Network and no Target Network. In that case, we do two passes through the Q Network, first to output the Predicted Q value, and then to output the Target Q value.

But this creates a potential problem. The Q Network's weights are updated at every time step, which improves the Predicted Q value. However, because the same network and weights also produce the Target Q values, every update shifts those targets as well. They do not stay steady and can fluctuate after each update, which is like chasing a moving target.

By employing a second network that doesn’t get trained, we ensure that the Target Q values remain stable, at least for a short period. But those Target Q values are also predictions after all and we do want them to improve, so a compromise is made. After a pre-configured number of time-steps, the learned weights from the Q Network are copied over to the Target Network.

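Is this like an EMA? Not exactly: this periodic hard copy is not an exponential moving average, although the soft target update used by some later methods is.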
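The two update styles side by side, as a minimal NumPy sketch with made-up weights: the hard update copies \(w\) into \(w_T\) every so often, while the soft (Polyak) update \(w_T \leftarrow \tau w + (1-\tau) w_T\), used for example in DDPG, is an exponential moving average of \(w\). With \(w\) held fixed, the gap \(w - w_T\) shrinks by a factor \(1-\tau\) per step.

```python
import numpy as np


def hard_update(w):
    """DQN-style sync: the target becomes an exact copy of the main weights."""
    return w.copy()


def soft_update(w_target, w, tau):
    """Polyak averaging: w_T <- tau * w + (1 - tau) * w_T, i.e. an EMA of w."""
    return tau * w + (1.0 - tau) * w_target


w = np.array([1.0, 2.0])  # main weights, held fixed here for illustration
w_T = np.zeros(2)         # initial target weights w_T0
tau, n_steps = 0.1, 10

for _ in range(n_steps):
    w_T = soft_update(w_T, w, tau)

# The gap shrinks geometrically: w - w_T = (1 - tau)^n * (w - w_T0), with w_T0 = 0 here
print(np.allclose(w - w_T, (1 - tau) ** n_steps * w))  # True
print(hard_update(w))                                  # [1. 2.]
```

A small `tau` makes the target drift slowly and smoothly, whereas the hard update keeps it frozen and then jumps.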

DQN

```python
import random

import torch
import torch.nn as nn


class DQN(nn.Module):
    def __init__(self, num_inputs, num_actions,
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        super(DQN, self).__init__()

        self.device = device
        self.num_inputs = num_inputs
        self.num_actions = num_actions
        # MLP mapping a state vector to one Q-value per action
        self.layers = nn.Sequential(
            nn.Linear(num_inputs, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions)
        ).to(self.device)

    def forward(self, x):
        return self.layers(x)

    def act(self, state, epsilon):
        # Epsilon-greedy: act greedily with probability 1 - epsilon, otherwise explore
        if random.random() > epsilon:
            with torch.no_grad():  # No gradients needed here, which saves memory and computation
                state = torch.FloatTensor(state).unsqueeze(0).to(self.device)  # Convert state to tensor and add batch dimension
                q_value = self.forward(state)  # Q-values for all actions
                action = q_value.max(1)[1].item()  # Index of the action with the maximum Q-value, as an int
        else:
            action = random.randrange(self.num_actions)
        return action
```
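The training loop below calls `epsilon_by_frame`, whose definition is not shown in this post. One common choice, sketched here with hypothetical constants, decays \(\epsilon\) exponentially from `epsilon_start` toward `epsilon_final`, so the agent explores heavily at first and acts mostly greedily later.

```python
import math

# Hypothetical schedule constants (assumptions, tune per task)
epsilon_start = 1.0
epsilon_final = 0.01
epsilon_decay = 500  # in frames; larger values decay more slowly


def epsilon_by_frame(frame_idx):
    """Exponentially decaying exploration rate for epsilon-greedy action selection."""
    return epsilon_final + (epsilon_start - epsilon_final) * math.exp(-frame_idx / epsilon_decay)


for frame_idx in (1, 500, 5000):
    print(frame_idx, round(epsilon_by_frame(frame_idx), 3))
```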


Loss

```python
import torch


def compute_td_loss(batch_size, replay_buffer, model, gamma, optimizer):
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    # Convert numpy arrays and tuples to torch tensors
    state = torch.FloatTensor(state).to(model.device)
    next_state = torch.FloatTensor(next_state).to(model.device)
    action = torch.LongTensor(action).to(model.device)
    reward = torch.FloatTensor(reward).to(model.device)
    done = torch.FloatTensor(done).to(model.device)

    # Compute Q-values for current states
    q_values = model(state)

    # Compute Q-values for next states without gradient tracking to save time and memory.
    # Note: this version uses the same network for the targets (no separate target network).
    with torch.no_grad():
        next_q_values = model(next_state)
    next_q_value = next_q_values.max(1)[0]  # Max Q-value along the action dimension

    # Calculate the expected Q-values; (1 - done) removes the bootstrap term at terminal states
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    # Loss between the Q-values of the taken actions and the expected Q-values
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    loss = (q_value - expected_q_value.detach()).pow(2).mean()  # detach() is redundant after no_grad, but harmless

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss
```
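The two lines that are easiest to misread are the `gather` call and the `(1 - done)` mask. This NumPy mirror on a made-up batch of three transitions shows what they compute: `np.take_along_axis` plays the role of `gather`, and the terminal transition keeps only its reward. Note that `compute_td_loss` above evaluates `next_q_values` with `model` itself; with a separate target network, those values would come from the target network instead.

```python
import numpy as np

# Made-up batch of 3 transitions with 2 actions (illustrative numbers only)
q_values = np.array([[1.0, 2.0],
                     [0.5, 0.1],
                     [3.0, 4.0]])
next_q_values = np.array([[2.0, 1.0],
                          [0.0, 5.0],
                          [1.0, 1.0]])
action = np.array([1, 0, 0])
reward = np.array([1.0, 0.0, 2.0])
done = np.array([0.0, 0.0, 1.0])  # the third transition ends the episode
gamma = 0.9

# Equivalent of q_values.gather(1, action.unsqueeze(1)).squeeze(1): pick q(s_i, a_i) per row
q_value = np.take_along_axis(q_values, action[:, None], axis=1).squeeze(1)

# Equivalent of next_q_values.max(1)[0], masked so terminal transitions do not bootstrap
expected_q_value = reward + gamma * next_q_values.max(axis=1) * (1 - done)

loss = ((q_value - expected_q_value) ** 2).mean()
print(q_value, expected_q_value, loss)
```

Here `q_value` is `[2, 0.5, 3]`, `expected_q_value` is `[2.8, 4.5, 2]`, and the loss is their mean squared difference, 5.88.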

Training process

```python
# env, model, optimizer, replay_buffer, epsilon_by_frame and plot are assumed
# to be defined earlier (e.g., in the full script in Source 4)
num_frames = 10000
batch_size = 32
gamma = 0.99

losses = []
all_rewards = []
episode_reward = 0

state, info = env.reset()
for frame_idx in range(1, num_frames + 1):
    epsilon = epsilon_by_frame(frame_idx)
    action = model.act(state, epsilon)

    next_state, reward, terminated, truncated, info = env.step(action)

    # Store `terminated` (not `truncated`) as done: a time-limit cut-off is not a real
    # terminal state, so the target should still bootstrap from next_state
    replay_buffer.push(state, action, reward, next_state, terminated)

    state = next_state
    episode_reward += reward

    # Reset on termination or truncation so we never keep stepping a finished episode
    if terminated or truncated:
        state, info = env.reset()
        all_rewards.append(episode_reward)
        episode_reward = 0

    if len(replay_buffer) > batch_size:
        loss = compute_td_loss(batch_size, replay_buffer, model, gamma, optimizer)
        losses.append(loss.item())

    if frame_idx % 200 == 0:
        plot(frame_idx, all_rewards, losses)
```

  1. Okay, I know this explanation is a bit ridiculous. In my opinion, the real reason we favor DQN over Q-learning with function approximation is simply that DQN works better.