Matplotlib is a graphical library for plotting mathematical functions and data in Python. The library is capable of producing a wide range of plots; however, this article will focus on generating 2D and 3D plots, as this is a common use case.
Installing the prerequisites
In addition to matplotlib, we will be using NumPy in this tutorial. You might be wondering, “Why not use Python’s built-in data structures?” The short answer is that built-in lists don’t support the element-wise arithmetic we need for anything beyond trivial examples. (More on this later.)
Matplotlib can be installed with pip using the following command:
python -m pip install matplotlib
If you have more than one version of Python installed, run the command with the interpreter you plan to use (for example, python3 -m pip install matplotlib) to avoid any surprises. (If this does not work for you, the official matplotlib website has more detailed installation help.)
If you don’t have NumPy installed already, you can install it the same way with python -m pip install numpy.
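To confirm that both packages are importable from the interpreter you intend to use, a quick sanity check like the following may help. It only prints the interpreter path and the installed versions; the exact version numbers will differ on your machine.

```python
import sys

import matplotlib
import numpy as np

# Which interpreter is running this script? pip must target the same one.
print("Python:", sys.executable)

# Installed versions of the two libraries used in this tutorial
print("matplotlib:", matplotlib.__version__)
print("NumPy:", np.__version__)
```

If either import fails, the package was most likely installed for a different Python version than the one running the script.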
Hello World in matplotlib
Now that you have successfully installed matplotlib, let’s jump into a simple example to give us a feel for how the library works. First, copy and paste the code below into your IDE and run it.
import matplotlib.pyplot as plt

plt.plot([1, 2, 3, 4])
plt.ylabel('y-axis')
plt.xlabel('x-axis')
plt.show()
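You should see a new window like the one below containing a straight line. Because we passed only one list, matplotlib treats its values as the y-coordinates and uses their indices (0, 1, 2, 3) as the x-coordinates. Note that if you are using PyCharm with scientific mode enabled, the graph appears inside the IDE instead of in a separate window.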
Notice that the x- and y-axes carry the labels we passed to the plt.xlabel() and plt.ylabel() functions. You need to call plt.show() to display the graph when you run the code.
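If you want to confirm how matplotlib chose the x-values, you can inspect the line object that plot returns. The sketch below is an illustration rather than part of the original example: it switches to the non-interactive Agg backend (an assumption, so it also runs without a display) and reads the data back from the returned Line2D object.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window is opened
import matplotlib.pyplot as plt
import numpy as np

# plot() returns a list of Line2D objects, one per line drawn
(line,) = plt.plot([1, 2, 3, 4])

# With only one list given, the x-values default to the indices 0..3
xs = np.asarray(line.get_xdata()).tolist()
ys = np.asarray(line.get_ydata()).tolist()
print(xs)  # x-values (the indices)
print(ys)  # y-values (the list we passed)
```

Passing both lists explicitly, as in plt.plot([0, 1, 2, 3], [1, 2, 3, 4]), draws the same line.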
Plotting multiple functions on one graph
We now have an idea of how matplotlib works, so let’s move on to a more complex and practical use case: plotting multiple functions in the same graph. Note that we will be using NumPy from now on.
The code below plots two quadratic functions on the same graph as seen in the image that follows.
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0, 1000, 20)
y1 = x ** 2 + 500 * x
y2 = x ** 2

plt.plot(x, y1)
plt.plot(x, y2)
plt.show()
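Because x is a NumPy array, each arithmetic expression is applied to every element at once. The short sketch below uses a smaller range than the example above (an arbitrary choice, so the output stays readable) to show what np.arange and the vectorized expressions produce.

```python
import numpy as np

# arange(start, stop, step): stop is excluded, so this yields 0, 20, 40, 60, 80
x = np.arange(0, 100, 20)

# Each operation is applied element-wise; no explicit loop is needed
y1 = x ** 2 + 500 * x
y2 = x ** 2

print(x.tolist())   # [0, 20, 40, 60, 80]
print(y1.tolist())  # [0, 10400, 21600, 33600, 46400]
print(y2.tolist())  # [0, 400, 1600, 3600, 6400]

# The range used in the plotting example holds 50 points
print(len(np.arange(0, 1000, 20)))  # 50
```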
If you are not familiar with NumPy, you might be surprised by the simplicity of the mathematical expressions in the code. Let’s try doing the same thing with Python’s lists using the code below.
import matplotlib.pyplot as plt

npoints = 1000
x = list(range(npoints))
y1 = [i**2 + 500 * i for i in range(npoints)]
y2 = [i**2 for i in range(npoints)]

plt.plot(x, y1)
plt.plot(x, y2)
plt.show()
This version runs, because plot accepts plain lists as well as arrays, but it is noticeably more verbose: every expression needs its own loop. That is because lists do not support element-wise arithmetic. If x is a list, x * 500 repeats the list 500 times, and x ** 2 raises a TypeError instead of squaring each element. As a rule of thumb, use NumPy arrays rather than lists for mathematically intensive work such as plotting functions.
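The root of the problem is that arithmetic operators mean different things for lists and for NumPy arrays. This small comparison, a sketch with made-up values, shows the difference:

```python
import numpy as np

values = [1, 2, 3]
arr = np.array(values)

# On a list, * repeats the sequence instead of multiplying each element
print(values * 2)          # [1, 2, 3, 1, 2, 3]
# On an array, * multiplies element-wise
print((arr * 2).tolist())  # [2, 4, 6]

# ** is not defined for lists at all
try:
    values ** 2
except TypeError as exc:
    print("lists do not support **:", exc)

# ...but it works element-wise on arrays
print((arr ** 2).tolist())  # [1, 4, 9]
```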
You may have noticed that matplotlib automatically gave each function a different color. You can control the color, along with the marker drawn at each point, by passing a format string as an extra argument to plot, like so:
plt.plot(x, y1, 'r o')
plt.plot(x, y2, 'g ^')
The format string combines a character for the color (r for red, g for green) with a character for the marker drawn at each data point (o for circles, ^ for triangles). The space between them is matplotlib’s “no line” style, so the points are not joined by a line. This is particularly useful for scatter plots, where the data may not be continuous. Replacing the calls to plot in the previous example with this code produces a graph like the one shown below.
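A format string is shorthand for keyword arguments to plot. The sketch below (using the non-interactive Agg backend, an assumption so it runs without a display) draws the y1 curve twice, once with 'r o' and once with the equivalent color, linestyle and marker keywords, and reads the resulting styles back from the Line2D objects.

```python
import matplotlib
matplotlib.use("Agg")  # no window needed for this check
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgba

x = np.arange(0, 1000, 20)
y1 = x ** 2 + 500 * x

# Shorthand: 'r' = red, ' ' = no connecting line, 'o' = circle markers
(short,) = plt.plot(x, y1, 'r o')
# The same style spelled out with keyword arguments
(explicit,) = plt.plot(x, y1, color='r', linestyle='None', marker='o')

for line in (short, explicit):
    # to_rgba normalizes the color so both spellings can be compared
    print(to_rgba(line.get_color()), line.get_linestyle(), line.get_marker())
```

The keyword form is longer but easier to read once you need more options, such as marker size or line width.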
Finally, I will show an example of creating a three-dimensional plot using a helix. A helix is simply the spiral shape of a coil spring. To make a 3D plot, we need to create three-dimensional axes and supply them with data for the x, y and z coordinates. The code below plots a helix.
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection (only required before matplotlib 3.2)
import numpy as np
from numpy import sin, cos

fig = plt.figure()
# preferred method for creating 3d axes
ax = fig.add_subplot(111, projection='3d')

r = 10
c = 50
t = np.linspace(0, 5000, 100)

# parametric equation of a helix
x = r*cos(t)
y = r*sin(t)
z = c*t

ax.plot(x, y, z, zdir='z', lw=2)
plt.show()
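The helix above follows the parametric form x(t) = r cos t, y(t) = r sin t, z(t) = c t: every point sits on a cylinder of radius r around the z-axis, and each full turn (t growing by 2π) raises z by 2πc. The numeric check below reuses the r, c and t values from the example; it is only a sanity check of those two properties, not part of the plotting code.

```python
import numpy as np

r = 10
c = 50
t = np.linspace(0, 5000, 100)

x = r * np.cos(t)
y = r * np.sin(t)
z = c * t

# Distance from the z-axis: should equal r for every sample
radius = np.hypot(x, y)
print(np.allclose(radius, r))  # True

# Height gained per full turn (the pitch of the helix)
pitch = 2 * np.pi * c
print(round(pitch, 2))  # 314.16
```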
Similar to 2D plots, you can make 3D scatter plots by replacing ax.plot(x, y, z, zdir='z', lw=2) with ax.scatter(x, y, z, zdir='z', lw=2, c='g', marker='^'). The images below show the results of the plot and scatter methods, respectively.
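For reference, here is a complete version of the scatter variant, assembled from the helix code and the ax.scatter call above. The Agg backend is an assumption of this sketch so that it runs without a display; in an interactive session you would drop the matplotlib.use line and call plt.show() at the end.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; use plt.show() interactively instead
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on matplotlib < 3.2)
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

r, c = 10, 50
t = np.linspace(0, 5000, 100)
x, y, z = r * np.cos(t), r * np.sin(t), c * t

# Same data as the helix, drawn as individual green triangles
points = ax.scatter(x, y, z, zdir='z', lw=2, c='g', marker='^')
```

To keep the result, fig.savefig() with a file name of your choice writes the figure to disk.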
As you might expect, you can rotate the 3D graph by clicking and dragging on the plot.
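If you want a fixed viewpoint instead, for example when saving the figure to a file, 3D axes also provide a view_init method that sets the elevation and azimuth angles in degrees. The angles below are arbitrary example values, and the Agg backend is again an assumption so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on matplotlib < 3.2)
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

t = np.linspace(0, 5000, 100)
ax.plot(10 * np.cos(t), 10 * np.sin(t), 50 * t, lw=2)

# Look at the helix from 30 degrees above the x-y plane,
# rotated 45 degrees around the z-axis
ax.view_init(elev=30, azim=45)
print(ax.elev, ax.azim)
```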