Implementation of Softmax activation function in PyTorch

Softmax is an activation function typically used in the output layer of a neural network for multiclass classification tasks. It turns logits (the numeric output of the last linear layer of a multiclass classification neural network) into probabilities that sum up to 1. Essentially, it assigns decimal probabilities to each class in a multiclass problem. Those probabilities can be understood as the model’s confidence in its prediction.

Mathematical form

Mathematically, the Softmax function is expressed as follows:

P(y_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}

where:

  • P(y_i) is the probability of a particular class over all possible classes.

  • e is the base of the natural logarithm.

  • z_i represents the input values to the softmax function, often referred to as logits. These are the outputs of the neural network’s last layer before applying softmax.

  • K represents the total number of classes.
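
As a quick sanity check of the formula, here is a minimal sketch (the logit values are arbitrary) that computes the softmax probabilities manually and compares them to PyTorch's built-in torch.softmax:

import torch

# Arbitrary example logits for a three-class problem (K = 3)
logits = torch.tensor([2.0, 1.0, 0.1])

# Manual computation: P(y_i) = e^{z_i} / sum_j e^{z_j}
manual = torch.exp(logits) / torch.exp(logits).sum()

# Built-in equivalent
builtin = torch.softmax(logits, dim=0)

print(manual)        # approximately [0.6590, 0.2424, 0.0986]
print(builtin)       # same values
print(manual.sum())  # 1.0 -- the probabilities sum to 1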

Softmax probabilities

The bar chart depicts the output of a Softmax activation function for a three-class classification problem, showing the predicted probabilities for each class. Class 1 has the highest probability, suggesting the model’s preference for this class, while Classes 2 and 3 have progressively lower probabilities, indicating less likelihood according to the model’s prediction.

Why do we use Softmax?

  1. Probabilistic interpretation: It converts the output logits of the model into a probability distribution, which is essential for classification where you want to understand the confidence of the model in its predictions across different classes.

  2. Differentiability: It is a differentiable function, allowing gradients to be calculated, which is necessary for backpropagation when training the network (both this and the probabilistic interpretation are illustrated in the short sketch after this list).

  3. Useful for multiclass problems: It is specifically designed for multiclass classification where each class is mutually exclusive.
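
To make points 1 and 2 concrete, here is a minimal sketch (using arbitrary logit values) showing that the softmax output is a valid probability distribution and that gradients flow through it:

import torch

# One instance with three arbitrary logits; requires_grad lets us backpropagate
logits = torch.tensor([[1.5, -0.3, 0.8]], requires_grad=True)
probs = torch.softmax(logits, dim=1)

print(probs.sum(dim=1))  # approximately [1.0] -- a valid probability distribution (point 1)

# Softmax is differentiable, so gradients can be computed through it (point 2)
loss = -torch.log(probs[0, 0])  # negative log-likelihood of class 0
loss.backward()
print(logits.grad)  # well-defined gradients with respect to the logits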

Implementation of Softmax function in neural network

Let’s see the implementation of the Softmax activation function in the neural network using PyTorch.

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # Fully connected layer 1
        self.relu = nn.ReLU()  # ReLU activation function
        self.fc2 = nn.Linear(hidden_size, output_size)  # Fully connected layer 2
        self.softmax = nn.Softmax(dim=1)  # Softmax activation function

    def forward(self, x):
        out = self.fc1(x)  # Apply the first fully connected layer
        out = self.relu(out)  # Apply the ReLU activation function
        out = self.fc2(out)  # Apply the second fully connected layer
        out = self.softmax(out)  # Apply the Softmax activation function
        return out

# Define network parameters
input_size = 64  # Number of input features
hidden_size = 128  # Number of neurons in the hidden layer
output_size = 10  # Number of output classes (adjust as necessary)

# Input data
input_data = torch.rand(32, input_size)  # 32 is the batch size
target = torch.randint(0, output_size, (32,), dtype=torch.long)  # Random target values, adjusted for multiclass

# Create an instance of the SimpleNN model
model = SimpleNN(input_size, hidden_size, output_size)

# Define the loss function and optimizer (Adam)
# Since we are applying Softmax in the model, we use NLLLoss here
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Example training loop
num_epochs = 10  # Define the number of training epochs
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(input_data)
    # Apply log on the outputs before calculating the loss
    loss = criterion(torch.log(outputs), target)  # Compute the loss
    # Backward pass and optimization
    optimizer.zero_grad()  # Clear gradients
    loss.backward()  # Backpropagate to compute gradients
    optimizer.step()  # Update the model parameters
    # Print the loss for each epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}')

Code explanation

  • Line 10: We create an instance of the Softmax activation function using nn.Softmax() and store it as an attribute of the SimpleNN class named self.softmax. The dim parameter specifies the dimension along which the softmax operation is computed. In our case, dim=1 ensures that the softmax is applied to each row (each set of logits for an instance), turning them into probabilities that sum to 1.

  • Line 16: We apply the Softmax activation function to the output of the final layer. It takes a tensor of real numbers as input and produces a probability distribution over multiple classes.
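
After training, the softmax outputs can be turned into class predictions with torch.argmax. Here is a short usage sketch that reuses the model and input_data defined above:

model.eval()  # not strictly required here, but good practice before inference
with torch.no_grad():
    probs = model(input_data)               # shape: (32, output_size); each row sums to 1
    predicted = torch.argmax(probs, dim=1)  # index of the most probable class per instance
    confidence = probs.max(dim=1).values    # the model's confidence in each prediction

print(predicted[:5])
print(confidence[:5])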


Insights for machine learning engineers:

  • It’s a common mistake to apply a Softmax activation and then pass the result to nn.CrossEntropyLoss, which is redundant and can be numerically unstable. nn.CrossEntropyLoss already expects raw logits and applies LogSoftmax internally, so when using it, feed in the raw output scores from the network (i.e., the logits); a minimal sketch of this pattern follows the list.

  • During training, the softmax and the subsequent logarithm are often fused into a single log-softmax operation (e.g., nn.LogSoftmax paired with nn.NLLLoss, or nn.CrossEntropyLoss on logits), which provides better numerical stability. During inference, softmax is applied to the logits to interpret the outputs as probabilities.

  • Softmax is suitable for multiclass classification where each instance belongs to exactly one class. If you’re dealing with a multilabel classification problem, consider using a Sigmoid activation at the output layer.
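
The recommended pattern from the first two insights can be sketched as follows, using a hypothetical SimpleNNLogits variant of the model above that returns raw logits: train with nn.CrossEntropyLoss directly on the logits, and apply softmax only at inference time.

import torch
import torch.nn as nn

class SimpleNNLogits(nn.Module):
    # Hypothetical variant of SimpleNN that returns raw logits (no softmax in forward)
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNNLogits, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))  # raw logits, no softmax here

model = SimpleNNLogits(64, 128, 10)
criterion = nn.CrossEntropyLoss()  # applies LogSoftmax + NLLLoss internally, so it expects logits
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

x = torch.rand(32, 64)           # dummy batch of 32 instances
y = torch.randint(0, 10, (32,))  # dummy integer class targets

optimizer.zero_grad()
loss = criterion(model(x), y)    # feed raw logits, not probabilities
loss.backward()
optimizer.step()

# Apply softmax only at inference time to interpret the logits as probabilities
with torch.no_grad():
    probs = torch.softmax(model(x), dim=1)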

