Implementation of Softmax activation function in PyTorch
Softmax is an activation function typically used in the output layer of a neural network for multiclass classification tasks. It turns logits (the numeric output of the last linear layer of a multiclass classification neural network) into probabilities that sum up to 1. Essentially, it assigns decimal probabilities to each class in a multiclass problem. Those probabilities can be understood as the model’s confidence in its prediction.
Mathematical form
Mathematically, the Softmax function is expressed as follows:

$$\sigma(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$

where:
$\sigma(z)_i$ is the probability of a particular class $i$ over all possible classes.
$e$ is the base of the natural logarithm.
$z$ represents the input values to the softmax function, often referred to as logits. These are the outputs of the last layer of the neural network before applying softmax.
$K$ represents the total number of classes.
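As a quick numerical check of the formula (the logit values below are arbitrary and chosen only for illustration), we can compare a manual computation against PyTorch's built-in softmax:

import torch

logits = torch.tensor([2.0, 1.0, 0.1])  # arbitrary example logits for three classes

# Manual computation following the formula above
manual = torch.exp(logits) / torch.exp(logits).sum()

# PyTorch's built-in implementation
builtin = torch.softmax(logits, dim=0)

print(manual)        # approximately tensor([0.6590, 0.2424, 0.0986])
print(builtin)       # matches the manual computation
print(manual.sum())  # the probabilities sum to 1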
The bar chart depicts the output of a Softmax activation function for a three-class classification problem, showing the predicted probabilities for each class. Class 1 has the highest probability, suggesting the model’s preference for this class, while Classes 2 and 3 have progressively lower probabilities, indicating less likelihood according to the model’s prediction.
Why do we use Softmax?
Probabilistic interpretation: It converts the output logits of the model into a probability distribution, which is essential for classification where you want to understand the confidence of the model in its predictions across different classes.
Differentiability: It is a differentiable function, allowing gradients to be calculated, which is necessary for backpropagation when the network is being trained (a short sketch after this list illustrates this).
Useful for multiclass problems: It is specifically designed for multiclass classification where each class is mutually exclusive.
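To illustrate the differentiability point, here is a minimal sketch; the logit values are arbitrary and used purely for demonstration:

import torch

logits = torch.tensor([1.5, -0.3, 0.2], requires_grad=True)
probs = torch.softmax(logits, dim=0)  # a probability distribution over three classes

# Any scalar loss built from the probabilities can be backpropagated
loss = -torch.log(probs[0])  # negative log-likelihood of class 0
loss.backward()

print(probs.detach())  # probabilities that sum to 1
print(logits.grad)     # gradients flow back through the softmax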
Implementation of Softmax function in neural network
Let’s walk through the implementation of the Softmax activation function in a neural network using PyTorch.
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # Fully connected layer 1
        self.relu = nn.ReLU()  # ReLU activation function
        self.fc2 = nn.Linear(hidden_size, output_size)  # Fully connected layer 2
        self.softmax = nn.Softmax(dim=1)  # Softmax activation function

    def forward(self, x):
        out = self.fc1(x)  # Apply the first fully connected layer
        out = self.relu(out)  # Apply the ReLU activation function
        out = self.fc2(out)  # Apply the second fully connected layer
        out = self.softmax(out)  # Apply the Softmax activation function
        return out

# Define network parameters
input_size = 64  # Number of input features
hidden_size = 128  # Number of neurons in the hidden layer
output_size = 10  # Number of output classes (adjust as necessary)

# Input data
input_data = torch.rand(32, input_size)  # 32 is the batch size
target = torch.randint(0, output_size, (32,), dtype=torch.long)  # Random target values for multiclass classification

# Create an instance of the SimpleNN model
model = SimpleNN(input_size, hidden_size, output_size)

# Define the loss function (negative log-likelihood loss) and optimizer (Adam)
# Since we are applying Softmax in the model, we use NLLLoss here
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Example training loop
num_epochs = 10  # Define the number of training epochs
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(input_data)
    # Apply log on the outputs before calculating the loss
    loss = criterion(torch.log(outputs), target)  # Compute the loss

    # Backward pass and optimization
    optimizer.zero_grad()  # Clear gradients
    loss.backward()  # Backpropagate to compute gradients
    optimizer.step()  # Update the model parameters

    # Print the loss for each epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}')
Code explanation
Line 10: We create an instance of the Softmax activation function using nn.Softmax() and store it as an attribute of the SimpleNN class named self.softmax. The dim parameter specifies the dimension along which the softmax operation is computed. In our case, dim=1 ensures that the softmax is applied to each row (each set of logits for an instance), turning them into probabilities that sum to 1.
Line 16: We apply the Softmax activation function to the output of the final layer. It takes a tensor of real numbers as input and produces a probability distribution over multiple classes.
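To make the effect of dim=1 concrete, here is a small, self-contained check on a made-up batch of logits (the shapes and values are assumptions for illustration only):

import torch
import torch.nn as nn

softmax = nn.Softmax(dim=1)

batch_logits = torch.randn(4, 3)  # 4 instances, 3 classes
probs = softmax(batch_logits)

print(probs.shape)       # torch.Size([4, 3])
print(probs.sum(dim=1))  # each row sums to 1 (up to floating-point error)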
For a more detailed explanation, take a look at this Educative Answer.
Insights for machine learning engineers:
It’s a common mistake to use a Softmax activation followed by nn.CrossEntropyLoss, which is numerically unstable and redundant. nn.CrossEntropyLoss already expects logits and applies LogSoftmax internally. If we're using nn.CrossEntropyLoss, we should feed the raw output scores from our network (i.e., the logits) into the loss, as sketched in the first example after this list.
During training, softmax is often coupled with the cross-entropy loss to form a single log-softmax layer, which provides better numerical stability. However, during inference, softmax is applied to the logits to interpret the outputs as probabilities.
Softmax is suitable for multiclass classification where each instance belongs to exactly one class. If you’re dealing with a multilabel classification problem, consider using a Sigmoid activation at the output layer instead (see the second sketch after this list).
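As a minimal sketch of the pattern described in the first insight above (the LogitsNN class name and the tensor shapes are made up for illustration), the model returns raw logits and nn.CrossEntropyLoss handles the LogSoftmax internally:

import torch
import torch.nn as nn

class LogitsNN(nn.Module):  # hypothetical variant of SimpleNN that returns raw logits
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),  # no Softmax here
        )

    def forward(self, x):
        return self.net(x)  # raw logits

model = LogitsNN(64, 128, 10)
criterion = nn.CrossEntropyLoss()  # applies LogSoftmax + NLLLoss internally
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

x = torch.rand(32, 64)
y = torch.randint(0, 10, (32,))

logits = model(x)
loss = criterion(logits, y)  # feed raw logits, not probabilities
optimizer.zero_grad()
loss.backward()
optimizer.step()

# At inference time, apply softmax to interpret the logits as probabilities
probs = torch.softmax(model(x), dim=1)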
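For the multilabel case mentioned in the last insight, a minimal sketch (a toy single-layer model with assumed shapes, purely illustrative) pairs a per-class Sigmoid with nn.BCEWithLogitsLoss:

import torch
import torch.nn as nn

num_classes = 5
model = nn.Linear(64, num_classes)  # toy single-layer "network"
criterion = nn.BCEWithLogitsLoss()  # applies Sigmoid internally

x = torch.rand(32, 64)
targets = torch.randint(0, 2, (32, num_classes)).float()  # each instance can have several labels

loss = criterion(model(x), targets)  # train on raw logits
probs = torch.sigmoid(model(x))      # per-class probabilities at inference; they need not sum to 1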