[Matplotlib] Multiple axes with different scales

When training a neural network to play games, we often plot the records it produces during training, such as the number of blocks it opens in each game.

As the number of records grows, the chart becomes crowded and it is hard to see the overall trend of the line.

To make the trend readable, we plot only the average of every 250 records instead.
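One way to compute these averages is sketched below. The helper name `block_averages` is hypothetical, and dropping an incomplete final block is an assumption on our part (it fits the plotting code later on, which places one averaged point every 250 games), not necessarily what the original code did.

```python
def block_averages(records, block_size=250):
    """Return the mean of each complete block of `block_size` consecutive records.

    Assumption: a trailing partial block (fewer than block_size records)
    is dropped rather than averaged.
    """
    n_blocks = len(records) // block_size  # number of complete blocks
    return [
        sum(records[i * block_size:(i + 1) * block_size]) / block_size
        for i in range(n_blocks)
    ]


# Usage: 500 records -> two blocks of 250 -> two averages.
print(block_averages(list(range(500))))
```

With `data` holding the per-game records, `data_avg = block_averages(data)` gives one value per 250 games.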

Not bad, but the averaged chart alone looks a bit monotonous. Let’s overlay the two charts in one figure.

Before we start, note that if we simply call plt.plot() twice to draw both datasets in the same chart, they share a single y axis scaled to fit the wider range, so the dataset with the narrower range gets flattened.
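The flattening is easy to check: when two series are drawn on one Axes, its y limits grow to cover both. The sketch below uses two small hypothetical series (`big`, `small`) and the off-screen Agg backend, so no window opens.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window is needed for this check
import matplotlib.pyplot as plt

big = [0, 100, 50, 100]  # hypothetical wide-range series
small = [1, 2, 1, 2]     # hypothetical narrow-range series

fig, ax = plt.subplots()
ax.plot(big)
ax.plot(small)  # same Axes, so same y axis as `big`

low, high = ax.get_ylim()  # autoscaled to fit both series
share = (max(small) - min(small)) / (high - low)
print(f"shared y range: {low:g} to {high:g}")
print(f"fraction of the plot height used by `small`: {share:.1%}")
plt.close(fig)
```

Because the y range has to reach 100, the swing of `small` occupies only about one percent of the plot height, which is exactly the flattening described above.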

A workable approach is to create two Axes that share the same x axis, each holding one dataset. Each Axes then has its own independent y scale, so neither dataset distorts the other.

import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()
ax1.set_title('Numbers of blocks NN opened per Game')
ax2 = ax1.twinx()  # second Axes: shares ax1's x axis, has its own y axis

The code above creates the figure and our first Axes, ax1, and sets the chart title. Finally, ax1.twinx() creates a second Axes, ax2, which shares ax1's x axis but has its own y axis on the right-hand side.
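To see what twinx() actually shares, the short sketch below sets limits on both Axes and reads them back. The limit values are arbitrary placeholders chosen only for the check.

```python
import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()

# x limits are shared: changing ax1's x range also changes ax2's.
ax1.set_xlim(0, 1000)
x0, x1 = ax2.get_xlim()
print(f"ax2 x range: {x0:g} to {x1:g}")

# y limits are independent: each Axes keeps its own scale.
ax1.set_ylim(0, 40)
ax2.set_ylim(0, 10)
print("y ranges differ:", ax1.get_ylim() != ax2.get_ylim())

# The twin's y axis label sits on the right-hand side.
print("ax2 y label side:", ax2.yaxis.get_label_position())

plt.close(fig)
```

This independence of the y axes is what lets the raw records and their averages each use the full height of the chart.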

Next, let’s plot the raw records on our first Axes, ax1.

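# Left y axis: raw per-game records (data is the list of records)
ax1.set_ylabel('Opened Blocks', color='tab:blue')
ax1.plot(range(len(data)), data, color='tab:blue', alpha=0.75)
ax1.tick_params(axis='y', labelcolor='tab:blue')  # tick labels match the line color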

Repeat the process on ax2, this time with the average of every 250 records. Each average is plotted at x = 250, 500, 750, and so on, i.e. at the last game of its block.

# Right y axis: one point per 250-game block, placed at the block's last game
ax2.set_ylabel('Opened Blocks (Avg of 250 Games)', color='black')
ax2.plot([ 250 * (i + 1) for i in range(len(data_avg))], data_avg, color='black', alpha=1)
ax2.tick_params(axis='y', labelcolor='black')

With both Axes populated, a call to plt.show() displays the chart with two y axes.
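Putting the pieces together, here is a self-contained sketch of the whole chart. Since the real training records are not available here, `data` is filled with hypothetical random values, and the 'Game' x axis label and the fig.tight_layout() call are our own additions; swap in your own records to reproduce the real plot.

```python
import random

import matplotlib.pyplot as plt

BLOCK = 250  # games per averaged point, as in the article

# Hypothetical stand-in for the training records: blocks opened per game.
random.seed(0)
data = [random.randint(0, 40) + i // 50 for i in range(2000)]

# Average each complete block of BLOCK games (a trailing partial block is dropped).
data_avg = [
    sum(data[i:i + BLOCK]) / BLOCK
    for i in range(0, len(data) - BLOCK + 1, BLOCK)
]

fig, ax1 = plt.subplots()
ax1.set_title('Numbers of blocks NN opened per Game')
ax1.set_xlabel('Game')  # our own addition
ax2 = ax1.twinx()  # shares ax1's x axis, gets its own y axis on the right

# Raw records on the left-hand y axis.
ax1.set_ylabel('Opened Blocks', color='tab:blue')
ax1.plot(range(len(data)), data, color='tab:blue', alpha=0.75)
ax1.tick_params(axis='y', labelcolor='tab:blue')

# Block averages on the right-hand y axis, each at the last game of its block.
ax2.set_ylabel('Opened Blocks (Avg of 250 Games)', color='black')
ax2.plot([BLOCK * (i + 1) for i in range(len(data_avg))], data_avg, color='black')
ax2.tick_params(axis='y', labelcolor='black')

fig.tight_layout()  # keep the right-hand y label from being clipped
plt.show()
```

Coloring each y axis's label and tick labels to match its line is what tells the reader which scale belongs to which dataset.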


Pretty neat, isn’t it?
