Learning to Play Tic-Tac-Toe with Jax

In this article we’ll learn how to train a neural network to play Tic-Tac-Toe using reinforcement learning in Jax. The article aims to be pedagogical, so the code we end up with won’t be heavily optimized, but it is still fast enough to train a model to perfect play in about 15 seconds on a laptop.

Code from this page can be found in the accompanying GitHub repo as well as in a Colab notebook (although the Colab notebook runs considerably more slowly).

Playing Tic-Tac-Toe in Jax

Before we get to the fancy neural networks and reinforcement learning, we’ll first look at how a Tic-Tac-Toe game can be represented in Jax. For this we’ll use the PGX library, which implements a number of games in pure Jax. PGX represents a game’s state with a dataclass called State, which has several fields, including:

current_player: This is simply 0 or 1, and it alternates on every turn. What may be confusing is that there is no fixed relationship between player 0 and X or O: player 0 is randomly assigned X or O in each game, and X always goes first. This is helpful because it means you can have your neural net always play as player 0 and be sure that it plays X (moving first) in about half of the games and O (moving second) in the other half.
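The bookkeeping behind this can be sketched in plain Python. This is not PGX code: new_game, mark_for, and current_player_on_turn are hypothetical helpers that mimic the behavior described above, namely a random choice of which player id gets X in each game, with X always moving first.

```python
import random


def new_game(rng: random.Random) -> int:
    """Hypothetical stand-in for PGX's per-game assignment: return the id (0 or 1) that plays X."""
    return rng.randrange(2)


def mark_for(player: int, x_player: int) -> str:
    """Which mark a player id uses in a game where `x_player` was assigned X."""
    return "X" if player == x_player else "O"


def current_player_on_turn(turn: int, x_player: int) -> int:
    """X always moves first, so X's owner moves on turns 0, 2, 4, ... and the other id on odd turns."""
    return x_player if turn % 2 == 0 else 1 - x_player


# Our agent is always player 0; count how often that makes it X (and therefore the first mover).
rng = random.Random(0)
x_players = [new_game(rng) for _ in range(10_000)]
agent_is_x = sum(x == 0 for x in x_players) / len(x_players)
print(f"player 0 played X in {agent_is_x:.1%} of games")
```

Because the X assignment is drawn independently for each game, an agent fixed as player 0 ends up as X in roughly half the games and as O in the rest, without any extra logic on our side.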

observation: This tells us what the board looks like at the current step. PGX represents it as a boolean array of shape (3, 3, 2). The first two axes index the 3x3 grid, as you might expect; along the last axis, the first channel is True wherever the current player has a piece and the second channel is True wherever the opponent has a piece. (Because current_player switches every turn, the two channels swap roles every turn as well.) For example, a board where the current player has pieces in the center and bottom-left squares and the opponent has pieces in the top-middle and top-right squares is represented as:

Array([[[False, False], [False,  True], [False,  True]],
       [[False, False], [ True, False], [False, False]],
       [[ True, False], [False, False], [False, False]]], dtype=bool)
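To make the channel layout concrete, here is a small plain-Python sketch (no Jax or PGX; nested lists stand in for the Jax array) that turns a (3, 3, 2) observation back into an X/O grid. The render helper is hypothetical, and it relies on one inference rather than on anything PGX reports directly: because X always moves first, the current player is X exactly when both sides have the same number of pieces.

```python
# The example observation above as nested lists: obs[row][col] == [current player's piece, opponent's piece]
EXAMPLE_OBS = [
    [[False, False], [False, True], [False, True]],
    [[False, False], [True, False], [False, False]],
    [[True, False], [False, False], [False, False]],
]


def render(obs) -> str:
    """Draw a (3, 3, 2) observation as an X/O grid, with '.' for empty squares.

    X always moves first, so if both sides have the same number of pieces it is
    X's turn (the current player is X); otherwise the current player is O.
    """
    mine = sum(cell[0] for row in obs for cell in row)
    theirs = sum(cell[1] for row in obs for cell in row)
    me, them = ("X", "O") if mine == theirs else ("O", "X")
    return "\n".join(
        " ".join(me if cell[0] else them if cell[1] else "." for cell in row)
        for row in obs
    )


print(render(EXAMPLE_OBS))
# . O O
# . X .
# X . .
```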

legal_action_mask: This is a flat boolean array of shape (9,) that is False for every filled square and True for every empty square.
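As a sketch of how this mask might be used, the plain-Python code below (not PGX's implementation) rebuilds it from the example observation and then uses it to zero out illegal moves in a softmax over nine policy logits. Two things here are assumptions: that actions are numbered row-major (action = 3 * row + col), and that masking the logits is how a policy consumes the mask. Masked softmax is a common approach, not necessarily the one this article's code takes; legal_action_mask and masked_softmax are hypothetical helper names.

```python
import math

# Same example observation as above: obs[row][col] == [current player's piece, opponent's piece]
EXAMPLE_OBS = [
    [[False, False], [False, True], [False, True]],
    [[False, False], [True, False], [False, False]],
    [[True, False], [False, False], [False, False]],
]


def legal_action_mask(obs) -> list:
    """True for every empty square, flattened row-major (assumed: action = 3 * row + col)."""
    return [not (cell[0] or cell[1]) for row in obs for cell in row]


def masked_softmax(logits, mask) -> list:
    """Turn 9 policy logits into probabilities that put zero mass on filled squares."""
    top = max(x for x, ok in zip(logits, mask) if ok)  # subtract the max for numerical stability
    exps = [math.exp(x - top) if ok else 0.0 for x, ok in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


mask = legal_action_mask(EXAMPLE_OBS)
print(mask)
# [True, False, False, True, False, True, False, True, True]
probs = masked_softmax([0.0] * 9, mask)
print([round(p, 2) for p in probs])
# [0.2, 0.0, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.2]
```

With uniform logits, the probability mass is spread evenly over the five empty squares, and the filled squares can never be sampled.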

rewards: This array has shape (2,) and gives the reward on each step: the first entry is the reward for player 0 and the second is the reward for player 1. Note that the reward appears on the state produced by the winning move, and by then current_player has already switched to the other player, so we have to account for that switch when deciding who earned the reward. Rewards are also not cumulative: if we keep transitioning to new “states” after the game has ended (which happens with batching, since games in a batch finish at different times), the rewards on those subsequent states are 0.
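Both points are easy to get wrong, so here is a plain-Python sketch of the bookkeeping. The trajectory is hypothetical, the +1 for a win and -1 for a loss are assumed values for illustration, and mover_reward and episode_return are helper names made up for this example rather than PGX functions.

```python
from typing import List, Tuple

# One recorded step: (current_player after the step, rewards for (player 0, player 1) after the step)
Step = Tuple[int, Tuple[float, float]]

# Hypothetical game in which player 1 wins on its third move. The +1/-1 values are assumed;
# the last two steps mimic a batched rollout that keeps stepping after this game has ended.
TRAJECTORY: List[Step] = [
    (1, (0.0, 0.0)),   # player 0 moved
    (0, (0.0, 0.0)),   # player 1 moved
    (1, (0.0, 0.0)),   # player 0 moved
    (0, (0.0, 0.0)),   # player 1 moved
    (1, (0.0, 0.0)),   # player 0 moved
    (0, (-1.0, 1.0)),  # player 1 completed a line; current_player has already switched to 0
    (0, (0.0, 0.0)),   # post-terminal padding: reward is 0
    (0, (0.0, 0.0)),   # post-terminal padding: reward is 0
]


def mover_reward(step: Step) -> float:
    """Reward for the player who just moved: current_player has already switched, so use the other id."""
    current_player, rewards = step
    return rewards[1 - current_player]


def episode_return(steps: List[Step], player: int) -> float:
    """Rewards are per step, not cumulative, so a player's return is the sum over all steps."""
    return sum(rewards[player] for _, rewards in steps)


winning_step = TRAJECTORY[5]
print(mover_reward(winning_step))        # 1.0
print(winning_step[1][winning_step[0]])  # -1.0 (indexing by current_player gives the loser's reward)
print(episode_return(TRAJECTORY, 0), episode_return(TRAJECTORY, 1))  # -1.0 1.0
```

Summing over the padded steps is safe precisely because post-terminal rewards are 0, so no per-game "done" bookkeeping is needed just to compute returns.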
