What is the vectorizing map (vmap) in JAX?
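The vectorizing map (vmap) is a function transformation in the JAX library. Given a function written for a single example, jax.vmap returns a new function that maps it over a batch axis of one or more input arguments. This makes it a convenient way to speed up code that would otherwise call the same function many times in a Python loop, once per input, because the whole batch is processed as a single vectorized computation.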
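As a rough illustration of what "mapping over a batch" means, the sketch below writes a function for a single pair of vectors and compares jax.vmap with an explicit Python loop over the batch. The function name `dot_one` and the sample data are hypothetical, chosen only for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example function: dot product of two 1-D vectors.
def dot_one(x, y):
    return jnp.dot(x, y)

xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 vectors, each of length 3
ys = jnp.ones((4, 3))                # another batch of 4 vectors

# Manual batching: call the function once per row and stack the results.
looped = jnp.stack([dot_one(x, y) for x, y in zip(xs, ys)])

# vmap builds the batched version automatically (maps over axis 0 of both inputs).
batched = jax.vmap(dot_one)(xs, ys)

print(looped)   # row sums of xs: 3, 12, 21, 30
print(batched)  # same values, computed without a Python loop
```

Calling `dot_one(xs, ys)` directly would fail here, because `jnp.dot` on two (4, 3) matrices is a matrix product with mismatched inner dimensions; vmap is what turns the per-example function into a row-wise one.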
Syntax
The syntax of the vmap function is as follows:
jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)
Parameters
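The vmap function has the following three main parameters, of which only fun is required:

- fun is the function to be mapped (the function object itself, not its name).
- in_axes specifies which axis of each input should be mapped over. It defaults to 0. It can be a single integer applied to every input, a tuple with one entry per positional argument, or None for an argument that should not be mapped.
- out_axes indicates where the mapped axis should appear in the function’s output. It also defaults to 0.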
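Setting an entry of in_axes to None passes that argument unchanged to every call instead of slicing it. A minimal sketch, using a hypothetical `dot_one` helper and made-up data, that turns a vector dot product into a matrix-vector product:

```python
import jax
import jax.numpy as jnp

def dot_one(x, y):
    # Dot product of two 1-D vectors (hypothetical helper for illustration).
    return jnp.dot(x, y)

M = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 rows to map over
v = jnp.array([10.0, 1.0])                            # one shared vector

# in_axes=(0, None): map over the rows of M, reuse the same v for every row.
mv = jax.vmap(dot_one, in_axes=(0, None))(M, v)
print(mv)  # same values as M @ v: 12, 34, 56
```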
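The remaining parameters are optional and used less often:

- axis_name is a hashable name for the mapped axis, so that collective operations such as jax.lax.psum can refer to it.
- axis_size is an integer giving the size of the mapped axis. It is only needed when the size cannot be inferred from the inputs.
- spmd_axis_name is an optional mesh axis name (or tuple of names) used when the batched computation is partitioned with single program, multiple data (SPMD) parallelism.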
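To show what axis_name is for, here is a minimal sketch in which each batch element is divided by the sum over the whole mapped axis. The name "batch" and the `normalize` function are arbitrary choices for illustration; jax.lax.psum sums a value across the named mapped axis.

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # x is one element of the batch; psum sums it across the axis named "batch".
    return x / jax.lax.psum(x, axis_name="batch")

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
out = jax.vmap(normalize, axis_name="batch")(xs)
print(out)  # each element divided by the total 10.0: 0.1, 0.2, 0.3, 0.4
```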
Return object
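The vmap function returns a new, batched version of fun; fun itself is not modified. The returned function accepts inputs with an extra mapped axis (as described by in_axes), applies the original function to each slice along that axis in a vectorized manner, and returns outputs with the mapped axis placed according to out_axes.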
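To make this concrete, the sketch below uses a hypothetical `squared_norm` function to compare output shapes before and after batching, and shows that the returned function is an ordinary callable that can itself be passed to jax.jit.

```python
import jax
import jax.numpy as jnp

def squared_norm(x):
    # Written for a single 1-D vector; returns a scalar.
    return jnp.sum(x ** 2)

batched = jax.vmap(squared_norm)  # a new function; squared_norm is unchanged
fast_batched = jax.jit(batched)   # the result composes with other transformations

xs = jnp.arange(15.0).reshape(5, 3)  # 5 vectors of length 3
print(squared_norm(xs[0]).shape)     # (): one scalar for one vector
print(fast_batched(xs).shape)        # (5,): one scalar per row
```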
How to use vmap
To use the vmap function, we first define the function we want to vectorize, written as if it handles a single example.
Then, we call jax.vmap with that function as the first argument and, optionally, the axes we want to vectorize over (in_axes) as the second argument.
Example 1
Consider the following code:
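```python
import jax
import jax.numpy as jnp

def add(x, y):  # written for a single pair of values
    return x + y

vmapped_add = jax.vmap(add)  # maps over axis 0 of both inputs by default

batch_x = jnp.array([1, 2, 3, 4, 5])
batch_y = jnp.array([2, 3, 4, 5, 6])

vmapped_result = vmapped_add(batch_x, batch_y)  # one call for the whole batch
print("Output of vmap function:", vmapped_result)
```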
- We define the add function, which adds its two arguments.
- To apply this operation across two batches of numbers, we pass add to jax.vmap and store the result as vmapped_add.
- We call vmapped_add with the two batches. The result is the element-wise sum of the batches: 3, 5, 7, 9, 11.
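One constraint worth knowing when batching like this: every mapped input must have the same size along its mapped axis. A minimal sketch reusing the add function from the example above with deliberately mismatched batch sizes; the exact error message may vary across JAX versions, so it is only printed, not matched.

```python
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

vmapped_add = jax.vmap(add)

# The mapped axes of all inputs must have the same length (5 vs. 4 here).
raised = False
try:
    vmapped_add(jnp.arange(5), jnp.arange(4))
except ValueError as err:
    raised = True
    print("vmap error:", err)
```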
Example 2
Each value in in_axes selects the axis of the corresponding input to map over: 0 maps over axis 0 (the rows of a matrix), while 1 maps over axis 1 (the columns). In matrix terms, mapping a 2-D input over axis 1 is the same as mapping over the rows of its transpose. If the function has n parameters, in_axes is a tuple of n values, one per parameter. In the following example, the add function has two parameters, so each in_axes value is a pair built from 0 and 1. For instance, the version created with in_axes=(0, 1) maps over the rows of the first input and the columns of the second, so it effectively adds the first input to the transpose of the second.
Let’s see how the in_axes parameter affects the behavior of the vmap function in the following code:
```python
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

# One batched version of add per combination of input axes
vmapped_add1 = jax.vmap(add, in_axes=(0, 0))
vmapped_add2 = jax.vmap(add, in_axes=(1, 0))
vmapped_add3 = jax.vmap(add, in_axes=(0, 1))
vmapped_add4 = jax.vmap(add, in_axes=(1, 1))

batch_x = jnp.array([[1, 2], [3, 4]])
batch_y = jnp.array([[5, 6], [7, 8]])

vmapped_result1 = vmapped_add1(batch_x, batch_y)
vmapped_result2 = vmapped_add2(batch_x, batch_y)
vmapped_result3 = vmapped_add3(batch_x, batch_y)
vmapped_result4 = vmapped_add4(batch_x, batch_y)

print("Output with in_axes = (0,0)")
print(vmapped_result1)
print("Output with in_axes = (1,0)")
print(vmapped_result2)
print("Output with in_axes = (0,1)")
print(vmapped_result3)
print("Output with in_axes = (1,1)")
print(vmapped_result4)
```
- We create four vmap versions of add, one for each combination of 0 and 1 in in_axes.
- We call all four versions with the same inputs and print the results.
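The four outputs can be explained with transposes: mapping an input over axis 1 is the same as mapping over the rows of its transpose, and with the default out_axes=0 each per-slice result becomes a row of the output. A small sketch that checks these identities on the same inputs:

```python
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

batch_x = jnp.array([[1, 2], [3, 4]])
batch_y = jnp.array([[5, 6], [7, 8]])

r00 = jax.vmap(add, in_axes=(0, 0))(batch_x, batch_y)  # rows of x + rows of y
r10 = jax.vmap(add, in_axes=(1, 0))(batch_x, batch_y)  # columns of x + rows of y
r01 = jax.vmap(add, in_axes=(0, 1))(batch_x, batch_y)  # rows of x + columns of y
r11 = jax.vmap(add, in_axes=(1, 1))(batch_x, batch_y)  # columns of x + columns of y

# Each line prints True.
print(jnp.array_equal(r00, batch_x + batch_y))
print(jnp.array_equal(r10, batch_x.T + batch_y))
print(jnp.array_equal(r01, batch_x + batch_y.T))
print(jnp.array_equal(r11, (batch_x + batch_y).T))
```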
Example 3
The out_axes parameter works like in_axes, but for the output: it specifies which axis of the result holds the mapped (batch) dimension.
Let’s look at the following example to see the effect:
```python
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

# Same input mapping, different placement of the mapped axis in the output
vmapped_add1 = jax.vmap(add, in_axes=(1, 0), out_axes=0)
vmapped_add2 = jax.vmap(add, in_axes=(1, 0), out_axes=1)

batch_x = jnp.array([[1, 2, 3], [4, 5, 6]])    # shape (2, 3), mapped over axis 1
batch_y = jnp.array([[5, 6], [7, 8], [9, 0]])  # shape (3, 2), mapped over axis 0

vmapped_result1 = vmapped_add1(batch_x, batch_y)
vmapped_result2 = vmapped_add2(batch_x, batch_y)

print("Output with out_axes = 0")
print(vmapped_result1)
print("Output with out_axes = 1")
print(vmapped_result2)
```
- We create two vmap functions that both use in_axes=(1, 0) but set out_axes to 0 and 1, respectively.
- We create two input batches with different shapes: batch_x is (2, 3) and batch_y is (3, 2). With in_axes=(1, 0), both are mapped over an axis of size 3.
- We call both functions with the same inputs. With out_axes=0, the three per-slice results are stacked as rows (shape (3, 2)); with out_axes=1, they are stacked as columns (shape (2, 3)).
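A short sketch that checks this on the same inputs: the two results hold the same values, and only the position of the mapped axis differs, so one is the transpose of the other.

```python
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

batch_x = jnp.array([[1, 2, 3], [4, 5, 6]])    # shape (2, 3): mapped over axis 1 (size 3)
batch_y = jnp.array([[5, 6], [7, 8], [9, 0]])  # shape (3, 2): mapped over axis 0 (size 3)

r0 = jax.vmap(add, in_axes=(1, 0), out_axes=0)(batch_x, batch_y)
r1 = jax.vmap(add, in_axes=(1, 0), out_axes=1)(batch_x, batch_y)

print(r0.shape)                   # (3, 2): the mapped axis comes first
print(r1.shape)                   # (2, 3): the mapped axis is placed second
print(jnp.array_equal(r1, r0.T))  # True
```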