Loss Function
Find out the difference between loss functions for regression and classification and which loss function is best for the MNIST classifier.
Mean squared error vs. binary cross-entropy loss
In the last section, we developed a neural network to classify images of hand-written digits. Even though we intentionally designed the network to be simple, it worked remarkably well, getting an accuracy score of about 87% with the MNIST test dataset.
Here we’ll explore some refinements which can help us improve a network’s performance. Some neural networks are designed to produce a continuous range of output values. For example, we might want a network predicting temperatures to be able to output any value in the range 0 to 100 degrees centigrade.
Other networks are designed to produce true/false or 1/0 output. For example, we would want a network that decides whether or not an image shows a cat to output values close to 0.0 or 1.0, and not the values in between.
If we did the maths to work out loss functions for each of these different scenarios, we’d find that the Mean Squared Error loss makes sense for the first kind of task. You may know that this is a regression task, but don’t worry if you don’t.
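To make the idea concrete, here is a minimal sketch of the mean squared error computed by hand in plain Python. The function name `mean_squared_error` and the temperature values are made up for illustration; they are not taken from the course code.

```python
def mean_squared_error(predictions, targets):
    """Average of the squared differences between predictions and targets."""
    squared_diffs = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return sum(squared_diffs) / len(squared_diffs)


# Made-up temperatures in degrees centigrade, for illustration only.
predicted = [21.0, 35.5, 8.0]
actual = [20.0, 37.5, 8.0]

mse = mean_squared_error(predicted, actual)
print(mse)
```

Because each difference is squared, a prediction that is off by 2 degrees contributes four times as much to the loss as one that is off by 1 degree.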
For the second task, a classification task, a different kind of loss function makes more sense. A popular one is the binary cross-entropy loss, which heavily penalizes outputs that are confident but wrong, and also penalizes, more mildly, outputs that are correct but not very confident. PyTorch provides this as nn.BCELoss().
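The difference in how the two losses treat confidence can be seen with a small calculation. The sketch below computes both by hand for a single output node whose target is 1.0. The helper names `squared_error` and `binary_cross_entropy` are hypothetical, and the formula assumes the output lies strictly between 0 and 1.

```python
import math


def squared_error(output, target):
    """Squared error for a single output value."""
    return (output - target) ** 2


def binary_cross_entropy(output, target):
    """Binary cross-entropy for a single output strictly between 0 and 1."""
    return -(target * math.log(output) + (1 - target) * math.log(1 - output))


target = 1.0
# Confidently right, unsure, and confidently wrong outputs.
for output in (0.99, 0.6, 0.01):
    se = squared_error(output, target)
    bce = binary_cross_entropy(output, target)
    print(f"output={output:.2f}  squared error={se:.4f}  BCE={bce:.4f}")
```

For a target of 1.0 the squared error can never exceed 1.0, whereas the binary cross-entropy keeps growing as the output approaches 0.0, so a confidently wrong answer is punished far more severely.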
Loss function for MNIST classifier
Our network, which classifies MNIST images, is of the second type. Ideally, all but one of the output node values should be confidently close to 0.0, with just the one matching the correct digit being confidently close to 1.0.
Let’s change the loss function from MSELoss() to BCELoss().
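As a rough sketch of what this change looks like in PyTorch, the example below swaps the loss on a small stand-in network. The layer sizes, the `model` and `loss_function` names, and the random input are assumptions made for illustration, not the classifier built in the last section. Note that nn.BCELoss() expects outputs between 0 and 1, which the final sigmoid provides here.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the classifier from the previous section:
# 784 pixel inputs, 10 outputs, with a sigmoid so each output lies in (0, 1).
model = nn.Sequential(
    nn.Linear(784, 200),
    nn.Sigmoid(),
    nn.Linear(200, 10),
    nn.Sigmoid(),
)

# Before: loss_function = nn.MSELoss()
loss_function = nn.BCELoss()

# Placeholder data: one random "image" and a one-hot target for the digit 3.
image = torch.rand(1, 784)
target = torch.zeros(1, 10)
target[0, 3] = 1.0

output = model(image)
loss = loss_function(output, target)

# Gradients flow back through the network exactly as they did with MSELoss.
loss.backward()
print(loss.item())
```

Only the line that creates the loss function changes; the rest of the training step stays the same.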