Pure Functions and Random Numbers

Learn about pure functions and how to generate random numbers in JAX.


Pure functions

A pure function has no side effects, and its output depends only on its inputs. JAX transformations such as jax.jit expect pure functions. When working with JAX, all input should be passed in through function parameters, and all output should come from the function's return values. A function that calls Python's print is therefore not pure, because printing is a side effect rather than part of the return value. This can be demonstrated using the impure_print_side_effect() function in the code example.
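The sketch below is one possible version of impure_print_side_effect(), together with a second impure function that reads a global instead of taking it as a parameter. The function body, the global `g`, and the name `impure_uses_globals` are illustrative assumptions, not the course's original listing.

```python
import jax

def impure_print_side_effect(x):
    # print() writes to stdout: a side effect that is not part of the return value
    print("Executing function")
    return x

jitted_print = jax.jit(impure_print_side_effect)
first = jitted_print(4.0)  # first call: JAX traces the Python body, so the print runs

# Hypothetical global read inside the function instead of being passed as a parameter
g = 0.0

def impure_uses_globals(x):
    # g is not a parameter, so its value is baked in when JAX traces this function
    return x + g

jitted_globals = jax.jit(impure_uses_globals)
before = jitted_globals(4.0)  # traced while g == 0.0
g = 10.0
after = jitted_globals(5.0)   # same type and shape: cached trace still uses g == 0.0
```

Here `after` is 5.0 rather than 15.0: the update to `g` is invisible to the compiled function, which is why JAX asks for every input to arrive through the parameters.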

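When such a function is wrapped in jax.jit, its side effects appear during the first call, because that is when JAX traces the Python code. Subsequent calls with arguments of the same type and shape do not repeat the side effects, because JAX reuses the cached compiled version of the function instead of running the Python body again. JAX reruns (retraces) the Python function only when the type or shape of an argument changes.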
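A minimal sketch of this caching behavior, assuming default JAX settings. The `trace_log` list is a hypothetical helper added only to record each time the Python body actually runs.

```python
import jax
import jax.numpy as jnp

trace_log = []  # appended to only when JAX executes the Python body (i.e. traces it)

def impure_print_side_effect(x):
    print("Executing function")                # visible only while tracing
    trace_log.append((x.shape, str(x.dtype)))  # record the abstract shape and dtype seen
    return x

jitted = jax.jit(impure_print_side_effect)

jitted(jnp.ones(3))                   # first call: traced, prints
jitted(jnp.zeros(3))                  # same shape and dtype: cached, no print
jitted(jnp.ones(5))                   # new shape: retraced, prints again
jitted(jnp.ones(3, dtype=jnp.int32))  # new dtype: retraced, prints again
```

Four calls produce only three traces: the second call hits the cache because its argument has the same shape and dtype as the first.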
