Chollet: Chapter 07: Working with Keras: A Deep Dive
Metadata
Title: Working with Keras: A deep dive
Number: 7
Core Ideas
Keras is guided by the principle of progressive disclosure of complexity. The idea is to make it easy to get started, but possible to handle high-complexity use cases after incremental learning.
So there is a spectrum of workflows in Keras, all sharing common APIs like Layer and Model.
There are three core APIs for building models. The Sequential model is the simplest, limited to a simple stack of layers. The functional API focuses on graph-like model architectures; it is a mid-point between usability and functionality, and the most commonly used. Model subclassing is the low-level option for full control; you lose built-in features and can make mistakes, so it tends to be used by researchers.
The Sequential Model
The simplest API: a list of layers that are called in sequence. You can specify the layers when instantiating the model, or add them one at a time with the add method. Layers' weights are only created when the layers are called for the first time (or built explicitly with build), since the weight shapes depend on the input tensor shape.
Once a model is built you can display its contents via the summary method, which is very handy for debugging. Adding a layer, building it, and printing the summary to check is a very common workflow.
You can name models and layers by passing a name parameter on instantiation.
You can declare the shape of the inputs in advance by using the Input class, which avoids calling build while still letting you inspect the model as you build it. This is really handy: you can see, for example, how many params you are creating in the model:
from tensorflow import keras
model = keras.Sequential(name="my_classifier")
model.add(keras.Input(shape=(10000,), name="document"))
model.add(keras.layers.Dense(64, activation="relu", name="hidden_state"))
model.summary() # check in the notebook then...
model.add(keras.layers.Dense(10, activation="softmax", name="predicted_classes"))
model.summary() # check again...
Functional API
The chapter then presents the Functional API, the most common API to use in practice.
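For reference, a minimal functional-API model (the layer sizes and names here are mine, mirroring the Sequential example above, not the book's listing):

```python
from tensorflow import keras

# inputs and outputs are symbolic tensors; Model ties them
# together into a graph you can inspect, plot, and reuse
inputs = keras.Input(shape=(10000,), name="document")
features = keras.layers.Dense(64, activation="relu")(inputs)
outputs = keras.layers.Dense(10, activation="softmax")(features)
model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()
```

Because the model is an explicit graph, Keras can inspect it, plot it, and let you pull out intermediate layers for feature extraction.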
Subclassing the Model Class
The chapter then introduces the lowest-level approach: subclassing the Model class and doing your own thing.
It looks similar to the Layer subclassing in ch. 3, and the classes are similar. Model is the top-level object though, and has the fit, evaluate, and predict methods on it. You can also save a model to disk.
Model subclassing allows you to define models that can't be expressed as directed acyclic graphs of layers, for example models where layers are called recursively.
But that means the subclassed model is not an explicit data structure you can inspect and plot; it's just bytecode. The way layers are connected is hidden in the model's call method (its forward pass). You can't access the nodes to do feature extraction, as there's no graph. It's just a black box once it's instantiated.
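As a sketch of what subclassing looks like, here is a hypothetical two-layer classifier of my own (not the book's example). The layers are created in __init__, and the connectivity exists only inside call():

```python
import tensorflow as tf
from tensorflow import keras

class SimpleClassifier(keras.Model):
    def __init__(self, num_classes=10):
        super().__init__()
        # sub-layers are attributes; Keras tracks their weights automatically
        self.hidden = keras.layers.Dense(64, activation="relu")
        self.classifier = keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs):
        # the forward pass -- opaque to Keras, so no graph to inspect or plot
        features = self.hidden(inputs)
        return self.classifier(features)

model = SimpleClassifier()
# weights are only created on the first call, once the input shape is known
out = model(tf.zeros((2, 100)))
```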
You can mix and match the APIs though. P. 184 shows defining a simple Classifier model subclass that will output a sigmoid for binary classification, and softmax for multiclass. It then shows using that in a functional chain.
Chollet's bottom line: if you can use the functional API (i.e. your model is a directed acyclic graph), then use it. The functional API with subclassed layers is maybe the best of both worlds.
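A hedged reconstruction of the p. 184 idea (the details here are from memory, not verbatim): a Model subclass that picks its output layer from the number of classes, then used as one node in a functional graph.

```python
from tensorflow import keras

class Classifier(keras.Model):
    def __init__(self, num_classes=2):
        super().__init__()
        # sigmoid with 1 unit for binary, softmax with N units for multiclass
        if num_classes == 2:
            num_units, activation = 1, "sigmoid"
        else:
            num_units, activation = num_classes, "softmax"
        self.dense = keras.layers.Dense(num_units, activation=activation)

    def call(self, inputs):
        return self.dense(inputs)

# the subclassed model is used like any layer inside a functional graph
inputs = keras.Input(shape=(3,))
features = keras.layers.Dense(64, activation="relu")(inputs)
outputs = Classifier(num_classes=10)(features)
model = keras.Model(inputs=inputs, outputs=outputs)
```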
Customizing Training and Evaluation
The chapter introduces some methods to customize the default training and evaluation loop. It shows how to write your own evaluation metrics that you can use for reporting (p. 186-7). Then it introduces custom callbacks.
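As a sketch of the custom-metric pattern, here is an RMSE metric following the shape of the pp. 186-7 example, reconstructed from memory (variable names may differ from the book):

```python
import tensorflow as tf
from tensorflow import keras

class RootMeanSquaredError(keras.metrics.Metric):
    def __init__(self, name="rmse", **kwargs):
        super().__init__(name=name, **kwargs)
        # state variables, updated batch by batch during training
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # compare one-hot targets against predicted probabilities
        y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1])
        mse = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(mse)
        self.total_samples.assign_add(tf.shape(y_pred)[0])

    def result(self):
        return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))

    def reset_state(self):
        # called between epochs so state doesn't leak across them
        self.mse_sum.assign(0.)
        self.total_samples.assign(0)
```

You would pass an instance in the metrics list when compiling, e.g. `model.compile(..., metrics=[RootMeanSquaredError()])`.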
Keras Callbacks
Chollet compares launching a training run on a large dataset to launching a paper airplane. After the initial impulse you lose control. It’s better to use a ‘drone’ that can sense its environment and respond rather than just fly on blindly come what may.
In Keras we do this via the callbacks API. A callback is an object passed to the model when you call fit, and then called by the model at various points during training. The callback can access the available data about the state of the model and its performance, and then take action, e.g. stop training, save the model, or change the state of the model.
We might use callbacks to:
Save the current state of the model at different points in training (checkpointing)
Interrupt training when validation loss is no longer improving (Early stopping)
Dynamically adjust parameter values during training (eg change learning rate)
Log metrics, or visualize them as you go.
Keras has a bunch of premade callbacks you can use, like early stopping, logging, checkpointing etc. Or you can subclass and roll your own.
Early stopping is a key callback. You can't predict when your model will start to overfit. You could just keep training until it overfits, then go back and train again from scratch for the optimal number of epochs, but this is expensive. Better to use EarlyStopping to interrupt training when a target metric hasn't improved for a fixed number of epochs. Use it with ModelCheckpoint, which lets you continually save the model during training, or just save the best model so far. Here's a simple example:
callback_list = [
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=2,
    ),
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoint_path.keras",
        monitor="val_loss",
        save_best_only=True,
    ),
]
model.fit(callbacks=callback_list)  # other arguments left out
pp. 189-190 walk through creating a custom callback by subclassing the Callback class and hooking it up to the events that fire during training.
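As a sketch of that pattern (the specific behavior below is my own, not the book's listing): subclass keras.callbacks.Callback and override any of the on_* hooks.

```python
from tensorflow import keras

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs=None):
        # logs carries the metrics for the batch that just finished
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs=None):
        mean_loss = sum(self.per_batch_losses) / len(self.per_batch_losses)
        print(f"Epoch {epoch}: mean per-batch loss {mean_loss:.4f}")
        self.per_batch_losses = []  # start fresh for the next epoch
```

You would pass an instance in the callbacks list: `model.fit(..., callbacks=[LossHistory()])`.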
TensorBoard
TensorBoard is a browser app that you can run locally or embed in your Colab notebook.
It lets you monitor metrics during training, visualize the model architecture, explore embeddings in 3D, and visualize histograms of activations and gradients.
To use it, create an instance of the TensorBoard callback and pass it to the model's fit, like this:
# embed the tensorboard if using Colab
%load_ext tensorboard
%tensorboard --logdir "/my_log_dir"
# create the callback
tensorboard = keras.callbacks.TensorBoard(
log_dir="/my_log_dir")
# pass to the model
model.fit(callbacks=[tensorboard])
Writing a Training Loop
The built-in fit workflow is based on supervised learning, but what if you are working on generative learning, reinforcement learning, self-supervised learning, or something else?
You have to write your own training logic. Pages 194-200 walk through how to do this; it's a bit beyond the needs of the course for now, though. It includes using the @tf.function decorator to compile your training steps to a TF graph, which makes them much faster.
You can either implement the whole loop yourself, or just override the training step while still taking advantage of much of the built-in training functionality (like callbacks).
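A condensed sketch of the second option, overriding train_step, based on the pattern in pp. 194-200 (the model and variable names here are mine). fit() still handles epochs, callbacks, and progress reporting, but each batch runs through our own logic:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

loss_fn = keras.losses.SparseCategoricalCrossentropy()
loss_tracker = keras.metrics.Mean(name="loss")

class CustomModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss = loss_fn(targets, predictions)
        # compute and apply gradients ourselves
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        loss_tracker.update_state(loss)
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        # metrics listed here are reset automatically at the start of each epoch
        return [loss_tracker]

# tiny usage example on random data
inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(3, activation="softmax")(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="rmsprop")  # no loss here: train_step computes it
history = model.fit(np.random.rand(16, 4), np.random.randint(0, 3, 16),
                    epochs=1, verbose=0)
```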