Proof of the Policy Gradient Theorem
Here we prove the policy gradient theorem, i.e., that the gradient of an objective function \(J(\theta)\) is \[ \color{orange}{\nabla_\theta J(\theta)=\sum_{s \in \mathcal{S}} \eta(s) \sum_{a \in \mathcal{A}} \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a)} \] where \(\eta\) is a state distribution and \(\nabla_\theta \pi\) is the gradient of \(\pi\) with respect to \(\theta\).
Moreover, this equation has a compact form in terms of an expectation: \[ \color{green}{\nabla_\theta J(\theta)=\mathbb{E}_{S \sim \eta, A \sim \pi(S, \theta)}\left[\nabla_\theta \ln \pi(A \mid S, \theta) q_\pi(S, A)\right]}, \] where \(\ln\) is the natural logarithm.
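The two forms are connected by the identity \(\nabla_\theta \pi(a \mid s, \theta)=\pi(a \mid s, \theta) \nabla_\theta \ln \pi(a \mid s, \theta)\), which turns the sum over actions into an expectation over \(A \sim \pi\). As a minimal NumPy sketch, assuming a single state, a softmax policy over four actions, and made-up action values (all hypothetical, not from the text), the snippet below checks that the sum form, the expectation form, and a finite-difference derivative of \(\sum_a \pi(a \mid \theta) q(a)\) (with \(q\) held fixed) coincide.

```python
import numpy as np

# Hypothetical example: one state, k actions, softmax policy, made-up action values q.
rng = np.random.default_rng(0)
k = 4
theta = rng.normal(size=k)
q = rng.normal(size=k)


def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()


pi = softmax(theta)
jac = np.diag(pi) - np.outer(pi, pi)   # jac[a, b] = d pi(a) / d theta_b; row a = grad pi(a)
score = np.eye(k) - pi                 # row a = grad ln pi(a) = e_a - pi

sum_form = jac.T @ q                                                # sum_a grad pi(a) q(a)
expectation_form = (pi[:, None] * score * q[:, None]).sum(axis=0)   # E_{A~pi}[grad ln pi(A) q(A)]

# Central finite differences of f(theta) = sum_a pi(a | theta) q(a), with q held fixed.
eps = 1e-6
fd = np.array([(softmax(theta + eps * e) @ q - softmax(theta - eps * e) @ q) / (2 * eps)
               for e in np.eye(k)])
print(np.allclose(sum_form, expectation_form), np.allclose(sum_form, fd))
```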
We prove this theorem separately for the discounted and undiscounted cases. In each case, we prove it for three different metrics: \(\bar{v}_\pi\), \(\bar{r}_\pi\), and \(\bar{v}_\pi^0\).
For simplicity, I only include the proofs for the discounted case (in the appendix). See the book for the proofs of the undiscounted case.
Derivation of the gradients in the discounted case
We next derive the gradients of the metrics in the discounted case where \(\gamma \in(0,1)\). The state value and action value in the discounted case are defined as \[ \begin{aligned} v_\pi(s) & =\mathbb{E}\left[R_{t+1}+\gamma R_{t+2}+\gamma^2 R_{t+3}+\ldots \mid S_t=s\right], \\ q_\pi(s, a) & =\mathbb{E}\left[R_{t+1}+\gamma R_{t+2}+\gamma^2 R_{t+3}+\ldots \mid S_t=s, A_t=a\right] . \end{aligned} \]
It holds that \(v_\pi(s)=\sum_{a \in \mathcal{A}} \pi(a \mid s, \theta) q_\pi(s, a)\) and the state value satisfies the Bellman equation.
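Both identities can be checked on a small example. The sketch below is a minimal NumPy illustration under assumed ingredients: a randomly generated MDP with 3 states and 2 actions, a tabular softmax policy, and the helper names softmax_policy and evaluate, all of which are hypothetical. It solves the Bellman equation \(v_\pi=r_\pi+\gamma P_\pi v_\pi\) directly, then checks \(v_\pi(s)=\sum_a \pi(a \mid s, \theta) q_\pi(s, a)\) and that repeatedly applying the Bellman equation converges to the same values.

```python
import numpy as np

# Hypothetical toy MDP (not from the text): n states, k actions, random dynamics.
rng = np.random.default_rng(0)
n, k = 3, 2
p = rng.random((n, k, n))
p /= p.sum(axis=2, keepdims=True)          # p[s, a, s'] = p(s' | s, a)
r = rng.random((n, k))                     # r[s, a]: expected reward, in [0, 1)
theta = rng.normal(size=(n, k))            # tabular softmax parameters, m = n * k


def softmax_policy(theta):
    """pi[s, a] = exp(theta[s, a]) / sum_b exp(theta[s, b])."""
    z = np.exp(theta - theta.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def evaluate(theta, gamma):
    """Solve the Bellman equation exactly; return pi, P_pi, r_pi, v_pi, q_pi."""
    pi = softmax_policy(theta)
    P = np.einsum("sa,sat->st", pi, p)     # [P_pi]_{ss'} = sum_a pi(a|s) p(s'|s,a)
    r_pi = (pi * r).sum(axis=1)
    v = np.linalg.solve(np.eye(n) - gamma * P, r_pi)
    q = r + gamma * p @ v                  # q(s,a) = r(s,a) + gamma sum_s' p(s'|s,a) v(s')
    return pi, P, r_pi, v, q


gamma = 0.9
pi, P, r_pi, v, q = evaluate(theta, gamma)

# Iterating v <- r_pi + gamma P_pi v converges to the same fixed point because gamma < 1.
v_iter = np.zeros(n)
for _ in range(500):
    v_iter = r_pi + gamma * P @ v_iter

print(np.allclose(v, (pi * q).sum(axis=1)), np.allclose(v, v_iter))
```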
First, we show that \(\bar{v}_\pi(\theta)\) and \(\bar{r}_\pi(\theta)\) are equivalent metrics.
Lemma 9.1
Lemma 9.1 (Equivalence between \(\bar{v}_\pi(\theta)\) and \(\bar{r}_\pi(\theta)\)). In the discounted case where \(\gamma \in(0,1)\), it holds that \[ \begin{equation} \label{eq9_13} \bar{r}_\pi=(1-\gamma) \bar{v}_\pi . \end{equation} \]
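Before the proof, the identity can be sanity-checked numerically. The sketch below is a minimal NumPy example under assumed, hypothetical ingredients: a random 3-state, 2-action MDP, a tabular softmax policy, and helper functions (softmax_policy, evaluate, stationary) introduced only for illustration. It computes \(d_\pi\) from \(d_\pi^T P_\pi = d_\pi^T\) with \(\sum_s d_\pi(s) = 1\) and compares \(\bar{r}_\pi\) with \((1-\gamma)\bar{v}_\pi\).

```python
import numpy as np

# Hypothetical toy MDP (not from the text): n states, k actions, random dynamics.
rng = np.random.default_rng(0)
n, k = 3, 2
p = rng.random((n, k, n))
p /= p.sum(axis=2, keepdims=True)          # p[s, a, s'] = p(s' | s, a)
r = rng.random((n, k))                     # r[s, a]: expected reward, in [0, 1)
theta = rng.normal(size=(n, k))            # tabular softmax parameters


def softmax_policy(theta):
    z = np.exp(theta - theta.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def evaluate(theta, gamma):
    """Solve the Bellman equation exactly; return pi, P_pi, r_pi, v_pi, q_pi."""
    pi = softmax_policy(theta)
    P = np.einsum("sa,sat->st", pi, p)
    r_pi = (pi * r).sum(axis=1)
    v = np.linalg.solve(np.eye(n) - gamma * P, r_pi)
    q = r + gamma * p @ v
    return pi, P, r_pi, v, q


def stationary(P):
    """Stationary distribution d: d^T P = d^T and sum(d) = 1 (P assumed ergodic)."""
    size = P.shape[0]
    A = np.vstack([P.T - np.eye(size), np.ones((1, size))])
    b = np.zeros(size + 1)
    b[-1] = 1.0
    return np.linalg.lstsq(A, b, rcond=None)[0]


gamma = 0.9
pi, P, r_pi, v, q = evaluate(theta, gamma)
d = stationary(P)
v_bar = d @ v          # bar v_pi = d_pi^T v_pi
r_bar = d @ r_pi       # bar r_pi = d_pi^T r_pi
print(r_bar, (1 - gamma) * v_bar)
```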
Proof: Note that \(\bar{v}_\pi(\theta)=d_\pi^T v_\pi\) and \(\bar{r}_\pi(\theta)=d_\pi^T r_\pi\), where \(v_\pi\) and \(r_\pi\) satisfy the Bellman equation \(v_\pi=r_\pi+\gamma P_\pi v_\pi\). Multiplying both sides of the Bellman equation on the left by \(d_\pi^T\) and using \(d_\pi^T P_\pi=d_\pi^T\), which holds because \(d_\pi\) is the stationary distribution under \(\pi\), yields \[ \bar{v}_\pi=\bar{r}_\pi+\gamma d_\pi^T P_\pi v_\pi=\bar{r}_\pi+\gamma d_\pi^T v_\pi=\bar{r}_\pi+\gamma \bar{v}_\pi, \] which implies \(\eqref{eq9_13}\). Second, the following lemma gives the gradient of \(v_\pi(s)\) for any \(s\).
Lemma 9.2
Lemma 9.2 (Gradient of \(v_\pi(s)\)). In the discounted case, it holds for any \(s \in \mathcal{S}\) that \[ \nabla_\theta v_\pi(s)=\sum_{s^{\prime} \in \mathcal{S}} \operatorname{Pr}_\pi\left(s^{\prime} \mid s\right) \sum_{a \in \mathcal{A}} \nabla_\theta \pi\left(a \mid s^{\prime}, \theta\right) q_\pi\left(s^{\prime}, a\right) \] where \[ \operatorname{Pr}_\pi\left(s^{\prime} \mid s\right) \doteq \sum_{k=0}^{\infty} \gamma^k\left[P_\pi^k\right]_{s s^{\prime}}=\left[\left(I_n-\gamma P_\pi\right)^{-1}\right]_{s s^{\prime}} \] is the discounted total probability of transitioning from \(s\) to \(s^{\prime}\) under policy \(\pi\). Here, \(n=|\mathcal{S}|\), \([\cdot]_{s s^{\prime}}\) denotes the entry in the \(s\)th row and \(s^{\prime}\)th column, and \(\left[P_\pi^k\right]_{s s^{\prime}}\) is the probability of transitioning from \(s\) to \(s^{\prime}\) in exactly \(k\) steps under \(\pi\).
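Lemma 9.2 can be sanity-checked against finite differences. The sketch below assumes a hypothetical setup (random MDP, tabular softmax policy with \(m = nk\) parameters) and two illustrative helpers, grad_pi for the analytic softmax gradient and fd_jacobian for central differences. It forms \(\operatorname{Pr}_\pi\) as \((I_n-\gamma P_\pi)^{-1}\) and compares row \(s\) of \(\operatorname{Pr}_\pi u\), where \(u(s)=\sum_a \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a)\), with a numerical gradient of \(v_\pi(s)\).

```python
import numpy as np

# Hypothetical toy MDP (not from the text): n states, k actions, random dynamics.
rng = np.random.default_rng(0)
n, k = 3, 2
p = rng.random((n, k, n))
p /= p.sum(axis=2, keepdims=True)          # p[s, a, s'] = p(s' | s, a)
r = rng.random((n, k))
theta = rng.normal(size=(n, k))            # tabular softmax parameters, m = n * k


def softmax_policy(theta):
    z = np.exp(theta - theta.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def evaluate(theta, gamma):
    """Solve the Bellman equation exactly; return pi, P_pi, r_pi, v_pi, q_pi."""
    pi = softmax_policy(theta)
    P = np.einsum("sa,sat->st", pi, p)
    r_pi = (pi * r).sum(axis=1)
    v = np.linalg.solve(np.eye(n) - gamma * P, r_pi)
    q = r + gamma * p @ v
    return pi, P, r_pi, v, q


def grad_pi(theta):
    """g[s, a] = grad_theta pi(a | s) for the tabular softmax, flattened to length n * k."""
    pi = softmax_policy(theta)
    g = np.zeros((n, k, n, k))
    for s in range(n):
        g[s, :, s, :] = np.diag(pi[s]) - np.outer(pi[s], pi[s])
    return g.reshape(n, k, n * k)


def fd_jacobian(f, theta, eps=1e-6):
    """Central differences of f with respect to the flattened theta (appended as last axis)."""
    flat = theta.ravel()
    cols = []
    for i in range(flat.size):
        step = np.zeros_like(flat)
        step[i] = eps
        plus = f((flat + step).reshape(theta.shape))
        minus = f((flat - step).reshape(theta.shape))
        cols.append((plus - minus) / (2 * eps))
    return np.stack(cols, axis=-1)


gamma = 0.9
pi, P, r_pi, v, q = evaluate(theta, gamma)
Pr = np.linalg.inv(np.eye(n) - gamma * P)          # Pr[s, s'] = Pr_pi(s' | s)
u = np.einsum("sam,sa->sm", grad_pi(theta), q)     # u(s) = sum_a grad pi(a|s) q(s, a)
grad_v_lemma = Pr @ u                              # row s: right-hand side of Lemma 9.2
grad_v_fd = fd_jacobian(lambda th: evaluate(th, gamma)[3], theta)   # row s: grad v_pi(s)
print(np.abs(grad_v_lemma - grad_v_fd).max())
```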
With the results in Lemma 9.2, we are ready to derive the gradient of \(\bar{v}_\pi^0\).
Theorem 9.2
Theorem 9.2 (Gradient of \(\bar{v}_\pi^0\) in the discounted case). In the discounted case where \(\gamma \in(0,1)\), the gradient of \(\bar{v}_\pi^0=d_0^T v_\pi\) is \[ \nabla_\theta \bar{v}_\pi^0=\mathbb{E}\left[\nabla_\theta \ln \pi(A \mid S, \theta) q_\pi(S, A)\right] \] where \(S \sim \rho_\pi\) and \(A \sim \pi(S, \theta)\). Here, the state distribution \(\color{purple}{\rho_\pi}\) is \[ \color{purple}{\rho_\pi(s)}=\sum_{s^{\prime} \in \mathcal{S}} d_0\left(s^{\prime}\right) \operatorname{Pr}_\pi\left(s \mid s^{\prime}\right), \quad s \in \mathcal{S} \] where \(\operatorname{Pr}_\pi\left(s \mid s^{\prime}\right)=\sum_{k=0}^{\infty} \gamma^k\left[P_\pi^k\right]_{s^{\prime} s}=\left[\left(I-\gamma P_\pi\right)^{-1}\right]_{s^{\prime} s}\) is the discounted total probability of transitioning from \(s^{\prime}\) to \(s\) under policy \(\pi\).
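Theorem 9.2 can be checked the same way. In this sketch the random MDP, the tabular softmax policy, the initial distribution d0, and the helpers are all hypothetical. Note that the entries of \(\rho_\pi\) sum to \(1/(1-\gamma)\) rather than 1, because each row of \((I-\gamma P_\pi)^{-1}\) sums to \(1/(1-\gamma)\); the check therefore compares the weighted sum \(\sum_s \rho_\pi(s)\sum_a \nabla_\theta \pi(a \mid s,\theta) q_\pi(s,a)\) directly with a finite-difference gradient of \(\bar{v}_\pi^0=d_0^T v_\pi\).

```python
import numpy as np

# Hypothetical toy MDP (not from the text): n states, k actions, random dynamics.
rng = np.random.default_rng(0)
n, k = 3, 2
p = rng.random((n, k, n))
p /= p.sum(axis=2, keepdims=True)
r = rng.random((n, k))
theta = rng.normal(size=(n, k))


def softmax_policy(theta):
    z = np.exp(theta - theta.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def evaluate(theta, gamma):
    """Solve the Bellman equation exactly; return pi, P_pi, r_pi, v_pi, q_pi."""
    pi = softmax_policy(theta)
    P = np.einsum("sa,sat->st", pi, p)
    r_pi = (pi * r).sum(axis=1)
    v = np.linalg.solve(np.eye(n) - gamma * P, r_pi)
    q = r + gamma * p @ v
    return pi, P, r_pi, v, q


def grad_pi(theta):
    """g[s, a] = grad_theta pi(a | s) for the tabular softmax, flattened to length n * k."""
    pi = softmax_policy(theta)
    g = np.zeros((n, k, n, k))
    for s in range(n):
        g[s, :, s, :] = np.diag(pi[s]) - np.outer(pi[s], pi[s])
    return g.reshape(n, k, n * k)


def fd_jacobian(f, theta, eps=1e-6):
    """Central differences of f with respect to the flattened theta (appended as last axis)."""
    flat = theta.ravel()
    cols = []
    for i in range(flat.size):
        step = np.zeros_like(flat)
        step[i] = eps
        plus = f((flat + step).reshape(theta.shape))
        minus = f((flat - step).reshape(theta.shape))
        cols.append((plus - minus) / (2 * eps))
    return np.stack(cols, axis=-1)


gamma = 0.9
d0 = np.array([0.5, 0.3, 0.2])                     # hypothetical initial-state distribution
pi, P, r_pi, v, q = evaluate(theta, gamma)
rho = d0 @ np.linalg.inv(np.eye(n) - gamma * P)    # rho(s) = sum_s' d0(s') Pr_pi(s | s')
u = np.einsum("sam,sa->sm", grad_pi(theta), q)
grad_theorem = rho @ u                             # sum_s rho(s) sum_a grad pi(a|s) q(s, a)
grad_fd = fd_jacobian(lambda th: d0 @ evaluate(th, gamma)[3], theta)   # grad of d0^T v_pi
print(rho.sum(), 1 / (1 - gamma))
```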
Theorem 9.3
Theorem 9.3 (Gradients of \(\bar{r}_\pi\) and \(\bar{v}_\pi\) in the discounted case). In the discounted case where \(\gamma \in(0,1)\), the gradients of \(\bar{r}_\pi\) and \(\bar{v}_\pi\) are \[ \begin{aligned} \nabla_\theta \bar{r}_\pi=(1-\gamma) \nabla_\theta \bar{v}_\pi & \approx \sum_{s \in \mathcal{S}} d_\pi(s) \sum_{a \in \mathcal{A}} \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a) \\ & =\mathbb{E}\left[\nabla_\theta \ln \pi(A \mid S, \theta) q_\pi(S, A)\right], \end{aligned} \] where \(S \sim d_\pi\) and \(A \sim \pi(S, \theta)\). Here, the approximation is more accurate when \(\gamma\) is closer to 1.
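The size of the neglected term can be inspected numerically. In this sketch (hypothetical random MDP, tabular softmax policy, and illustrative helpers), the exact \(\nabla_\theta \bar{r}_\pi\) is obtained as \((1-\gamma)\) times a finite-difference gradient of \(\bar{v}_\pi=d_\pi^T v_\pi\), using Lemma 9.1, and compared with \(\sum_s d_\pi(s)\sum_a \nabla_\theta \pi(a \mid s,\theta) q_\pi(s,a)\). The gap should equal \((1-\gamma)\sum_s \nabla_\theta d_\pi(s) v_\pi(s)\), which is the term dropped in the proof in the appendix. The printed relative errors describe this toy example only, not a general bound.

```python
import numpy as np

# Hypothetical toy MDP (not from the text): n states, k actions, random dynamics.
rng = np.random.default_rng(0)
n, k = 3, 2
p = rng.random((n, k, n))
p /= p.sum(axis=2, keepdims=True)
r = rng.random((n, k))
theta = rng.normal(size=(n, k))


def softmax_policy(theta):
    z = np.exp(theta - theta.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def evaluate(theta, gamma):
    """Solve the Bellman equation exactly; return pi, P_pi, r_pi, v_pi, q_pi."""
    pi = softmax_policy(theta)
    P = np.einsum("sa,sat->st", pi, p)
    r_pi = (pi * r).sum(axis=1)
    v = np.linalg.solve(np.eye(n) - gamma * P, r_pi)
    q = r + gamma * p @ v
    return pi, P, r_pi, v, q


def stationary(P):
    """Stationary distribution d: d^T P = d^T and sum(d) = 1 (P assumed ergodic)."""
    size = P.shape[0]
    A = np.vstack([P.T - np.eye(size), np.ones((1, size))])
    b = np.zeros(size + 1)
    b[-1] = 1.0
    return np.linalg.lstsq(A, b, rcond=None)[0]


def grad_pi(theta):
    """g[s, a] = grad_theta pi(a | s) for the tabular softmax, flattened to length n * k."""
    pi = softmax_policy(theta)
    g = np.zeros((n, k, n, k))
    for s in range(n):
        g[s, :, s, :] = np.diag(pi[s]) - np.outer(pi[s], pi[s])
    return g.reshape(n, k, n * k)


def fd_jacobian(f, theta, eps=1e-6):
    """Central differences of f with respect to the flattened theta (appended as last axis)."""
    flat = theta.ravel()
    cols = []
    for i in range(flat.size):
        step = np.zeros_like(flat)
        step[i] = eps
        cols.append((f((flat + step).reshape(theta.shape))
                     - f((flat - step).reshape(theta.shape))) / (2 * eps))
    return np.stack(cols, axis=-1)


def v_bar(th, gamma):
    """bar v_pi(theta) = d_pi^T v_pi for a given gamma."""
    _, P, _, v, _ = evaluate(th, gamma)
    return stationary(P) @ v


results = {}
for gamma in (0.5, 0.9, 0.99):
    pi, P, r_pi, v, q = evaluate(theta, gamma)
    d = stationary(P)
    u = np.einsum("sam,sa->sm", grad_pi(theta), q)    # u(s) = sum_a grad pi(a|s) q(s, a)
    approx = d @ u                                      # right-hand side of Theorem 9.3
    exact = (1 - gamma) * fd_jacobian(lambda th: v_bar(th, gamma), theta)   # grad r_bar
    grad_d = fd_jacobian(lambda th: stationary(evaluate(th, gamma)[1]), theta)
    neglected = (1 - gamma) * v @ grad_d                # (1 - gamma) sum_s grad d(s) v(s)
    results[gamma] = (exact, approx, neglected)
    rel = np.linalg.norm(exact - approx) / np.linalg.norm(exact)
    print(f"gamma={gamma}: relative error {rel:.2e}")
```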
Appendix
Proof of Lemma 9.2
First, for any \(s \in \mathcal{S}\), it holds that \[ \begin{aligned} \nabla_\theta v_\pi(s) & =\nabla_\theta\left[\sum_{a \in \mathcal{A}} \pi(a \mid s, \theta) q_\pi(s, a)\right] \\ & =\sum_{a \in \mathcal{A}}\left[\nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a)+\pi(a \mid s, \theta) \color{silver}{\nabla_\theta q_\pi(s, a)}\right] \end{aligned} \] where \(q_\pi(s, a)\) is the action value given by \[ q_\pi(s, a)=r(s, a)+\gamma \sum_{s^{\prime} \in \mathcal{S}} p\left(s^{\prime} \mid s, a\right) v_\pi\left(s^{\prime}\right) \]
Since \(r(s, a)=\sum_r r p(r \mid s, a)\) and \(p\left(s^{\prime} \mid s, a\right)\) are independent of \(\theta\), we have \[ \color{silver}{\nabla_\theta q_\pi(s, a)}=0+\gamma \sum_{s^{\prime} \in \mathcal{S}} p\left(s^{\prime} \mid s, a\right) \nabla_\theta v_\pi\left(s^{\prime}\right) \] Substituting this result into the expression for \(\nabla_\theta v_\pi(s)\) above yields \[ \begin{aligned} \nabla_\theta v_\pi(s) & =\sum_{a \in \mathcal{A}}\left[\nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a)+\pi(a \mid s, \theta) \color{silver}{\gamma \sum_{s^{\prime} \in \mathcal{S}} p\left(s^{\prime} \mid s, a\right) \nabla_\theta v_\pi\left(s^{\prime}\right)}\right] \\ & =\sum_{a \in \mathcal{A}} \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a)+\gamma \sum_{a \in \mathcal{A}} \pi(a \mid s, \theta) \sum_{s^{\prime} \in \mathcal{S}} p\left(s^{\prime} \mid s, a\right) \nabla_\theta v_\pi\left(s^{\prime}\right) \end{aligned} \]
It is notable that \(\nabla_\theta v_\pi\) appears on both sides of the above equation, so we solve for it using the matrix-vector form. In particular, let \[ u(s) \doteq \sum_{a \in \mathcal{A}} \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a) . \] Since \[ \sum_{a \in \mathcal{A}} \pi(a \mid s, \theta) \sum_{s^{\prime} \in \mathcal{S}} p\left(s^{\prime} \mid s, a\right) \nabla_\theta v_\pi\left(s^{\prime}\right)=\sum_{s^{\prime} \in \mathcal{S}} p\left(s^{\prime} \mid s\right) \nabla_\theta v_\pi\left(s^{\prime}\right)=\sum_{s^{\prime} \in \mathcal{S}}\left[P_\pi\right]_{s s^{\prime}} \nabla_\theta v_\pi\left(s^{\prime}\right), \] where \(p\left(s^{\prime} \mid s\right)=\sum_{a \in \mathcal{A}} \pi(a \mid s, \theta) p\left(s^{\prime} \mid s, a\right)\) is the state transition probability under \(\pi\), the equation \[ \nabla_\theta v_\pi(s) = \sum_{a \in \mathcal{A}} \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a)+\gamma \sum_{a \in \mathcal{A}} \pi(a \mid s, \theta) \sum_{s^{\prime} \in \mathcal{S}} p\left(s^{\prime} \mid s, a\right) \nabla_\theta v_\pi\left(s^{\prime}\right) \] can be written in matrix-vector form as \[ \underbrace{\left[\begin{array}{c} \vdots \\ \nabla_\theta v_\pi(s) \\ \vdots \end{array}\right]}_{\nabla_\theta v_\pi \in \mathbb{R}^{m n}}=\underbrace{\left[\begin{array}{c} \vdots \\ u(s) \\ \vdots \end{array}\right]}_{u \in \mathbb{R}^{m n}}+\gamma\left(P_\pi \otimes I_m\right) \underbrace{\left[\begin{array}{c} \vdots \\ \nabla_\theta v_\pi\left(s^{\prime}\right) \\ \vdots \end{array}\right]}_{\nabla_\theta v_\pi \in \mathbb{R}^{m n}}, \] which can be written concisely as \[ \nabla_\theta v_\pi=u+\gamma\left(P_\pi \otimes I_m\right) \nabla_\theta v_\pi . \] Here, \(n=|\mathcal{S}|\), \(m\) is the dimension of the parameter vector \(\theta\), and the stacked vectors \(\nabla_\theta v_\pi, u \in \mathbb{R}^{mn}\) consist of \(n\) blocks of length \(m\), one per state. The Kronecker product \(\otimes\) appears because each \(\nabla_\theta v_\pi(s)\) is an \(m\)-dimensional vector rather than a scalar: \(P_\pi \otimes I_m\) applies the weight \(\left[P_\pi\right]_{s s^{\prime}}\) to every component of \(\nabla_\theta v_\pi\left(s^{\prime}\right)\).
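The above equation is linear in \(\nabla_\theta v_\pi\). Since \(P_\pi\) is row-stochastic, the spectral radius of \(\gamma P_\pi\) is \(\gamma<1\), so \(I_n-\gamma P_\pi\) is invertible and the equation can be solved as \[ \begin{aligned} \nabla_\theta v_\pi & =\left(I_{n m}-\gamma P_\pi \otimes I_m\right)^{-1} u \\ & =\left(I_n \otimes I_m-\gamma P_\pi \otimes I_m\right)^{-1} u \\ & =\left[\left(I_n-\gamma P_\pi\right)^{-1} \otimes I_m\right] u , \end{aligned} \] where the last two steps use \(I_{nm}=I_n \otimes I_m\), \(A \otimes C-B \otimes C=(A-B) \otimes C\), and \((A \otimes C)^{-1}=A^{-1} \otimes C^{-1}\) for invertible \(A\) and \(C\).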
For any state \(s\), it follows from this equation that \[ \begin{aligned} \nabla_\theta v_\pi(s) & =\sum_{s^{\prime} \in \mathcal{S}}\left[\left(I_n-\gamma P_\pi\right)^{-1}\right]_{s s^{\prime}} u\left(s^{\prime}\right) \\ & =\sum_{s^{\prime} \in \mathcal{S}}\left[\left(I_n-\gamma P_\pi\right)^{-1}\right]_{s s^{\prime}} \sum_{a \in \mathcal{A}} \nabla_\theta \pi\left(a \mid s^{\prime}, \theta\right) q_\pi\left(s^{\prime}, a\right) . \end{aligned} \]
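The Kronecker manipulations above can be checked numerically. A minimal NumPy sketch, assuming a random row-stochastic matrix as a stand-in for \(P_\pi\) and random vectors as stand-ins for \(u(s)\) (both hypothetical): stacked vectors hold \(n\) blocks of length \(m\) in state order, matching the layout used above. It checks that \(P_\pi \otimes I_m\) mixes whole blocks with the weights \([P_\pi]_{ss'}\), that \((I_{nm}-\gamma P_\pi\otimes I_m)^{-1}=(I_n-\gamma P_\pi)^{-1}\otimes I_m\), and that the stacked solution agrees with the per-state formula.

```python
import numpy as np

# Hypothetical stand-ins: a random row-stochastic P (for P_pi) and random per-state
# vectors u(s) in R^m (for u). Stacked vectors hold n blocks of length m, in state order.
rng = np.random.default_rng(0)
n, m, gamma = 3, 4, 0.9
P = rng.random((n, n))
P /= P.sum(axis=1, keepdims=True)
U = rng.normal(size=(n, m))            # row s = u(s)
I_n, I_m = np.eye(n), np.eye(m)

# (P kron I_m) applied to a stacked vector mixes whole blocks: block s = sum_s' P[s, s'] x(s').
X = rng.normal(size=(n, m))
mixed = np.kron(P, I_m) @ X.ravel()

# Kronecker-form inverse versus the n x n inverse lifted by "kron I_m".
big_inv = np.linalg.inv(np.eye(n * m) - gamma * np.kron(P, I_m))
lifted_inv = np.kron(np.linalg.inv(I_n - gamma * P), I_m)

# Per-state solution: row s = sum_s' [(I_n - gamma P)^{-1}]_{ss'} u(s').
G = np.linalg.inv(I_n - gamma * P) @ U

spectral_radius = np.max(np.abs(np.linalg.eigvals(gamma * P)))
print(spectral_radius)   # gamma up to rounding, since a stochastic P has spectral radius 1
```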
The quantity \(\left[\left(I_n-\gamma P_\pi\right)^{-1}\right]_{s s^{\prime}}\) has a clear probabilistic interpretation. Since the spectral radius of \(\gamma P_\pi\) is less than 1, the Neumann series gives \[ \left(I_n-\gamma P_\pi\right)^{-1}=I_n+\gamma P_\pi+\gamma^2 P_\pi^2+\cdots , \] so that \[ \left[\left(I_n-\gamma P_\pi\right)^{-1}\right]_{s s^{\prime}}=\left[I_n\right]_{s s^{\prime}}+\gamma\left[P_\pi\right]_{s s^{\prime}}+\gamma^2\left[P_\pi^2\right]_{s s^{\prime}}+\cdots=\sum_{k=0}^{\infty} \gamma^k\left[P_\pi^k\right]_{s s^{\prime}} . \] Note that \(\left[P_\pi^k\right]_{s s^{\prime}}\) is the probability of transitioning from \(s\) to \(s^{\prime}\) in exactly \(k\) steps (see my previous post). Therefore, \(\left[\left(I_n-\gamma P_\pi\right)^{-1}\right]_{s s^{\prime}}\) is the discounted total probability of transitioning from \(s\) to \(s^{\prime}\) using any number of steps. Denoting it by \[ \operatorname{Pr}_\pi\left(s^{\prime} \mid s\right) \doteq \left[\left(I_n-\gamma P_\pi\right)^{-1}\right]_{s s^{\prime}} , \] we obtain \[ \nabla_\theta v_\pi(s)=\sum_{s^{\prime} \in \mathcal{S}}\left[\left(I_n-\gamma P_\pi\right)^{-1}\right]_{s s^{\prime}} \sum_{a \in \mathcal{A}} \nabla_\theta \pi\left(a \mid s^{\prime}, \theta\right) q_\pi\left(s^{\prime}, a\right)=\sum_{s^{\prime} \in \mathcal{S}} \operatorname{Pr}_\pi\left(s^{\prime} \mid s\right) \sum_{a \in \mathcal{A}} \nabla_\theta \pi\left(a \mid s^{\prime}, \theta\right) q_\pi\left(s^{\prime}, a\right) , \] which is the claim of Lemma 9.2.
Q. E. D.
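The series expansion used in this proof can be checked with a short NumPy sketch, assuming a random row-stochastic matrix as a hypothetical stand-in for \(P_\pi\). Truncating \(\sum_k \gamma^k P_\pi^k\) after \(K\) terms leaves a tail whose row sums equal \(\gamma^K/(1-\gamma)\), so a few hundred terms match the matrix inverse to machine precision for \(\gamma = 0.9\). Each row of \((I_n-\gamma P_\pi)^{-1}\) sums to \(1/(1-\gamma)\), which is why \(\operatorname{Pr}_\pi\) is called a discounted total probability rather than a probability.

```python
import numpy as np

# Hypothetical stand-in for P_pi: a random row-stochastic matrix.
rng = np.random.default_rng(0)
n, gamma = 3, 0.9
P = rng.random((n, n))
P /= P.sum(axis=1, keepdims=True)

Pr = np.linalg.inv(np.eye(n) - gamma * P)   # Pr[s, s'] = Pr_pi(s' | s)

# Truncated series sum_{k=0}^{K-1} gamma^k P^k; each row of the tail sums to gamma^K / (1 - gamma).
K = 400
series = np.zeros((n, n))
term = np.eye(n)                            # gamma^k P^k, starting at k = 0
for _ in range(K):
    series += term
    term = gamma * term @ P

# [P^k]_{ss'} are k-step transition probabilities, so every row of P^k sums to 1.
P5 = np.linalg.matrix_power(P, 5)
print(Pr.sum(axis=1))                       # each row: 1 + gamma + gamma^2 + ... = 1 / (1 - gamma)
```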
Proof of Theorem 9.2
Since \(d_0(s)\) is independent of \(\pi\), we have \[ \nabla_\theta \bar{v}_\pi^0=\nabla_\theta \sum_{s \in \mathcal{S}} d_0(s) v_\pi(s)=\sum_{s \in \mathcal{S}} d_0(s) \color{brown}{\nabla_\theta v_\pi(s)} . \]
Substituting the expression of \(\nabla_\theta v_\pi(s)\) given in Lemma 9.2 into the above equation yields \[ \begin{aligned} \nabla_\theta \bar{v}_\pi^0 = \sum_{s \in \mathcal{S}} d_0(s) \color{brown}{\nabla_\theta v_\pi(s)} & = \sum_{s \in \mathcal{S}} d_0(s) \color{brown}{\sum_{s^{\prime} \in \mathcal{S}} \operatorname{Pr}_\pi\left(s^{\prime} \mid s\right) \sum_{a \in \mathcal{A}} \nabla_\theta \pi\left(a \mid s^{\prime}, \theta\right) q_\pi\left(s^{\prime}, a\right)} \\ & =\sum_{s^{\prime} \in \mathcal{S}}\left(\sum_{s \in \mathcal{S}} d_0(s) \operatorname{Pr}_\pi\left(s^{\prime} \mid s\right)\right) \sum_{a \in \mathcal{A}} \nabla_\theta \pi\left(a \mid s^{\prime}, \theta\right) q_\pi\left(s^{\prime}, a\right) \\ & \doteq \sum_{s^{\prime} \in \mathcal{S}} \color{purple}{\rho_\pi}\left(s^{\prime}\right) \sum_{a \in \mathcal{A}} \nabla_\theta \pi\left(a \mid s^{\prime}, \theta\right) q_\pi\left(s^{\prime}, a\right) \\ & =\sum_{s \in \mathcal{S}} \rho_\pi(s) \sum_{a \in \mathcal{A}} \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a) \quad \quad \text { (change } s^{\prime} \text { to } s \text { ) } \\ & =\sum_{s \in \mathcal{S}} \rho_\pi(s) \sum_{a \in \mathcal{A}} \pi(a \mid s, \theta) \nabla_\theta \ln \pi(a \mid s, \theta) q_\pi(s, a) \\ & =\mathbb{E}\left[\nabla_\theta \ln \pi(A \mid S, \theta) q_\pi(S, A)\right], \end{aligned} \] where \(S \sim \rho_\pi\) and \(A \sim \pi(S, \theta)\). The second-to-last equality uses \(\nabla_\theta \pi(a \mid s, \theta)=\pi(a \mid s, \theta) \nabla_\theta \ln \pi(a \mid s, \theta)\). The proof is complete.
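The expectation form is what makes the gradient estimable from samples. A hedged sketch, assuming a hypothetical random MDP, a tabular softmax policy, a uniform \(d_0\), and exact \(q_\pi\) (in practice \(q_\pi\) would itself have to be estimated): sample \(S\) from \(\rho_\pi\) normalized to sum to 1, sample \(A \sim \pi(\cdot \mid S)\), average \(\nabla_\theta \ln \pi(A \mid S, \theta) q_\pi(S, A)\), and multiply by \(1/(1-\gamma)\) to undo the normalization. The choice \(\gamma = 0.5\) only keeps the Monte Carlo variance small for this illustration.

```python
import numpy as np

# Hypothetical toy MDP (not from the text): n states, k actions, random dynamics.
rng = np.random.default_rng(0)
n, k = 3, 2
p = rng.random((n, k, n))
p /= p.sum(axis=2, keepdims=True)
r = rng.random((n, k))
theta = rng.normal(size=(n, k))


def softmax_policy(theta):
    z = np.exp(theta - theta.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def evaluate(theta, gamma):
    """Solve the Bellman equation exactly; return pi, P_pi, r_pi, v_pi, q_pi."""
    pi = softmax_policy(theta)
    P = np.einsum("sa,sat->st", pi, p)
    r_pi = (pi * r).sum(axis=1)
    v = np.linalg.solve(np.eye(n) - gamma * P, r_pi)
    q = r + gamma * p @ v
    return pi, P, r_pi, v, q


gamma = 0.5
d0 = np.full(n, 1.0 / n)                          # hypothetical uniform initial distribution
pi, P, r_pi, v, q = evaluate(theta, gamma)
rho = d0 @ np.linalg.inv(np.eye(n) - gamma * P)   # entries sum to 1 / (1 - gamma)

# Exact sum_s rho(s) sum_a pi(a|s) grad ln pi(a|s) q(s, a), arranged like theta (n x k).
score = np.eye(k)[None, :, :] - pi[:, None, :]   # score[s, a] = e_a - pi(. | s)
exact = rho[:, None] * (pi[:, :, None] * score * q[:, :, None]).sum(axis=1)

# Monte Carlo: S ~ rho / sum(rho), then A ~ pi(. | S).
N = 200_000
S = rng.choice(n, size=N, p=rho / rho.sum())
A = np.empty(N, dtype=int)
for s in range(n):
    idx = S == s
    A[idx] = rng.choice(k, size=idx.sum(), p=pi[s])
w = q[S, A]
est = np.zeros((n, k))
np.add.at(est, (S, A), w)                         # e_A part of the score
np.add.at(est, S, -pi[S] * w[:, None])            # -pi(. | S) part of the score
est /= (1 - gamma) * N                            # sample mean, rescaled by sum(rho)
print(np.abs(est - exact).max())
```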
Proof of Theorem 9.3
It follows from the definition of \(\bar{v}_\pi\) that \[ \begin{aligned} \nabla_\theta \bar{v}_\pi & =\nabla_\theta \sum_{s \in \mathcal{S}} d_\pi(s) v_\pi(s) \\ & =\sum_{s \in \mathcal{S}} \nabla_\theta d_\pi(s) v_\pi(s)+\sum_{s \in \mathcal{S}} d_\pi(s) \nabla_\theta v_\pi(s) \end{aligned} \]
This equation contains two terms. On the one hand, substituting the solution \(\nabla_\theta v_\pi=\left[\left(I_n-\gamma P_\pi\right)^{-1} \otimes I_m\right] u\) obtained in the proof of Lemma 9.2 into the second term gives \[ \begin{aligned} \sum_{s \in \mathcal{S}} d_\pi(s) \nabla_\theta v_\pi(s) & =\left(d_\pi^T \otimes I_m\right) \nabla_\theta v_\pi \\ & =\left(d_\pi^T \otimes I_m\right)\left[\left(I_n-\gamma P_\pi\right)^{-1} \otimes I_m\right] u \\ & =\left(\left[d_\pi^T\left(I_n-\gamma P_\pi\right)^{-1}\right] \otimes I_m\right) u , \end{aligned} \] where the last step uses the mixed-product property \((A \otimes B)(C \otimes D)=(A C) \otimes(B D)\).
It can be verified that \[ d_\pi^T\left(I_n-\gamma P_\pi\right)^{-1}=\frac{1}{1-\gamma} d_\pi^T \] by multiplying both sides on the right by \(\left(I_n-\gamma P_\pi\right)\) and using \(d_\pi^T P_\pi=d_\pi^T\). Therefore, the second term becomes \[ \begin{aligned} \sum_{s \in \mathcal{S}} d_\pi(s) \nabla_\theta v_\pi(s) & =\frac{1}{1-\gamma}\left(d_\pi^T \otimes I_m\right) u \\ & =\frac{1}{1-\gamma} \sum_{s \in \mathcal{S}} d_\pi(s) \sum_{a \in \mathcal{A}} \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a) . \end{aligned} \] On the other hand, the first term of \[ \nabla_\theta \bar{v}_\pi=\sum_{s \in \mathcal{S}} \nabla_\theta d_\pi(s) v_\pi(s)+\sum_{s \in \mathcal{S}} d_\pi(s) \nabla_\theta v_\pi(s) \] involves \(\nabla_\theta d_\pi\). Since the second term carries the factor \(\frac{1}{1-\gamma}\), it dominates as \(\gamma \rightarrow 1\), and the first term becomes negligible in comparison, provided that it remains bounded. Therefore, \[ \nabla_\theta \bar{v}_\pi \approx \frac{1}{1-\gamma} \sum_{s \in \mathcal{S}} d_\pi(s) \sum_{a \in \mathcal{A}} \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a) . \]
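The two identities used in this step, \(d_\pi^T(I_n-\gamma P_\pi)^{-1}=d_\pi^T/(1-\gamma)\) and the mixed-product property of \(\otimes\), can be verified with a short NumPy sketch. A random row-stochastic matrix stands in for \(P_\pi\) and random vectors stand in for \(u(s)\) (hypothetical data); the helper stationary is an illustrative least-squares solve of \(d^T P = d^T\), \(\sum_s d(s) = 1\).

```python
import numpy as np

# Hypothetical stand-ins: random row-stochastic P (for P_pi) and random u(s) in R^m.
rng = np.random.default_rng(0)
n, m, gamma = 3, 4, 0.9
P = rng.random((n, n))
P /= P.sum(axis=1, keepdims=True)
U = rng.normal(size=(n, m))                  # row s = u(s)


def stationary(P):
    """Stationary distribution d: d^T P = d^T and sum(d) = 1 (P assumed ergodic)."""
    size = P.shape[0]
    A = np.vstack([P.T - np.eye(size), np.ones((1, size))])
    b = np.zeros(size + 1)
    b[-1] = 1.0
    return np.linalg.lstsq(A, b, rcond=None)[0]


d = stationary(P)
Pr = np.linalg.inv(np.eye(n) - gamma * P)

# The second term computed three ways.
kron_form = np.kron(d[None, :], np.eye(m)) @ np.kron(Pr, np.eye(m)) @ U.ravel()
mixed_product = (d @ Pr) @ U                 # ([d^T (I_n - gamma P)^{-1}] kron I_m) u
closed_form = d @ U / (1 - gamma)            # (1 / (1 - gamma)) sum_s d(s) u(s)
print(closed_form)
```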
Furthermore, it follows from \(\bar{r}_\pi=(1-\gamma) \bar{v}_\pi\) that \[ \begin{aligned} \nabla_\theta \bar{r}_\pi=(1-\gamma) \nabla_\theta \bar{v}_\pi & \approx \sum_{s \in \mathcal{S}} d_\pi(s) \sum_{a \in \mathcal{A}} \nabla_\theta \pi(a \mid s, \theta) q_\pi(s, a) \\ & =\sum_{s \in \mathcal{S}} d_\pi(s) \sum_{a \in \mathcal{A}} \pi(a \mid s, \theta) \nabla_\theta \ln \pi(a \mid s, \theta) q_\pi(s, a) \\ & =\mathbb{E}\left[\nabla_\theta \ln \pi(A \mid S, \theta) q_\pi(S, A)\right] \end{aligned} \] The approximation in the above equation requires that the first term, \(\sum_{s \in \mathcal{S}} \nabla_\theta d_\pi(s) v_\pi(s)\), does not go to infinity as \(\gamma \rightarrow 1\).
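This condition can be probed numerically. Since \(\sum_s d_\pi(s)=1\) for every \(\theta\), we have \(\sum_s \nabla_\theta d_\pi(s)=0\), so the first term depends only on the differences \(v_\pi(s)-v_\pi(s')\); for an ergodic chain these differences stay bounded as \(\gamma\to 1\) even though \(v_\pi\) itself grows like \(1/(1-\gamma)\). The minimal NumPy sketch below uses a hypothetical random MDP (all transition probabilities positive, hence ergodic) and illustrative helpers, and prints the first term for \(\gamma\) approaching 1. The values are expected to level off rather than grow, but this is only an illustration on one toy example.

```python
import numpy as np

# Hypothetical toy MDP (not from the text): n states, k actions, random dynamics.
rng = np.random.default_rng(0)
n, k = 3, 2
p = rng.random((n, k, n))
p /= p.sum(axis=2, keepdims=True)
r = rng.random((n, k))
theta = rng.normal(size=(n, k))


def softmax_policy(theta):
    z = np.exp(theta - theta.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def evaluate(theta, gamma):
    """Solve the Bellman equation exactly; return pi, P_pi, r_pi, v_pi, q_pi."""
    pi = softmax_policy(theta)
    P = np.einsum("sa,sat->st", pi, p)
    r_pi = (pi * r).sum(axis=1)
    v = np.linalg.solve(np.eye(n) - gamma * P, r_pi)
    q = r + gamma * p @ v
    return pi, P, r_pi, v, q


def stationary(P):
    """Stationary distribution d: d^T P = d^T and sum(d) = 1 (P assumed ergodic)."""
    size = P.shape[0]
    A = np.vstack([P.T - np.eye(size), np.ones((1, size))])
    b = np.zeros(size + 1)
    b[-1] = 1.0
    return np.linalg.lstsq(A, b, rcond=None)[0]


def fd_jacobian(f, theta, eps=1e-6):
    """Central differences of f with respect to the flattened theta (appended as last axis)."""
    flat = theta.ravel()
    cols = []
    for i in range(flat.size):
        step = np.zeros_like(flat)
        step[i] = eps
        cols.append((f((flat + step).reshape(theta.shape))
                     - f((flat - step).reshape(theta.shape))) / (2 * eps))
    return np.stack(cols, axis=-1)


def d_of(th):
    """d_pi as a function of theta (P_pi does not depend on gamma)."""
    return stationary(evaluate(th, 0.9)[1])


grad_d = fd_jacobian(d_of, theta)            # row s = grad_theta d_pi(s)
first_terms = {}
for gamma in (0.9, 0.99, 0.999):
    v = evaluate(theta, gamma)[3]
    first_terms[gamma] = v @ grad_d           # sum_s grad d_pi(s) v_pi(s)
    print(f"gamma={gamma}: first term {np.round(first_terms[gamma], 4)}, max v {v.max():.1f}")
```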