Reinforcement Learning with PyTorch: Mastering CartPole-v0!
Reinforcement Learning (RL) is a powerful paradigm that enables agents to learn how to make decisions by interacting with their environments. In this post, we'll dive into the fundamental concepts of RL and demonstrate how to enhance the training process using PyTorch. We'll solve the classic CartPole-v0 environment using a Q-network and explore the benefits of incorporating experience replay and a target network.
Reinforcement Learning
Reinforcement learning is an interesting area of machine learning. The rough idea is that you have an agent and an environment: the agent takes actions, and the environment gives rewards based on those actions. The goal is to teach the agent optimal behaviour in order to maximize the reward it receives from the environment.
Introduction to CartPole-v0
The CartPole-v0 environment simulates a pole balancing on a cart. The agent can apply forces to the left or right to keep the pole upright. The goal is to prevent the pole from falling over for as long as possible. The state consists of cart position, cart velocity, pole angle, and pole angular velocity.
A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart.
The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. The episode ends when the pole is more than 12 degrees from vertical, the cart moves more than 2.4 units from the centre, or 200 timesteps have elapsed. CartPole-v0 is considered solved when the average reward over 100 consecutive episodes reaches 195.
Observation Space
The observation space represents the information available to the agent about the current state of the environment. It is a crucial link that allows the agent to perceive and understand its surroundings. The nature and dimensionality of the observation space greatly depend on the problem being solved.
For example, in the CartPole-v0 environment, the observation space consists of four continuous values: cart position, cart velocity, pole angle, and pole angular velocity. Each time step, the environment provides these four values as observations to the agent. In more complex environments, the observation space can include images, sensor data, or any other relevant information.
Understanding the observation space is essential for designing appropriate neural network architectures to process these observations. In the case of image-based environments, convolutional neural networks (CNNs) are commonly used to extract meaningful features from visual data.
Action Space
The action space defines the set of actions an agent can take to interact with the environment. It represents the choices available to the agent for influencing the environment and achieving its goals. Like the observation space, the action space can vary significantly depending on the task.
In the CartPole-v0 environment, the action space consists of two discrete actions: applying a force to the left or right. The agent chooses one of these actions at each time step to control the cart's movement and keep the pole balanced. In more complex scenarios, the action space can be continuous, involving a range of possible actions with real-valued parameters.
Choosing an appropriate action is the crux of reinforcement learning. Agents employ various strategies, such as exploitation (choosing actions based on current knowledge) and exploration (trying out new actions to discover their effects), to maximize cumulative rewards over time.
Exploring Observation and Action Spaces
To solidify our understanding of observation and action spaces, let's run some code to interact with the CartPole-v0 environment. We'll use the OpenAI Gym library to create the environment, and we'll print out the observation space and action space to observe their characteristics.
import gym
# Create the CartPole-v0 environment
env = gym.make('CartPole-v0')
# Display information about the observation and action spaces
print("Observation Space:", env.observation_space)
print("Action Space:", env.action_space)
done = False
state = env.reset()
while not done:
    env.render()
    # Take a random action (an RL agent would choose this instead)
    action = env.action_space.sample()
    new_state, reward, done, info = env.step(action)
env.close()
This code snippet offers a practical way to visualize how the environment responds to different actions. Keep in mind that in real-world scenarios, you would replace the random action selection with actions chosen by your RL agent.
By interacting with the environment in this way, you gain a concrete sense of how observations change with actions, setting the stage for implementing more advanced RL algorithms.
The interplay between the observation and action spaces forms the foundation of the agent's decision-making process. The agent receives observations from the environment, processes them using its internal mechanisms (such as neural networks), and selects an action based on its current policy. This action is then sent back to the environment, causing a transition to a new state and yielding a reward, thus closing the feedback loop.
Understanding and appropriately handling the observation and action spaces is crucial for successful RL algorithm design. The chosen representation of observations, the dimensionality of the action space, and the strategies for exploration all influence the agent's ability to learn and optimize its behavior.
A Dive into Q-Networks
Now that we've grasped the fundamentals of observation and action spaces, let's delve into one of the foundational concepts of reinforcement learning: Q-learning. Q-learning is a simple yet powerful algorithm that enables agents to learn optimal policies through interaction with an environment.
Q-Networks: Approximating the Q-Value Function
In complex environments, it's often infeasible to maintain a table of Q-values for every state-action pair. This is where Q-networks come into play. A Q-network is a neural network that approximates the Q-value function. It takes the current state as input and outputs Q-values for all possible actions.
Implementing a Q-network is relatively straightforward using libraries like PyTorch. Here's how to set up a simple Q-network for our CartPole-v0 environment:
import torch
import torch.nn as nn

number_of_inputs = env.observation_space.shape[0]
number_of_outputs = env.action_space.n

class QNetwork(nn.Module):
    def __init__(self, number_of_inputs, number_of_outputs, hidden_size=64):
        super(QNetwork, self).__init__()
        # Two fully connected layers: state in, one Q-value per action out
        self.fc1 = nn.Linear(number_of_inputs, hidden_size)
        self.fc2 = nn.Linear(hidden_size, number_of_outputs)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
This network takes the state as input and outputs Q-values for each action. During training, the network adjusts its parameters to minimize the discrepancy between predicted Q-values and actual rewards.
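As a quick sanity check, you can feed a single observation through an untrained network and inspect the Q-values it produces. This is only a sketch that reuses the env created earlier; the printed numbers will be arbitrary until training begins:

q_net = QNetwork(number_of_inputs, number_of_outputs)
state = env.reset()
with torch.no_grad():
    q_values = q_net(torch.tensor(state, dtype=torch.float32))
print(q_values)  # one Q-value per action (left, right)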
The Q-Value Function
At the heart of Q-learning lies the Q-value function. For each state-action pair, the Q-value represents the expected cumulative reward that an agent can obtain by starting from that state, taking a specific action, and then following an optimal policy. Mathematically, it's expressed as:
Q(s, a) = E[ r + γ · max_a' Q(s', a') ]
To bring this equation to life, we use Q-networks—a neural network-based approach. Imagine a Q-network as a mathematical wizard. It takes the state as input and outputs Q-values for each action. This mapping from states to Q-values enables the agent to discern which actions are most promising in any given situation.
For instance, let's revisit our CartPole-v0 environment. The Q-network, akin to a mathematical oracle, receives the cart's position, velocity, pole's angle, and angular velocity, and it outputs Q-values for the available actions—pushing left or right. It's this neural network's grasp of mathematical relationships that allows agents to decide on actions that yield maximum rewards.
Q-learning itself is astonishingly simple. After each step, the agent nudges its current estimate towards the reward it just received plus the discounted value of the best action available in the next state:

Q(s, a) ← Q(s, a) + α · [ r + γ · max_a' Q(s', a') − Q(s, a) ]

where α is the learning rate and γ is the discount factor.
This iterative process encapsulates the elegance of Q-learning. It doesn't rely on complex mathematics or intricate algorithms. Instead, it's a natural expression of the agent's gradual learning as it navigates states and actions.
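As a toy illustration of the update above (with made-up numbers and states, not part of the CartPole code), a single tabular Q-learning step looks like this:

alpha, gamma = 0.1, 0.99
Q = {("s0", "left"): 0.0, ("s0", "right"): 0.0,
     ("s1", "left"): 2.0, ("s1", "right"): 5.0}

# The agent takes "right" in "s0", receives reward 1 and lands in "s1"
reward = 1.0
td_target = reward + gamma * max(Q[("s1", a)] for a in ("left", "right"))
Q[("s0", "right")] += alpha * (td_target - Q[("s0", "right")])
print(Q[("s0", "right")])  # 0.595

The Q-value for ("s0", "right") moves a small step (controlled by alpha) towards the observed reward plus the discounted best value of the next state.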
Implementing Q-Learning
Selecting Actions with the QNet_Agent
In the realm of reinforcement learning, the selection of actions is a fundamental aspect that guides an agent's behavior. The QNet_Agent class, introduced below, includes a method named select_action that employs an epsilon-greedy strategy to make informed decisions about which action to take. Let's dive into how this strategy works and its significance in the Q-learning process.
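Before looking at the class itself, note that it refers to a handful of module-level settings (device, learning_rate, gamma, and a couple of flags) that are defined elsewhere in the full code. The values below are placeholder assumptions, not the author's exact configuration:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 0.001             # Adam step size (assumed)
gamma = 0.99                      # discount factor for future rewards (assumed)
resume_previous_training = False  # set True to continue from a saved checkpoint
save_model_frequency = 10000      # checkpoint every N optimization steps (assumed)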
import os
import torch.optim as optim

class QNet_Agent:
    def __init__(self, number_of_inputs, number_of_outputs):
        # Initialize the Q-network
        self.nn = QNetwork(number_of_inputs, number_of_outputs).to(device)
        # Define loss function
        self.loss_func = nn.MSELoss()
        # Define optimizer
        self.optimizer = optim.Adam(params=self.nn.parameters(), lr=learning_rate)
        # Initialize the number of frames processed so far
        self.number_of_frames = 0
        # Load a previously saved model if specified
        if resume_previous_training and os.path.exists(file2save):
            print("Loading previously saved model ... ")
            self.nn.load_state_dict(load_model())

    def select_action(self, state, epsilon):
        random_for_egreedy = torch.rand(1)[0]
        if random_for_egreedy > epsilon:
            # Exploit: take the action with the highest predicted Q-value
            with torch.no_grad():
                state = torch.tensor(state, dtype=torch.float32).to(device)
                action_from_nn = self.nn(state)
                action = torch.max(action_from_nn, 0)[1]
                action = action.item()
        else:
            # Explore: sample a random action
            action = env.action_space.sample()
        return action

    def optimize(self, state, action, new_state, reward, done):
        state = torch.tensor(state, dtype=torch.float32).to(device)
        new_state = torch.tensor(new_state, dtype=torch.float32).to(device)
        reward = torch.tensor(reward, dtype=torch.float32).to(device)
        if done:
            # Terminal transition: the target is just the reward
            target_value = reward
        else:
            # Bellman target: reward plus the discounted best Q-value of the next state
            new_state_values = self.nn(new_state).detach()
            max_new_state_values = torch.max(new_state_values)
            target_value = reward + gamma * max_new_state_values
        predicted_value = self.nn(state)[action]
        loss = self.loss_func(predicted_value, target_value)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if self.number_of_frames % save_model_frequency == 0:
            print("** save the model **")
            save_model(self.nn)
        self.number_of_frames += 1
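The agent also calls save_model and load_model (and checks file2save), helper functions that live in the full codebase rather than in this snippet. A minimal sketch that persists only the network's state dict could look like this; the filename is an assumption:

import os
import torch

file2save = "cartpole_qnet.pth"  # assumed filename

def save_model(model, path=file2save):
    # Persist only the learnable parameters
    torch.save(model.state_dict(), path)

def load_model(path=file2save):
    # Return the saved state dict so it can be passed to load_state_dict()
    return torch.load(path)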
Epsilon-Greedy Strategy: Balancing Exploration and Exploitation
The concept behind the epsilon-greedy strategy is to balance the exploration of new actions with the exploitation of previously learned knowledge. This strategy recognizes the need to explore uncharted territories (untried actions) while exploiting the known actions that have yielded positive results in the past.
Here's a breakdown of how the select_action method within the QNet_Agent class implements this strategy:
def select_action(self, state, epsilon):
    # Draw a random number to decide between exploration and exploitation
    random_for_egreedy = torch.rand(1)[0]
    if random_for_egreedy > epsilon:
        # Exploit: query the Q-network and take the highest-valued action
        with torch.no_grad():
            state = torch.tensor(state, dtype=torch.float32).to(device)
            action_from_nn = self.nn(state)
            action = torch.max(action_from_nn, 0)[1]
            action = action.item()
    else:
        # Explore: sample a random action from the action space
        action = env.action_space.sample()
    return action
The epsilon-greedy strategy allows the agent to gradually reduce exploration as it gains more knowledge about the environment. Initially, when epsilon is high, there's a greater likelihood of exploring new actions. As training progresses and the agent refines its Q-network, the exploitation of learned actions becomes more dominant.
Balancing exploration and exploitation is critical for the agent's learning process. Too much exploration may slow down learning, while too much exploitation could lead to suboptimal policies. The epsilon-greedy strategy ensures that the agent efficiently explores the environment while maximizing its cumulative rewards over time.
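In the training loop later in the post, the decay of epsilon is handled by a calculate_epsilon helper that isn't shown in these snippets. A common choice, and an assumption on my part rather than the author's exact schedule, is an exponential decay from a high starting value down to a small floor:

import math

egreedy = 0.9          # initial epsilon (assumed)
egreedy_final = 0.02   # minimum epsilon (assumed)
egreedy_decay = 500    # controls how quickly epsilon decays (assumed)

def calculate_epsilon(steps_done):
    # Decay epsilon exponentially with the number of frames seen so far
    return egreedy_final + (egreedy - egreedy_final) * \
        math.exp(-1.0 * steps_done / egreedy_decay)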
Refining Q-Values: The Essence of Optimization
The optimize method in the QNet_Agent class is responsible for refining the Q-values stored within the Q-network. These Q-values represent an agent's estimates of the expected cumulative rewards associated with taking certain actions in specific states.
Let's break down how the optimize method is designed within the QNet_Agent class:
def optimize(self, state, action, new_state, reward, done):
    # Move the transition onto the training device
    state = torch.tensor(state, dtype=torch.float32).to(device)
    new_state = torch.tensor(new_state, dtype=torch.float32).to(device)
    reward = torch.tensor(reward, dtype=torch.float32).to(device)
    if done:
        # Terminal transition: no future rewards to discount
        target_value = reward
    else:
        # Bellman target: reward plus the discounted best Q-value of the next state
        new_state_values = self.nn(new_state).detach()
        max_new_state_values = torch.max(new_state_values)
        target_value = reward + gamma * max_new_state_values
    # Q-value currently predicted for the action that was taken
    predicted_value = self.nn(state)[action]
    loss = self.loss_func(predicted_value, target_value)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    # Periodically checkpoint the network
    if self.number_of_frames % save_model_frequency == 0:
        print("** save the model **")
        save_model(self.nn)
    self.number_of_frames += 1
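To make the target concrete, with made-up numbers: if the reward is 1, gamma is 0.99, and the highest Q-value the network assigns to the new state is 10, the target becomes 1 + 0.99 × 10 = 10.9, and the loss pushes the predicted Q-value for the chosen action towards that number. When the episode has ended (done is True), there is no future to discount, so the target is simply the reward itself.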
Training the Agent: Embarking on 800 Episodes
With a solid understanding of Q-learning, Q-networks, and the QNet_Agent class, it's time to dive into the heart of the matter: training the agent to tackle the CartPole-v0 environment. Over the course of 800 episodes, we'll witness the agent's journey as it refines its decision-making skills and strives to achieve stability in balancing the pole.
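The loop below also relies on a couple of settings and an import that live in the full code; the values here are reasonable placeholders rather than the author's exact configuration:

import time

num_episodes = 800    # number of training episodes, as discussed above
score_to_solve = 195  # CartPole-v0 counts as solved at an average reward of 195 over 100 episodes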
number_of_inputs = env.observation_space.shape[0]
number_of_outputs = env.action_space.n
# Create an instance of the QNet_Agent class
qnet_agent = QNet_Agent(number_of_inputs, number_of_outputs)
# Initialize lists to track steps and rewards
steps_total = []
frames_total = 0
solved_after = 0
solved = False
start_time = time.time()
for i_episode in range(num_episodes):
    state = env.reset()
    step = 0
    while True:
        step += 1
        frames_total += 1
        # Decay epsilon as more frames are seen
        epsilon = calculate_epsilon(frames_total)
        # Choose an action, apply it, and learn from the resulting transition
        action = qnet_agent.select_action(state, epsilon)
        new_state, reward, done, info = env.step(action)
        qnet_agent.optimize(state, action, new_state, reward, done)
        state = new_state
        if done:
            steps_total.append(step)
            mean_reward_100 = sum(steps_total[-100:]) / 100
            if mean_reward_100 > score_to_solve and not solved:
                print("SOLVED! After %i episodes " % i_episode)
                solved_after = i_episode
                solved = True
                print("Saving model after solving the game")
                save_model(qnet_agent.nn, path="solved_cartpole.pt")
            elapsed_time = time.time() - start_time
            print("Elapsed time: ", time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
            break
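Once training finishes, a short summary (again just a sketch) gives a feel for how the agent did:

print("Average reward: %.2f" % (sum(steps_total) / len(steps_total)))
print("Average reward (last 100 episodes): %.2f" % (sum(steps_total[-100:]) / 100))
if solved:
    print("Solved after %i episodes" % solved_after)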
Exploring the Complete Code
If you're eager to dive even deeper into the details and explore the complete code implementation, you can find it on my GitHub repository. This repository hosts the entire codebase for training the QNet_Agent in the CartPole-v0 environment, along with additional resources that can help you grasp the concepts more effectively.