How to create 3D Graphs using matplotlib in Python
Overview
We can create 3D plots in Python thanks to the mplot3d toolkit in the matplotlib library.
Matplotlib was originally designed with only 2D plotting in mind. Around the time of the 1.0 release, 3D plotting utilities were built on top of the 2D ones, which is why 3D visualizations are available in matplotlib today.
In this article, we’ll cover the following using matplotlib:
- 3D scatter plot
- Rotate plot angle
- 3D line plot
- Surface plots
3D scatter plot
First, let’s import pyplot and NumPy. We use NumPy's random module to generate x, y, and z data: 40 points in total (n), drawn from a normal distribution with mean 3. We also use rcParams to set the default figure size.
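# Import pyplot from matplotlib, plus NumPy (np is used in the later sections)
from matplotlib import pyplot as plt
import numpy as np
from numpy import random
# Registers the '3d' projection on older matplotlib versions (< 3.2); harmless on newer ones
from mpl_toolkits import mplot3d

random.seed(31)  # fix the seed so the random data are reproducible
mu = 3           # mean of the normal distribution
n = 40           # number of points
x = random.normal(mu, 1, size=n)
y = random.normal(mu, 1, size=n)
z = random.normal(mu, 1, size=n)
plt.rcParams['figure.figsize'] = (7, 5)  # default figure size (width, height) in inches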
Let’s go ahead and create a 3D Scatter Plot.
Setting up our axes
We start by creating a set of axes with pyplot, passing the keyword argument projection='3d' to tell matplotlib that we want to plot in three dimensions.
ax = plt.axes(projection='3d');
Then we add the data to our axes and label each axis:
# Add the data; s=40 sets the marker size (its area in points squared)
ax.scatter3D(x, y, z, s=40)
# Label the axes
ax.set_xlabel('x-axis')
ax.set_ylabel('y-axis')
ax.set_zlabel('z-axis');
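A common variation, sketched here with freshly generated data, colors each point by its z value through the c and cmap arguments and adds a colorbar. It also uses the object-oriented route, fig.add_subplot(projection='3d'), which creates the same kind of 3D axes as plt.axes(projection='3d'). The 'viridis' colormap is just one choice, and the values differ from the random.seed(31) data above because a different generator is used.

```python
import matplotlib.pyplot as plt
import numpy as np

# New-style NumPy generator; the values will not match random.seed(31) above
rng = np.random.default_rng(31)
n = 40
x, y, z = rng.normal(3, 1, size=(3, n))  # three rows of 40 normal samples

fig = plt.figure(figsize=(7, 5))
ax = fig.add_subplot(projection='3d')  # object-oriented equivalent of plt.axes(projection='3d')

# c=z maps each point's z value to a color from the colormap
points = ax.scatter3D(x, y, z, c=z, cmap='viridis', s=40)
fig.colorbar(points, ax=ax, label='z value')

ax.set_xlabel('x-axis')
ax.set_ylabel('y-axis')
ax.set_zlabel('z-axis')
```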
Rotate plot angle
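To rotate the plot, we use ax.view_init(). It takes two arguments: the elevation viewing angle and the azimuthal viewing angle, both in degrees. The elevation is the angle above the x-y plane, and the azimuth is the rotation around the z-axis.
plt.figure()  # start a new figure (needed when running as a script)
ax = plt.axes(projection='3d')
ax.scatter3D(x, y, z, s=100)
ax.set_xlabel('x-axis')
ax.set_ylabel('y-axis')
ax.set_zlabel('z-axis')
# Elevation of 45 degrees, azimuth of 100 degrees
ax.view_init(45, 100);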
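To get a feel for the two angles, a small sketch can draw the same scatter from several viewpoints side by side. The (elevation, azimuth) pairs below are arbitrary choices for illustration, and the data are freshly generated stand-ins for the points above.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(31)
x, y, z = rng.normal(3, 1, size=(3, 40))

# Arbitrary (elevation, azimuth) pairs in degrees
views = [(0, 0), (30, 45), (90, -90)]

fig = plt.figure(figsize=(12, 4))
axes = []
for i, (elev, azim) in enumerate(views, start=1):
    ax = fig.add_subplot(1, len(views), i, projection='3d')
    ax.scatter3D(x, y, z, s=40)
    # Tilt the camera by elev above the x-y plane, rotate it by azim around the z-axis
    ax.view_init(elev=elev, azim=azim)
    ax.set_title(f'elev={elev}, azim={azim}')
    axes.append(ax)
```

An elevation of 90 looks straight down the z-axis, which makes the plot resemble an ordinary 2D scatter of x against y.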
3D line plot
Here we plot a trigonometric spiral, so let’s create some new data. z_line is an array of evenly spaced values from 0 to 5, and x_line and y_line are the cosine and sine of spiral * z_line, respectively.
The NumPy linspace function generates a sequence of evenly spaced values within a specified interval.
We specify the start and end points of the interval and then the total number of values we want in that interval.
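A quick check of how np.linspace spaces its values: with num points, the step is (stop - start) / (num - 1), because both endpoints are included by default. The retstep=True flag asks NumPy to return that step as well.

```python
import numpy as np

# 6 evenly spaced values from 0 to 5, both endpoints included
values = np.linspace(0, 5, 6)
print(values)  # [0. 1. 2. 3. 4. 5.]

# The spiral's z_line uses 100 points, so its step is 5 / 99
z_line, step = np.linspace(0, 5, 100, retstep=True)
print(len(z_line), step)
```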
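# spiral is the angular frequency: larger values wind the line into more turns
spiral = 5
# Data: 100 evenly spaced heights, and the cosine/sine of the matching angles
z_line = np.linspace(0, 5, 100)
x_line = np.cos(spiral*z_line)
y_line = np.sin(spiral*z_line)
plt.figure()  # start a new figure (needed when running as a script)
ax = plt.axes(projection='3d')
ax.plot3D(x_line, y_line, z_line, lw=5);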
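Because x_line and y_line are the cosine and sine of the same angle, spiral * z_line, every point of the line sits on a circle of radius 1 around the z-axis, and the number of full turns is spiral * (5 - 0) / (2π), roughly 4 for spiral = 5. A short numerical check of both facts, reusing the same values as the spiral code:

```python
import numpy as np

spiral = 5
z_line = np.linspace(0, 5, 100)
x_line = np.cos(spiral * z_line)
y_line = np.sin(spiral * z_line)

# cos^2 + sin^2 = 1, so every point is at distance 1 from the z-axis
radius = np.sqrt(x_line**2 + y_line**2)
print(np.allclose(radius, 1.0))  # True

# Total angle swept, divided by one full turn (2*pi radians)
turns = spiral * (z_line[-1] - z_line[0]) / (2 * np.pi)
print(round(turns, 2))  # 3.98
```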
Surface plots
A surface plot is a representation of a three-dimensional data set. It describes the functional relationship between two independent variables, X and Z, and a dependent variable, Y, without showing individual data points. Surface plots are companion plots to contour plots.
For example, let's draw the equation $y = 45 - (x^2 + z^2)$ as a surface plot.
In this case, x and z are the independent variables, while y is the dependent variable.
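Plugging a few points into $y = 45 - (x^2 + z^2)$ shows the shape of the surface: y peaks at 45 at the origin and falls off as x and z move away from zero, so the surface is a downward-opening paraboloid. The sample points below are chosen only for illustration.

```python
def function_y(x, z):
    # y = 45 - (x^2 + z^2)
    return 45 - (x**2 + z**2)

print(function_y(0, 0))  # 45  (the peak)
print(function_y(3, 4))  # 20  (45 - 25)
print(function_y(5, 5))  # -5  (a corner of the -5..5 range)
```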
We can solve this problem by following the steps below:
- Define a function for the y variable that takes x and z as parameters. This function describes the surface we want to draw.
- Set up a grid using NumPy's meshgrid function, which takes our 1D arrays of x and z values and returns 2D grids X and Z.
- Make a scatter plot of X against Z to see the grid points.
- Compute Y by passing the X and Z grids into the function.
- Finally, create the surface plot by passing the three arrays X, Z, and Y to the plot_surface() method.
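To see what np.meshgrid produces, here is a tiny sketch with three x values and two z values (chosen only for illustration). It returns two 2D arrays in which each (row, column) position pairs one x value with one z value, so a function of x and z can then be evaluated element-wise over the whole grid.

```python
import numpy as np

x_values = np.array([1, 2, 3])
z_values = np.array([10, 20])

# Default indexing='xy': output shape is (len(z_values), len(x_values))
X, Z = np.meshgrid(x_values, z_values)
print(X)
# [[1 2 3]
#  [1 2 3]]
print(Z)
# [[10 10 10]
#  [20 20 20]]

# Evaluating y = 45 - (x^2 + z^2) on the grid works element-wise
Y = 45 - (X**2 + Z**2)
print(Y.shape)  # (2, 3)
```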
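def function_y(x, z):
    # The surface: y = 45 - (x^2 + z^2)
    return 45 - (x**2 + z**2)

N = 40  # number of grid points along each axis
x_values = np.linspace(-5, 5, N)
z_values = np.linspace(-5, 5, N)
# Build 2D coordinate grids from the 1D x and z values
X, Z = np.meshgrid(x_values, z_values)
# Scatter the grid points to see what meshgrid produced
plt.scatter(X, Z);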
Next, we compute Y over the grid and draw the surface with plot_surface():
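# Evaluate y at every grid point
Y = function_y(X, Z)
plt.figure()  # new figure so the surface is not drawn over the 2D grid scatter
ax = plt.axes(projection='3d')
# plot_surface draws its third argument (Y) on the vertical axis
ax.plot_surface(X, Z, Y);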
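A variation on the surface plot, sketched below with the same function and grid, adds a colormap and labels the axes to match the roles of the variables (plot_surface always puts its third argument on the vertical axis, so Y is vertical here). A filled 2D contour plot of the same data sits next to it as the companion view mentioned earlier. The 'viridis' colormap and the 10 contour levels are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np

def function_y(x, z):
    return 45 - (x**2 + z**2)

N = 40
X, Z = np.meshgrid(np.linspace(-5, 5, N), np.linspace(-5, 5, N))
Y = function_y(X, Z)

fig = plt.figure(figsize=(12, 5))

# Left: 3D surface colored by height
ax3d = fig.add_subplot(1, 2, 1, projection='3d')
surf = ax3d.plot_surface(X, Z, Y, cmap='viridis', edgecolor='none')
ax3d.set_xlabel('x')
ax3d.set_ylabel('z')
ax3d.set_zlabel('y')  # the third argument (Y) is drawn on the vertical axis
fig.colorbar(surf, ax=ax3d, shrink=0.6, label='y')

# Right: filled contour plot of the same data, seen from above
ax2d = fig.add_subplot(1, 2, 2)
contours = ax2d.contourf(X, Z, Y, levels=10, cmap='viridis')
ax2d.set_xlabel('x')
ax2d.set_ylabel('z')
fig.colorbar(contours, ax=ax2d, label='y')
```

The contour rings are the level curves of the paraboloid: each ring marks the points where the surface has the same height.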