CartPole DQN
Overview
This tutorial will show you how to solve the popular CartPole problem using deep Q-learning. The CartPole problem is as follows:
A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. The episode ends when the pole is more than 15 degrees from vertical, or the cart moves more than 2.4 units from the center.
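To get a feel for the environment before building the agent, the snippet below runs one episode with random actions. It is only an illustrative sketch: it assumes the classic gym CartPole-v0 environment and the older gym API in which reset() returns just the state and step() returns four values.

import gym

env = gym.make('CartPole-v0')
state = env.reset()                       # older gym API: reset() returns only the initial state
done = False
total_reward = 0
while not done:
    action = env.action_space.sample()    # 0 pushes the cart left, 1 pushes it right
    state, reward, done, info = env.step(action)
    total_reward += reward                # +1 for every timestep the pole stays upright
print(total_reward)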
Tutorial
This section will walk you through the steps of solving the CartPole problem with a deep Q-network. This tutorial is written for Python 3.
Packages
You must first pip install the following packages: gym, keras, and numpy.
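For example, from a terminal (depending on your Python setup you may need versions of gym and standalone keras contemporary with this tutorial):

pip install gym keras numpy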
DQN Agent
The first step of our implementation is creating a DQNAgent class. This object manages the state of our learning and is independent of the CartPole problem: it contains all the generic parts of a deep Q-learning agent and can be reused for other deep Q-learning applications.
Start by creating a file DQNAgent.py and including the following imports:
from keras.layers import Input, Dense
from keras.optimizers import RMSprop
from keras.models import Model
from collections import deque
The reason for each import will become apparent as our implementation continues. Next, add a blank DQNAgent class with an empty constructor:
class DQNAgent:
    def __init__(self):
        pass
This class will take in all of our hyperparameters, so let's update the constructor to accept them. We also provide default values for most of them.
class DQNAgent:
    def __init__(self, input_dim, output_dim, learning_rate=.005,
                 mem_size=50000, batch_size=64, gamma=.99,
                 explore_start=1.0, explore_stop=.01, decay_rate=.0005):
        pass
input_dim is the number of input nodes for our DQN.
output_dim is the number of output nodes for our DQN.
learning_rate is the Keras optimizer's learning rate, describing how much we value new information.
mem_size is the maximum number of experiences kept in our replay memory.
batch_size is the number of experience tuples we train our model on during each replay event.
gamma is our discount factor for the Bellman equation update.
explore_start is the initial exploration probability.
explore_stop is the lowest the exploration probability can ever get.
decay_rate is the rate at which the exploration probability decays.
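As a rough sketch of where these parameters end up, the constructor will typically just record them and create the replay memory as a deque capped at mem_size. The exact attribute names below are assumptions for illustration, not necessarily what the rest of the tutorial uses.

class DQNAgent:
    def __init__(self, input_dim, output_dim, learning_rate=.005,
                 mem_size=50000, batch_size=64, gamma=.99,
                 explore_start=1.0, explore_stop=.01, decay_rate=.0005):
        # network dimensions and training hyperparameters
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.gamma = gamma
        # exploration schedule
        self.explore_start = explore_start
        self.explore_stop = explore_stop
        self.decay_rate = decay_rate
        # replay memory: the deque drops the oldest experience once mem_size is exceeded
        self.memory = deque(maxlen=mem_size)

Using a deque with maxlen means old experiences are discarded automatically once the memory is full, which is why mem_size itself does not need to be stored as an attribute.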