Flax: Linear Modules

This lesson will provide an overview of linear modules in Flax.

This lesson will focus on linear modules. These modules are helpful in designing feedforward neural networks. For simplicity, we’ll omit the optional arguments.

Dense layers

We can use Dense() to apply an affine transformation to the previous layer, x, as:

y = w^T x + b

Since Dense() applies its transformation only to the last dimension of the input, Flax provides DenseGeneral() to generalize the operation to any dimension. For example, we can apply the transformation along a chosen axis through:

y = DenseGeneral(features=<features>, axis=<axis>)(x)

Note: Much documentation incorrectly refers to the affine transformation as a linear transformation. This lesson will call it affine from here on. (For further reading, check any linear algebra book, such as Boyd and Vandenberghe’s Introduction to Applied Linear Algebra (2018).)

If we print a dense layer, several optional arguments are also displayed. Some of them are quite helpful, like:

  • use_bias controls whether or not to use the bias, b. It’s True by default.
  • kernel_init is the function used to initialize the weight matrix, w.
