JAX Expressions (Jaxpr)
Explore how JAX expressions, called jaxprs, encode a function's sequence of operations for analysis and optimization. Understand lambda expressions, the structure of a jaxpr, and how JAX expressions help compare JIT-compiled and non-JIT functions to make deep learning code more efficient.
We just discussed how tracer objects are used to record the sequence of operations a function performs. This lesson covers the JAX expressions extracted from that trace.
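As a minimal sketch of that extraction, `jax.make_jaxpr` traces a function with abstract tracer values and returns its jaxpr. The function `f` and the input `3.0` below are illustrative choices, and the exact printed text can differ between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Two operations: sin, then multiplication by 2.0
    return 2.0 * jnp.sin(x)

# Trace f with an example input; the result describes f's operations,
# not the numerical value f(3.0)
jaxpr = jax.make_jaxpr(f)(3.0)
print(jaxpr)
```

The printed jaxpr opens with `{ lambda`, lists the input variables, and then the primitive operations (here `sin` and `mul`) that the tracer recorded, which is one reason a short review of lambda expressions is useful before reading jaxprs.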
Lambda expressions
Before turning to JAX expressions, a quick introduction to lambda expressions is helpful.
Lambda expressions define anonymous functions, and writing one is straightforward:
- Specify the lambda keyword.
- List the parameters, followed by a colon (:).
- Specify the expression body.
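A minimal sketch of these three parts in Python, using two illustrative functions named `add` and `square`:

```python
# General form: lambda <parameters>: <expression>
add = lambda x, y: x + y   # keyword 'lambda', parameters x and y, ':' then the body
square = lambda x: x * x   # a single parameter works the same way

print(add(2, 3))   # 5
print(square(4))   # 16
```

The body must be a single expression; its value is returned automatically, so no `return` statement is written.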
For example, consider a simple lambda expression that calculates the circumference of a circle.
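A minimal sketch, assuming the radius is passed as the single parameter `r` and using Python's `math.pi` for π, so the expression computes C = 2πr:

```python
import math

# Circumference of a circle with radius r: C = 2 * pi * r
circumference = lambda r: 2 * math.pi * r

print(circumference(1.0))  # 6.283185307179586
```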
Similarly, below is an example ...