Related Tags

python

# What is a Hessian matrix?

Talha Irfan

### Overview

The Hessian matrix of a scalar-valued function $f:R^n\to R$ is the square matrix of its second-order partial derivatives. When these second partial derivatives are continuous, the result doesn't depend on the order of differentiation, so Hessian matrices are symmetric.

$H_f= \begin{bmatrix} \dfrac{\partial^2 f}{\partial x_1^2} & \dfrac{\partial^2 f}{\partial x_1\,\partial x_2} & \cdots & \dfrac{\partial^2 f}{\partial x_1\,\partial x_n} \\[2.2ex] \dfrac{\partial^2 f}{\partial x_2\,\partial x_1} & \dfrac{\partial^2 f}{\partial x_2^2} & \cdots & \dfrac{\partial^2 f}{\partial x_2\,\partial x_n} \\[2.2ex] \vdots & \vdots & \ddots & \vdots \\[2.2ex] \dfrac{\partial^2 f}{\partial x_n\,\partial x_1} & \dfrac{\partial^2 f}{\partial x_n\,\partial x_2} & \cdots & \dfrac{\partial^2 f}{\partial x_n^2} \end{bmatrix}$

A Hessian matrix has a number of uses, including:

• $2^{nd}$ order optimization
• Testing for critical points
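To make the first use concrete, here is a minimal sketch of a single Newton step, the basic building block of $2^{nd}$ order optimization. The function `f`, the helper `newton_step`, and the starting point are illustrative assumptions and are not part of any library.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical convex quadratic with its minimum at (1, -2)
    return (x[0] - 1.0) ** 2 + 2.0 * (x[1] + 2.0) ** 2

def newton_step(f, x):
    # Newton update: x_new = x - H^{-1} g, solved as a linear system
    g = jax.grad(f)(x)
    H = jax.hessian(f)(x)
    return x - jnp.linalg.solve(H, g)

x0 = jnp.array([5.0, 5.0])
x1 = newton_step(f, x0)
print(x1)
```

Because `f` is quadratic, its Hessian is constant and one Newton step lands on the minimizer $(1, -2)$ up to floating-point error. For a general function, the step would be repeated until the gradient is close to zero.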

### Code

We can evaluate a Hessian in a number of numerical computing libraries. For reference, here are the relevant functions in two of them:

• PyTorch - torch.autograd.functional.hessian() computes the Hessian matrix of a scalar-valued function. To learn more, refer to the official documentation.
• JAX - jax.hessian() computes the Hessian matrix. The official documentation explains its internal implementation.
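As a hedged illustration of what that internal implementation amounts to, the sketch below compares jax.hessian() with a forward-over-reverse composition of jax.jacfwd and jax.jacrev. The test function `g` and the evaluation point are assumptions chosen only for this example.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical scalar-valued function of a 2-vector
    return jnp.sin(x[0]) * x[1] ** 2

x = jnp.array([0.5, 2.0])

H_builtin = jax.hessian(g)(x)
# Jacobian (forward mode) of the gradient (reverse mode)
H_composed = jax.jacfwd(jax.jacrev(g))(x)

print(jnp.allclose(H_builtin, H_composed, atol=1e-5))
print(jnp.allclose(H_builtin, H_builtin.T, atol=1e-5))  # symmetry check
```

Both constructions should agree up to floating-point error, and the second check illustrates the symmetry mentioned in the overview.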

Limiting ourselves to JAX here, we can calculate a Hessian matrix directly using jax.hessian(). As an example, take the function $f(x,y) = 3x^3-6y^2+3$:

import jax
import jax.numpy as jnp

def F(x):
    # x is the vector [x, y]; f(x, y) = 3x^3 - 6y^2 + 3
    return 3*x[0]**3 - 6*x[1]**2 + 3

hessian_x = jax.hessian(F)
print(type(hessian_x)) # a function that we will use later
We can directly calculate Hessian using jax.hessian()

Having created hessian_x, we can now evaluate it at any point.

The function returned by hessian() takes a single vector as input, which means that we need to pack the $(x,y)$ pair into a vector

$X = \begin{bmatrix} x \\ y \end{bmatrix}$

before passing it as an input.

X=jnp.array([1.2,3.4])
print("Hessian evaluation at (1.2,3.4)")
print(hessian_x(X))
Y=jnp.array([-1.0,1.0])
print("Hessian evaluation at (-1,1)")
print(hessian_x(Y))
Evaluating the Hessian at some pair of points
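As a sanity check, the result can be compared with the Hessian derived by hand. For $f(x,y) = 3x^3-6y^2+3$, differentiating twice gives $\partial^2 f/\partial x^2 = 18x$, $\partial^2 f/\partial y^2 = -12$, and zero mixed partials. The sketch below is self-contained; the names `f` and `analytic_hessian` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def f(v):
    # f(x, y) = 3x^3 - 6y^2 + 3, with v = [x, y]
    return 3 * v[0] ** 3 - 6 * v[1] ** 2 + 3

def analytic_hessian(v):
    # Hand-derived Hessian: [[18x, 0], [0, -12]]
    return jnp.array([[18.0 * v[0], 0.0],
                      [0.0, -12.0]])

point = jnp.array([1.2, 3.4])
numeric = jax.hessian(f)(point)
expected = analytic_hessian(point)

print(numeric)
print(jnp.allclose(numeric, expected, atol=1e-4))
```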

### Testing for critical points

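As mentioned at the start, we can use a Hessian to classify a critical point (a point where the gradient is zero) in a pretty simple way.

If the Hessian at that point is:

• Positive definite (i.e., all eigenvalues are positive): The point is a local minimum.
• Negative definite (i.e., all eigenvalues are negative): The point is a local maximum.
• Indefinite (i.e., some eigenvalues are positive and others are negative): It is a saddle point (and we are in trouble).

If the Hessian is only semidefinite (some eigenvalues are zero and the rest share one sign), the test is inconclusive.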
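from jax.numpy import linalg # used for checking matrix definiteness

def TestCriticalPoint(hess, x):
    # Only inspects the Hessian's eigenvalues at x; the labels are
    # meaningful only if x is a critical point (zero gradient).
    # The Hessian is symmetric, so eigvalsh returns real eigenvalues.
    eigenvalues = linalg.eigvalsh(hess(x))
    if jnp.all(eigenvalues > 0):
        print("Point is a local minimum")
    elif jnp.all(eigenvalues < 0):
        print("Point is a local maximum")
    else:
        print("Point is a saddle point or the test is inconclusive")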

X=jnp.array([1.2,3.4])
Y=jnp.array([-1.0,1.0])
print("----Testing for X-----")
TestCriticalPoint(hessian_x,X)
print("----Testing for Y-----")
TestCriticalPoint(hessian_x,Y)
The Hessian is frequently used for testing critical points
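The second-derivative test only applies where the gradient vanishes, so a more careful version checks that first. The following is a minimal sketch under stated assumptions: the helper `classify`, its tolerance `tol`, and the extra test functions `saddle` and `bowl` are hypothetical names introduced for illustration.

```python
import jax
import jax.numpy as jnp

def classify(f, x, tol=1e-5):
    # Step 1: the test is only valid at critical points (gradient ~ 0)
    if jnp.linalg.norm(jax.grad(f)(x)) > tol:
        return "not a critical point"
    # Step 2: inspect the signs of the Hessian's (real) eigenvalues
    eig = jnp.linalg.eigvalsh(jax.hessian(f)(x))
    if jnp.all(eig > tol):
        return "local minimum"
    if jnp.all(eig < -tol):
        return "local maximum"
    if jnp.any(eig > tol) and jnp.any(eig < -tol):
        return "saddle point"
    return "inconclusive"  # some eigenvalue is (numerically) zero

def article_f(v):
    return 3 * v[0] ** 3 - 6 * v[1] ** 2 + 3

def saddle(v):
    return v[0] ** 2 - v[1] ** 2

def bowl(v):
    return v[0] ** 2 + v[1] ** 2

origin = jnp.array([0.0, 0.0])
print(classify(article_f, jnp.array([1.2, 3.4])))
print(classify(article_f, origin))
print(classify(saddle, origin))
print(classify(bowl, origin))
```

For the article's function, the only critical point is the origin, where one eigenvalue of the Hessian is zero, so the test cannot decide there. The two extra functions show the saddle and minimum cases.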

The Hessian is used less frequently in $2^{nd}$ order optimization methods because of its memory and computational requirements: for $n$ variables, the matrix has $n^2$ entries.
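One common way around this cost is to compute Hessian-vector products without ever forming the full matrix. The sketch below does this by applying jax.jvp to the gradient function; the helper `hvp` and the test function `f` are assumptions for illustration, and the full Hessian is built here only to check the result.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar-valued function of an n-vector
    return jnp.sum(jnp.sin(x) * x ** 2)

def hvp(f, x, v):
    # Directional derivative of the gradient along v equals H(x) @ v
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([0.3, -1.2, 2.0])
v = jnp.array([1.0, 0.0, -1.0])

product = hvp(f, x, v)
reference = jax.hessian(f)(x) @ v  # full Hessian, for comparison only

print(jnp.allclose(product, reference, atol=1e-4))
```

The product has only $n$ entries, so this approach avoids storing the $n \times n$ matrix.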

CONTRIBUTOR Talha Irfan 