In this project, we will train a Deep Q-Network (DQN) agent to solve Unity's Banana Collector environment.
A reward of +1 is provided for collecting a yellow banana, and a reward of -1 is provided for collecting a blue banana. Thus, the goal of our agent is to collect as many yellow bananas as possible while avoiding blue bananas.
The state space has 37 dimensions and contains the agent's velocity, along with ray-based perception of objects around the agent's forward direction. Given this information, the agent has to learn how best to select actions. Four discrete actions are available, corresponding to:
- 0 - move forward.
- 1 - move backward.
- 2 - turn left.
- 3 - turn right.

The task is episodic, and in order to solve the environment, our agent must get an average score of +13 over 100 consecutive episodes.
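The solving criterion amounts to a rolling average over the most recent 100 episode scores. The sketch below is illustrative only: `first_solved_episode` is a hypothetical helper, not code taken from Navigation_Train.py. It reports the first episode at which the 100-episode window ending there averages at least +13.

```python
from collections import deque

SOLVE_SCORE = 13.0  # target average score, from the task description
WINDOW = 100        # number of consecutive episodes in the average


def first_solved_episode(episode_scores, target=SOLVE_SCORE, window=WINDOW):
    """Return the 1-based episode index at which the average of the last
    `window` scores first reaches `target`, or None if it never does.

    Hypothetical helper for illustration; not part of the repo.
    """
    recent = deque(maxlen=window)  # automatically drops the oldest score
    for episode, score in enumerate(episode_scores, start=1):
        recent.append(score)
        # Only judge once a full window of episodes is available.
        if len(recent) == window and sum(recent) / window >= target:
            return episode
    return None
```

For example, 60 episodes scoring 12 followed by episodes scoring 14 first satisfy the criterion at episode 110, when the window holds fifty of each.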
Here are the instructions to follow if you'd like to try out this agent on your machine. First, you'll need Python 3.6 or later installed on your system, along with the libraries listed below. Once Python is installed, the third-party libraries can be installed from your terminal with the 'pip install' command; 'collections' is part of the Python standard library and needs no installation.
- numpy - NumPy is the fundamental package for scientific computing with Python
- collections - High-performance container datatypes
- torch - PyTorch is an optimized tensor library for deep learning using GPUs and CPUs
- unityagents - Unity Machine Learning Agents allows researchers and developers to transform games and simulations created using the Unity Editor into environments where intelligent agents can be trained using reinforcement learning, evolutionary strategies, or other machine learning methods through a simple to use Python API
- matplotlib.pyplot - Provides a MATLAB-like plotting framework
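Before running the scripts, a quick way to confirm that the modules listed above can be found is a check like the one below. This is a minimal sketch: `missing_modules` is a hypothetical helper, not part of the repo, and it only verifies that each module can be located, not that its version is compatible with the project code.

```python
import importlib.util

# Module names taken from the dependency list; "collections" ships with Python.
REQUIRED_MODULES = ["numpy", "collections", "torch", "unityagents", "matplotlib.pyplot"]


def missing_modules(names):
    """Return the names from `names` that cannot be located on this interpreter."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # find_spec raises this for a dotted name whose parent package is absent.
            found = False
        if not found:
            missing.append(name)
    return missing


if __name__ == "__main__":
    absent = missing_modules(REQUIRED_MODULES)
    if absent:
        print("Missing modules:", ", ".join(absent))
    else:
        print("All required modules were found.")
```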
Download the environment from one of the links below. You need only select the environment that matches your operating system:
- Linux: click here
- Mac OSX: click here
- Windows (32-bit): click here
- Windows (64-bit): click here
(For Windows users) Check out this link if you need help with determining if your computer is running a 32-bit version or 64-bit version of the Windows operating system.
(For AWS) If you'd like to train the agent on AWS (and have not enabled a virtual screen), then please use this link to obtain the environment.
Place the file in a folder and unzip (or decompress) it. The environment and the agent's training code must be in the same folder. If you'd like to store the environment elsewhere, update its path in the 'Navigation_Train.py' and 'Navigation_Test.py' files accordingly.
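Because the environment path is the part that differs between machines, it can help to derive it from the operating system. The sketch below is hypothetical: the folder and executable names in `ENV_PATHS` are assumptions about how the unzipped downloads are laid out (64-bit Windows is assumed), and `banana_env_path` / `open_banana_env` are illustrative helpers rather than code from the repo's scripts. Only `UnityEnvironment(file_name=...)` and `brain_names` come from the unityagents package.

```python
import platform

# ASSUMED layout of the unzipped Banana builds; adjust to match your download.
ENV_PATHS = {
    "Darwin": "Banana.app",
    "Linux": "Banana_Linux/Banana.x86_64",
    "Windows": "Banana_Windows_x86_64/Banana.exe",  # assumes the 64-bit build
}


def banana_env_path(system=None):
    """Return the assumed environment path for `system` (defaults to this OS)."""
    system = system or platform.system()
    try:
        return ENV_PATHS[system]
    except KeyError:
        raise ValueError(f"No Banana build listed for platform {system!r}") from None


def open_banana_env(path=None):
    """Open the Unity environment and return it with its default brain name.

    unityagents is imported lazily so the path helper works without it installed.
    """
    from unityagents import UnityEnvironment

    env = UnityEnvironment(file_name=path or banana_env_path())
    brain_name = env.brain_names[0]
    return env, brain_name
```

Passing an explicit `path` to `open_banana_env` covers the case where the environment is stored outside the project folder.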
The repo contains three main files:
- Navigation_Train.py - This file, written in Python 3.6 with the PyTorch framework, contains the agent and the model used to train it. It runs until the agent has solved the environment, which typically takes 200-300 episodes depending on the hyperparameter selection.
- Navigation_Test.py - This file, also written in Python 3.6, contains the code to test the trained agent. It runs for a total of 10 episodes and plots the agent's score in each one of them.
- checkpoint.pth - This file stores the trained agent's DQN weights. You may use it if you'd like a pretrained agent to solve the Banana Collector environment. It is also recreated every time you run Navigation_Train.py, so you can create your own checkpoint.pth file with your choice of hyperparameters!
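A PyTorch checkpoint of this kind is usually a saved `state_dict`, which can only be loaded into a network with the same layer layout. The sketch below assumes that convention. `QNetwork` here is a hypothetical stand-in (37 state inputs, two hidden layers of 64 units, 4 action values); the actual architecture is described in Report.md and may differ, in which case loading the repo's checkpoint.pth into this class will fail. `save_checkpoint`, `load_checkpoint` and `greedy_action` are likewise illustrative helpers.

```python
import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """Hypothetical stand-in: maps a 37-dim state to 4 action values."""

    def __init__(self, state_size=37, action_size=4, hidden=64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(state_size, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, action_size),
        )

    def forward(self, state):
        return self.layers(state)


def save_checkpoint(model, path="checkpoint.pth"):
    # Store only the learned parameters, not the Python class itself.
    torch.save(model.state_dict(), path)


def load_checkpoint(path="checkpoint.pth", **net_kwargs):
    # Rebuild the network, then copy the saved parameters into it.
    model = QNetwork(**net_kwargs)
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()
    return model


def greedy_action(model, state):
    """Pick the action with the highest predicted Q-value for one state."""
    with torch.no_grad():
        q_values = model(torch.as_tensor(state, dtype=torch.float32).unsqueeze(0))
    return int(q_values.argmax(dim=1).item())
```

Saving the `state_dict` rather than the whole model keeps the file independent of where the class is defined, at the cost of needing a matching class definition when loading.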
- Clone/download the three files listed above and place them in the same folder as the Banana Collector environment on your machine. You can run the code from any terminal where Python commands such as 'pip' are available, for example Anaconda Prompt.
- Once you have navigated to the project folder with the 'cd' command, run 'Navigation_Train.py' if you'd like to train the agent, or 'Navigation_Test.py' if you'd like to see a pretrained agent in action!
Please refer to the Report.md file for an in-depth look at the architecture.
