What is the numpy.squeeze() function in NumPy?
Overview
The squeeze() function in NumPy removes axes of length 1 from an input array. By default it removes every such axis; the axis parameter restricts it to specific ones.
Every dimension of a NumPy array has a corresponding axis, numbered from 0. For example, a 2-D array has two axes: axis 0 runs vertically downward across the rows, and axis 1 runs horizontally across the columns.
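To make the idea of axes concrete, here is a minimal sketch using a small 2-D array. The array m and the variable names are illustrative only and do not come from the original example.

```python
import numpy as np

# A 2-D array with 2 rows and 3 columns (illustrative data)
m = np.array([[1, 2, 3],
              [4, 5, 6]])

print(m.ndim)   # number of axes: 2
print(m.shape)  # (2, 3): length 2 along axis 0, length 3 along axis 1

# Reducing along axis 0 moves down the rows, leaving one value per column
col_sums = m.sum(axis=0)   # values 5, 7, 9

# Reducing along axis 1 moves across the columns, leaving one value per row
row_sums = m.sum(axis=1)   # values 6, 15

print(col_sums)
print(row_sums)
```

Neither axis of m has length 1, so squeeze() would leave its shape unchanged.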
Syntax
numpy.squeeze(a, axis=None)
Syntax for the squeeze() function
Parameter value
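The squeeze() function takes the following parameter values:
- a: This is the input array. It is a required parameter.
- axis: This is an integer or a tuple of integers that selects which of the length-1 axes to remove. It is an optional parameter. If it is omitted (None), all axes of length 1 are removed. If a selected axis has a length greater than 1, a ValueError is raised.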
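The effect of the axis parameter can be sketched as follows. The array built with np.zeros() is an illustrative stand-in with shape (1, 4, 1), and the variable error_message is a name chosen for this sketch.

```python
import numpy as np

# Illustrative array of shape (1, 4, 1): axes 0 and 2 have length 1
a = np.zeros((1, 4, 1))

print(np.squeeze(a).shape)               # (4,)   all length-1 axes removed
print(np.squeeze(a, axis=0).shape)       # (4, 1) only axis 0 removed
print(np.squeeze(a, axis=2).shape)       # (1, 4) only axis 2 removed
print(np.squeeze(a, axis=(0, 2)).shape)  # (4,)   both selected axes removed

# Axis 1 has length 4, so it cannot be squeezed
try:
    np.squeeze(a, axis=1)
except ValueError as err:
    error_message = str(err)
    print("ValueError:", error_message)
```

Passing axis explicitly is a reasonable safeguard when a dimension is expected to have length 1: if it does not, the call fails with an error instead of silently returning an unchanged shape.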
Return value
The squeeze() function returns the input array, a, with all dimensions of length 1 removed, or only the selected ones if axis is given. The result is always a itself or a view into a, not a copy.
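NumPy documents the result of squeeze() as the input array itself or a view into it, so the squeezed array shares memory with the original. The short sketch below illustrates that behavior; the variable c and the use of copy() are additions for illustration.

```python
import numpy as np

a = np.array([[[1], [2], [3], [4]]])  # shape (1, 4, 1)
b = np.squeeze(a)                     # shape (4,)

# b is a view: it shares its data buffer with a
print(np.shares_memory(a, b))  # True

# Writing through the view also changes the original array
b[0] = 99
print(a[0, 0, 0])  # 99

# Taking an explicit copy gives an independent array
c = np.squeeze(a).copy()
c[1] = -1
print(a[0, 1, 0])  # 2 (unchanged)
```

If the squeezed result will be modified and the original must stay intact, calling copy() on the result is one way to decouple the two.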
Example
import numpy as np
# creating an input array
a = np.array([[[1], [2], [3], [4]]])

# getting the shape of a
print(a.shape)

# removing the dimensions with length 1
b = np.squeeze(a)

# obtaining the shape of the new array
print(b.shape)
Code explanation
- Line 1: We import the numpy module.
- Line 3: We create an input array, a, using the array() function.
- Line 6: We obtain and print the shape of a, which is (1, 4, 1), using the shape attribute.
- Line 9: We remove the dimensions of length 1 from the input array, a, using the squeeze() function. The result is assigned to a variable, b.
- Line 12: We obtain and print the shape of the squeezed array, b, which is (4,) now that the dimensions of length 1 have been removed.