Adaptive Optimizers
Learn about the adaptive optimizers used in JAX and Flax.
We'll cover the following...
AdaBelief
AdaBelief works on the concept of “belief” in the current gradient direction: the optimizer maintains a prediction of the gradient (its exponential moving average). If the observed gradient stays close to this prediction, the direction is trusted and a large update is applied. Otherwise, the direction is distrusted and the step size is reduced.
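Concretely, AdaBelief differs from Adam only in its second-moment estimate: instead of tracking the mean of the squared gradient, it tracks the squared deviation of the gradient from its EMA prediction. Below is a minimal, illustrative sketch of one update step (bias correction omitted; the function and parameter names here are our own, not part of any library):

```python
import jax.numpy as jnp

def adabelief_step(param, grad, m, s, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # m: EMA of gradients (the "predicted" gradient direction).
    m = b1 * m + (1 - b1) * grad
    # s: EMA of the squared deviation of the gradient from its prediction.
    # A small deviation means high "belief" and hence a larger effective step.
    s = b2 * s + (1 - b2) * (grad - m) ** 2
    param = param - lr * m / (jnp.sqrt(s) + eps)
    return param, m, s
```

Compare this with Adam, where `s` would instead accumulate `grad ** 2` directly.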
The authors of AdaBelief introduced the optimizer to:
- Converge fast, as in adaptive methods.
- Have good generalization like SGD.
- Be stable during training.
Let’s look at how to initialize the AdaBelief optimizer for a Flax model.
```python
import optax
seed = random.PRNGKey(0)
learning_rate = jnp.array(1/1e4)

model = CNN()
weights = model.init(seed, X_train[:5])
optimizer = optax.adabelief(learning_rate=learning_rate) # Initialize AdaBelief optimizer
optimizer_state = optimizer.init(weights) # Optimizer state
```
In the code above:
- Line 1: We import the `optax` library for optimizers.
- Lines 2–3: We define the random `seed` variable and the learning rate for the CNN network.
- Lines 5–6: We instantiate the CNN model using `CNN()` and set the initial weights using the ...