Difference between revisions of "CartPole DQN"
(→Replay) |
|||
(3 intermediate revisions by the same user not shown) | |||
Line 133: | Line 133: | ||
for i in range(len(predictions)): | for i in range(len(predictions)): | ||
# Flag states as terminal (the last state before a epoch ended). | # Flag states as terminal (the last state before a epoch ended). | ||
− | terminal_state = (next_states[i] == np.array([ | + | terminal_state = (next_states[i] == np.array([None]*self.input_dim)).all() |
# Update each state's Q-value prediction with our new estimate. | # Update each state's Q-value prediction with our new estimate. | ||
# Terminal states have no future, so set their Q-value to their immediate reward. | # Terminal states have no future, so set their Q-value to their immediate reward. | ||
Line 151: | Line 151: | ||
from DQNAgent import DQNAgent | from DQNAgent import DQNAgent | ||
</PRE> | </PRE> | ||
− | + | The first thing we want to do is create the CartPole gym environment and refresh the environment. | |
− | |||
− | The | ||
<PRE> | <PRE> | ||
− | env = gym.make( | + | env = gym.make("CartPole-v0") |
env.reset() | env.reset() | ||
</PRE> | </PRE> | ||
− | |||
Gym caps episodes of CartPole to 200 steps. In other words, the epoch will be cut off after our cart takes 200 actions. We can disable this limit by setting it to <code>None</code> or by raising it to some constant. For this tutorial, we will set the limit to 1000. | Gym caps episodes of CartPole to 200 steps. In other words, the epoch will be cut off after our cart takes 200 actions. We can disable this limit by setting it to <code>None</code> or by raising it to some constant. For this tutorial, we will set the limit to 1000. | ||
<PRE> | <PRE> | ||
Line 173: | Line 170: | ||
<ul> | <ul> | ||
<li><code>state</code> is the state the environment is in after a step. For cartpole, this will be an numpy array of 4 values representing the position of the cart from the center, the carts velocity, the angle of the pole from the vertical, and the angular velocity of the pole.</li> | <li><code>state</code> is the state the environment is in after a step. For cartpole, this will be an numpy array of 4 values representing the position of the cart from the center, the carts velocity, the angle of the pole from the vertical, and the angular velocity of the pole.</li> | ||
− | <li><code>reward</code> is the immediate reward witnessed by the agent for taking this action.</li> | + | <li><code>reward</code> is the immediate reward witnessed by the agent for taking this action. In CartPole, the reward is always 1 for staying alive.</li> |
<li><code>done</code> a boolean indicating whether the epoch is over.</li> | <li><code>done</code> a boolean indicating whether the epoch is over.</li> | ||
</ul> | </ul> | ||
Line 189: | Line 186: | ||
if done: # Episode is completed due to failure or cap being reached. | if done: # Episode is completed due to failure or cap being reached. | ||
− | print("Episode: {}, Total reward: {}, Explore P: {}".format(ep, total_reward, agent.explore_p) | + | print("Episode: {}, Total reward: {}, Explore P: {}".format(ep, total_reward, agent.explore_p)) |
− | + | if total_reward == 999: # Simulation completed without failure. Save a copy of this network. | |
− | + | agent.model.save("cartpole.h5") | |
− | if | + | # Add experience to bucket (next_state is None since epoch is over). |
− | agent.model.save(" | ||
− | # Add experience to | ||
agent.remember(state, action, None, reward) | agent.remember(state, action, None, reward) | ||
− | env.reset() | + | env.reset() # Reset environment |
− | |||
− | |||
break | break | ||
− | else: | + | else: # Episode not over. |
− | agent.remember(state, action, next_state, reward) | + | agent.remember(state, action, next_state, reward) # Store tuple. |
− | state = next_state | + | state = next_state # Advance state |
− | agent.replay() | + | agent.replay() # Train the network form replay samples. |
+ | </PRE> | ||
+ | |||
+ | Being training with <code>python cartpole.py</code>. You should initially see small total rewards... | ||
+ | <PRE> | ||
+ | Episode: 0, Total reward: 10.0, Explore P: 0.9980017990403361 | ||
+ | Episode: 1, Total reward: 19.0, Explore P: 0.9942162108059645 | ||
+ | Episode: 2, Total reward: 12.0, Explore P: 0.9918327148817936 | ||
+ | Episode: 3, Total reward: 17.0, Explore P: 0.9884658738293698 | ||
+ | Episode: 4, Total reward: 18.0, Explore P: 0.9849134396468638 | ||
+ | Episode: 5, Total reward: 16.0, Explore P: 0.9817664398149591 | ||
+ | </PRE> | ||
+ | ...which slowly climb as more training episodes go by. | ||
+ | <PRE> | ||
+ | Episode: 225, Total reward: 134.0, Explore P: 0.0064666044211831 | ||
+ | Episode: 226, Total reward: 159.0, Explore P: 0.006264181737888733 | ||
+ | Episode: 227, Total reward: 421.0, Explore P: 0.005758284210537673 | ||
+ | Episode: 228, Total reward: 340.0, Explore P: 0.005379700746561338 | ||
+ | Episode: 229, Total reward: 627.0, Explore P: 0.004745611079457094 | ||
+ | Episode: 230, Total reward: 1000.0, Explore P: 0.0038853000157592645 | ||
</PRE> | </PRE> |
Latest revision as of 01:41, 18 February 2018
Contents
Overview
This tutorial will show you how to solve the popular CartPole problem using deep Q-learning. The CartPole problem is as follows:
A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. The episode ends when the pole is more than 15 degrees from vertical, or the cart moves more than 2.4 units from the center.
Tutorial
This section will walk you through the steps of solving the CartPole problem with a deep Q-network. This tutorial is written for python 3.
Packages
You must first pip install
the following packages: gym
keras
and
numpy
DQN Agent
The first step of our implementation will be creating a DQNAgent object. This object will manage the state of our learning, and is independent of the CartPole problem. It has all the generic parts of a Q-learning agent and can be reused for other deep Q-learning applications. Every subsection will contain a part of the DQNAgent class you must implement.
Imports
Start by creating a file DQNAgent.py
and include the following imports:
import numpy as np from keras.layers import Input, Dense from keras.optimizers import RMSprop from keras.models import Model from collections import deque
Constructor
The reason for each import will become apparent as our implementation continues. Next add a blank DQNAgent
class with an empty constructor.
class DQNAgent: def __init__(self): pass
This class will take in all of our hyperparemeters, so let's update our constructor to take in those parameters. We also provide some default values for some of those hyperparameters.
class DQNAgent: def __init__(self, input_dim, output_dim, learning_rate=.005, mem_size=5000, batch_size=64, gamma=.99, decay_rate=.0002): pass
input_dim
is the number of input nodes for our DQN.output_dim
is the number of output nodes for our DQN.learning_rate
is a Keras parameter for our network describing how much we value new information.mem_size
is the maximum number of instances allowed in our bucket for experience replay.batch_size
is the number of experience tuples we train our model on each replay event.gamma
is our discount factor for the Bellman equation update.decay_rate
is the rate at which exploration probability decays.
Now for the next step, we complete our constructor by saving all of these parameters as instance variables, defining a neural network model, and defining a few other parameters.
def __init__(self, input_dim, output_dim, learning_rate=.005, mem_size=5000, batch_size=64, gamma=.99, decay_rate=.0002): # Save instance variables. self.input_dim = input_dim self.output_dim = output_dim self.batch_size = batch_size self.gamma = gamma self.decay_rate = decay_rate # Define other instance variables. self.explore_p = 1 # The current probability of taking a random action. self.memory = deque(maxlen=mem_size) # Define our experience replay bucket as a deque with size mem_size. # Define and compile our DQN. This network has 3 layers of 24 nodes. This is sufficient to solve # CartPole, but you should definitely tweak the architecture for other implementations. input_layer = Input(shape=(input_dim,)) hl = Dense(24, activation="relu")(input_layer) hl = Dense(24, activation="relu")(hl) hl = Dense(24, activation="relu")(hl) output_layer = Dense(output_dim, activation="linear")(hl) self.model = Model(input_layer, output_layer) self.model.compile(loss="mse", optimizer=RMSprop(lr=learning_rate))
Act
The most fundamental part of a Q-learning problem is the ability for the agent to take an action. Actions are either determined by the current policy (based off Q-function values) or are picked randomly, depending on the current exploration probability. We now define an act
function which, given the current state of the environment, determines which action to take next. Note that with OpenAI gym, actions correspond to integers (0, 1, 2, ...).
def act(self, state): # First, decay our explore probability self.explore_p *= 1 - self.decay_rate # With probability explore_p, randomly pick an action if self.explore_p > np.random.rand(): return np.random.randint(self.output_dim) # Otherwise, find the action that should maximize future rewards according to our current Q-function policy. else: return np.argmax(self.model.predict(np.array([state]))[0])
Remember
One of the crucial parts of deep Q-learning is experience replay, where we store instances in a bucket and randomly draw from them to train our model. We now define the remember
function, which stores the given experience tuple in that experience replay bucket for later sampling.
def remember(self, state, action, next_state, reward): # Create a blank state. Serves as next_state if this was the last experience tuple before the epoch ended. terminal_state = np.array([None]*self.input_dim) # Add experience tuple to bucket. Bucket is a deque, so older tuple falls out on overflow. self.memory.append((state, action, terminal_state if next_state is None else next_state, reward))
Replay
The replay step is where experience tuples are randomly sampled from the bucket and are used to train the DQN. We now define the replay
function to do just that.
def replay(self): # Only conduct a replay if we have enough experience to sample from. if len(self.memory) < self.batch_size: return # Pick random indices from the bucket without replacement. batch_size determines number of samples. idx = np.random.choice(len(self.memory), size=self.batch_size, replace=False) minibatch = np.array(self.memory)[idx] self.train(minibatch) # Extract the columns from our sample states = np.array(list(minibatch[:,0])) actions = minibatch[:,1] next_states = np.array(list(minibatch[:,2])) rewards = np.array(minibatch[:,3]) # Compute a new estimate for each Q-value. This uses the second half of Bellman's equation. estimate = rewards + self.gamma * np.amax(self.model.predict(next_states), axis=1) # Get the network's current Q-value predictions for the states in this sample. predictions = self.model.predict(states) # Update the network's predictions with the new predictions we have. for i in range(len(predictions)): # Flag states as terminal (the last state before a epoch ended). terminal_state = (next_states[i] == np.array([None]*self.input_dim)).all() # Update each state's Q-value prediction with our new estimate. # Terminal states have no future, so set their Q-value to their immediate reward. predictions[i][actions[i]] = rewards[i] if terminal_state else estimate[i] # Propagate the new predictions through our network. self.model.fit(states, predictions, epochs=1, verbose=0)
This completes our DQNAgent object. Now let's move on to the actual driver of the learning process, which interacts with CartPole to drive learning.
CartPole
Now we will create the script that utilizes a DQNAgent to learn how to play CartPole. Start by creating a file CartPole.py
and include the following imports:
import gym import numpy as np from DQNAgent import DQNAgent
The first thing we want to do is create the CartPole gym environment and refresh the environment.
env = gym.make("CartPole-v0") env.reset()
Gym caps episodes of CartPole to 200 steps. In other words, the epoch will be cut off after our cart takes 200 actions. We can disable this limit by setting it to None
or by raising it to some constant. For this tutorial, we will set the limit to 1000.
env._max_episode_steps = 1000
Now create an instance of a DQNAgent. The input_dim is equal to the number of features in our state (4 features for CartPole, explained later) and the output_dim is equal to the number of actions we can take (2 for CartPole, left or right).
agent = DQNAgent(input_dim=4, output_dim=2)
We now take the first step of our simulation and save its results to some variables:
state, reward, done, _ = env.step(env.action_space.sample())
state
is the state the environment is in after a step. For cartpole, this will be an numpy array of 4 values representing the position of the cart from the center, the carts velocity, the angle of the pole from the vertical, and the angular velocity of the pole.reward
is the immediate reward witnessed by the agent for taking this action. In CartPole, the reward is always 1 for staying alive.done
a boolean indicating whether the epoch is over.
Now initialization is complete and we can enter our training loop.
# Play the game many times for ep in range(0, 500): # 500 episodes of learning total_reward = 0 # Maintains the score for this episode. while True: env.render() # Show the animation of the cartpole action = agent.act(state) # Get action next_state, reward, done, _ = env.step(action) # Take action total_reward += reward # Accrue reward if done: # Episode is completed due to failure or cap being reached. print("Episode: {}, Total reward: {}, Explore P: {}".format(ep, total_reward, agent.explore_p)) if total_reward == 999: # Simulation completed without failure. Save a copy of this network. agent.model.save("cartpole.h5") # Add experience to bucket (next_state is None since epoch is over). agent.remember(state, action, None, reward) env.reset() # Reset environment break else: # Episode not over. agent.remember(state, action, next_state, reward) # Store tuple. state = next_state # Advance state agent.replay() # Train the network form replay samples.
Being training with python cartpole.py
. You should initially see small total rewards...
Episode: 0, Total reward: 10.0, Explore P: 0.9980017990403361 Episode: 1, Total reward: 19.0, Explore P: 0.9942162108059645 Episode: 2, Total reward: 12.0, Explore P: 0.9918327148817936 Episode: 3, Total reward: 17.0, Explore P: 0.9884658738293698 Episode: 4, Total reward: 18.0, Explore P: 0.9849134396468638 Episode: 5, Total reward: 16.0, Explore P: 0.9817664398149591
...which slowly climb as more training episodes go by.
Episode: 225, Total reward: 134.0, Explore P: 0.0064666044211831 Episode: 226, Total reward: 159.0, Explore P: 0.006264181737888733 Episode: 227, Total reward: 421.0, Explore P: 0.005758284210537673 Episode: 228, Total reward: 340.0, Explore P: 0.005379700746561338 Episode: 229, Total reward: 627.0, Explore P: 0.004745611079457094 Episode: 230, Total reward: 1000.0, Explore P: 0.0038853000157592645