Pioneered by the seminal backpropagation work of Rumelhart, Hinton, and Williams (1986), the majority of current machine learning optimization methods rely on derivatives. So, there is a pressing need to compute them efficiently.

Manual calculations

Most early machine learning researchers (for example, Bottou, 1998, for stochastic gradient descent) had to derive analytical derivatives by hand, a slow, laborious, and error-prone process.

Using computer programs

Programming-based solutions are less laborious, but computing derivatives in a program can still be tricky. These approaches fall into three paradigms:

  1. Symbolic differentiation
  2. Numeric differentiation
  3. Automatic differentiation

The first two methods suffer from several problems:

  • Symbolic differentiation produces long, complex expressions, which makes higher-order derivatives tricky to compute and can lead to inefficient code.
  • Numeric differentiation relies on discretization, so truncation and round-off errors reduce accuracy.
  • Both are slow at computing partial derivatives, a key requirement of gradient-based optimization algorithms.
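The accuracy loss from discretization can be seen in a few lines of Python. The `numeric_derivative` helper below is illustrative: a finite-difference approximation truncates when the step `h` is too large and amplifies round-off when `h` is too small.

```python
# Minimal sketch: forward-difference numeric differentiation and its error.
def numeric_derivative(f, x, h):
    # Forward-difference approximation: (f(x + h) - f(x)) / h
    return (f(x + h) - f(x)) / h

f = lambda x: x ** 2   # exact derivative at x = 3 is 6

# A large h truncates; a tiny h suffers catastrophic cancellation.
for h in (1e-2, 1e-8, 1e-14):
    approx = numeric_derivative(f, 3.0, h)
    print(f"h={h:g}  approx={approx}  error={abs(approx - 6.0)}")
```

No single step size is best for every function and point, which is one reason numeric differentiation is unreliable for optimization.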

Automatic differentiation

Automatic differentiation (also known as autodiff) addresses all the issues above and is a key feature of modern ML/DL libraries.

Chain rule

Autodiff is built on the chain rule, the fundamental rule of calculus for differentiating composed functions.

For example,

y = 2x^2 + 4


x = 3w

Since y is not written directly in terms of w, we cannot differentiate y with respect to w (i.e., dy/dw) in one step. Instead, we compute it indirectly using the chain rule:

dy/dx = 4x

dx/dw = 3


dy/dw = dy/dx × dx/dw = 4x × 3 = 12x = 36w
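The worked example above can be checked in a few lines of Python (the variable names are illustrative):

```python
# Verify the chain-rule example at a sample point, w = 2.
def x_of_w(w):
    return 3 * w          # x = 3w

def y_of_x(x):
    return 2 * x ** 2 + 4  # y = 2x^2 + 4

w = 2.0
x = x_of_w(w)             # x = 6

dy_dx = 4 * x             # derivative of y with respect to x
dx_dw = 3                 # derivative of x with respect to w
dy_dw = dy_dx * dx_dw     # chain rule: 12x = 36w

print(dy_dw)              # 36 * 2 = 72.0
```

Substituting x = 3w directly gives y = 18w^2 + 4, whose derivative 36w evaluates to the same 72 at w = 2.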

There are two main ways to accumulate the products in the chain rule.

Forward accumulation

In forward accumulation, we fix the independent variable and propagate derivatives from the inputs toward the output, computing the derivative of each sub-expression alongside its value.
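Forward accumulation can be sketched with dual numbers, which carry a derivative alongside each value. The `Dual` class below is illustrative, not a library API; it differentiates the running example in a single forward pass.

```python
# Minimal sketch of forward accumulation via dual numbers.
class Dual:
    def __init__(self, value, deriv):
        self.value = value   # primal value
        self.deriv = deriv   # derivative w.r.t. the chosen independent variable

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Product rule, carried forward alongside the value
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__

# Differentiate y = 2x^2 + 4 with x = 3w at w = 2.
w = Dual(2.0, 1.0)        # seed: dw/dw = 1
x = 3 * w                 # x = 6, dx/dw = 3
y = 2 * x * x + 4         # y = 76, dy/dw = 72
print(y.value, y.deriv)   # 76.0 72.0
```

Note that one forward pass yields the derivative with respect to one input; with many inputs, forward mode needs one pass per input.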

Reverse accumulation

In deep learning, reverse accumulation (i.e., backpropagation) is what all the major frameworks, such as PyTorch and TensorFlow, use.

JAX conveniently supports both modes (jax.jacfwd and jax.jacrev). The choice between forward and reverse accumulation usually depends on the number of inputs relative to outputs: reverse accumulation is the default in deep learning because a single scalar loss depends on many parameters, and one reverse pass yields all of their gradients.
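For intuition, here is a toy reverse-accumulation sketch in plain Python. The `Var` class is illustrative, not a real framework API; it records each operation's local gradients during the forward pass, then a `backward()` call propagates the output gradient back to the inputs. (Real frameworks traverse the graph in topological order rather than by naive recursion.)

```python
# Minimal sketch of reverse accumulation (the mode behind backpropagation).
class Var:
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents   # (parent, local_gradient) pairs
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, [(self, 1.0), (other, 1.0)])

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   [(self, other.value), (other, self.value)])

    __rmul__ = __mul__

    def backward(self, seed=1.0):
        # Accumulate the incoming gradient, then push it to the parents.
        self.grad += seed
        for parent, local_grad in self.parents:
            parent.backward(seed * local_grad)

# y = 2x^2 + 4 with x = 3w, evaluated at w = 2.
w = Var(2.0)
x = 3 * w
y = 2 * x * x + 4
y.backward()     # one reverse pass from the output
print(w.grad)    # 72.0, matching the chain-rule result 36w
```

A single backward pass from y fills in the gradient of every variable it depends on, which is exactly what makes reverse mode efficient for a scalar loss over many parameters.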
