Softmax

Explore how to implement the softmax function in TensorFlow to convert model logits into probability distributions for multiclass classification. Learn to update your model's training with softmax cross entropy loss, generate class predictions using tf.math.argmax, and measure accuracy for models handling multiple classes.

Chapter Goals:

  • Update the model to use the softmax function
  • Perform multiclass classification

A. The softmax function

To convert the model to multiclass classification, we need to make a few changes to the metrics and training parameters. Previously, we used the sigmoid function to convert logits to probabilities, then rounded those probabilities to get a predicted label. However, now that there are multiple possible classes, we need to use the generalization of the sigmoid function, known as the softmax function.
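To make the "generalization" claim concrete, here is a minimal sketch in plain Python (standard library only, not the course's TensorFlow code). It assumes a hypothetical logit value `z = 1.5` and checks that a two-class softmax over the logits `[0, z]` gives the same probability as the sigmoid of `z`.

```python
import math

def sigmoid(z):
    """Binary case: map a single logit to the probability of the positive class."""
    return 1.0 / (1.0 + math.exp(-z))

z = 1.5  # hypothetical logit, chosen only for illustration

# Binary classification: sigmoid, then round to get the predicted label.
p_sigmoid = sigmoid(z)
binary_label = round(p_sigmoid)

# Two-class softmax over logits [0, z]: probability assigned to the second class.
p_softmax = math.exp(z) / (math.exp(0.0) + math.exp(z))

print(p_sigmoid, p_softmax, binary_label)
```

The two probabilities agree up to floating-point error, which is why the sigmoid can be viewed as the two-class special case of softmax.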

The softmax function takes in a vector of numbers (logits for each class), and converts the numbers to a probability distribution. This means that the sum of the probabilities across all the classes equals 1, and each class's individual probability is based on how large its logit was relative to the sum of all ...
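As a rough illustration of the conversion from logits to a probability distribution, the sketch below implements softmax in plain Python. The logits `[2.0, 1.0, 0.1]` are made-up values for three classes, and subtracting the largest logit before exponentiating is a common numerical-stability trick assumed here, not something the lesson prescribes. In TensorFlow, `tf.nn.softmax` and `tf.math.argmax` perform the analogous operations on tensors.

```python
import math

def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # hypothetical logits for three classes
probs = softmax(logits)   # roughly [0.659, 0.242, 0.099]

# The predicted class is the index of the largest probability (an argmax).
predicted_class = max(range(len(probs)), key=probs.__getitem__)

print(probs, sum(probs), predicted_class)
```

Because softmax preserves the ordering of the logits, taking the argmax of the probabilities selects the same class as taking the argmax of the raw logits.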