Auto Parallelization

This lesson will introduce auto parallelization in JAX.


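With ever-growing datasets and computational resources, there is a pressing need for parallel processing, and JAX supports it as well. vmap() vectorizes a function over a batch axis. In a similar way, pmap() maps a function over the leading axis of its input, but it compiles the function with XLA and runs each slice on a separate device (for example, one slice per GPU or TPU core) in parallel.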
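The following is a minimal sketch, assuming a recent JAX installation. The function `scaled_sum` and the array shapes are illustrative choices, not part of the original lesson. To keep the example runnable on a machine with a single CPU device, it sizes the mapped axis with jax.local_device_count(). It then compares pmap() with vmap() on the same input.

```python
import jax
import jax.numpy as jnp

# Number of devices JAX can see on this host (1 on a plain CPU machine).
n_devices = jax.local_device_count()

def scaled_sum(x):
    # Each device receives one row of the input and reduces it independently.
    return jnp.sum(x * 2.0)

# pmap requires the leading axis to be no larger than the number of devices,
# so we create exactly one row per device.
xs = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

# pmap compiles scaled_sum with XLA and runs one copy per device in parallel.
parallel_sums = jax.pmap(scaled_sum)(xs)

# vmap computes the same values, vectorized on a single device.
vectorized_sums = jax.vmap(scaled_sum)(xs)

print(parallel_sums)
print(vectorized_sums)
```

Both calls return one value per row. The difference is where the work happens: pmap() places each row on its own device, while vmap() batches all rows on one device.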

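Remember: pmap() is designed to spread work across multiple accelerators (GPUs or TPUs), so the pmap() snippets in this lesson are provided just as a guide. On a machine with a single CPU device, calling a pmap()-ed function on an array whose leading axis is larger than the number of available devices results in an error. vmap(), by contrast, runs fine on a CPU.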
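One way to avoid that error is to check the device count before calling pmap(). The sketch below uses a hypothetical helper, `map_over_devices`, which is not part of JAX or the original lesson. It uses pmap() when every slice can get its own device and otherwise falls back to vmap() on the default device. The `normalize` function and the batch shape are likewise illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def map_over_devices(fn, xs):
    """Hypothetical helper: use pmap when each slice of the leading axis can
    get its own device, otherwise fall back to vmap on the default device."""
    if xs.shape[0] <= jax.local_device_count():
        return jax.pmap(fn)(xs)
    return jax.vmap(fn)(xs)

def normalize(row):
    # Scale a single row to unit Euclidean length.
    return row / jnp.linalg.norm(row)

batch = jnp.ones((8, 3))
result = map_over_devices(normalize, batch)
print(result.shape)  # (8, 3)
```

On a single-CPU machine this takes the vmap() path. On a host with eight or more devices it takes the pmap() path. To try pmap() on a CPU, you can ask XLA to expose several host devices by setting `XLA_FLAGS=--xla_force_host_platform_device_count=8` before JAX is first imported.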
