
Pseudo-Random Number Generation

Explore how pseudo-random number generation works in JAX, focusing on stateless PRNGs like Threefry. Understand the advantages of JAX's approach, including parallelization, smaller state vectors, and improved reproducibility for deep learning and numerical computing tasks.

Random variable

Consider a function:

$$f(x) = x^2 + 5$$

No matter how many times we call this function, it will always give the same output for a given input.

On the other hand, consider a random (stochastic) process: the seventh roll of a die, the temperature at 3 PM, or the quantum state of a particle. None of them is guaranteed to give the same output every time.

The outcome of such a random process is known as a random variable. It's hard to think of a field that doesn't rely in some way on random variables, so it's no wonder that they play a central role in modeling almost any real-world system.

PRNG

One of the advantages computers have over manual work is their deterministic behavior: if we run an algorithm under the same conditions, it will always give the same result.

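This advantage can turn into a disadvantage when we actually need randomness, for example in simulations or when initializing neural network weights. Computer scientists have long been aware of this trait and have devised Pseudo-Random Number Generators (PRNGs): deterministic algorithms that, starting from an initial seed, produce numbers that look random.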

John von Neumann, who proposed one of the earliest PRNG algorithms, cautioned in his famous quote:

“Anyone who considers arithmetical methods of producing random digits is, of course, in a state of sin.” - (1951)
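To make the "pseudo" part concrete, here is a minimal sketch using NumPy's legacy RandomState generator (the seeds 42 and 7 are arbitrary choices for illustration). Two generators seeded identically produce exactly the same "random" sequence, while a different seed gives a different one.

```python
import numpy as np

# Two generators seeded identically walk through exactly the same sequence:
# the numbers look random, but they are fully determined by the seed.
gen_a = np.random.RandomState(42)
gen_b = np.random.RandomState(42)

sample_a = gen_a.uniform(size=5)
sample_b = gen_b.uniform(size=5)
print(sample_a)
print(np.array_equal(sample_a, sample_b))  # True: same seed, same numbers

# A different seed gives a different sequence.
gen_c = np.random.RandomState(7)
sample_c = gen_c.uniform(size=5)
print(np.array_equal(sample_a, sample_c))  # False
```

This determinism is exactly what makes experiments reproducible: recording the seed is enough to regenerate every "random" number.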

Mersenne Twister

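One of the most commonly used PRNGs, Mersenne Twister (named after its period, the Mersenne prime 2^19937 - 1), is widely used in languages and libraries such as C++, R, CUDA, and NumPy.

Despite its broad adoption, Mersenne Twister has some issues: its state is large (624 32-bit words), and numbers are generated strictly sequentially from that state, which makes parallel generation difficult.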

Threefry counter

Counter-based PRNGs, which we'll review shortly, address these shortcomings of Mersenne Twister and other conventional PRNGs. In 2011, D. E. Shaw Research introduced Threefry, a counter-based PRNG that forms the basis of the JAX implementation.

Below, we review the features of the JAX implementation (referred to going forward as JAX PRNG).
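The core idea of a counter-based PRNG is that the n-th output is a function of the key and the counter n alone, with no hidden state carried between calls. The sketch below illustrates only that idea. It uses SHA-256 from Python's hashlib as a stand-in mixing function (an assumption for illustration; this is not Threefry and not how JAX computes its numbers), and counter_based_uniform is a hypothetical helper name.

```python
import hashlib
import struct


def counter_based_uniform(key: int, counter: int) -> float:
    """Toy counter-based PRNG: hash (key, counter) into a float in [0, 1).

    SHA-256 is only a stand-in for illustration; it is NOT Threefry.
    """
    data = struct.pack("<QQ", key, counter)      # encode key and counter as bytes
    digest = hashlib.sha256(data).digest()
    word = int.from_bytes(digest[:8], "little")   # take 64 bits of the hash
    return (word >> 11) / float(1 << 53)          # keep 53 bits -> float in [0, 1)


key = 0
# Any position in the stream can be computed directly and in any order,
# because the output depends only on (key, counter), not on earlier calls.
forward = [counter_based_uniform(key, n) for n in range(5)]
backward = [counter_based_uniform(key, n) for n in reversed(range(5))]
print(forward)
print(forward == backward[::-1])  # True
```

Because no call changes anything, different workers can evaluate different counters at the same time without coordinating.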


Advantages of JAX PRNG

Naturally, any user will be interested in knowing how JAX PRNG manages to address the shortcomings of Mersenne Twister.

Stateful and stateless generation

One of the key advantages of JAX PRNG is stateless generation. For those not yet familiar with these concepts, here is an overview.

Stateful generation

Every PRNG uses a state vector to describe its current state. We use this current state to generate a new sample.

Mersenne Twister is based on stateful generation: the generator keeps an internal state, and every random number it produces automatically updates that state vector.

The following example illustrates this. It uses these NumPy random functions:

  • seed(<int>) initializes the PRNG
  • get_state() returns the generator's state as a tuple that contains the state vector along with some other information, such as the current position in that vector
  • uniform() samples a random number. We'll discuss this in detail in the following lesson.
Python 3.8
import numpy as np
import pandas as pd

np.random.seed(0)
Twister_State = np.random.get_state()
print(type(Twister_State))  # A tuple: ('MT19937', key array, position, has_gauss, cached_gaussian)
print(pd.DataFrame(Twister_State))

# Sample some numbers to check the effect on the state
np.random.uniform()
Twister_State = np.random.get_state()
print(pd.DataFrame(Twister_State)[1:3])  # Showing only relevant info: key array and position
np.random.uniform()
Twister_State = np.random.get_state()
print(pd.DataFrame(Twister_State)[1:3])

for i in range(1, 101):  # Sample it further
    print(np.random.uniform())  # Value will be different each time

print("State info after sampling 2+100 times")
Twister_State = np.random.get_state()
print(pd.DataFrame(Twister_State)[1:3])  # As expected, the position will be 204

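As the output shows, each call to uniform() advances the position (the second row displayed) by two: every sample consumes two uint32 values from the state vector. Right after seeding, the position reads 624, which forces the vector to be regenerated on the first draw; after that, 2 + 100 = 102 samples leave the position at 204.

The main issue here is that this state is hidden and updated implicitly, so the user cannot do much with it. If we need to modify the state, we do so entirely at our own risk. NumPy's documentation clearly states: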

If the internal state is manually altered, the user should know exactly what he/she is doing.
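The one manipulation that is comparatively safe is treating the state as an opaque snapshot: save it with get_state() and restore it with set_state() to replay a sequence. A minimal sketch with a RandomState instance (seed 0 chosen arbitrarily):

```python
import numpy as np

rng = np.random.RandomState(0)

snapshot = rng.get_state()     # copy of the full Mersenne Twister state tuple
first = rng.uniform(size=3)    # drawing advances the hidden state

rng.set_state(snapshot)        # rewind to the saved snapshot
replay = rng.uniform(size=3)   # ...and the same numbers come out again

print(np.array_equal(first, replay))  # True
```

Note that even here we never edit the state's contents; we only copy it and put it back.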

Stateless generation

Threefry, like any counter-based method, instead uses a much simpler, explicit key. The key is passed to every sampling call and is never modified by it, so no matter how many times we sample with it, it stays the same. In other words, the generation is stateless.

Python 3.8
import jax.random as random

Threefry_Key = random.PRNGKey(0)
print(type(Threefry_Key))  # Much simpler and smaller in size
print(Threefry_Key)
for i in range(1, 101):
    print(random.uniform(Threefry_Key, (1, 1)))  # Value will be the same in every iteration
random.uniform(Threefry_Key, (1, 1))
print(Threefry_Key)  # Key value won't change unless you change it explicitly

As we saw, unless we update the key, the output does not change between calls: there is no stochasticity. Some may already be skeptical about the need to switch to explicit key handling, but hang on! The following section explains where JAX PRNG has the edge over classical methods.

Parallelization

Traditional PRNGs such as Mersenne Twister are sequential: the next state is obtained by applying a transformation to the current one:

$$s_{n+1} = f(s_n)$$

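Because every state depends on the previous one, naively generating random numbers in parallel means that threads either contend for one shared state or risk drawing overlapping streams, and the results can depend on the order in which the threads happen to run.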
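To illustrate the recurrence, here is a toy sketch. The transition f is a hypothetical stand-in: a 32-bit linear congruential generator with the well-known Numerical Recipes constants, not Mersenne Twister. In this direct implementation, reaching s_n means applying f n times, so workers that each need a different position must either replay the sequence or share one state.

```python
# Hypothetical toy transition f: a 32-bit linear congruential generator (LCG).
# The constants are the classic "Numerical Recipes" ones; this is NOT
# Mersenne Twister, just the simplest instance of s_{n+1} = f(s_n).
A, C, M = 1664525, 1013904223, 2**32


def f(state: int) -> int:
    return (A * state + C) % M


def nth_state(seed: int, n: int) -> int:
    """Return s_n by walking through every intermediate state."""
    s = seed
    for _ in range(n):
        s = f(s)
    return s


seed = 0
# To hand worker k the k-th number, this implementation must first step
# through all previous states, one after another.
print([nth_state(seed, k) for k in range(1, 5)])
```

Contrast this with a counter-based generator, where the n-th output is computed from the key and n directly.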

This often makes life unpleasant for data scientists and engineers, and it is exactly where Threefry's explicit key handling comes into play.

In JAX PRNG, as in any counter-based PRNG, we split a key into new, independent subkeys, which enables parallel generation.

Let's revisit the example above using splitting. All we need is the split() function, which takes a key and returns new keys (two by default); we unpack them into a new key and a subkey.

Python 3.8
import jax.random as random

Threefry_Key = random.PRNGKey(0)
print("Originally key is: ", Threefry_Key)
#print(random.uniform(Threefry_Key, (1, 1)))
Threefry_Key, Threefry_Subkey = random.split(Threefry_Key)
print("Key after splitting:", Threefry_Key)
print("Value of subkey:", Threefry_Subkey)
#print(random.uniform(Threefry_Key, (1, 1)))

Threefry is based on Threefish, a block cipher introduced in 2008 as the core of the Skein hash function. Threefry is a scaled-down version of Threefish, which explains the name: fry is a term for juvenile fish.

Python 3.8
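import jax.random as random

Threefry_Key = random.PRNGKey(0)
for i in range(1, 101):
    # Replace the key with a fresh one and use the subkey for sampling
    Threefry_Key, subkey = random.split(Threefry_Key)
    print(subkey)
    print(random.uniform(subkey, (1, 1)))  # A different value in every iteration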
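The loop above splits one key at a time. As a sketch of how splitting supports parallelism, split() also accepts the number of keys to produce, and jax.vmap can then map a sampler over the whole batch of subkeys (4 keys here is an arbitrary choice):

```python
import jax
import jax.random as random

key = random.PRNGKey(0)
# split(key, num) returns `num` new keys in one call
# (shape (num, 2) for the default raw uint32 Threefry keys).
subkeys = random.split(key, 4)

# vmap maps the sampler over the batch of keys; each key yields its own draw.
samples = jax.vmap(lambda k: random.uniform(k))(subkeys)
print(samples.shape)  # (4,)
print(samples)

# Reproducible: splitting the same key again gives the same subkeys and samples.
again = jax.vmap(lambda k: random.uniform(k))(random.split(key, 4))
print(bool((samples == again).all()))  # True
```

No draw depends on another, so the batch can be evaluated in any order, or all at once, with identical results.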

The point of the whole story is this: JAX not only provides pseudo-random number generation but also ensures:

  • A much smaller state (a key of just two uint32 values), which makes PRNG initialization quick.
  • Safe parallelization through key splitting, which makes it work with JAX transformations such as jit and vmap and with XLA compilation.
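The size difference is easy to check. A minimal sketch, assuming the default raw uint32 keys returned by random.PRNGKey and NumPy's legacy RandomState for Mersenne Twister:

```python
import numpy as np
import jax.random as random

# Mersenne Twister: the key array inside NumPy's legacy state tuple
mt_state = np.random.RandomState(0).get_state()[1]
# Threefry in JAX: a raw key created by PRNGKey, converted to a NumPy array
jax_key = np.asarray(random.PRNGKey(0))

print(mt_state.dtype, mt_state.shape, mt_state.nbytes)  # uint32 (624,) 2496
print(jax_key.dtype, jax_key.shape, jax_key.nbytes)     # uint32 (2,) 8
```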