This blogpost accompanies my GitHub project TF-rex.

The goal of this project is to play Google’s offline T-rex Dino game using Reinforcement Learning (RL). The RL algorithm is based on the Deep Q-Learning algorithm [1] and is implemented in TensorFlow (TF), hence the name TF-rex ;). Google’s offline game consists of a T-rex that strives to dodge obstacles, such as cacti and birds, and to survive as long as possible. The dino can perform three actions: “jumping”, “ducking” and “going forward”. You can try it yourself.

T-Rex game played by Q-learner

In this post, I’ll talk about the project’s implementation, the RL setup and algorithm, and the training procedure. We assume that the reader is familiar with the basics of RL. If not, no worries! This blogpost from Arthur Juliani will quickly get you up to speed on the essentials.

After reading this post, you should be able to fork the GitHub project and get going yourself. There are several ways to extend this work; we’ll discuss them below in the TODO section.

Overview

We want to create an “AI algorithm” that can play Google Chrome’s T-rex game. This project isn’t the first one to try this, but compared to existing projects [2, 3], we do two things differently. First, the AI interacts in real time with the real game, hosted in the browser. We do not use any sort of emulator or framework to slow down the game or the frame rate. While this makes the framework genuinely more interesting and more directly applicable to other browser-based games, it requires some extra engineering. For example, we need to extract the game state from the browser and implement a duplex channel to allow for communication between the AI program and the browser. This channel is necessary to pass information such as actions and game states.

Secondly, the AI algorithm is trained directly on images of the game state. This is in contrast to earlier attempts, which extract handcrafted features (e.g. the distance to the next obstacle and the length of that obstacle) to train the AI. In our opinion, this setup much more closely resembles a real-life learning environment.

The codebase consists of two main components: the javascript T-rex game and the Python AI program. Both modules are shown in the diagram below. They communicate through a bidirectional websocket using predefined commands.

Architecture

The diagram depicts the typical RL cycle [4]: an agent interacts with its environment in discrete time steps. At each time step, the agent receives an observation, which typically includes the reward and the state of the game. It then chooses an action from the set of available actions and sends it to the environment. The environment moves to a new state by executing the action, and the reward associated with the transition is determined. The goal of a reinforcement learning agent is to collect as much reward as possible. More on this below.

In the next section, “Implementation”, we’ll talk about the structure of the code and how we turned the in-browser game into an RL environment. In the “Reinforcement Learning” section the AI algorithm is discussed. There we abstract away the fact that we are interacting with a real game and assume that we have an environment that provides us with the next state and associated reward, given a new action.

Implementation

Creating the RL environment

We want to turn the javascript game, which runs inside the browser, into an RL environment. This means that we want an interface similar to, for example, OpenAI Gym environments. The OpenAI Gym interface looks as follows [5]:

ob0 = env.reset() # sample environment state, return first observation
a0 = agent.act(ob0) # agent chooses first action
ob1, rew0, done0, info0 = env.step(a0) # environment returns observation,
# reward, and boolean flag indicating if the episode is complete.
a1 = agent.act(ob1)
ob2, rew1, done1, info1 = env.step(a1)
...
a99 = agent.act(ob99)
ob100, rew99, done99, info99 = env.step(a99)
# done99 == True => terminal

To achieve this, we need to clear some hurdles. We need to…

  1. implement a bidirectional communication channel between the browser/javascript code and the Python AI program,
  2. perform actions, i.e. let the dino jump and duck, and
  3. extract the game state from the game and turn it into a parseable format.

Let’s tackle these one by one.

To allow for bidirectional communication, we extend the game’s javascript code. We chose to use WebSocket objects as the transportation medium. When the page is done loading (i.e. the onDocumentLoad event), the code will try to establish a connection with an entity listening on 127.0.0.1:9090.

function onDocumentLoad() {
    runner = new Runner('.interstitial-wrapper');
};
document.addEventListener('DOMContentLoaded', onDocumentLoad);
var socket = new WebSocket("ws://127.0.0.1:9090"); // will connect to the Python-side server

The only thing left to do is to create the listening entity on the Python side. In the project’s repo you can find the websocket_server.py file, which contains the implementation of a websocket server. When an Environment object is initialized, it creates such a websocket server and orders it to listen for connections on 127.0.0.1:9090. (Additionally, the socket is placed in its own thread and communicates with the main thread using a multiprocessing.Queue.) As long as the browser hasn’t established a connection with this socket, the Python program will block.

import multiprocessing
import threading

from websocket_server import WebsocketServer

class Environment:
    """
    Environment is responsible for the communication with the game, through the socket.
    """
    def __init__(self, host, port): # host = 127.0.0.1 and port = 9090
        self.queue = multiprocessing.Queue()
        self.server = WebsocketServer(port, host=host)
        # run the websocket server in its own thread so it doesn't block the main program
        thread = threading.Thread(target=self.server.run_forever)
        thread.start()
    ...

As long as both entities (the javascript code inside the browser and the Python websocket server) stay alive, we have a duplex communication channel. On both sides we have functions to send messages and callback methods to process incoming messages. We will use this to solve the next hurdle: executing the actions in the game that the AI program requests.

The javascript snippet below shows the onmessage callback in the javascript code, which runs when a new message from the Python code comes in:

socket.onmessage = function (event)
{
    var command = event.data;
    var runner = new Runner();
    console.log(command);

    switch (command) {
        case 'STATE':
            runner.postState(socket);
            break;
        case 'START':
            simulateKey("keydown", 32); // space
            setTimeout(function() {simulateKey("keyup", 32);}, 1000);
            break;
        case 'REFRESH':
            location.reload(true);
            break;
        case 'UP':
            simulateKey("keydown", 38); // arrow up
            setTimeout(function() {simulateKey("keyup", 38);}, 400);
            break;
        case 'DOWN':
            simulateKey("keydown", 40); // arrow down
            setTimeout(function() {simulateKey("keyup", 40);}, 400);
            break;
        default:
    }
};

You can see that we specified a simple communication protocol consisting of 5 commands: ‘STATE’, ‘START’, ‘REFRESH’, ‘UP’ and ‘DOWN’. The latter two, ‘UP’ and ‘DOWN’, are the actions used in the game to control the dino. We execute an action by simulating a keypress on the keyboard. This is easily done in javascript using the function simulateKey(type, keyCode) in combination with setTimeout(). For example, when the Python AI program sends the message ‘UP’, the javascript code will receive the message and subsequently simulate a press on the arrow-up key, which makes the T-rex jump.

The first three commands, ‘STATE’, ‘START’ and ‘REFRESH’, are control commands. The ‘STATE’ command instructs the javascript code to collect the current game state (i.e. the current image displaying the dino and the obstacles) and send it over the socket to the Python side. The postState function, run when a ‘STATE’ message is received, looks like this:

postState: function (socket) {
    console.log("in postState function");
    var canvas = document.getElementById('runner-id');
    var dataUrl = canvas.toDataURL("image/png");
    var state = {
        world: dataUrl,
        crashed: this.crashed.toString()
    };
    socket.send(JSON.stringify(state));
}

The function reads the canvas (this is the PNG image you see on the screen) and encodes it as a base64 string. Next, a state message is created containing the base64-encoded image string and a boolean indicating whether or not the dino is still alive. The state message is serialized to JSON and sent over the socket. On the Python side, we parse the state message as follows:

import base64
import json
import re
from io import BytesIO

import numpy as np
from PIL import Image

data = json.loads(message)
image, crashed = data['world'], data['crashed']

# remove data-info at the beginning of the image
image = re.sub('data:image/png;base64,', '', image)
# convert image from base64 encoding to a 2D numpy array
image = np.array(Image.open(BytesIO(base64.b64decode(image)))) # <-- tricky one
# cast string boolean to python boolean
crashed = crashed in ('True', 'true')

The most important line in the snippet above is the decoding of the base64 image string into a 2D image matrix. All the functionality is provided by standard Python libraries, such as PIL (Python Imaging Library), json and base64. However, it took me some time to find the right combination, as weird behaviour easily occurs when passing images between different programming languages or even libraries.

We have now tackled the three hurdles required to turn the in-browser game into an RL environment. The Environment class contains most of this logic and provides an easy interface to the game. The most important method is do_action() (comparable to step() in OpenAI Gym):

class Environment:
    def __init__(self, host, port):
        ...  # see above

    def do_action(self, action):
        """
        Performs an action and returns the next state, reward and crash status
        """
        if action != Action.FORWARD:
            # nothing needs to be sent when the action is going forward
            self.server.send_message(self.game_client, self.actions[action])

        time.sleep(0.1)

        return self._get_state(action)

    def _get_state(self, action):
        self.server.send_message(self.game_client, "STATE")
        next_state, crashed = self.queue.get()  # <-- blocks while the queue is empty (waiting for the state message)
        reward = _calculate_reward(action, crashed)
        return next_state, reward, crashed

Besides the do_action() method, an Environment object can also start() and refresh() the game, in a similar fashion to the OpenAI Gym function reset().
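Putting it together, a rough usage sketch of this interface looks like the loop below. The names are taken from the snippets in this post; the exact signatures in the repo may differ slightly.

# Rough usage sketch, mirroring the OpenAI Gym loop above.
# Environment and Action come from the project's own modules.
env = Environment("127.0.0.1", 9090)

frame, reward, crashed = env.start_game()                    # comparable to gym's reset()
while not crashed:
    frame, reward, crashed = env.do_action(Action.FORWARD)   # comparable to gym's step()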

Architecture

The diagram below shows the overall architecture of the project. It illustrates the division between the javascript code running inside the browser and the Python code, which communicate over a websocket. The WebSocket and Environment classes were discussed above. In the next section, we’ll focus on the remaining classes. They are responsible for the actual learning.

Architecture

Reinforcement Learning (RL)

To teach the dino how to dodge the approaching obstacles we chose a Deep Q-learning approach, proposed in [1] by DeepMind. We briefly discuss this algorithm from a theoretical point-of-view, and explain how we implemented it.

The main idea of Deep Q-learning is to use a (deep) parametric neural network to approximate the Q-function. The Q-function of a Markov decision process (MDP), often denoted by $Q(s, a)$, gives the expected utility of taking a given action in a given state and following an optimal policy thereafter. In other words, the Q-function takes as arguments a state $s$ and an action $a$ and returns the expected total future reward of the agent if it executes action $a$ in state $s$ and continues performing optimal actions afterwards. As we wish to maximize our reward in every state, we typically execute the action which maximizes the expected total future reward, namely $a^* = \arg\max_a Q(s, a)$. In the literature, this is referred to as the agent’s policy.

An important property of the Q-function is that $Q(s, a) = r + \gamma \max_{a'} Q(s', a')$, where $s'$ is the state the agent ends up in after performing action $a$ in state $s$, $r$ is the corresponding reward and $\gamma \in [0, 1]$ is the discount factor. This equation is referred to as the Bellman equation. We will later use it to train our neural network.
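To make this concrete, a single observed transition $(s, a, r, s')$ gives us a learning target and an error we can minimize:

$$
y = r + \gamma \max_{a'} Q(s', a'), \qquad \text{error} = \bigl(y - Q(s, a)\bigr)^2 .
$$

This is exactly the quantity we will minimize with gradient descent in the “Learning” section below.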

If we let the agent interact with the environment for a while, we end up with a collection of SARSA elements. SARSA stands for State, Action, Reward, State’ and Action’. It holds the current state of the agent S, the action the agent chooses A, the reward R the agent gets for choosing this action, the state S’ that the agent will now be in after taking that action, and finally the next action A’ the agent will choose in its new state. Taking every letter in the quintuple yields the word SARSA. This sequence is depicted in the figure below.

RL loop

In the context of TF-rex, our action space is limited to three elements: “ducking”, “jumping” and “going forward”. This makes things relatively easy, as we only need to deal with a small number of discrete actions; handling continuous actions or a large set of discrete ones typically makes learning much harder. The state space, on the other hand, is quite large, as it consists of four stacked images of size 80×80 (i.e. an input dimensionality of $80 \times 80 \times 4 = 25{,}600$). In the next section we show how we create these input vectors.

Preprocessing

We don’t directly use the images we receive from the javascript game as states. We need to preprocess them before using them as the inputs of the deep Q-learning algorithm. This accelerates the training as we eliminate noisy parts from the image. We also collect multiple images into a single state which serves as a kind of memory.

The image below shows the original version, i.e. the image collected and sent by the javascript game and received by the Environment. It is a grey-scale image and has dimensions 150×600. The highscore and current score are shown in the upper-right corner.

Original image

We apply two preprocessing steps on this image.

  1. extract a region-of-interest (ROI)

    To dodge the obstacles, the left part of the image is clearly much more informative than the right side. Therefore, we select roi = image[:, :420], which drops 30% of the right-side pixels. This operation reduces the number of input pixels (roi is 150×420) and removes the meta-information in the upper-right corner.

  2. remove harmless objects

    The cloud, which you can see in the middle of the original image, doesn’t hurt the dino. The dino can touch it (i.e. have overlapping pixels) without dying. Harmless obstacles are easily filtered out, as they have a lighter colour than real obstacles, so we chose to remove them from the images using some straightforward masking techniques.

ROI and Remove clouds

We then apply a standard preprocessing step (inspired by the Atari games paper [1]): resizing the image to an 80×80 grid of pixels.

Squaring

Finally, we stack the last 4 frames in order to produce an 80×80×4 array which serves as a state for the Deep Q-Learning algorithm.

All these operations are performed by the Preprocessor class.
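For illustration, here is a minimal sketch of what such a Preprocessor could look like, assuming the 150×600 grey-scale frames described above. The cloud-masking threshold and the exact resizing call are my assumptions; the class in the repo may differ.

import numpy as np
from PIL import Image

class Preprocessor:
    """Minimal sketch, assuming the 150x600 grey-scale frames described above."""

    def process(self, image):
        roi = image[:, :420]                  # 1. keep the informative left part (150x420)
        roi = np.where(roi > 200, 255, roi)   # 2. wash out light, harmless objects such as clouds (threshold is an assumption)
        resized = Image.fromarray(roi.astype(np.uint8)).resize((80, 80))
        return np.asarray(resized) / 255.0    # 80x80 frame, scaled to [0, 1]

    def get_initial_state(self, frame):
        # repeat the first frame four times to build the initial 80x80x4 state
        self.state = np.stack([frame] * 4, axis=-1)
        return self.state

    def get_updated_state(self, frame):
        # drop the oldest frame and append the newest one
        self.state = np.concatenate([self.state[..., 1:], frame[..., np.newaxis]], axis=-1)
        return self.state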

Learning

Like most RL algorithms, Deep Q-learning uses a SARSA observation to get an unbiased estimate of the error $L(\theta) = \left(r + \gamma \max_{a'} Q_\theta(s', a') - Q_\theta(s, a)\right)^2$, where $Q_\theta$ denotes that the Q-function is built out of a neural net with parameters $\theta$. Minimizing $L(\theta)$ w.r.t. $\theta$ is used to optimize the neural network parameters.

While this approach works in theory, in practice we see that during the optimization the neural net tends to oscillate or diverge. See, for example, David Silver’s courses for an explanation of why [6]. A couple of tricks are introduced to reduce this unwanted behaviour.

1. Experience Replay

Instead of using only a single SARSA observation in the error function we use a batch of observations. This breaks correlations in the data, reduces the variance of the gradient estimator, and allows the network to learn from a more varied array of past experiences.

From an implementation point of view this means that we need to store each observed SARSA element. At training time, we also need to be able to sample from this memory in order to get a new batch of experiences. We implemented a Memory class which does exactly this. We chose to use a fixed-size FIFO buffer to store the SARSA elements.

import random

import numpy as np

class Memory:

    def __init__(self, size):
        self.size = size
        self.mem = np.ndarray((size, 5), dtype=object)  # each row holds one SARSA element
        self.iter = 0
        self.current_size = 0

    def remember(self, state1, action, reward, state2, crashed):
        self.mem[self.iter, :] = state1, action, reward, state2, crashed
        self.iter = (self.iter + 1) % self.size  # overwrite the oldest element when full
        self.current_size = min(self.current_size + 1, self.size)

    def sample(self, n):
        n = min(self.current_size, n)
        random_idx = random.sample(list(range(self.current_size)), n)
        sample = self.mem[random_idx]
        return (np.stack(sample[:, i], axis=0) for i in range(5))

2. Target Network

In order to avoid oscillations, we use a second network, called the target network, to estimate the Q-values during several epochs. The target network has fixed parameters, which reduces the variance and makes the learning more stable. We update the target network parameters with the values of our main network periodically. The loss is now calculated as $L(\theta) = \left(r + \gamma \max_{a'} Q_{\theta^-}(s', a') - Q_\theta(s, a)\right)^2$, where $Q_\theta$ is our main network and $Q_{\theta^-}$ is our target network. Importantly, both networks have the same architecture but can have different values for the parameters. Periodically we need to update the target parameters with the newest values, i.e. $\theta^- \leftarrow \theta$. By using this separate network to compute the targets we get a more stable training procedure, as we reduce the number of constantly shifting values in the loss function.
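As an illustration, the periodic $\theta^- \leftarrow \theta$ copy can be expressed in TensorFlow as a list of assign operations. This is a sketch, assuming the main and target networks build their variables under the variable scopes "main" and "target"; the actual implementation may organise its variables differently.

import tensorflow as tf

# Sketch of the periodic target-network update, assuming both networks create
# their variables under the variable scopes "main" and "target" in the same order.
main_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="main")
target_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="target")

# one assign op per parameter pair: theta_minus <- theta
update_target_ops = [target.assign(main) for main, target in zip(main_vars, target_vars)]

# periodically: sess.run(update_target_ops)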

Most of the learning logic is implemented in the DDQNAgent class. Below we show the method that computes the targets. We start by sampling from the experience memory and then calculate the target $r + \gamma \max_{a'} Q_{\theta^-}(s', a')$ for every SARSA element in the sample.

class DDQNAgent:

    def replay(self):
        states, actions, rewards, states_next, crashes = self.memory.sample(self.batch_size)
        target = rewards
        # add discounted Q value of next state to non-terminal states (i.e. not crashed)
        target[~crashes] += self.discount * self.target_dqn.get_action_and_q(states_next[~crashes])[1]
        self.main_dqn.train(states, actions, target)

Network architecture

The deep convolutional network architecture used to solve the T-rex game is based on [2] but is extended with a dueling structure. It contains three convolution layers with ReLU activation functions, one max pooling layer and two fully connected layers. The state (i.e. four stacked images) first goes through a convolution layer with 32 filters of size 8×8 and stride 4, followed by a ReLU. Then 2×2 max pooling is applied to the output of the convolution. The tensor then goes through two more convolution layers: 64 filters of size 4×4 with stride 2, and 64 filters of size 3×3 with stride 1. We then flatten the tensor and pass it through a dueling layer.

The idea behind a dueling structure is to decompose the Q-value into two parts. The first is the value function $V(s)$, which indicates the value of being in a certain state. The second is the advantage function $A(s, a)$, which tells how much better taking a certain action is compared to the others. We can then think of $Q(s, a) = V(s) + A(s, a)$. The goal of Dueling DQN is to have a network that separately computes the advantage and value functions, and combines them back into a single Q-function only at the final layer [7, 8]. We achieve this by implementing two fully connected layers for both the value and the advantage function. We then combine the output of these two parts to end up with the final Q-value.
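To make this concrete, here is a minimal TensorFlow sketch of such a network following the description above. The sizes of the two fully connected streams and the mean-advantage subtraction follow the standard dueling formulation [7, 8]; the actual DQN class may differ in its details.

import tensorflow as tf

def build_dueling_dqn(num_actions=3):
    # 80x80x4 input state: four stacked, preprocessed frames
    states = tf.placeholder(tf.float32, [None, 80, 80, 4])

    conv1 = tf.layers.conv2d(states, 32, 8, strides=4, activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(conv1, 2, 2)
    conv2 = tf.layers.conv2d(pool1, 64, 4, strides=2, activation=tf.nn.relu)
    conv3 = tf.layers.conv2d(conv2, 64, 3, strides=1, activation=tf.nn.relu)
    flat = tf.layers.flatten(conv3)

    # dueling head: separate fully connected streams for V(s) and A(s, a)
    value = tf.layers.dense(tf.layers.dense(flat, 256, activation=tf.nn.relu), 1)
    advantage = tf.layers.dense(tf.layers.dense(flat, 256, activation=tf.nn.relu), num_actions)

    # combine the streams into Q-values; subtracting the mean advantage keeps the split identifiable
    q_values = value + advantage - tf.reduce_mean(advantage, axis=1, keepdims=True)
    return states, q_values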

The network logic is implemented in the DQN class.

Reward function

The reward function I used to train the model looks as follows:

if crashed:
    reward = -100
else:
    if action == UP:
        reward = -5
    elif action == DOWN:
        reward = -3
    else:
        reward = 1

It favors going forward over jumping and ducking, and ducking over jumping. This reward helps the model understand that random actions are unnecessary when there are no obstacles. In earlier attempts I used another pretty straightforward reward function: -100 if the dino crashes, else +1. While I think that this reward function should work in principle, in practice it leads to a lot of spurious jumps and ducks.
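Wrapped up as the _calculate_reward helper referenced in _get_state above, this could look like the sketch below. Note that the Action.UP and Action.DOWN names are my assumptions, mirroring the protocol commands; only Action.FORWARD appears literally in the snippets above.

def _calculate_reward(action, crashed):
    # hedged sketch of the reward shaping above; Action.UP / Action.DOWN names are assumed
    if crashed:
        return -100
    if action == Action.UP:
        return -5      # discourage unnecessary jumps
    if action == Action.DOWN:
        return -3      # ducking is penalized less than jumping
    return 1           # going forward is rewarded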

Driver

Finally, in the main.py file you will find the code that initialises the different components, such as the agent, the environment and the preprocessor, and starts the learning loop. There is also some code to checkpoint the models (i.e. save the current state of the neural network in order to restore it later) and to send some interesting statistics to TensorBoard in order to monitor the experiments.

epoch = 0
while True:
    epoch += 1

    frame, _ , crashed = env.start_game()
    frame = preprocessor.process(frame)
    state = preprocessor.get_initial_state(frame)
    ep_steps, ep_reward = 0, 0

    while not crashed:

        action, explored = agent.act(state)
        next_frame, reward, crashed = env.do_action(action)
        next_frame = preprocessor.process(next_frame)
        next_state = preprocessor.get_updated_state(next_frame)
        agent.remember(state, action, reward, next_state, crashed)

        ep_steps += 1
        ep_reward += reward
        state = next_state

    agent.replay(epoch)
    agent.explore_less()
    agent.update_target_network()
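The act() method returns both the chosen action and a flag indicating whether it was an exploratory one, and explore_less() anneals the exploration rate. A plausible epsilon-greedy sketch of these two DDQNAgent methods is shown below; the actual implementation and its hyperparameters may differ.

import numpy as np

# Plausible epsilon-greedy sketch of two DDQNAgent methods; get_action_and_q()
# is the network method used in the replay() snippet above.
def act(self, state):
    if np.random.rand() < self.epsilon:
        # explore: pick a random action and flag that we explored
        return np.random.randint(self.num_actions), True
    # exploit: act greedily w.r.t. the main network's Q-values
    action, _ = self.main_dqn.get_action_and_q(state[np.newaxis])
    return int(action), False

def explore_less(self, decay=0.995, min_epsilon=0.1):
    # anneal epsilon so the agent gradually exploits what it has learned
    self.epsilon = max(min_epsilon, self.epsilon * decay)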

Results

Training a TF-rex model takes about half a day. As we are interacting with the real game, we cannot speed up this process: the TF-rex just has to play the games one by one. After training, TF-rex typically reaches a score of around 1600, which is reasonably good, but on average worse than when I play the game myself. The reason for this is that the speed of the game changes over time. At the start of the game the obstacles approach the dino relatively slowly, but as the game advances the obstacles move faster and faster towards the dino. This causes two problems:

  1. The probability that the dino dies before reaching the higher-velocity part of the game is high. This limits the number of high-speed SARSA elements in the experience replay buffer, making our Memory highly unbalanced. As a result, when we sample uniformly from this memory we mainly end up with low-speed transitions to optimize the Bellman equation error, effectively overlooking the behaviour of TF-rex in the high-speed parts.

  2. Exploration in these high-speed parts is typically catastrophic. Any uncontrolled jump or duck will result in a crash. This limits the dino’s ability to learn new behaviour in the changing environment.

These problems still need to be addressed; I’m open to suggestions. I believe Prioritized Experience Replay [9] could prove useful.

Progress whilst learning

Runs of trained model

TODO

  1. Implement more techniques to speed up the learning: e.g. gradient clipping and batch norm.
  2. Try the learning framework on other javascript browser games.
  3. Connect the game with other learning algorithms from, for example, TensorForce.
  4. In all the models I’ve trained so far, the dino prefers jumping over the birds rather than ducking underneath them, even though I’ve modified the reward function to encourage the dino to duck instead of jump. It would be interesting to see which reward function or approach makes the dino duck.
  5. And of course: improve the performance :)

References

[1] Playing Atari with Deep Reinforcement Learning
[2] AI for Chrome Offline Game
[3] IAMDinosaur
[4] Wikipedia page RL
[5] Open AI Gym
[6] RL course David Silver
[7] RL blogpost
[8] Target network paper
[9] Prioritized Experience Replay

Thanks!

Thanks for reading this lengthy post about my project. I hope you enjoyed it and found it useful in some way. If you have any questions, please don’t hesitate to drop me an email or find me on Twitter.