Training with an EMA (Exponential Moving Average)

Learn to train an image classification model with an exponential moving average.

The PyTorch Image Models (timm) framework supports training with an exponential moving average (EMA), which maintains moving averages of the trained weights using exponential decay.

The implementation of an EMA is as follows:

  1. Create shadow copies of the trained weights during initialization.
  2. At each training step, update the shadow copies with a moving average of the trained weights, computed using exponential decay.

$$weights = decay \cdot weights + (1 - decay) \cdot weights_{new}$$

Note: The decay rate is usually set close to 1.0. A good value is typically made up of several nines, such as 0.99, 0.999, or 0.9999.
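The two steps above can be sketched in a few lines of plain Python. This is only a minimal illustration of the update rule, not the framework's actual implementation: the function names init_shadow and ema_update are hypothetical, and scalar values stand in for real weight tensors.

```python
def init_shadow(weights):
    """Step 1: create shadow copies of the trained weights."""
    return dict(weights)


def ema_update(shadow, weights, decay):
    """Step 2: blend each shadow value with the newly trained value."""
    for name, new_value in weights.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * new_value
    return shadow


# Hypothetical scalar "weights" standing in for real tensors.
weights = {"w": 1.0, "b": 0.0}
shadow = init_shadow(weights)

# Pretend that one optimizer step moved the trained weights.
weights = {"w": 2.0, "b": 1.0}
shadow = ema_update(shadow, weights, decay=0.9)
print(shadow)
```

With a decay of 0.9, each shadow value moves only a tenth of the way toward the new weight, which is why the averaged weights change more smoothly than the trained ones.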

Applying an EMA during training can improve the performance of the model. Some architectures, such as MobileNet-V3, EfficientNet, and MNASNet, rely on EMA-smoothed weights to perform well.

Training with an EMA

To train with an EMA, set the --model-ema flag and the --model-ema-decay argument. The --model-ema-decay argument specifies the decay rate for the EMA and accepts a floating-point value.

Call it as follows:

python train.py /app/dataset2 --model resnet50 --num-classes 4 --model-ema --model-ema-decay 0.99

By default, the value of --model-ema-decay is 0.9998.

Note: With a decay rate of 0.99, the command above keeps 99% of each averaged weight from the existing state. At each iteration, only 1% of the value comes from the newly trained weights.
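To get a feel for how the decay rate controls how quickly older weights fade from the average, the short sketch below compares the 0.99 used in the command with the default 0.9998. It relies only on the update formula: after n updates, the share of the average that still comes from the earlier state is decay raised to the power n. The helper names ema_half_life and old_weight_share are hypothetical and are not part of the framework.

```python
import math


def ema_half_life(decay):
    """Number of updates after which the earlier state's share of the average halves."""
    return math.log(0.5) / math.log(decay)


def old_weight_share(decay, steps):
    """Share of the average still attributable to the state from `steps` updates ago."""
    return decay ** steps


for decay in (0.99, 0.9998):
    half_life = ema_half_life(decay)
    share = old_weight_share(decay, 100)
    print(f"decay={decay}: half-life ~ {half_life:.0f} updates, "
          f"share left after 100 updates = {share:.3f}")
```

A decay closer to 1.0 averages over many more training steps, so it smooths more but also reacts more slowly to recent changes in the trained weights.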

Track an EMA on the CPU

Training with an EMA requires additional memory for the shadow weights, which may cause an out-of-memory error. To avoid this, we can force the EMA to be tracked on the CPU via the --model-ema-force-cpu flag.

python train.py /app/dataset2 --model resnet50 --num-classes 4 --model-ema --model-ema-decay 0.99 --model-ema-force-cpu

Note: Training with the --model-ema-force-cpu flag disables validation of the EMA weights.

Most of the best-performing models in the PyTorch Image Models framework use an EMA in the training process.
