
Auto Vectorization

Explore how to apply JAX's vmap function for automatic vectorization to handle batch processing efficiently in deep learning. Understand how to vectorize over different input and output axes to simplify implementation while improving computational performance.

Introduction

Those who are familiar with stochastic gradient descent (SGD) will know that, in its pure form, it is applied one sample at a time, which makes it computationally inefficient. Instead, we apply it in batches, a technique usually known as mini-batch gradient descent.
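To make the contrast concrete, here is a minimal sketch of the two update styles for a toy linear model y = w * x. The model, data, and names (loss_grad, xs, ys, lr) are illustrative assumptions, not part of any JAX API.

import jax.numpy as jnp

def loss_grad(w, x, y):
    # Gradient of the squared error 0.5 * (w * x - y)**2 w.r.t. w
    return (w * x - y) * x

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
ys = jnp.array([2.0, 4.0, 6.0, 8.0])
lr = 0.1

# Pure SGD: one parameter update per sample (inherently sequential)
w = 0.0
for x, y in zip(xs, ys):
    w = w - lr * loss_grad(w, x, y)

# Mini-batch gradient descent: a single update from the gradient
# averaged over the whole batch (one vectorized computation)
w_mb = 0.0
w_mb = w_mb - lr * jnp.mean(loss_grad(w_mb, xs, ys))

The batched update processes all samples in one vectorized computation, which is exactly the pattern the rest of this lesson automates.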

This batching operation is common practice throughout deep learning and applies to many tasks, such as convolution and optimization.

Let’s have a look at a convolution function for 1D vectors:

import jax.numpy as jnp

a = jnp.arange(5)      # input signal: [0 1 2 3 4]
b = jnp.arange(2, 5)   # filter: [2 3 4]

def Convolve(x, f):
    # Slide a window of size 3 over x and take its dot product
    # with the filter f at each valid position
    output = []
    for i in range(1, len(x) - 1):
        output.append(x[i - 1:i + 2] @ f)
    return jnp.array(output)

print(Convolve(a, b))  # [11 20 29]

Batching

So far, the above function handles only a single pair of vectors. To apply it to a batch ...
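One way to promote Convolve to batches without rewriting the loop is jax.vmap, which maps a function over a leading batch axis. Below is a minimal sketch that continues from the Convolve, a, and b defined above; the batch arrays xs and fs are illustrative names, not part of any API.

import jax

# Stack two copies of the inputs to form batches with a leading
# batch axis: xs has shape (2, 5), fs has shape (2, 3)
xs = jnp.stack([a, a])
fs = jnp.stack([b, b])

# vmap maps Convolve over axis 0 of every argument by default
batched_convolve = jax.vmap(Convolve)
print(batched_convolve(xs, fs))
# [[11 20 29]
#  [11 20 29]]

# in_axes controls which argument axes are mapped: here we batch
# only x (axis 0) and broadcast a single shared filter (None)
shared_filter_convolve = jax.vmap(Convolve, in_axes=[0, None])
print(shared_filter_convolve(xs, b))

With in_axes, each argument can be batched along a different axis or broadcast unchanged, which is what the lesson description means by vectorizing over different input and output axes.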