Testing numerical code can be tricky, and especially so with JAX, where computations run in parallel on GPUs and TPUs. The JAX ecosystem provides a library for exactly this purpose: Chex, which offers utilities such as:
- Debugging transformations (like
- Testing code across JIT and non-JIT versions.
Traditional Python type annotations cannot express the size or shape of a DeviceArray, so Chex provides assertions of its own.
With assert_shape() and assert_rank(), we can validate both the shape and the rank (number of dimensions) of a given JAX array.