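If we take the second-order derivative of a scalar-valued function of several variables, the resultant matrix is called a Hessian matrix. Since the derivative of a derivative is commutative (provided the second partial derivatives are continuous), the Hessian matrix is symmetric.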
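For reference, the standard definition can be written out entry by entry. The symbols f, x_i, and n are notation assumed here for a function f: R^n -> R, since the article's own formula is not shown:

```latex
H_f(\mathbf{x}) =
\begin{bmatrix}
\dfrac{\partial^2 f}{\partial x_1^2} & \cdots & \dfrac{\partial^2 f}{\partial x_1 \, \partial x_n} \\
\vdots & \ddots & \vdots \\
\dfrac{\partial^2 f}{\partial x_n \, \partial x_1} & \cdots & \dfrac{\partial^2 f}{\partial x_n^2}
\end{bmatrix},
\qquad
\left(H_f\right)_{ij} = \frac{\partial^2 f}{\partial x_i \, \partial x_j}.
```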
A Hessian matrix has a number of uses, including testing critical points and second-order optimization methods.
We can evaluate a Hessian in a number of numerical computing libraries. For reference, here are the relevant functions in some of them:
torch.autograd.functional.hessian() is used to calculate a Hessian matrix in PyTorch. To learn more, refer to the official documentation.
jax.hessian() is used to calculate a Hessian matrix in JAX. The official documentation can help in understanding its internal implementation.
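To give a rough idea of what jax.hessian() computes, the sketch below builds the same matrix by composing jax.jacfwd and jax.jacrev, which is how the JAX documentation describes it. The test function g and the point p are made up for illustration.

```python
import jax
import jax.numpy as jnp

def g(v):
    # Hypothetical scalar-valued function of a 2-vector: g(x, y) = x^2 * y + y^3
    x, y = v
    return x**2 * y + y**3

p = jnp.array([1.0, 2.0])

# Forward-over-reverse composition: the Jacobian of the gradient
manual_hessian = jax.jacfwd(jax.jacrev(g))(p)
builtin_hessian = jax.hessian(g)(p)

print(manual_hessian)
print(jnp.allclose(manual_hessian, builtin_hessian))     # both routes agree
print(jnp.allclose(builtin_hessian, builtin_hessian.T))  # the Hessian is symmetric
```

For this g, the analytic second derivatives are g_xx = 2y, g_xy = 2x, and g_yy = 6y, which can be checked by hand against the printed matrix.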
Limiting ourselves to only JAX here, we can calculate a Hessian matrix directly using the jax.hessian() function.
As an example, take a function:
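```python
import jax
import jax.numpy as jnp

def F(x):
    # hessian() expects a scalar-valued function, so the expression is summed
    # over the components of x (an assumption about the intended F)
    return jnp.sum(3*x*x*x - 6*x*x + 3)

hessian_x = jax.hessian(F)
print(type(hessian_x))  # a function that we will use later
```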
Now that we have hessian_x, we are in a position to evaluate it at any value.
Like any other linear algebra function, hessian() also works on vector-valued inputs, which means that we need to convert the pair (x, y) into a vector [x, y] before passing it as an input.
```python
X = jnp.array([1.2, 3.4])
print("Hessian evaluation at (1.2,3.4)")
print(hessian_x(X))

Y = jnp.array([-1.0, 1.0])
print("Hessian evaluation at (-1,1)")
print(hessian_x(Y))
```
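When the Hessian is needed at several points, the evaluations can be batched rather than called one by one. Here is a minimal sketch using jax.vmap; the function h and the array points are illustrative stand-ins, not the function used above.

```python
import jax
import jax.numpy as jnp

def h(v):
    # Hypothetical scalar-valued function: h(x, y) = x^3 + x * y^2
    return v[0]**3 + v[0] * v[1]**2

hessian_h = jax.hessian(h)

# Each row is one (x, y) input
points = jnp.array([[1.2, 3.4],
                    [-1.0, 1.0]])

# vmap maps hessian_h over the leading axis, giving one 2x2 matrix per point
batched = jax.vmap(hessian_h)(points)
print(batched.shape)
print(batched)
```

The result has one leading batch dimension, so batched[i] is the Hessian at points[i].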
As mentioned at the start, we can use a Hessian to test the critical points in a pretty simple way.
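If a Hessian is positive definite (all eigenvalues positive), the point is a local minimum; if it is negative definite (all eigenvalues negative), the point is a local maximum; and if it has eigenvalues of both signs, the point is a saddle point.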
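```python
from jax.numpy import linalg  # used for checking matrix definiteness

def TestCriticalPoint(hess, x):
    # The Hessian is symmetric, so eigvalsh returns real eigenvalues
    eigenvalues = linalg.eigvalsh(hess(x))
    if jnp.all(eigenvalues > 0):
        print("Point is a local minimum")
    elif jnp.all(eigenvalues < 0):
        print("Point is a local maximum")
    elif jnp.any(eigenvalues > 0) and jnp.any(eigenvalues < 0):
        print("It's a saddle point")
    else:
        print("Test is inconclusive")  # at least one eigenvalue is zero

X = jnp.array([1.2, 3.4])
Y = jnp.array([-1.0, 1.0])

print("----Testing for X-----")
TestCriticalPoint(hessian_x, X)
print("----Testing for Y-----")
TestCriticalPoint(hessian_x, Y)
```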
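The second-derivative test is only meaningful at a critical point, where the gradient is zero. The sketch below checks that condition with jax.grad before classifying. The function saddle_fn, the helper classify, and the tolerance tol are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def saddle_fn(v):
    # Hypothetical function with a critical point at the origin: x^2 - y^2
    return v[0]**2 - v[1]**2

def classify(f, point, tol=1e-6):
    # The test only applies where the gradient vanishes
    if not jnp.allclose(jax.grad(f)(point), 0.0, atol=tol):
        return "not a critical point"
    # eigvalsh assumes a symmetric matrix and returns real eigenvalues
    eigs = jnp.linalg.eigvalsh(jax.hessian(f)(point))
    if jnp.all(eigs > tol):
        return "local minimum"
    if jnp.all(eigs < -tol):
        return "local maximum"
    if jnp.any(eigs > tol) and jnp.any(eigs < -tol):
        return "saddle point"
    return "inconclusive"

print(classify(saddle_fn, jnp.array([0.0, 0.0])))
print(classify(saddle_fn, jnp.array([1.0, 1.0])))
```

Returning a label instead of printing makes the helper easier to reuse, and the explicit "inconclusive" case covers Hessians with a zero eigenvalue, where the test says nothing.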
The Hessian's use in second-order optimization methods is less frequent due to its memory and computational requirements.
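One common way around the memory cost is to never form the full matrix and to compute Hessian-vector products instead. The sketch below uses the forward-over-reverse pattern described in the JAX autodiff documentation; the function quad, the helper name hvp, and the inputs are hypothetical.

```python
import jax
import jax.numpy as jnp

def quad(v):
    # Hypothetical scalar-valued function of an n-vector
    return jnp.sum(v**3) + jnp.prod(v)

def hvp(f, x, direction):
    # Differentiate the gradient along `direction`: yields H(x) @ direction
    # without materializing the n x n Hessian
    return jax.jvp(jax.grad(f), (x,), (direction,))[1]

x = jnp.array([1.0, 2.0, 3.0])
d = jnp.array([1.0, 0.0, -1.0])

product = hvp(quad, x, d)

# Cross-check against the explicit Hessian (only feasible for small n)
print(jnp.allclose(product, jax.hessian(quad)(x) @ d))
```

The product costs roughly a constant multiple of one gradient evaluation, whereas storing the full Hessian needs memory that grows with the square of the input size.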