Stochastic Gradient Descent
Learn about SGD-based optimizers in JAX and Flax.
SGD implements stochastic gradient descent with support for momentum and Nesterov acceleration.
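As a quick sketch (the learning-rate and momentum values below are illustrative, not taken from the lesson), the `momentum` and `nesterov` arguments of `optax.sgd` select between the plain, momentum, and Nesterov variants:

```python
import optax

# Illustrative hyperparameter values; each call returns a GradientTransformation.
plain_sgd = optax.sgd(learning_rate=1e-4)                                   # vanilla SGD
momentum_sgd = optax.sgd(learning_rate=1e-4, momentum=0.9)                  # heavy-ball momentum
nesterov_sgd = optax.sgd(learning_rate=1e-4, momentum=0.9, nesterov=True)   # Nesterov acceleration
```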
Let’s understand how to use SGD in the following playground:
```python
import optax
import jax.numpy as jnp
from jax import random

seed = random.PRNGKey(0)
learning_rate = jnp.array(1 / 1e4)

model = CNN()                            # CNN and X_train are defined earlier in the lesson
weights = model.init(seed, X_train[:5])  # Initialize model parameters on a small batch

optimizer = optax.sgd(learning_rate=learning_rate)  # Initialize SGD as the optimizer
optimizer_state = optimizer.init(weights)           # Initialize the optimizer state
```
In the code above: