Improving Model Fit (Chollet)
Section 3 of Chollet, Chapter 5: Fundamentals of Machine Learning
To achieve perfect model fit you must first overfit. You don’t know where the boundary is, so you have to cross it to find it.
You first have to achieve a model that shows some generalization power and is able to overfit. Once you have such a model, you can focus on refining generalization by fighting overfitting. Three problems commonly get in the way:
Training doesn’t get started: training loss doesn’t decrease over time.
Training starts OK, but the model doesn’t generalize: you can’t beat your simple baseline.
Training and validation loss both go down and you beat your baseline, but you can’t overfit, which means you’re still underfitting.
Addressing these issues gets you to the first big milestone: a model that has some generalization power and is able to overfit.
Tuning Gradient Descent Parameters
If training doesn’t get started, or stalls early, your loss is stuck. This is always something you can overcome: it’s a problem with the configuration of the gradient descent process, namely the choice of optimizer, the distribution of initial weight values, the learning rate, and the batch size. These parameters are interdependent, but it’s usually sufficient to tune the learning rate and batch size while keeping the rest constant.
If you find yourself stuck with very low performance, try:
Lowering or increasing the learning rate. Too high and you’ll overshoot a proper fit; too low and training will stall.
Increasing the batch size. A batch with more samples yields gradients that are more informative and less noisy.
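The effect of the learning rate is easy to see on a toy problem. A minimal sketch, assuming plain gradient descent on a simple quadratic (illustrative numbers only, not code from the book):

```python
import numpy as np

def descend(lr, steps=100):
    """Minimize f(w) = w**2 with plain gradient descent."""
    w = 1.0
    for _ in range(steps):
        grad = 2 * w  # df/dw
        w -= lr * grad
    return w

# A reasonable learning rate converges toward the minimum at w = 0.
print(abs(descend(lr=0.1)))    # close to 0
# Too low: training "stalls" -- it barely moves in the same step budget.
print(abs(descend(lr=1e-5)))   # still close to the starting point, 1.0
# Too high: each update overshoots the minimum and the loss diverges.
print(abs(descend(lr=1.1)))    # blows up
```

In a real model the loss surface is vastly more complicated, but the failure modes are the same: a stuck loss often means the learning rate is in one of the two bad regimes.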
Using Better Architecture Priors
You have a model that fits, but your validation metrics aren’t improving: you can’t beat the baseline of a random classifier. The model trains but doesn’t generalize.
This is the worst situation because it indicates something is fundamentally wrong in your approach. It’s not always easy to tell what.
It may be that the input data you’re using doesn’t contain sufficient information to predict your targets. The problem as formulated isn’t solvable.
It may be that the kind of model you’re using is not suited to the problem at hand. For example, on a timeseries problem a densely connected architecture may not beat a trivial baseline, while a recurrent architecture might. You need a model that makes the right assumptions about the problem.
Textbooks like this one cover the best architectures for the major data modalities (images, time series, text, etc.); the research literature covers more specific problems and suitable architectures.
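To make the “right assumptions” point concrete, here is a toy sketch (plain NumPy, not from the book): a task whose answer depends on temporal order is unsolvable for any model restricted to order-invariant features, yet trivial for one whose structure can exploit order.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy task: given a two-step series (x0, x1), predict whether it went up.
x = rng.normal(size=(1000, 2))
y = (x[:, 1] > x[:, 0]).astype(int)

# Wrong prior: a model that only sees order-invariant features (here, the
# sorted values) receives the identical input for a series and its
# reversal, even though their labels are opposite -- there is no signal
# for it to learn, so it cannot beat a random baseline.
swapped = x[:, ::-1]
assert np.array_equal(np.sort(x, axis=1), np.sort(swapped, axis=1))

# Right prior: a model structured around step-to-step change
# solves the task exactly.
pred = (x[:, 1] - x[:, 0] > 0).astype(int)
accuracy = (pred == y).mean()
print(accuracy)  # 1.0
```

A dense network fed flattened, shuffled features is in an analogous (if less extreme) position on real sequence problems; a recurrent architecture builds the ordering assumption into the model itself.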
Increase Model Capacity
If you get a model that fits, validation metrics are going down, and you have some generalization power (you beat your baseline), you’re almost there. Now you need to start overfitting.
It should always be possible to overfit. If your validation metrics seem to stall, or improve very slowly, rather than peaking and reversing course, then you’re not overfitting. You’ll encounter this often.
If you can’t overfit, the issue is the representational power of your model. You’ll need one with more capacity, i.e., one that can store more information.
You can add more layers, use bigger layers, or use better architectural priors (different kinds of layers).