
Pure Functions and Random Numbers

Explore why JAX relies on pure functions, which always produce the same output for the same inputs and have no side effects, a requirement for just-in-time compilation. Understand JAX's approach to random number generation with explicit PRNG keys, which keeps results reproducible and functions pure in deep learning workflows.


Pure functions

A pure function has no side effects, and its output depends only on its inputs. JAX's transformation functions expect pure functions: all input should be passed in through function parameters, and all output should be returned as function results. Python's print function, for example, is not pure, because writing to the console is a side effect. This can be demonstrated using the impure_print_side_effect() function in the code example.
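Reading global state is another common way to break purity. The following is a minimal sketch, assuming a recent JAX release with jax.jit; the names scale, impure_scale, and pure_scale are hypothetical and are not part of the original example. The impure version reads a global variable, so once jax.jit has traced it, the global's value at trace time is baked into the compiled code and later changes to it are silently ignored.

```python
import jax
import jax.numpy as jnp

# Hypothetical global state, used only to illustrate an impure function.
scale = 2.0

def impure_scale(x):
    # Impure: the result depends on the global `scale`, not only on `x`.
    return x * scale

def pure_scale(x, s):
    # Pure: everything the function needs arrives through its parameters.
    return x * s

jit_impure = jax.jit(impure_scale)
jit_pure = jax.jit(pure_scale)

x = jnp.arange(3.0)

first = jit_impure(x)       # traced here; the current value 2.0 is captured
scale = 10.0                # change the global after tracing
second = jit_impure(x)      # same shape and dtype, so the cached trace is reused:
                            # the result is still based on 2.0, not 10.0
third = jit_pure(x, scale)  # the new value is passed in explicitly

print(first, second, third)
```

Passing the value in as a parameter, as pure_scale does, keeps the compiled function's behavior tied entirely to its inputs.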

The side effects appear during the first run. Subsequent runs with parameters of the same type and shape may not show the side effects. This is because JAX ...
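The first-call behavior can be observed with a sketch like the following, again assuming jax.jit; noisy_square_sum and trace_log are hypothetical names chosen for illustration. Python-level side effects such as print run while JAX traces the function, and jax.jit reuses the compiled version for later calls whose arguments have the same shape and dtype, so the message only reappears when a new shape forces a retrace.

```python
import jax
import jax.numpy as jnp

# Hypothetical list that records each time the Python body actually runs.
trace_log = []

def noisy_square_sum(x):
    # Side effects: printing and appending to an outer list.
    # Under jax.jit these run only while the function is being traced.
    print("Tracing for shape", x.shape, "and dtype", x.dtype)
    trace_log.append(x.shape)
    return jnp.sum(x ** 2)

jit_noisy = jax.jit(noisy_square_sum)

a = jit_noisy(jnp.ones(3))   # first call with shape (3,): traced, message printed
b = jit_noisy(jnp.zeros(3))  # same shape and dtype: cached version reused, no message
c = jit_noisy(jnp.ones(4))   # new shape (4,): retraced, message printed again

print(trace_log)  # [(3,), (4,)]
```

Only two entries end up in trace_log even though the function was called three times, which is why relying on side effects inside a transformed function is unreliable.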