Value Function Approximation

In the previous post, we introduced TD learning algorithms. There, all state/action values were represented by tables, which becomes inefficient when the state or action space is large.

In this post, we will use the function approximation method for TD learning. It is also where artificial neural networks are incorporated into reinforcement learning as function approximators.

Sources:

  1. Shiyu Zhao. Chapter 8: Value Function Approximation. Mathematical Foundations of Reinforcement Learning.

TD learning with function approximation

With function approximation methods, the goal of TD learning is to find the best $w$ that minimizes the loss function

$$J(w) = \mathbb{E}\big[\big(v_\pi(S) - \hat{v}(S, w)\big)^2\big] \tag{1}$$

Stationary distribution

The expectation in (1) is with respect to the random variable $S \in \mathcal{S}$. There are several ways to define the probability distribution of $S$.

In reinforcement learning (RL), we commonly use the stationary distribution under the policy as the distribution of $S$.

The stationary distribution of $S$ under policy $\pi$ is denoted by $\{d_\pi(s)\}_{s \in \mathcal{S}}$. By definition, $d_\pi(s) \geq 0$ and $\sum_{s \in \mathcal{S}} d_\pi(s) = 1$.

Let $n_\pi(s)$ denote the number of times that $s$ is visited in a very long episode generated by $\pi$. Then $d_\pi(s)$ can be approximated by

$$d_\pi(s) \approx \frac{n_\pi(s)}{\sum_{s' \in \mathcal{S}} n_\pi(s')}$$

Meanwhile, the converged values $d_\pi(s)$ can be computed directly by solving the equation $d_\pi^T = d_\pi^T P_\pi$, i.e., $d_\pi$ is the left eigenvector of $P_\pi$ associated with the eigenvalue 1. The proof is here.
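
As a concrete illustration, here is a minimal NumPy sketch that estimates $d_\pi$ in both ways. The three-state transition matrix `P_pi`, the episode length, and the random seed are made-up assumptions for illustration only.

```python
import numpy as np

# Hypothetical state-transition matrix P_pi under a fixed policy pi (3 states).
P_pi = np.array([
    [0.5, 0.3, 0.2],
    [0.1, 0.6, 0.3],
    [0.2, 0.2, 0.6],
])

# 1) Approximate d_pi by visit counts along a long simulated episode.
rng = np.random.default_rng(0)
n_visits = np.zeros(3)
s = 0
for _ in range(100_000):
    n_visits[s] += 1
    s = rng.choice(3, p=P_pi[s])
d_pi_counts = n_visits / n_visits.sum()

# 2) Compute d_pi directly as the left eigenvector of P_pi with eigenvalue 1,
#    i.e., solve d_pi^T = d_pi^T P_pi.
eigvals, eigvecs = np.linalg.eig(P_pi.T)
d_pi = np.real(eigvecs[:, np.argmin(np.abs(eigvals - 1))])
d_pi = d_pi / d_pi.sum()

print(d_pi_counts, d_pi)  # the two estimates should be close
```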

Optimization algorithms

The loss function (1) can be rewritten as

$$J(w) = \mathbb{E}\big[\big(v_\pi(S) - \hat{v}(S, w)\big)^2\big] = \sum_{s \in \mathcal{S}} d_\pi(s) \big(v_\pi(s) - \hat{v}(s, w)\big)^2$$

Now that we have the objective function, the next step is to optimize it.

To minimize $J(w)$, we can use the gradient-descent algorithm:

$$w_{k+1} = w_k - \alpha_k \nabla_w J(w_k)$$

The true gradient is

$$\begin{aligned}
\nabla_w J(w) &= \nabla_w \mathbb{E}\big[\big(v_\pi(S) - \hat{v}(S, w)\big)^2\big] \\
&= \mathbb{E}\big[\nabla_w \big(v_\pi(S) - \hat{v}(S, w)\big)^2\big] \\
&= 2\mathbb{E}\big[\big(v_\pi(S) - \hat{v}(S, w)\big)\big(-\nabla_w \hat{v}(S, w)\big)\big] \\
&= -2\mathbb{E}\big[\big(v_\pi(S) - \hat{v}(S, w)\big)\nabla_w \hat{v}(S, w)\big]
\end{aligned}$$
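
For example, if we choose a linear approximator $\hat{v}(s, w) = \phi(s)^\top w$ with some feature vector $\phi(s)$ (a common special case used here for illustration, not a requirement of the derivation), then $\nabla_w \hat{v}(s, w) = \phi(s)$ and the true gradient becomes

$$\nabla_w J(w) = -2\mathbb{E}\big[\big(v_\pi(S) - \phi(S)^\top w\big)\phi(S)\big]$$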

The true gradient above involves the calculation of an expectation. We can use the stochastic gradient to replace the true gradient:

$$w_{t+1} = w_t + \alpha_t \big(v_\pi(s_t) - \hat{v}(s_t, w_t)\big) \nabla_w \hat{v}(s_t, w_t)$$

where $s_t$ is a sample of $S$. Here, the factor 2 from the gradient is merged into $\alpha_t$.

- This algorithm is not implementable because it requires the true state value $v_\pi$, which is exactly the unknown to be estimated.
- We can replace $v_\pi(s_t)$ with an approximation so that the algorithm becomes implementable.

In particular:

- First, Monte Carlo learning with function approximation: let $g_t$ be the discounted return starting from $s_t$ in the episode. Then $g_t$ can be used to approximate $v_\pi(s_t)$, and the algorithm becomes
  $$w_{t+1} = w_t + \alpha_t \big(g_t - \hat{v}(s_t, w_t)\big) \nabla_w \hat{v}(s_t, w_t)$$
- Second, TD learning with function approximation: in the spirit of TD learning, we can replace $v_\pi(s_t)$ with $r_{t+1} + \gamma \hat{v}(s_{t+1}, w_t)$. Please remember that this substitution is not rigorous; the reason is omitted for simplicity.

In the end, the TD learning algorithm with function approximation becomes

$$w_{t+1} = w_t + \alpha_t \big[r_{t+1} + \gamma \hat{v}(s_{t+1}, w_t) - \hat{v}(s_t, w_t)\big] \nabla_w \hat{v}(s_t, w_t)$$
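
Below is a minimal sketch of this update with a linear approximator $\hat{v}(s, w) = \phi(s)^\top w$. The feature map `phi`, the gym-like environment interface (`reset()`/`step()`), and the behavior `policy` are assumptions for the sketch, not part of the algorithm statement above.

```python
import numpy as np

def td_linear(env, policy, phi, num_steps=10_000, alpha=0.01, gamma=0.9):
    """TD(0) with a linear value-function approximator: v_hat(s, w) = phi(s)^T w.

    Assumes a gym-like env with reset() -> s and step(a) -> (s_next, r, done),
    a behavior policy(s) -> a, and a feature map phi(s) -> numpy array.
    """
    s = env.reset()
    w = np.zeros_like(phi(s), dtype=float)
    for _ in range(num_steps):
        a = policy(s)
        s_next, r, done = env.step(a)
        # TD target: r + gamma * v_hat(s', w), with v_hat(terminal) = 0.
        target = r if done else r + gamma * (phi(s_next) @ w)
        # Stochastic gradient step: grad_w v_hat(s, w) = phi(s).
        w += alpha * (target - phi(s) @ w) * phi(s)
        s = env.reset() if done else s_next
    return w
```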

Therefore, we can solve TD learning problems[^1] with the function approximation method.

TD learning of action values based on function approximation

Sarsa with function approximation

(See Algorithm 8.2 in Zhao's book for the complete pseudocode.)

Suppose that $q_\pi(s, a)$ is approximated by $\hat{q}(s, a, w)$. Replacing $\hat{v}(s, w)$ in

$$w_{t+1} = w_t + \alpha_t \big[r_{t+1} + \gamma \hat{v}(s_{t+1}, w_t) - \hat{v}(s_t, w_t)\big] \nabla_w \hat{v}(s_t, w_t)$$

by $\hat{q}(s, a, w)$ gives

$$w_{t+1} = w_t + \alpha_t \big[r_{t+1} + \gamma \hat{q}(s_{t+1}, a_{t+1}, w_t) - \hat{q}(s_t, a_t, w_t)\big] \nabla_w \hat{q}(s_t, a_t, w_t).$$
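
A minimal sketch of this update with a linear approximator $\hat{q}(s, a, w) = \phi(s, a)^\top w$ might look as follows; the feature map `phi`, the gym-like environment interface, and the $\epsilon$-greedy behavior policy are assumptions for illustration.

```python
import numpy as np

def sarsa_linear(env, phi, actions, feat_dim, episodes=500,
                 alpha=0.01, gamma=0.9, eps=0.1):
    """Sarsa with a linear action-value approximator: q_hat(s, a, w) = phi(s, a)^T w."""
    rng = np.random.default_rng(0)
    w = np.zeros(feat_dim)

    def eps_greedy(s):
        # Behavior policy: explore with probability eps, otherwise act greedily.
        if rng.random() < eps:
            return actions[rng.integers(len(actions))]
        return max(actions, key=lambda a: phi(s, a) @ w)

    for _ in range(episodes):
        s = env.reset()
        a = eps_greedy(s)
        done = False
        while not done:
            s_next, r, done = env.step(a)
            a_next = None if done else eps_greedy(s_next)
            # On-policy TD target: uses the next action actually taken.
            target = r if done else r + gamma * (phi(s_next, a_next) @ w)
            w += alpha * (target - phi(s, a) @ w) * phi(s, a)
            s, a = s_next, a_next
    return w
```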

Q-learning with function approximation

(See Algorithm 8.3 in Zhao's book for the complete pseudocode.)

Tabular Q-learning can also be extended to the case of function approximation. The update rule is

$$w_{t+1} = w_t + \alpha_t \Big[r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat{q}(s_{t+1}, a, w_t) - \hat{q}(s_t, a_t, w_t)\Big] \nabla_w \hat{q}(s_t, a_t, w_t).$$

Here, the TD target is $r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat{q}(s_{t+1}, a, w_t)$. The update is the same as Sarsa with function approximation except that $\hat{q}(s_{t+1}, a_{t+1}, w_t)$ is replaced with $\max_{a \in \mathcal{A}(s_{t+1})} \hat{q}(s_{t+1}, a, w_t)$.
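
A corresponding sketch of Q-learning with a linear approximator $\hat{q}(s, a, w) = \phi(s, a)^\top w$, under the same illustrative assumptions about `phi` and the environment interface, differs only in the TD target:

```python
import numpy as np

def q_learning_linear(env, phi, actions, feat_dim, episodes=500,
                      alpha=0.01, gamma=0.9, eps=0.1):
    """Q-learning with a linear action-value approximator: q_hat(s, a, w) = phi(s, a)^T w."""
    rng = np.random.default_rng(0)
    w = np.zeros(feat_dim)

    for _ in range(episodes):
        s, done = env.reset(), False
        while not done:
            # Behavior policy: epsilon-greedy with respect to the current q_hat.
            if rng.random() < eps:
                a = actions[rng.integers(len(actions))]
            else:
                a = max(actions, key=lambda b: phi(s, b) @ w)
            s_next, r, done = env.step(a)
            # Off-policy TD target: maximize q_hat over the next actions.
            target = r if done else r + gamma * max(phi(s_next, b) @ w for b in actions)
            w += alpha * (target - phi(s, a) @ w) * phi(s, a)
            s = s_next
    return w
```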

NOTE: although Q-learning with function approximation can be implemented with neural networks, in practice we usually use Deep Q-learning, also known as the deep Q-network (DQN), instead. DQN is not exactly the Q-learning algorithm introduced in my post, but it shares the core idea of Q-learning.

1 "TD learning problems" refer to the problems aimed to solve by TD learning algorithms, i.e., the state/action value function estimation given a dataset generated by a given policy π.