
Chex

Explore how to use Chex to test and debug numerical computations in JAX, including shape and type assertions, device availability checks, and emulating parallel code for easier development and troubleshooting.

Testing numerical code can be tricky, especially in JAX, where computations may be compiled and run in parallel on GPUs or TPUs. Chex, a library from the JAX ecosystem, provides testing utilities such as:

  • Assertions.
  • Faking transformations (like jit or pmap) for easier debugging.
  • Testing code across JIT and non-JIT versions.

Assertions

Standard Python type annotations cannot express the size or shape of a JAX array (DeviceArray), so Chex provides its own assertions.

Primitives

By using assert_shape() and assert_rank(), we can validate the shape and the rank (number of dimensions) of a given JAX array.

Python 3.8
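import jax.numpy as jnp
from chex import assert_shape, assert_rank
x = jnp.ones((5, 5))
y = jnp.ones((2, 5, 3, 4))
# Each assertion returns None on success and raises an AssertionError on failure.
assert_shape(x, (5, 5))
# assert_shape(x, (2, 4))  # would raise an AssertionError: the shapes are inconsistent
assert_rank(y, 4)
print("All assertions passed")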

We can also check that several arrays share the same shape using assert_equal_shape(), and use assert_type() to verify that their data types are consistent.

Please note that assert_type() ...