Extending the theory and intuition behind one of our introductory algorithms
This post is a direct continuation of part 1, which introduces vanilla policy gradient (VPG). Start there!
Starting from the most basic working version of the algorithm introduced in that post, we'll now extend VPG with several of the foundational "tricks" in RL that turn it into a usable algorithm.
Now, we can do better than the very simple algorithm presented in the last post. Take a breather, then let's step back and think.
If I take an action, what parts of the trajectory will this affect? I’m not currently in possession of a time machine, so any action I take will not affect the past.
Therefore, the only part of the trajectory my action at time $t$ can affect is the part that comes after it: the rewards from time $t$ onward.
This set of rewards is sometimes known as the “rewards-to-go.” This term always struck me as kind of confusing, maybe because “to-go” isn’t as precise as I’d like. About half the time I see it, I think “‘to-go?’ Where are the rewards going? Oh, you mean ‘to-go’ as in what we have to go, or remaining.” Let’s avoid any temporary confusion and use “future return.”
Let’s be as explicit about this as possible.
Remember our trajectory:

$$\tau = (s_0, a_0, r_0,\; s_1, a_1, r_1,\; \dots,\; s_T, a_T, r_T)$$

The future return at time step $t$ only counts the rewards from $t$ onward. In other words,

$$\hat{R}_t = \sum_{t'=t}^{T} \gamma^{\,t'-t}\, r_{t'}$$

where $T$ is the total number of time steps in an episode (and we've remembered to include the discounting factor $\gamma$).

With the above argument about not being able to affect the past, it turns out there's little reason to ever use the original full-trajectory return $R(\tau)$ in the policy gradient. We just swap in the future return:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\!\left[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{R}_t\right]$$
It looks the same? Oh yes, but now we understand it differently!
Now, how do we actually calculate $\hat{R}_t$?
The obstacle here is that you never have the future rewards for a time step as you're playing out that step. You only have immediate access to the rewards you don't care about: those in the past. So you're forced to keep "rolling out" the trajectory to its conclusion, recording all future time steps, and only then can you accurately calculate $\hat{R}_t$ for every step.
Once you've collected the rewards, the calculation itself is straightforward. Starting from the final time step, each $\hat{R}_t$ can be computed recursively:

$$\hat{R}_T = r_T, \qquad \hat{R}_t = r_t + \gamma\, \hat{R}_{t+1}$$
Now just loop this backward in time.
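To make that concrete, here's a minimal sketch of the backward pass in plain Python/NumPy (the function name and the toy numbers are just for illustration):

```python
import numpy as np

def future_returns(rewards, gamma=0.99):
    """Compute the discounted future return ("rewards-to-go") at every
    time step by walking backward through a finished trajectory."""
    returns = np.zeros(len(rewards))
    running = 0.0
    # R_T = r_T, then R_t = r_t + gamma * R_{t+1} going backward in time
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Example: one short episode's rewards
print(future_returns([1.0, 0.0, 2.0], gamma=0.9))  # [2.62, 1.8, 2.0]
```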
Whew. That seems a lot more tedious than we were hoping for, doesn’t it?
Now, modern "offline" algorithms aren't actually this bad. They usually collect short sequences of transitions rather than full trajectories, and store them in a buffer to be sampled and asynchronously learned from. So the "reflection between games" is happening at all times, and the algorithm is "reflecting on all past games," so to speak, rather than only on its most recent experience. In other words, in "online" algorithms the state space currently being learned from is the same one currently being explored by the active policy. In "offline" algorithms, however, the current policy (implicit or explicit) sets the explored state space, while the buffer of historical data encapsulates a larger and potentially different state space. There are pros and cons to each approach.
I would like my algorithm to be "online," meaning I would like to avoid having to collect full episodes before knowing what my return $\hat{R}_t$ is.
The value function $V(s)$ is exactly the tool for the job. It is defined as the expected future return when starting from state $s$ and following the current policy from then on:

$$V^{\pi}(s) = \mathbb{E}_{\pi}\!\left[\,\sum_{t'=t}^{T} \gamma^{\,t'-t}\, r_{t'} \;\middle|\; s_t = s\right]$$
This is a lot simpler than it may seem. The value function is just another neural network $V_\phi(s)$: it takes in a state and spits out a single number, its best guess at the future return from that state.
We have a trajectory of states, and for each state $s_t$ we have also measured the future return $\hat{R}_t$ that followed it. So we can treat this as a plain supervised regression problem: take the squared error

$$L(\phi) = \left(V_\phi(s_t) - \hat{R}_t\right)^2$$

and minimize this loss by regression with some kind of gradient descent, training our neural network to predict the return.
And by sticking the value function's estimate in place of the measured future return $\hat{R}_t$ in our policy gradient, we no longer have to wait for the end of an episode before learning from a time step.
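Here's a rough sketch of that regression step in plain Python, with a toy linear value function standing in for the neural network and made-up data:

```python
import numpy as np

def value_regression_step(phi, states, targets, lr=0.01):
    """One gradient-descent step on the squared error between a toy linear
    value estimate V_phi(s) = phi . s and the measured future returns.
    (A real implementation would use a neural network; the linear model
    just keeps the sketch short.)"""
    preds = states @ phi                                  # V_phi(s_t) for each visited state
    grad = states.T @ (preds - targets) / len(targets)    # gradient of the mean squared error, up to a factor of 2
    return phi - lr * grad

# Made-up data: 4 states with 3 features each, plus their measured future returns
states = np.random.randn(4, 3)
targets = np.array([2.62, 1.8, 2.0, 0.5])
phi = np.zeros(3)
for _ in range(200):
    phi = value_regression_step(phi, states, targets)
print(states @ phi)   # the predictions drift toward the targets as training proceeds
```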
Now for a few caveats.
This approximation may only see a subset of the state space at any given time, and it may never see the full state space. We cannot assume it is an unbiased estimator unless it has trained uniformly on data from all parts of the state space, and typically that will only be the case once training is complete, if it happens at all. In parts of the state space that are relatively unexplored, it may be very inaccurate, which could lead to learning a policy that the value function says is optimal but that is, in reality, nonsense.
One of the best solutions to this problem is seen in the popular RL algorithm PPO. Once you’ve understood VPG, you should work your way through TRPO, and then PPO. These links are decent references, though if you feel this blog post helped you understand VPG, let me know in the comments. If there’s enough interest I’ll continue these tutorials!
PPO's goal is to avoid veering too far into unexplored territory too quickly. It gives the value function time to "acclimate" to its new surroundings and produce more accurate estimates, so the policy is always learning from a "good enough" value function. It accomplishes this by limiting the size of the policy update step, clipping a surrogate loss function, but I won't digress too much on this point; it needs its own post to do it justice.
It also turns out that despite removing the noise from past rewards, this simple solution is still quite high-variance in practice. Because of this, it’s difficult for the actor to learn a good policy and to do so quickly. Luckily there are many battle-hardened techniques for reducing variance. Read on for one such technique!
What is the lowest-bias approximation of the return $\hat{R}_t$? The measured rewards themselves, of course; they aren't approximations at all.

Hol' up. Calculating $\hat{R}_t$ from the full set of measured rewards is exactly the thing we were trying to avoid. Aren't we going in circles?
Indeed we are. So perhaps we can combine the two approaches. What if I use the actual reward at the next time step, and add it to the value function's approximation of the discounted return from that point on? Aha, a slightly more accurate estimate! This one-step return is shown in the following figure and equation:

$$\hat{R}_t^{(1)} = r_t + \gamma\, V_\phi(s_{t+1})$$
Perhaps we can do better still by using the next two real rewards instead of only one. Or the next three? Each addition requires a slightly longer delay between taking an action and being able to learn from it, because we need to collect a longer slice of the full trajectory. This is known as the n-step return:

$$\hat{R}_t^{(n)} = \sum_{k=0}^{n-1} \gamma^{k}\, r_{t+k} \;+\; \gamma^{n}\, V_\phi(s_{t+n})$$
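Here's a small sketch of that mix of real rewards and bootstrapped value, again in plain Python with made-up numbers:

```python
def n_step_return(rewards, values, t, n, gamma=0.99):
    """The n-step return: n real rewards, then the value function's estimate
    of everything that comes after. `values[i]` is V(s_i); the entry at
    index t + n must exist (use 0.0 for a terminal state)."""
    ret = sum((gamma ** k) * rewards[t + k] for k in range(n))
    ret += (gamma ** n) * values[t + n]   # bootstrap the tail with V
    return ret

rewards = [1.0, 0.0, 2.0, 1.0]
values  = [0.9, 0.8, 1.5, 0.7, 0.0]       # one V estimate per state, terminal last
print(n_step_return(rewards, values, t=0, n=1, gamma=0.9))  # 1.0 + 0.9*0.8
print(n_step_return(rewards, values, t=0, n=2, gamma=0.9))  # 1.0 + 0.9*0.0 + 0.81*1.5
```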
Now the question becomes: what’s the optimal tradeoff between accuracy and this learning delay? At this point you just need to experiment and see what works best for your problem, but I will tell you this: the answer is somewhere between 100% accuracy and zero delay, as seen in this figure:
It also turns out you can do even better by doing a weighted average of all of these options: one, two, three, etc. real rewards, followed by an approximation, and also experimenting to see what the right weighting is. There is a similar tradeoff between accuracy and computation speed, yielding a chart like the one above. Optimizing these hyperparameters is problem dependent.
We now have a more accurate approximation with our n-step return $\hat{R}_t^{(n)}$. But accuracy (bias) was only half the problem.
The n-step return handily deals with the bias problem in approximating the value function, but doesn’t do much to help the variance. To reduce the variance we need to understand where it comes from.
In any trajectory, there are several sources of randomness contributing to what direction it takes as it's rolled out:

- the initial state the environment starts in,
- the action the (stochastic) policy samples at each step,
- the next state and reward the environment responds with.

Each variation at each time step ultimately contributes to the variance of the return $\hat{R}_t$ we measure: run the same policy twice and you'll get two different numbers.
I’ll start with the punchline: it turns out you can take the expression you want (the n-step return), and subtract another similar expression (a baseline) to reduce the variance of your approximation. Take the following figure as a simple example.
We're in the business of using "real-time" approximations, so the n-step return is what we'd like a low-variance estimate of. Consider the 1-step return:

$$\hat{R}_t^{(1)} = r_t + \gamma\, V_\phi(s_{t+1})$$

The most common choice for a baseline is the value function, so the above becomes:

$$r_t + \gamma\, V_\phi(s_{t+1}) - V_\phi(s_t)$$

or more generally,

$$\hat{R}_t^{(n)} - V_\phi(s_t) = \sum_{k=0}^{n-1} \gamma^{k}\, r_{t+k} + \gamma^{n}\, V_\phi(s_{t+n}) - V_\phi(s_t)$$
Now hold on. This is just the same as we had before. It’s the n-step return, but for some reason we’re subtracting the value function evaluated on the current state. What’s the point?
Intuitively, the advantage $\hat{A}_t = \hat{R}_t^{(n)} - V_\phi(s_t)$ at its most basic level answers the question: how much better was the actual return than what the value function expected from that state? If an action led to a better-than-expected return, its advantage is positive and the policy makes it more likely; if it led to a worse-than-expected return, its advantage is negative and the policy makes it less likely.
With the addition of the advantage, we have generalized advantage estimation (GAE) in a nutshell. I've focused on intuition here. For the specifics of how these calculations are done, see the paper; your mental picture should now be organized enough for all those formalisms to swiftly fall into place.
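For the curious, here's a compact sketch of the GAE backward recursion described in the paper, with hypothetical inputs (per-step rewards and value estimates):

```python
import numpy as np

def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Generalized advantage estimation: an exponentially weighted blend of
    all the n-step advantages, computed in a single backward pass.
    `values` needs one entry per state, including the state after the final
    reward (use 0.0 if that state is terminal)."""
    T = len(rewards)
    advantages = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        # One-step TD error: how much better this step went than V predicted
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # Blend it with the discounted advantage of everything that follows
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages

rewards = np.array([1.0, 0.0, 2.0])
values  = np.array([0.9, 0.8, 1.5, 0.0])   # V(s_0)..V(s_3); s_3 is terminal
print(gae_advantages(rewards, values, gamma=0.9, lam=0.95))
```

Setting `lam=0` recovers the 1-step advantage, while `lam=1` gives the full discounted return minus the baseline; everything in between is the weighted average described above.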
We can now convert our n-step return into an advantage estimate $\hat{A}_t$ and use it in the policy gradient in place of the raw return.
It is worth noting that GAE is not limited to on-policy methods; it can be applied anywhere you use a value function, or even a Q function.
If the intuition isn’t cutting it for you, see this brief post about why subtracting a baseline reduces variance.
Finally, we sum the policy gradient over a set of trajectories $\mathcal{D}$ collected with the current policy, and average.
Why? Well, to gain a lower variance estimator of the true policy gradient. This should be quite familiar to you if you’ve ever done mini-batched stochastic gradient descent. Many trajectories averaged together will smooth out the noise from any one trajectory, and better represent the optimal policy gradient.
With this, we finally have the full policy gradient!

$$\nabla_\theta J(\theta) \;\approx\; \frac{1}{|\mathcal{D}|} \sum_{\tau \in \mathcal{D}} \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{A}_t$$

where $|\mathcal{D}|$ is the number of trajectories in the set and $\hat{A}_t$ is the advantage estimate at time step $t$ of trajectory $\tau$.
With this term, we can update the parameters of our policy with SGD:

$$\theta \;\leftarrow\; \theta + \alpha\, \nabla_\theta J(\theta)$$
or otherwise use an optimizer like Adam.
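To make the update concrete, here's a toy sketch with a linear-softmax policy standing in for a neural network (all names and numbers are illustrative, not from any particular library):

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

def policy_gradient_step(theta, states, actions, advantages, lr=0.01):
    """One gradient *ascent* step for a toy linear-softmax policy
    pi_theta(a | s) = softmax(s @ theta). The batch average of
    grad log pi(a|s) weighted by the advantage is the important part;
    the linear policy just keeps the sketch short."""
    grad = np.zeros_like(theta)
    for s, a, adv in zip(states, actions, advantages):
        probs = softmax(s @ theta)
        onehot = np.eye(theta.shape[1])[a]
        # gradient of log pi(a|s) for a linear-softmax policy, weighted by the advantage
        grad += np.outer(s, onehot - probs) * adv
    return theta + lr * grad / len(states)   # ascend: we are maximizing J(theta)

# Made-up batch: 3 transitions, 4 state features, 2 possible actions
states = np.random.randn(3, 4)
actions = np.array([0, 1, 1])
advantages = np.array([0.5, -0.2, 1.3])
theta = policy_gradient_step(np.zeros((4, 2)), states, actions, advantages)
```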
This is the policy learning taken care of. Now we turn to updating the value function. Recalling its loss function:

$$L(\phi) = \left(V_\phi(s_t) - \hat{R}_t\right)^2$$

We add a sum over all time steps in the trajectory, and a sum over all trajectories in our gathered set of trajectories $\mathcal{D}$:

$$L(\phi) = \frac{1}{|\mathcal{D}|\, T} \sum_{\tau \in \mathcal{D}} \sum_{t=0}^{T} \left(V_\phi(s_t) - \hat{R}_t\right)^2$$

We divide by $|\mathcal{D}|\, T$ so that what we minimize is the mean squared error over every time step we collected.
Standing in for the true expected return in this loss is the future return $\hat{R}_t$ measured from our collected trajectories.
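A sketch of that batched loss, again with a toy linear value function standing in for the network and made-up data:

```python
import numpy as np

def batched_value_loss(phi, trajectories, gamma=0.99):
    """Squared error between a toy linear value estimate V_phi(s) = phi . s
    and the measured future returns, summed over every time step of every
    trajectory and divided by the total number of time steps."""
    total, count = 0.0, 0
    for states, rewards in trajectories:           # states: (T, d) array, rewards: (T,) array
        # recompute the future-return targets for this trajectory
        targets = np.zeros(len(rewards))
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            targets[t] = running
        total += np.sum((states @ phi - targets) ** 2)
        count += len(rewards)
    return total / count

# Two made-up trajectories of different lengths, 3 state features each
trajs = [(np.random.randn(5, 3), np.random.randn(5)),
         (np.random.randn(3, 3), np.random.randn(3))]
print(batched_value_loss(np.zeros(3), trajs))
```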
With each gradient descent step, the value function, on average, better approximates the real sum of rewards. The policy then has a more accurate value function with which to adjust its own actions. Together, they explore the state and action space until a locally optimal policy is reached. Though not every step is guaranteed to be an improvement, due to approximation error, even this "vanilla" reinforcement learning algorithm can learn to complete simple tasks.
Now you can put all these ingredients together. Convince yourself that you understand the full algorithm!
Go over it and see that everything lines up for you.
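If it helps, here's a toy end-to-end sketch of the whole loop on a trivial two-armed-bandit "environment" I made up for illustration (the real JAX implementation comes in the next post):

```python
import numpy as np

# One state, two actions: action 1 pays +1, action 0 pays 0. The policy is a
# pair of logits and the value function is a single number, standing in for
# the neural networks a real implementation would use.
rng = np.random.default_rng(0)
n_actions, lr_pi, lr_v = 2, 0.1, 0.1
theta = np.zeros(n_actions)   # policy parameters (logits)
v = 0.0                       # value estimate of the single state

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

for iteration in range(200):
    # 1. Roll out a small batch of (one-step) episodes with the current policy
    acts = [rng.choice(n_actions, p=softmax(theta)) for _ in range(8)]
    returns = np.array([1.0 if a == 1 else 0.0 for a in acts])

    # 2. Advantage = measured return minus the value-function baseline
    advantages = returns - v

    # 3. Policy gradient ascent step
    grad = np.zeros_like(theta)
    for a, adv in zip(acts, advantages):
        grad += (np.eye(n_actions)[a] - softmax(theta)) * adv
    theta += lr_pi * grad / len(acts)

    # 4. Value regression step toward the measured returns
    v += lr_v * np.mean(returns - v)

print(softmax(theta))   # should now put most of its probability on action 1
```

The structure is the same for a real task: roll out, compute returns and advantages, take a policy gradient step, take a value regression step, repeat.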
Thanks for reading!
In the next post, we'll take a look at an implementation in JAX.