Trusted answers to developer questions
Trusted Answers to Developer Questions

Related Tags

python
autograd

What is a Hessian matrix?

Talha Irfan

Overview

If we take the second-order derivative of f:RnRf:R^n\to R, the resultant matrix is called a Hessian matrix. Since the derivative of a derivative is commutativedoesn’t depend on the order, Hessian matrices are symmetric.

Hf=[2fx122fx1x22fx1xn2fx2x12fx222fx2xn2fxnx12fxnx22fxn2] H_f= \begin{bmatrix} \dfrac{\partial^2 f}{\partial x_1^2} & \dfrac{\partial^2 f}{\partial x_1\,\partial x_2} & \cdots & \dfrac{\partial^2 f}{\partial x_1\,\partial x_n} \\[2.2ex] \dfrac{\partial^2 f}{\partial x_2\,\partial x_1} & \dfrac{\partial^2 f}{\partial x_2^2} & \cdots & \dfrac{\partial^2 f}{\partial x_2\,\partial x_n} \\[2.2ex] \vdots & \vdots & \ddots & \vdots \\[2.2ex] \dfrac{\partial^2 f}{\partial x_n\,\partial x_1} & \dfrac{\partial^2 f}{\partial x_n\,\partial x_2} & \cdots & \dfrac{\partial^2 f}{\partial x_n^2} \end{bmatrix}

A Hessian matrix has a number of uses, including:

  • 2nd2^{nd} order optimization
  • Testing for critical points

Code

We can evaluate a Hessian in a number of numerical computing libraries. As a reference, I am mentioning the respective functions/syntax of some of them:

  • PyTorch - torch.autograd.functional.hessian() is used to calculate a hessian matrix in PyTorch. To learn more, refer to the official documentation.
  • JAX - jax.hessian() is used for the calculation of hessian matrix. The official documentation can help understand its internal implementation.

Limiting ourselves to only JAX here, we can calculate a Hessian matrix directly using the jax.hessian(). As an example, take a function: f(x,y)=3x36y2+3f(x,y) = 3x^3-6y^2+3

import jax
import jax.numpy as jnp

def F(x):
  return 3*x[0]*x[0]*x[0]-6*x[1]*x[1]+3

hessian_x = jax.hessian(F)
print(type(hessian_x)) # a function that we will use later
We can directly calculate Hessian using jax.hessian()

Having initialized hessian_x, now we are in a position to evaluate it at any value.

Like any other linear algebra function, hessian() also works on vector-valued inputs, which means that we need to convert the (x,y)(x,y) pair into a vector:

X=[xy] X = \begin{bmatrix} x \\ \\y \end{bmatrix}

before passing as an input.

X=jnp.array([1.2,3.4])
print("Hessian evaluation at (1.2,3.4)")
print(hessian_x(X))
Y=jnp.array([-1.0,1.0])
print("Hessian evaluation at (-1,1)")
print(hessian_x(Y))
Evaluating the Hessian at some pair of points

Testing for critical points

As mentioned at the start, we can use a Hessian to test the critical points in a pretty simple way.

If a Hessian is:

  • Positive semidefinite: The point is a global minima.
  • Negative semidefinite (i.e., all eigenvalues are negative): The point is a global maxima.
  • IndefiniteA square matrix having some eigenvalues positive and others negative: It is a saddle point (and we are in trouble).
from jax.numpy import linalg #used for checking matrix definiteness

def TestCriticalPoint(hess, x):
    if(jnp.all(linalg.eigvals(hess(x)) > 0)):
        print("Point is local minima")
    elif(jnp.all(linalg.eigvals(hess(x)) < 0)):
        print("Point is local maxima")
    else:
        print("Its a Saddle point")

X=jnp.array([1.2,3.4])
Y=jnp.array([-1.0,1.0])
print("----Testing for X-----")
TestCriticalPoint(hessian_x,X)
print("----Testing for Y-----")
TestCriticalPoint(hessian_x,Y)
Hessian is frequently used for testing the critical points

Hessian's use in 2nd2^{nd} order methods is less frequent due to memory and computational requirements.

RELATED TAGS

python
autograd

CONTRIBUTOR

Talha Irfan
Copyright ©2022 Educative, Inc. All rights reserved
RELATED COURSES

View all Courses

Keep Exploring