Solution: Train and Test a Model

Let’s review the solution for the previous challenge.

After importing all the necessary libraries, your first step is to define the network.

Define the network

Create a convolutional neural network with the Linen API by subclassing nn.Module. Because the architecture in this case is relatively simple (you’re just stacking layers), you can define the submodules inline within the __call__ method and decorate that method with @nn.compact.

from flax import linen as nn


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        # Two conv -> ReLU -> 2x2 average-pool stages; each pool halves height and width.
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten [batch, height, width, channels] into [batch, features].
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        # One output logit per class (two classes here).
        x = nn.Dense(features=2)(x)
        return x

Define the loss

We simply use optax.softmax_cross_entropy(). This function expects both logits and labels to have the shape [batch, num_classes], and it returns one loss value per example, with shape [batch]. Since the labels are read from the TensorFlow dataset as integer class indices, we first need to convert them to one-hot vectors.

Our function returns a simple scalar value ready for optimization, so we ...