How do you beat overfitting in Machine Learning? Part 3: Early stopping
Early stopping is a method used to prevent overfitting in machine learning. Overfitting occurs when the model becomes too specialized in the training data and loses its ability to generalize to new data. In this article, we will explain in detail what early stopping is, how it works, its advantages, its limitations, as well as its possible combinations with other regularization techniques.
We will use a practical example to illustrate each step, while detailing fundamental concepts such as MNIST, validation metrics, normalization, dense layers, and more.
What is early stopping?
Early stopping consists of stopping the training of a model when its performance on a validation set begins to degrade. The idea is to prevent the model from training for too long and learning specific details of the training data that would not generalize (e.g., noise).
Detailed steps for implementing early stopping
Step 1: Select a validation metric to monitor
To monitor the model’s performance, a validation metric is used, such as:
- Accuracy: Measures the percentage of correct predictions.
- Loss: Measures the model’s overall error. For example, in classification, it is often cross-entropy loss.
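To make the two metrics concrete, here is a small NumPy sketch that computes them from predicted class probabilities. The function names and the three-sample data are illustrative assumptions. The loss is the quantity Keras calls sparse categorical cross-entropy (integer labels, probability outputs), written out by hand rather than taken from the library.

```python
import numpy as np

def accuracy(y_true, probs):
    """Fraction of samples whose highest-probability class matches the label."""
    return float(np.mean(np.argmax(probs, axis=1) == y_true))

def sparse_categorical_crossentropy(y_true, probs, eps=1e-12):
    """Mean negative log-probability assigned to the true class."""
    p_true = probs[np.arange(len(y_true)), y_true]
    return float(-np.mean(np.log(np.clip(p_true, eps, 1.0))))

# Three hypothetical samples, three classes: predicted probability for each class
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4]])
y_true = np.array([0, 1, 0])  # true labels

print(accuracy(y_true, probs))                         # 2 of 3 predictions are correct
print(sparse_categorical_crossentropy(y_true, probs))  # average of -log(0.7), -log(0.8), -log(0.3)
```

Accuracy only counts whether the top class is right. The loss also penalizes how confident the model was, which is why validation loss often reacts earlier than validation accuracy.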
Why monitor a metric?
The model’s performance depends on its ability to minimize an error or maximize accuracy. If the validation metric improves, the model is still learning patterns that generalize to new data. If it starts to degrade while the training metric keeps improving, the model is likely overfitting.
Step 2: Evaluate the model on the validation set
During training, the model is evaluated on a validation set after each epoch (a complete pass over the training data). This set is a portion of the data that is never used to update the weights, only to measure performance, which simulates how the model would perform on unseen data.
Step 3: Store the best performance and save the corresponding weights
The model’s weights are its internal parameters (the connection weights and biases of the neurons), which change with each training iteration. With early stopping:
- After each epoch, if the validation performance is better than the best performance seen so far, the model weights are saved.
- At the end of training, the weights corresponding to the best validation performance are restored. This ensures that the final model uses the parameters that performed best on the validation set.
Step 4: Stop training if performance stops improving
Training stops when there is no improvement for a certain number of consecutive epochs (called patience). This parameter prevents training from stopping due to minor fluctuations.
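The four steps above can be sketched without any framework. This minimal Python sketch replays a list of per-epoch validation losses through the patience rule. The loss values are made up for illustration, not results. The function name run_early_stopping and the optional min_delta threshold are assumptions of this sketch: Keras' EarlyStopping also has a min_delta parameter, but this is not its implementation.

```python
def run_early_stopping(val_losses, patience, min_delta=0.0):
    """Replay per-epoch validation losses through early-stopping logic.

    Returns (best_epoch, stop_epoch), both 1-based. stop_epoch is the epoch
    after which training halts, or the last epoch if patience never runs out.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss - min_delta:
            # Step 3: new best -> remember it (a real loop would also save the weights here)
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            # Step 4: patience exhausted -> stop; the weights from best_epoch get restored
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses)

# Hypothetical validation losses: best at epoch 3, then three epochs without improvement
losses = [0.60, 0.48, 0.40, 0.42, 0.41, 0.45, 0.39]
print(run_early_stopping(losses, patience=3))  # (3, 6): stops after epoch 6, keeps epoch 3
print(run_early_stopping(losses, patience=5))  # (7, 7): longer patience reaches the 0.39 at epoch 7
```

The two calls show the trade-off that patience controls. A small value saves epochs but can stop before a later improvement, and a larger value tolerates fluctuations at the cost of extra training.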
Practical example: Image recognition with MNIST
What is MNIST?
MNIST is a well-known database used for machine learning demonstrations, ideal for illustrating fundamental concepts such as normalization or, in this case, early stopping. MNIST contains 70,000 images of handwritten digits (0 to 9).
For this example, the data is divided into three sets: training, validation, and testing. Out of the box, the database is pre-divided into:
- 60,000 images for training
- 10,000 images for testing
We then divide the original training set (60,000 images) into two parts:
- 50,000 images for training
- 10,000 images for validation
The 50,000 training images are used for training.
The 10,000 validation images are used to monitor performance DURING training to detect overfitting.
The 10,000 test images are used ONLY at the end to evaluate the model’s final performance, such as calculating the test accuracy.
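The split described above takes the last 10,000 training images as the validation set. A common alternative is a random split with a fixed seed, so that the result does not depend on the order in which the data is stored. Here is a minimal NumPy sketch of that alternative. The helper split_train_val is hypothetical, and zero-filled arrays stand in for the real MNIST images.

```python
import numpy as np

def split_train_val(x, y, n_val, seed=0):
    """Shuffle (x, y) together and hold out n_val samples for validation."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))  # same random order for images and labels
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    return x[train_idx], y[train_idx], x[val_idx], y[val_idx]

# Stand-in arrays with the shape of MNIST's training set (60,000 images of 28x28 pixels)
x = np.zeros((60000, 28, 28), dtype=np.uint8)
y = np.arange(60000) % 10
x_tr, y_tr, x_val, y_val = split_train_val(x, y, n_val=10000)
print(x_tr.shape, x_val.shape)  # (50000, 28, 28) (10000, 28, 28)
```

The test set is left out of this function on purpose. It stays untouched until the final evaluation.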
Is it always necessary to have a validation set?
No, not in all cases.
It depends on the method and the data available:
- With Early Stopping (3 sets required):
You need a validation set to monitor performance during training, and a separate test set to evaluate final generalization.
- Without Early Stopping (2 sets possible):
If you do not use a validation set, all 60,000 training images can be used for learning, and the 10,000 test images serve as the final evaluation.
However, this rules out techniques such as early stopping that need held-out data during training.
Implementation with Python
Here is an example with TensorFlow and Keras:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.callbacks import EarlyStopping

# Load the data (MNIST)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize the data
x_train, x_test = x_train / 255.0, x_test / 255.0  # Divide each pixel by 255 to get values between 0 and 1

# Split the data: hold out the last 10,000 training images as the validation set
x_train, x_val = x_train[:50000], x_train[50000:]
y_train, y_val = y_train[:50000], y_train[50000:]

# Define the model
model = Sequential([
    Input(shape=(28, 28)),           # Each input is a 28x28 grayscale image
    Flatten(),                       # Flatten the 28x28 images into vectors of 784 values
    Dense(128, activation='relu'),   # Dense layer with 128 neurons
    Dense(10, activation='softmax')  # Output layer with 10 neurons (10 classes for digits 0-9)
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Early stopping: monitor the loss on the validation set
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Train with early stopping
history = model.fit(x_train, y_train,
                    validation_data=(x_val, y_val),
                    epochs=50,
                    callbacks=[early_stopping])

# Final evaluation on the untouched test set
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {test_accuracy}")
Data normalization:
The images consist of pixels with values between 0 and 255. Dividing by 255 scales these values to between 0 and 1, which facilitates model convergence during training.
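One detail worth knowing: in NumPy, dividing a uint8 array by 255.0 produces float64 values. Casting to float32 first halves the memory and matches the float32 precision Keras typically uses by default. This minimal sketch uses a hypothetical 2x2 "image" instead of real MNIST data.

```python
import numpy as np

# Hypothetical 2x2 "image" with raw pixel intensities in [0, 255], stored as uint8 like MNIST
img = np.array([[0, 64], [128, 255]], dtype=np.uint8)

scaled = img.astype(np.float32) / 255.0  # cast first, then scale to [0, 1]
print(scaled.dtype, scaled.min(), scaled.max())  # float32 0.0 1.0
```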
Dataset splitting:
The 60,000 training images are divided into:
- 50,000 for training.
- 10,000 for validation.
Dense:
A dense layer is a fully connected layer where each neuron is connected to all neurons in the previous layer.
Example:
- Dense(128): 128 neurons.
- activation='relu': the ReLU activation function, max(0, x), which introduces non-linearity.
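What "fully connected" means can be seen in a forward pass written by hand. This NumPy sketch is an illustration, not Keras' implementation. The function dense_relu and the random weights are hypothetical. It computes ReLU(xW + b) for a batch of flattened 28x28 images going into 128 neurons.

```python
import numpy as np

def dense_relu(x, W, b):
    """Forward pass of a fully connected layer: every output neuron sees every input."""
    return np.maximum(0.0, x @ W + b)  # weighted sum plus bias, then ReLU

rng = np.random.default_rng(0)
x = rng.random((32, 784))                  # batch of 32 flattened 28x28 images
W = rng.normal(0, 0.05, size=(784, 128))   # one weight per (input, neuron) pair
b = np.zeros(128)                          # one bias per neuron
h = dense_relu(x, W, b)
print(h.shape)          # (32, 128)
print(W.size + b.size)  # 100480 trainable parameters (784 * 128 + 128)
```

The parameter count, 784 × 128 weights plus 128 biases, is the number Keras reports for Dense(128) after a Flatten of 28x28 inputs.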
Restore Best Weights:
When early stopping is triggered, the weights are reset to the state that gave the best validation performance. This ensures the final model is the best-performing version observed on the validation set, rather than the version from the last epoch.
During training, after each epoch, the algorithm compares the model’s performance on the validation set to previous performances:
- If the metric (e.g., validation loss or accuracy) improves, the current model state (weights and biases) is saved.
- If the metric worsens, training continues, but the weights corresponding to the best observed performance are retained.
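The save-and-restore logic can be sketched in plain Python. It follows the behavior described for restore_best_weights=True, but Keras' internals differ. The ToyModel class, its .weights dictionary, and the made-up validation curve are assumptions of this sketch.

```python
import copy

def train_with_best_weights(model, train_one_epoch, evaluate, max_epochs, patience):
    """Snapshot the weights whenever validation loss improves; restore them at the end."""
    best_loss, best_weights, wait = float("inf"), None, 0
    for epoch in range(1, max_epochs + 1):
        train_one_epoch(model, epoch)
        val_loss = evaluate(model, epoch)
        if val_loss < best_loss:
            best_loss, wait = val_loss, 0
            best_weights = copy.deepcopy(model.weights)  # save a copy, not a reference
        else:
            wait += 1  # keep training; the saved best weights stay untouched
            if wait >= patience:
                break
    if best_weights is not None:
        model.weights = best_weights  # restore the best observed state
    return best_loss

class ToyModel:
    """Hypothetical stand-in for a model: its 'weights' are a dict of lists."""
    def __init__(self):
        self.weights = {"w": [0.0]}

def train_one_epoch(model, epoch):
    model.weights["w"][0] = float(epoch)  # mutate weights in place, like an optimizer step

# Made-up validation losses: lowest at epoch 3, then 3 epochs without improvement
val_curve = {1: 0.9, 2: 0.6, 3: 0.4, 4: 0.5, 5: 0.45, 6: 0.7}

model = ToyModel()
best = train_with_best_weights(model, train_one_epoch,
                               evaluate=lambda m, e: val_curve[e],
                               max_epochs=6, patience=3)
print(best, model.weights)  # 0.4 {'w': [3.0]}
```

The deep copy matters. Without it, best_weights would point at the same dictionary that training keeps mutating, and the "restored" weights would be those of epoch 6.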
Triggering Early Stopping:
Once training stops (e.g., after the validation loss has not decreased for 5 consecutive epochs, as with patience=5 in the code above), the algorithm restores the saved version of the model with the best weights. For example, with a patience of 3:
- The best weights were reached at epoch 3, with a validation loss of 0.40.
- If early stopping is triggered after epoch 6 (because no improvement occurred for 3 consecutive epochs), the algorithm restores the weights saved at epoch 3.
The same logic applies when monitoring accuracy:
- The best weights were reached at epoch 3, where the validation accuracy was 84%.
- If training continues without improvement for 3 epochs (with a patience of 3), early stopping will be triggered, and the weights saved at epoch 3 will be restored.
Test Accuracy:
Test accuracy is the metric reported at the very end: the proportion of test images that the final model classifies correctly. Because the test set was used neither for training nor for early stopping, it indicates how well the model generalizes to unseen data.
Advantages of early stopping
- Simple to use.
- Cost reduction: Saves time and resources by stopping training early.
- Improves generalization.
Limitations of early stopping
- Dependency on the validation set:
If the validation set is poorly chosen (too small or unrepresentative), it can skew results.
- Not always optimal:
In some cases, other methods like L2 or L1 regularization may be more effective.
- Sensitivity to fluctuations and an overly small patience parameter:
If validation performance fluctuates (e.g., small increases or decreases) and the patience parameter is too small, training may stop prematurely before the model reaches its full potential.
=> Example: suppose the validation metric is accuracy:
Epoch 1: 85%
Epoch 2: 86%
Epoch 3: 85.5%
Epoch 4: 85.8%
Epoch 5: 86.5%
=> Here, accuracy fluctuates but continues to improve overall. If the patience parameter is too small (e.g., 1 epoch), training might stop at epoch 3, before accuracy reaches 86.5%. This means the model is not fully optimized.
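Replaying the accuracy sequence above through the patience rule confirms the effect. This is a minimal sketch. The function stop_epoch is hypothetical, and mode="max" means "higher is better", as it does for accuracy.

```python
def stop_epoch(values, patience, mode="max"):
    """Return the 1-based epoch after which early stopping halts (or the last epoch)."""
    sign = 1 if mode == "max" else -1  # compare in a "higher is better" space
    best, wait = float("-inf"), 0
    for epoch, v in enumerate(values, start=1):
        if sign * v > best:
            best, wait = sign * v, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(values)

acc = [85.0, 86.0, 85.5, 85.8, 86.5]  # validation accuracy (%) from the example above
print(stop_epoch(acc, patience=1))  # 3: stops before reaching 86.5%
print(stop_epoch(acc, patience=3))  # 5: survives the dip and reaches 86.5%
```

In Keras, the corresponding setting would be EarlyStopping(monitor='val_accuracy', mode='max', patience=3). Its min_delta parameter can additionally ignore improvements too small to matter.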
Can early stopping be combined with other techniques?
Yes, here are common combinations based on various criteria:
- Model complexity:
=> Complex model: combine Dropout, L2, and Early Stopping.
=> Simple model: Early Stopping alone may suffice.
- Data quantity:
=> Small dataset: favor L2 or Data Augmentation.
=> Large dataset: combine Batch Normalization with Early Stopping.
- Noise in the data:
=> If your data is very noisy, use Dropout and a reduced learning rate alongside Early Stopping.
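As an illustration of the "complex model" combination, here is a hedged Keras sketch that adds L2 regularization and Dropout to the MNIST model and keeps the same EarlyStopping callback. The layer width (256), dropout rate (0.3), and L2 coefficient (1e-4) are illustrative assumptions, not tuned values. The fit call is commented out because it needs the data prepared in the earlier example.

```python
import tensorflow as tf
from tensorflow.keras import layers, regularizers

# Dropout + L2 + early stopping; hyperparameters are illustrative assumptions
model = tf.keras.Sequential([
    layers.Input(shape=(28, 28)),
    layers.Flatten(),
    layers.Dense(256, activation="relu",
                 kernel_regularizer=regularizers.l2(1e-4)),  # L2 penalty on this layer's weights
    layers.Dropout(0.3),  # randomly zeroes 30% of activations, during training only
    layers.Dense(10, activation="softmax"),
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),  # lower this for noisy data
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True)

# model.fit(x_train, y_train, validation_data=(x_val, y_val),
#           epochs=50, callbacks=[early_stopping])
```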
Conclusion
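Early Stopping is a powerful method but must be used wisely: it needs a representative validation set and a patience value suited to how much the validation metric fluctuates. When combined effectively with other techniques, it can become a crucial asset to prevent overfitting and achieve high-performing models.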
Sirine Amrane