Solution: Train and Test a Model
Let’s review the solution for the previous challenge.
After importing all the necessary libraries, your first step is to define the network.
Define the network
Create a convolutional neural network with the Linen API by subclassing nn.Module. Because the architecture in this case is relatively simple (you’re just stacking layers), you can define the inlined submodules directly within the __call__ method and decorate it with @nn.compact.
import flax.linen as nn


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        # Block 1: 3x3 convolution with 64 filters, ReLU, then 2x2 average pooling
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Block 2: the same pattern again
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)  # one logit per class
        return x
Define loss
We simply use optax.softmax_cross_entropy(). Note that this function expects both logits and labels to have the shape [batch, num_classes]. Since the labels will be read from the TensorFlow dataset as integer values, we first need to convert them to one-hot encoding.
Our function returns a simple scalar value ready for optimization, so we ...