To make sure your input can supply "at least steps_per_epoch * epochs batches", set steps_per_epoch to:
steps_per_epoch = len(X_train)//batch_size
validation_steps = len(X_test)//batch_size # if you have validation data
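A minimal sketch of a generator that never runs out. It assumes plain Python lists and hypothetical sizes (1000 samples, batch_size 32, 10 epochs). `batch_generator` is a hypothetical helper, not a Keras API. It loops over the data forever and yields only full batches, so with steps_per_epoch = len(X_train) // batch_size it can always produce steps_per_epoch * epochs batches:

```python
def batch_generator(X, y, batch_size=32):
    """Hypothetical helper: loop over (X, y) forever, yielding only full batches."""
    n = len(X)
    if n < batch_size:
        raise ValueError("need at least one full batch of data")
    while True:  # restart from the beginning instead of running out
        for start in range(0, n - batch_size + 1, batch_size):
            yield X[start:start + batch_size], y[start:start + batch_size]

# Hypothetical data and settings, for illustration only
X_train = list(range(1000))
y_train = [x % 2 for x in X_train]
batch_size = 32
epochs = 10

steps_per_epoch = len(X_train) // batch_size  # 1000 // 32 = 31 full batches

gen = batch_generator(X_train, y_train, batch_size)
needed = steps_per_epoch * epochs  # 31 * 10 = 310 batches in total
batches = [next(gen) for _ in range(needed)]  # never raises StopIteration
print(steps_per_epoch, len(batches))  # 31 310
```

In TF 2.x Keras you would then pass the generator directly, e.g. `model.fit(batch_generator(X_train, y_train, batch_size), steps_per_epoch=steps_per_epoch, epochs=epochs)`.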
When training is interrupted, the progress bar shows the maximum number of batches that model.fit() could draw:
5230/10000 [==============>...............] - ETA: 2:05:22 - loss: 0.0570
Here, the maximum would be 5230 - 1 = 5229.
Keep in mind that model.fit() uses batch_size=32 by default, so use that value in the formulas above if you don't set it yourself.
If you're using a tf.data.Dataset, you can also call its repeat() method, but be careful: without a count argument, it repeats indefinitely.
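To show how many batches repeat() supplies, here is a plain-Python stand-in. It is not tf.data itself, only an analogy for `tf.data.Dataset.from_tensor_slices(data).batch(batch_size).repeat(count)`, where the last batch of each pass may be short. It assumes hypothetical sizes (1000 samples, batch_size 32, 10 epochs). With `repeat(epochs)`, the dataset yields epochs * ceil(1000 / 32) = 320 batches. That is enough for steps_per_epoch * epochs = 310. With `repeat()`, the stream never ends, and Keras simply stops after the requested number of steps:

```python
from itertools import islice

def batched_then_repeated(data, batch_size, count=None):
    """Plain-Python analogy for .batch(batch_size).repeat(count).
    count=None means repeat forever, like repeat() with no argument."""
    if not data:
        return  # avoid an endless loop that never yields
    passes = 0
    while count is None or passes < count:
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]  # last batch of a pass may be short
        passes += 1

data = list(range(1000))  # hypothetical data
batch_size, epochs = 32, 10
steps_per_epoch = len(data) // batch_size  # 31

# Finite: repeat(epochs) supplies epochs * ceil(1000 / 32) = 10 * 32 = 320 batches
finite = sum(1 for _ in batched_then_repeated(data, batch_size, count=epochs))
print(finite, finite >= steps_per_epoch * epochs)  # 320 True

# Infinite: repeat() never ends, so only steps_per_epoch * epochs batches are taken
taken = sum(1 for _ in islice(batched_then_repeated(data, batch_size), steps_per_epoch * epochs))
print(taken)  # 310
```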