
Optax: Advanced Features

Explore Optax's advanced features in this lesson to understand learning-rate scheduling methods like exponential and cosine decay, combining multiple optimizers, gradient clipping to handle exploding gradients, adding gradient noise to improve learning, and applying differential privacy techniques for secure model training. Gain practical knowledge of leveraging Optax for more effective and configurable optimization in JAX-based deep learning projects.

While the last lesson reviewed some common loss functions and optimizers, Optax has much more to offer than we can reasonably cover in a single lesson, so we'll restrict ourselves to just a handful of its features here.

Learning-rate scheduling

Not content with the default setting of a constant learning rate, the deep learning community has been experimenting with variable learning rates. Optax offers more than a dozen versions of this technique, which is known as learning-rate scheduling. Let’s review a few:

  • Exponential decay
  • Cosine decay
  • Combining (multiple, existing) schedules
  • Injecting hyperparameters

Exponential decay

This scheduling scheme decays the learning rate exponentially over time:

$\eta_t = \eta_0 e^{-kt}$

where $\eta_0$ is the initial learning rate, $k$ is the decay rate, and $t$ is the current training step.

We can switch between continuous and discrete (step-wise) decay by setting the staircase argument to False or True, respectively.
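As a minimal sketch (the hyperparameter values below are illustrative, not recommendations), optax.exponential_decay builds such a schedule, and the resulting function can be passed to any Optax optimizer in place of a constant learning rate:

```python
import optax

# Illustrative values only: decay the learning rate by a factor of 0.9
# every 1,000 steps, starting from 1e-3.
schedule = optax.exponential_decay(
    init_value=1e-3,         # eta_0: initial learning rate
    transition_steps=1_000,  # steps over which one factor of decay_rate is applied
    decay_rate=0.9,          # multiplicative decay factor
    staircase=False,         # True -> discrete (step-wise) decay, False -> continuous
)

# A schedule is simply a function from the step count to a learning rate.
print(schedule(0), schedule(500), schedule(2_000))

# It can be passed wherever a learning rate is expected, e.g. to an optimizer.
optimizer = optax.adam(learning_rate=schedule)
```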

Cosine decay

In 2017, Ilya Loshchilov and Frank Hutter proposed Stochastic Gradient Descent with Warm Restarts (SGDR). It uses a cosine-decay schedule, which can be expressed as:

$\eta_t = \eta_{\min}^i + \frac{1}{2}\left(\eta_{\max}^i - \eta_{\min}^i\right)\left(1 + \cos\left(\frac{T_{cur}}{T_i}\pi\right)\right)$

where $\eta_{\min}^i$ and $\eta_{\max}^i$ bound the learning rate during the $i$-th run, $T_{cur}$ counts the epochs since the last restart, and $T_i$ is the length of the current run.
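As a minimal sketch, assuming illustrative hyperparameter values, optax.cosine_decay_schedule implements a single cosine-decay cycle from a peak learning rate down toward a floor; the warm-restart variant chains several such cycles:

```python
import optax

# Illustrative values only: one cosine-decay cycle from 1e-3 toward zero
# over 10,000 steps.
schedule = optax.cosine_decay_schedule(
    init_value=1e-3,     # eta_max: peak learning rate at the start of the cycle
    decay_steps=10_000,  # T_i: length of the decay cycle in steps
    alpha=0.0,           # final value as a fraction of init_value (eta_min / eta_max)
)

# Like any Optax schedule, it maps a step count to a learning rate
# and can be handed directly to an optimizer.
optimizer = optax.sgd(learning_rate=schedule)
```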