
Talha Irfan

Taking all second-order partial derivatives of a function $f:\mathbb{R}^n\to \mathbb{R}$ and arranging them in a matrix gives the **Hessian matrix**. Since mixed partial derivatives commute (for twice continuously differentiable functions), the Hessian is *symmetric*.

$H_f= \begin{bmatrix} \dfrac{\partial^2 f}{\partial x_1^2} & \dfrac{\partial^2 f}{\partial x_1\,\partial x_2} & \cdots & \dfrac{\partial^2 f}{\partial x_1\,\partial x_n} \\[2.2ex] \dfrac{\partial^2 f}{\partial x_2\,\partial x_1} & \dfrac{\partial^2 f}{\partial x_2^2} & \cdots & \dfrac{\partial^2 f}{\partial x_2\,\partial x_n} \\[2.2ex] \vdots & \vdots & \ddots & \vdots \\[2.2ex] \dfrac{\partial^2 f}{\partial x_n\,\partial x_1} & \dfrac{\partial^2 f}{\partial x_n\,\partial x_2} & \cdots & \dfrac{\partial^2 f}{\partial x_n^2} \end{bmatrix}$
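As a quick worked example (chosen here for illustration, not from the original answer), take $f(x,y) = x^2 y$. Its first partial derivatives are $\partial f/\partial x = 2xy$ and $\partial f/\partial y = x^2$, so

$H_f = \begin{bmatrix} 2y & 2x \\ 2x & 0 \end{bmatrix}$

and the equal off-diagonal entries show the symmetry in action.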

A Hessian matrix has a number of uses, including:

- $2^{nd}$ order optimization
- Testing for critical points
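Both uses build on the same object. As a minimal sketch of the first one (the objective function, starting point, and step count below are illustrative assumptions, not from the original answer), Newton's method repeatedly applies the update $x \leftarrow x - H_f^{-1}\nabla f(x)$:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative convex objective (an assumption for this sketch)
    return jnp.sum(x**4) + jnp.sum((x - 1.0)**2)

grad_f = jax.grad(f)
hess_f = jax.hessian(f)

def newton_step(x):
    # Solve H dx = grad(f) rather than explicitly inverting the Hessian
    return x - jnp.linalg.solve(hess_f(x), grad_f(x))

x = jnp.array([3.0, -2.0])
for _ in range(10):
    x = newton_step(x)
print(x, grad_f(x))  # the gradient should be close to zero after a few steps
```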

We can evaluate a Hessian in a number of numerical computing libraries. For reference, here are the relevant functions in two of them:

- **PyTorch** - `torch.autograd.functional.hessian()` is used to calculate a Hessian matrix in PyTorch. To learn more, refer to the official documentation.
- **JAX** - `jax.hessian()` is used for the calculation of the Hessian matrix. The official documentation can help understand its internal implementation.

Limiting ourselves to only JAX here, we can calculate a Hessian matrix directly using `jax.hessian()`.
As an example, take a function:

```python
import jax
import jax.numpy as jnp

def F(x):
    return 3*x[0]*x[0]*x[0] - 6*x[1]*x[1] + 3

hessian_x = jax.hessian(F)  # a function that we will use later
print(type(hessian_x))
```

We can directly calculate the Hessian using `jax.hessian()`

Having initialized `hessian_x`, we are now in a position to evaluate it at any value.

Like any other linear algebra function, `hessian()` also works on **vector-valued inputs**, which means that we need to convert the $(x, y)$ pair into a vector

$X = \begin{bmatrix} x \\ y \end{bmatrix}$

before passing it as an input.

```python
X = jnp.array([1.2, 3.4])
print("Hessian evaluation at (1.2, 3.4)")
print(hessian_x(X))

Y = jnp.array([-1.0, 1.0])
print("Hessian evaluation at (-1, 1)")
print(hessian_x(Y))
```

Evaluating the Hessian at some pair of points
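To sanity-check the output, we can compare against the Hessian derived by hand: for $F(x) = 3x_0^3 - 6x_1^2 + 3$, the analytic Hessian is $\begin{bmatrix} 18x_0 & 0 \\ 0 & -12 \end{bmatrix}$. A small self-contained check (redefining `F` so the snippet runs on its own):

```python
import jax
import jax.numpy as jnp

def F(x):
    return 3*x[0]**3 - 6*x[1]**2 + 3

def analytic_hessian(x):
    # Hand-derived second derivatives of F
    return jnp.array([[18.0 * x[0], 0.0],
                      [0.0, -12.0]])

X = jnp.array([1.2, 3.4])
auto = jax.hessian(F)(X)
manual = analytic_hessian(X)
print(jnp.allclose(auto, manual))  # → True
```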

As mentioned at the start, we can use the Hessian to classify critical points in a pretty simple way.

If the Hessian at a critical point is:

- Positive definite (i.e., all eigenvalues are positive): the point is a local *minimum*.
- Negative definite (i.e., all eigenvalues are negative): the point is a local *maximum*.
- Indefinite (some eigenvalues are positive and others negative): it is a *saddle point*, and we are in trouble.

```python
from jax.numpy import linalg  # used for checking matrix definiteness

def TestCriticalPoint(hess, x):
    # eigvalsh returns real eigenvalues for a symmetric matrix such as a Hessian
    eigenvalues = linalg.eigvalsh(hess(x))
    if jnp.all(eigenvalues > 0):
        print("Point is a local minimum")
    elif jnp.all(eigenvalues < 0):
        print("Point is a local maximum")
    else:
        print("It's a saddle point")

X = jnp.array([1.2, 3.4])
Y = jnp.array([-1.0, 1.0])
print("----Testing for X-----")
TestCriticalPoint(hessian_x, X)
print("----Testing for Y-----")
TestCriticalPoint(hessian_x, Y)
```

The Hessian is frequently used for classifying critical points

The Hessian's use in $2^{nd}$ order optimization methods is less frequent due to its memory and computational requirements.
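The cost comes from materializing an $n \times n$ matrix. One common way to sidestep it (a sketch of a standard JAX pattern, not part of the original answer) is a Hessian-vector product, which differentiates the gradient along a single direction instead of forming the full Hessian:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function (an assumption for this sketch)
    return jnp.sum(x**3) + jnp.dot(x, x)

def hvp(fun, x, v):
    # Forward-over-reverse: JVP of the gradient along direction v.
    # Needs O(n) memory instead of the O(n^2) full Hessian.
    return jax.jvp(jax.grad(fun), (x,), (v,))[1]

x = jnp.arange(4.0)
v = jnp.ones(4)
print(hvp(f, x, v))           # Hessian-vector product, no n x n matrix
print(jax.hessian(f)(x) @ v)  # same numbers, but builds the full matrix
```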



Copyright ©2022 Educative, Inc. All rights reserved
