Multivariate Calculus

This lesson will introduce multivariate calculus in JAX.

We can build on our linear algebra refresher and apply automatic differentiation to vectors and matrices as well. This is instrumental in neural networks, where gradient-based optimization is applied to whole matrices.


If we have a multivariate function (a function of more than one variable), $f: \mathbb{R}^n \to \mathbb{R}$, we calculate its derivative by taking the partial derivative with respect to every input (independent) variable.

For example, if we have:

$$f(x,y,z) = 4x + 3y^2 - z$$

its derivative, called the gradient and denoted by $\nabla f$, is calculated as:

$$\nabla f(x,y,z) = \begin{bmatrix} \frac{\partial f(x,y,z)}{\partial x} \\ \frac{\partial f(x,y,z)}{\partial y} \\ \frac{\partial f(x,y,z)}{\partial z} \end{bmatrix} = \begin{bmatrix} 4 \\ 6y \\ -1 \end{bmatrix}$$

Generally, we can define this as:

$$\nabla f = \begin{bmatrix} \frac{\partial f}{\partial x_1} \\ \frac{\partial f}{\partial x_2} \\ \vdots \\ \frac{\partial f}{\partial x_n} \end{bmatrix}$$
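To make the general definition concrete, here is a minimal sketch using `jax.grad`. The function `g` (a sum of squares) is a hypothetical example chosen for illustration and is not part of the lesson. Its partial derivatives are $2x_i$, so the result should have one entry per input variable.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Hypothetical example: g(x) = x_1^2 + ... + x_n^2, so dg/dx_i = 2 * x_i
    return jnp.sum(x ** 2)


x = jnp.arange(1.0, 6.0)  # n = 5 inputs: [1, 2, 3, 4, 5]

# jax.grad(g) returns a function that evaluates the gradient of g
grad_g = jax.grad(g)(x)

print(grad_g.shape)  # one partial derivative per input variable
print(grad_g)        # analytically 2 * x
```

The gradient has the same shape as the input, which is what lets the same call work for vectors of any length.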

The gradient of the example function $f(x,y,z) = 4x + 3y^2 - z$ can be computed with JAX's grad().
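A minimal sketch of that computation is shown here. It assumes the three variables are packed into a single input vector `v = (x, y, z)`, and the evaluation point is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp


def f(v):
    # v is a length-3 vector holding (x, y, z)
    x, y, z = v[0], v[1], v[2]
    return 4.0 * x + 3.0 * y ** 2 - z


# jax.grad(f) returns a new function that evaluates the gradient of f
grad_f = jax.grad(f)

point = jnp.array([1.0, 2.0, 3.0])  # x=1, y=2, z=3 (arbitrary test point)
gradient = grad_f(point)

print(gradient)  # analytically [4, 6*y, -1] = [4, 12, -1] at this point
```

Note that `jax.grad` expects floating-point inputs, which is why the point is written with decimal values.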

Note: Since the first and third components of the gradient are constants (4 and -1), we can change the values of x and z in the input vector and still get the same output. Only the value of y affects the result, through the second component, 6y.
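One way to check this numerically is sketched below. The specific input values are arbitrary, and `f` is redefined so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp


def f(v):
    # f(x, y, z) = 4x + 3y^2 - z, with v = (x, y, z)
    return 4.0 * v[0] + 3.0 * v[1] ** 2 - v[2]


grad_f = jax.grad(f)

# Same y, different x and z: the gradients should be identical
g_a = grad_f(jnp.array([1.0, 2.0, 3.0]))
g_b = grad_f(jnp.array([-7.0, 2.0, 10.0]))

# Different y: only the middle component (6y) should change
g_c = grad_f(jnp.array([1.0, 5.0, 3.0]))

print(g_a)
print(g_b)
print(g_c)
```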
