Alex's Notes

VanderPlas: Chapter 04 Visualization with Matplotlib

Metadata

Core Ideas

Plotting basics

We can plot from a script by calling plt.show(), which opens a graphics window and displays the figure. plt.show() should be called only once per Python session; calling it more than once leads to unpredictable, backend-dependent behaviour, so it usually goes at the very end of the script.

We can also plot from an IPython shell, but first we need to enable matplotlib mode with the %matplotlib magic command after starting IPython. After this, any plt.plot() command will open a graphics window. Some changes (such as modifying properties of lines that are already drawn) won't appear automatically; to force an update we can run plt.draw(). plt.show() is not required.

In a notebook we have the choice of two modes:

  • %matplotlib notebook will embed interactive plots in the notebook

  • %matplotlib inline will embed static images of the plot.

In inline mode, any cell within the notebook that creates a plot will embed a PNG of the graphic after it.

To save a figure, create it with fig = plt.figure(), draw into it, then call fig.savefig('my_fig.png'). The file format is inferred from the file extension; fig.canvas.get_supported_filetypes() lists the supported formats.
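A minimal script-style sketch of saving a figure. The file name and the temp-directory location are arbitrary choices, and the Agg backend line is an assumption added only so the script runs without a display.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window, file output only
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig = plt.figure()
plt.plot(x, np.sin(x), '-')
plt.plot(x, np.cos(x), '--')

# the ".png" extension tells savefig which format to write
out_path = os.path.join(tempfile.gettempdir(), "my_fig.png")
fig.savefig(out_path)

# dict mapping file extension -> human-readable format name
supported = fig.canvas.get_supported_filetypes()
print(sorted(supported))
```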

Interfaces

Matplotlib has two interfaces, which can get confusing.

The first is a MATLAB-style interface, reflecting Matplotlib’s origins as a MATLAB alternative. It is stateful: it keeps track of a ‘current’ figure and axes. So if you create a subplot and then another, the context switches to the second, and going back to change the first is clunky.

The second interface is object-oriented, and is better suited to more complex plotting needs or when you want more control. It doesn’t rely on a ‘current’ or ‘active’ figure or axes; instead it uses explicit figure and axes objects, which provide methods to manipulate them.
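A sketch contrasting the two interfaces on the same two-panel plot (sine on top, cosine below; the data is arbitrary). In the stateful version each plt.* call acts on whichever subplot was created last; in the object-oriented version every panel keeps its own handle, so the first one can still be edited after the second is drawn.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# MATLAB-style (stateful): plt.* calls target the "current" axes
matlab_fig = plt.figure()
plt.subplot(2, 1, 1)   # (rows, columns, panel number); this panel becomes current
plt.plot(x, np.sin(x))
plt.subplot(2, 1, 2)   # now the lower panel is current
plt.plot(x, np.cos(x))

# Object-oriented: explicit handles, methods called on them
oo_fig, ax = plt.subplots(2)   # ax is an array of two Axes objects
ax[0].plot(x, np.sin(x))
ax[1].plot(x, np.cos(x))
ax[0].set_title("sin(x)")      # editing the first panel after the second is easy
```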

Creating Line and Scatter Plots

For all Matplotlib plots we start by creating a figure and axes, in their simplest form fig = plt.figure() and ax = plt.axes().

  • The figure (an instance of the class plt.Figure) can be thought of as a single container that holds all the objects representing axes, graphics, text, and labels.

  • The axes (an instance of plt.Axes) is a bounding box with ticks and labels that will contain the plot elements that make up the visualization.

Once we have an axes instance we can call ax.plot to plot some data. If we want to create a single figure with multiple lines, we can call ax.plot multiple times.

We can pass a linestyle= keyword argument to set the line style (e.g. ‘solid’, ‘dashed’, ‘dashdot’, ‘dotted’), and color= to set the colour. The colour can be a hex code, an RGB tuple of values between 0 and 1, an HTML colour name, a greyscale value between 0 and 1 given as a string (e.g. '0.75'), or a single-letter colour code (‘rgbcmyk’).
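A sketch of several ax.plot calls on one axes, each with a different linestyle and a different way of specifying the colour. The shifted sine curves are just placeholder data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)

fig = plt.figure()
ax = plt.axes()

# each ax.plot call adds another line to the same axes
ax.plot(x, np.sin(x - 0), color='blue', linestyle='solid')      # colour by name
ax.plot(x, np.sin(x - 1), color='g', linestyle='dashed')        # single-letter code
ax.plot(x, np.sin(x - 2), color='0.75', linestyle='dashdot')    # greyscale, as a string
ax.plot(x, np.sin(x - 3), color='#FFDD44', linestyle='dotted')  # hex code
ax.plot(x, np.sin(x - 4), color=(1.0, 0.2, 0.3))                # RGB tuple in 0..1
```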

Matplotlib sets default axis limits, but you can set them explicitly with plt.xlim(start, end) and plt.ylim(start, end), or all at once with plt.axis([xstart, xend, ystart, yend]). plt.axis can also set an equal aspect ratio with plt.axis('equal') (one unit in x is drawn the same length as one unit in y) or tighten the bounds around the data with plt.axis('tight').

You can label a plot with plt.title("My graph") and axes with plt.xlabel("x") and plt.ylabel("y").
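A minimal pyplot-style sketch of the limit and label calls described above. The sine data and the specific limit values are arbitrary; plt.axis() with no arguments is used here to read the limits back.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
plt.figure()
plt.plot(x, np.sin(x))

plt.xlim(-1, 11)       # explicit x range
plt.ylim(-1.5, 1.5)    # explicit y range
plt.title("A Sine Curve")
plt.xlabel("x")
plt.ylabel("sin(x)")

limits_before = plt.axis()       # no arguments: returns (xmin, xmax, ymin, ymax)

plt.axis([-1, 11, -1.5, 1.5])    # the same limits, set in a single call
plt.axis('equal')                # equal scaling for x and y units
```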

In the object-oriented style the methods have slightly different names (mostly the pyplot name with set_ in front, e.g. ax.set_xlabel), and it’s often more convenient to use the ax.set() method to set them all at once:

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 10, 1000)
ax = plt.axes()
ax.plot(x, np.sin(x))
ax.set(xlim=(0, 10), ylim=(-2, 2), xlabel='x', ylabel='sin(x)', title='a simple plot')
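For reference, a sketch of the individual setter methods that ax.set() bundles together, each annotated with the pyplot function it corresponds to. The values mirror the ax.set() example and are otherwise arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
ax = plt.axes()
ax.plot(x, np.sin(x))

# Axes method                     # pyplot equivalent
ax.set_xlim(0, 10)                # plt.xlim
ax.set_ylim(-2, 2)                # plt.ylim
ax.set_xlabel('x')                # plt.xlabel
ax.set_ylabel('sin(x)')           # plt.ylabel
ax.set_title('a simple plot')     # plt.title
```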

You can create a legend if plotting multiple functions as follows:

plt.plot(x, np.sin(x), 'green', label='sin(x)')
plt.plot(x, np.cos(x), 'blue', label='cos(x)')

plt.legend()

You can use plt.plot to make a scatter plot as well, by passing a marker symbol as the third argument (e.g. plt.plot(x, y, 'o')), but a more powerful way is plt.scatter(x, y, marker='o'). It’s more powerful because the properties of each individual point (size, face colour, edge colour, etc.) can be controlled individually or mapped to the data.

For example:

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.RandomState(0)
x = rng.randn(100)
y = rng.randn(100)
colors = rng.rand(100)
sizes = 1000 * rng.rand(100)

plt.scatter(x, y, c=colors, s=sizes, alpha=0.3, cmap='viridis')
plt.colorbar()  # show color scale

The s argument sets each marker’s area in points squared, and the c values are mapped onto the colormap (whose scale the colorbar shows). So we can use colour and size to visualize extra dimensions of the data.

Histograms

The basic plt.hist function can take a plain NumPy array and draw a histogram with the default number of bins (10), or it can be customized:

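import numpy as np
import matplotlib.pyplot as plt
data = np.random.randn(1000)
plt.hist(data, bins=30, density=True, alpha=0.5,  # density= replaces the removed normed=
         histtype='stepfilled', color='steelblue',
         edgecolor='none')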

histtype='stepfilled' combined with a low alpha works well for comparing histograms of several distributions on the same axes.
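A sketch of that comparison: three made-up normal samples drawn on one axes with shared styling passed through a kwargs dict. With density=True, each histogram's bar areas sum to 1, so samples of different sizes would still be comparable.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.RandomState(42)
x1 = rng.normal(0, 0.8, 1000)
x2 = rng.normal(-2, 1, 1000)
x3 = rng.normal(3, 2, 1000)

# shared styling: translucent filled outlines so overlaps stay visible
kwargs = dict(histtype='stepfilled', alpha=0.3, density=True, bins=40)

plt.figure()
# plt.hist returns (counts, bin_edges, patches) for each sample
results = [plt.hist(sample, **kwargs) for sample in (x1, x2, x3)]
```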

We can use a variety of methods to plot two-dimensional histograms:


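import numpy as np
import matplotlib.pyplot as plt

# correlated 2D Gaussian sample data
x, y = np.random.multivariate_normal([0, 0], [[1, 1], [1, 2]], 10000).T

# straightforward 2D histogram
plt.hist2d(x, y, bins=30, cmap='Blues')
cb = plt.colorbar()
cb.set_label('counts in bin')

# use a hexagonal tessellation instead (in a new figure)
plt.figure()
plt.hexbin(x, y, gridsize=30, cmap='Blues')
cb = plt.colorbar(label='count in bin')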
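If only the bin counts are needed, without any plotting, NumPy's np.histogram2d computes the same kind of 2D binning. A sketch with made-up correlated Gaussian data:

```python
import numpy as np

# correlated 2D Gaussian sample data (arbitrary mean and covariance)
rng = np.random.RandomState(0)
x, y = rng.multivariate_normal([0, 0], [[1, 1], [1, 2]], 10000).T

# counts[i, j] is the number of points falling in x-bin i and y-bin j
counts, xedges, yedges = np.histogram2d(x, y, bins=30)
print(counts.shape)  # (30, 30)
```

By default the bins span the full data range, so every point lands in some bin and the counts add up to the sample size.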

Legends

The simplest legend can be created with ax.legend() or plt.legend(), but there are lots of options, of course:

  • loc=<location>, e.g. ‘upper left’, ‘lower center’

  • frameon=False turns off the frame

  • ncol=<cols> sets the number of columns

  • fancybox=True rounds the box corners

  • framealpha=1 sets the opacity of the legend background (0 is fully transparent, 1 fully opaque)

  • shadow=True adds a shadow behind the box

  • borderpad=1 sets the padding inside the border
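A sketch combining several of these options on a two-line plot. The option values are arbitrary; ncol is still accepted even though newer Matplotlib versions also spell it ncols.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), '-b', label='Sine')
ax.plot(x, np.cos(x), '--r', label='Cosine')

leg = ax.legend(loc='upper left', frameon=True, ncol=2,
                fancybox=True, framealpha=1, shadow=True, borderpad=1)
```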

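The legend includes all labeled elements by default and ignores any element without a label. You can customize this by passing the legend the objects you want to include (the plot methods return them), together with their labels, e.g. lines = plt.plot(x, y), then plt.legend(lines[:2], ['first', 'second']). Passing only the handles as a single positional argument does not work, because a lone positional argument is read as the list of labels.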
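A sketch of both ways of restricting the legend, using made-up phase-shifted sine curves. The first figure passes explicit handles and labels; the second relies on only some lines having a label.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
# y has 4 columns, so plt.plot draws (and returns) 4 Line2D objects
y = np.sin(x[:, np.newaxis] + np.pi * np.arange(0, 2, 0.5))

# 1. explicit handles plus matching labels: only the first two lines appear
plt.figure()
lines = plt.plot(x, y)
leg = plt.legend(lines[:2], ['first', 'second'])

# 2. label only the lines you want; unlabeled lines are left out
plt.figure()
plt.plot(x, y[:, 0], label='first')
plt.plot(x, y[:, 1], label='second')
plt.plot(x, y[:, 2:])   # no label, so not in the legend
leg2 = plt.legend()
```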

Sometimes we want to show information not in the legend defaults. For example if we’re using size semantically, we might want to show a legend accordingly.

We can plot empty lists to ‘fake’ the data since the legend must reference some object in the plot:


import matplotlib.pyplot as plt

for area in [100, 300, 500]:
    plt.scatter([], [], c='k', alpha=0.3, s=area,
                label=str(area) + ' km$^2$')

plt.legend(scatterpoints=1, frameon=False,
           labelspacing=1, title='City Area')

Annoyingly, Matplotlib only allows a single legend per axes through the main legend() method; calling it a second time just replaces the first legend. To get a second legend, you have to create a Legend object manually and add it with the ax.add_artist() method, like this:

from matplotlib.legend import Legend
leg = Legend(ax, lines[2:], ['line C', 'line D'], loc='lower right', frameon=False)
ax.add_artist(leg)
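A self-contained version of the two-legend pattern. The four phase-shifted sine curves and the 'line A' to 'line D' labels are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.legend import Legend

x = np.linspace(0, 10, 1000)
fig, ax = plt.subplots()
# a 2D y array gives one Line2D per column: four lines here
lines = ax.plot(x, np.sin(x[:, np.newaxis] - np.arange(4)))

# first legend via the normal method
ax.legend(lines[:2], ['line A', 'line B'], loc='upper right', frameon=False)

# second legend: build the Legend artist ourselves and attach it to the axes
leg = Legend(ax, lines[2:], ['line C', 'line D'], loc='lower right', frameon=False)
ax.add_artist(leg)
```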

Colour Bars

Colour bars are for continuous rather than discrete labels. Most simply you can just call plt.colorbar(). You can also pass a cmap argument to the plotting function to choose the colour map. To see the available choices, type plt.cm.<TAB> in IPython.

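The colorbar is drawn in its own plt.Axes instance (available as the colorbar’s .ax attribute), so you can use all the axes and tick formatting tools on it. Colour maps can also be made discrete rather than continuous by calling plt.cm.get_cmap() with the colormap name and the number of bins (newer Matplotlib versions have removed plt.cm.get_cmap(); plt.get_cmap() takes the same arguments).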

Other options include flagging out-of-bounds values with the extend argument when creating the colorbar, and narrowing the colour limits with plt.clim(lower, upper).
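A sketch of these colorbar options on an image. The sin·cos pattern is arbitrary data with values in [-1, 1]; 'Blues' cut into 6 bins gives a discrete map, extend='both' adds arrowheads for out-of-range values, and plt.clim narrows the limits so such values exist.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 500)
image = np.sin(x) * np.cos(x[:, np.newaxis])   # 500x500 array, values in [-1, 1]

plt.figure()
# discrete colormap: 'Blues' resampled into 6 bins
im = plt.imshow(image, cmap=plt.get_cmap('Blues', 6))
cb = plt.colorbar(extend='both')   # arrowheads mark values beyond the colour limits
plt.clim(-0.5, 0.5)                # |value| > 0.5 now falls outside the limits

cb.ax.set_title('value')           # the colorbar lives in an ordinary Axes (cb.ax)
```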

Multiple Plots

Often we want to compare different views of data side by side. Matplotlib has the concept of subplots for this: groups of smaller axes that can exist together in a single figure. These subplots might be insets, grids of plots, or more complex layouts.

By default plt.axes creates an axes object that fills the figure. Instead we can pass a list of four numbers [left, bottom, width, height] in figure coordinates, which range from 0 at the bottom left of the figure to 1 at the top right; width and height are likewise fractions of the figure size (between 0 and 1). In the object-oriented API the best way of doing this is fig.add_axes():

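import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()

# [left, bottom, width, height] in figure coordinates (0 to 1)
ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.4])
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.4])

x = np.linspace(0, 10)
ax1.plot(np.sin(x))
ax2.plot(np.cos(x))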

If you want to create a regular grid of plots you can do it using convenience methods, here’s the OO api approach:

import matplotlib.pyplot as plt

fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i in range(1, 7):
    ax = fig.add_subplot(2, 3, i)  # panels are numbered from 1, row by row
    ax.text(0.5, 0.5, str((2, 3, i)), fontsize=18, ha='center')
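Another convenience method, plt.subplots, creates the whole grid in one call and returns the axes as a 2D NumPy array indexed from 0, unlike add_subplot's 1-based panel numbers. A sketch of the same 2x3 grid, with axes shared down each column and along each row:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# 2x3 grid; x axes shared within each column, y axes within each row
fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')

# ax is a 2D array of Axes, indexed [row, col]
for i in range(2):
    for j in range(3):
        ax[i, j].text(0.5, 0.5, str((i, j)), fontsize=18, ha='center')
```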

More complex arrangements can be achieved with the plt.GridSpec() tool, which creates something like a CSS grid; subplots can then be placed in it, each spanning one or more cells.

import matplotlib.pyplot as plt

grid = plt.GridSpec(2, 3, wspace=0.4, hspace=0.3)  # creates a 2x3 grid

plt.subplot(grid[0, 0])   # occupies one cell
plt.subplot(grid[0, 1:])  # column span
plt.subplot(grid[1, :2])
plt.subplot(grid[1, 2])

See the book, p. 267 for a complex example.

Seaborn

Seaborn provides an API on top of matplotlib designed to give better defaults, remove boilerplate, and work nicely with Pandas.