To make sure that you have “at least steps_per_epoch * epochs batches”, set steps_per_epoch to:
steps_per_epoch = len(X_train)//batch_size
validation_steps = len(X_test)//batch_size # if you have validation data
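As a minimal sketch of the arithmetic above, the helper below computes both values with plain integer division. The function name `steps_for` and the sample counts are hypothetical and chosen only for illustration; substitute the lengths of your own arrays.

```python
def steps_for(num_samples: int, batch_size: int = 32) -> int:
    """Return the number of full batches in one pass over the data."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    # Floor division drops the final partial batch, so the step count
    # never asks for more batches than the data can supply.
    return num_samples // batch_size


# Hypothetical dataset sizes, for illustration only.
n_train, n_test = 10_000, 2_000
batch_size = 32

steps_per_epoch = steps_for(n_train, batch_size)    # 10_000 // 32 == 312
validation_steps = steps_for(n_test, batch_size)    # 2_000 // 32 == 62
```

These values would then be passed to model.fit() as steps_per_epoch and validation_steps.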
You can see the maximum number of batches that model.fit() can take from the progress bar at the point where training is interrupted:
5230/10000 [==============>...............] - ETA: 2:05:22 - loss: 0.0570
Here, the maximum would be 5230 - 1 = 5229.
Importantly, keep in mind that batch_size defaults to 32 in model.fit().
If you’re using a tf.data.Dataset, you can also add the repeat() method, but be careful: it will loop indefinitely unless you pass it a count.
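To check whether a finite repeat count is large enough, the following sketch counts the batches a pipeline would yield and compares that with steps_per_epoch * epochs. It assumes the dataset is batched first and then repeated, as in dataset.batch(batch_size).repeat(count); the helper names and the numbers are hypothetical, and the code does not call TensorFlow itself.

```python
import math


def batches_available(num_samples, batch_size, repeat_count=1, drop_remainder=False):
    """Batches yielded by a pipeline shaped like batch(batch_size).repeat(repeat_count)."""
    if drop_remainder:
        per_pass = num_samples // batch_size
    else:
        # Without drop_remainder, the last partial batch still counts as a batch.
        per_pass = math.ceil(num_samples / batch_size)
    return per_pass * repeat_count


def has_enough_batches(num_samples, batch_size, steps_per_epoch, epochs, repeat_count=1):
    """True if the pipeline can supply at least steps_per_epoch * epochs batches."""
    needed = steps_per_epoch * epochs
    return batches_available(num_samples, batch_size, repeat_count) >= needed


# Hypothetical numbers: 1000 samples, batch size 32, 5 epochs.
steps = 1000 // 32  # 31 steps per epoch

enough_without_repeat = has_enough_batches(1000, 32, steps, epochs=5)
enough_with_repeat = has_enough_batches(1000, 32, steps, epochs=5, repeat_count=5)
```

With these numbers, a single pass supplies 32 batches against the 155 that are needed, while repeating 5 times supplies 160. If the order is reversed, so that repeat() comes before batch(), the per-pass rounding changes and the count should be recomputed.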