Activation Functions in JAX
Explore common activation functions in JAX like ReLU, Sigmoid, Softmax, and advanced options such as ELU and GELU. Understand their roles in neural networks for different tasks and learn to define custom functions like LeakyReLU for tailored performance.
Overview
Activation functions are applied in neural networks to introduce non-linearity and to shape the network's output into the desired form. At the output layer, the activation function maps the result into a specific range. For instance, when solving a binary classification problem, the output should be a number between 0 and 1, representing the probability that an item belongs to one of the two classes.
In a regression problem, however, we want an unbounded numerical prediction of a quantity, for example, the price of an item. The activation function should therefore be chosen to match the problem being solved. Let’s look at common activation functions in JAX and Flax.
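As a quick illustration (a minimal sketch with made-up logits, not tied to any particular model), jax.nn.sigmoid squashes a raw output into the (0, 1) range for binary classification, while a regression output is typically used as-is:

```python
import jax.numpy as jnp
from jax import nn

# Made-up raw outputs (logits) from a model's final layer
logits = jnp.array([-2.0, 0.0, 3.5])

# Binary classification: sigmoid maps each logit into (0, 1),
# so the result can be read as a probability
probabilities = nn.sigmoid(logits)
print(probabilities)  # ~[0.12 0.5 0.97]

# Regression: the raw output is used directly, with no squashing
prediction = logits
print(prediction)  # [-2.  0.  3.5]
```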
ReLU
The Rectified Linear Unit (ReLU) activation function is primarily used in the hidden layers of neural networks to introduce non-linearity. The function computes max(0, x): inputs below zero are returned as zero, while values above zero are passed through unchanged. As a result, a ReLU layer never produces negative activations.
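To see this behavior directly, jax.nn.relu can be applied to a small array of sample values (the numbers below are arbitrary):

```python
import jax.numpy as jnp
from jax import nn

# Arbitrary sample values, some negative and some positive
x = jnp.array([-3.0, -0.5, 0.0, 0.5, 3.0])

# ReLU returns 0 for negative inputs and the input itself otherwise
print(nn.relu(x))  # [0.  0.  0.  0.5 3. ]
```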
Let’s see how to apply the ReLU activation function in the following code snippet:
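The sketch below assumes a small Flax CNN for a two-class image classification task; the layer sizes are illustrative. nn.relu, which is Flax's re-export of jax.nn.relu, is applied after each convolution and after the first Dense layer:

```python
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten before the dense layers
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)
        x = nn.log_softmax(x)
        return x
```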
In lines 7, 10, and 14, we apply the ReLU activation ...