What is the numpy.expand_dims() function in NumPy?
Overview
The expand_dims() function in NumPy expands the shape of an input array by inserting a new axis of length 1. The new axis appears at the position given by axis in the shape of the resulting array.
Every NumPy array has one axis per dimension, numbered from 0. For example, a 2-D array has two axes: axis 0 runs vertically downward across rows, and axis 1 runs horizontally across columns.
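To make the axis numbering concrete, here is a minimal sketch that sums a 2-D array along each axis. The array m and its values are illustrative and not part of the original example.

```python
import numpy as np

# A 2-D array with 2 rows and 3 columns (illustrative values)
m = np.array([[1, 2, 3],
              [4, 5, 6]])

# Summing along axis 0 moves down the rows: one result per column
col_sums = m.sum(axis=0)
print(col_sums.tolist())  # [5, 7, 9]

# Summing along axis 1 moves across the columns: one result per row
row_sums = m.sum(axis=1)
print(row_sums.tolist())  # [6, 15]
```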
Syntax
numpy.expand_dims(a, axis)
Syntax for the "expand_dims()" function
Parameter values
The expand_dims() function takes the following values:
- a: This is the input array.
- axis: This is the position in the resulting array where the new axis is to be placed.
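The effect of the axis parameter is easiest to see by comparing shapes. The sketch below uses hypothetical variable names (row, col, last, both) and assumes NumPy 1.18 or later for the tuple form of axis.

```python
import numpy as np

a = np.array([1, 2, 3, 4])             # shape (4,)

row = np.expand_dims(a, axis=0)        # new axis inserted at position 0
col = np.expand_dims(a, axis=1)        # new axis inserted at position 1
last = np.expand_dims(a, axis=-1)      # negative values count from the end
both = np.expand_dims(a, axis=(0, 2))  # tuple of positions (assumes NumPy >= 1.18)

print(row.shape)   # (1, 4)
print(col.shape)   # (4, 1)
print(last.shape)  # (4, 1)
print(both.shape)  # (1, 4, 1)
```

In every case the new axis has length 1, so the total number of elements does not change.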
Return value
The expand_dims() function returns a view of the input array with an increased number of dimensions.
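Because the result is a view, it shares memory with the input array rather than copying it. The following sketch illustrates one consequence: writing through the view changes the original, while an explicit copy() does not. The variable c and the values written are illustrative.

```python
import numpy as np

a = np.array([1, 2, 3, 4])
b = np.expand_dims(a, axis=1)   # view of a with shape (4, 1)

# The view and the original refer to the same underlying data
print(np.shares_memory(a, b))   # True

# Writing through the view is visible in the original array
b[0, 0] = 99
print(a.tolist())               # [99, 2, 3, 4]

# An explicit copy is independent of a
c = np.expand_dims(a, axis=1).copy()
c[1, 0] = -1
print(a[1])                     # 2
```

If the original array must stay untouched, taking a copy as shown for c is one straightforward option.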
Code
import numpy as np
# creating an input array
a = np.array([1, 2, 3, 4])

# getting the dimension of a
print(a.shape)

# expanding the axis of a
b = np.expand_dims(a, axis=1)

# getting the dimension of the new array
print(b.shape)
Explanation
- Line 1: We import the numpy module.
- Line 3: We create an input array a using the array() function.
- Line 6: We obtain and print the shape of a using the shape attribute.
- Line 9: We expand the input array a using the expand_dims() function. Then, we assign the result to a variable b.
- Line 12: We obtain and print the shape attribute of the newly expanded array b.
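For the example above, the two printed shapes are (4,) and (4, 1). As a follow-up sketch, the same expansion can also be written by indexing with np.newaxis, and the resulting column shape is useful for broadcasting. The names b_indexed and table are hypothetical and chosen only for this illustration.

```python
import numpy as np

a = np.array([1, 2, 3, 4])

# Two ways to add an axis at position 1
b = np.expand_dims(a, axis=1)
b_indexed = a[:, np.newaxis]

print(b.shape, b_indexed.shape)      # (4, 1) (4, 1)
print(np.array_equal(b, b_indexed))  # True

# Broadcasting a (4, 1) column against a (4,) row gives a (4, 4) table
table = b * a
print(table.shape)                   # (4, 4)
print(table[1].tolist())             # [2, 4, 6, 8]
```

expand_dims() states the target position explicitly, which can be easier to read than indexing when the position is computed at run time.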