Keras uses way too much GPU memory when calling train_on_batch, fit, etc

It is a very common mistake to forget that the activations, gradients, and the optimizer's moment-tracking variables also take VRAM, not just the parameters, which increases memory usage considerably. The backprop calculations alone mean that the training phase takes almost double the VRAM of forward / inference use of the network, and the Adam optimizer roughly triples the space usage.

So, when the network is first created, only the parameters are allocated. However, when training starts, the model activations, backprop computations, and the optimizer's tracking variables get allocated as well, increasing memory use by a large factor.
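
As a rough back-of-the-envelope check, the sketch below (assuming float32 values and an illustrative model; activation memory, which also depends on batch size, is left out) shows how the footprint grows once gradients and Adam's moment estimates are counted:

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(1024,)),
        tf.keras.layers.Dense(4096, activation="relu"),
        tf.keras.layers.Dense(4096, activation="relu"),
        tf.keras.layers.Dense(10),
    ])

    n_params = model.count_params()
    bytes_per_value = 4  # float32

    weights_mb = n_params * bytes_per_value / 2**20
    gradients_mb = weights_mb        # one gradient per parameter during backprop
    adam_state_mb = 2 * weights_mb   # Adam keeps two moment estimates per parameter

    print(f"weights only:           {weights_mb:8.1f} MB")
    print(f"+ gradients (training): {weights_mb + gradients_mb:8.1f} MB")
    print(f"+ Adam moments:         {weights_mb + gradients_mb + adam_state_mb:8.1f} MB")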

To allow the training of larger models, people:

  • use model parallelism to spread the weights and computations over different accelerators;
  • use gradient checkpointing, which trades extra computation for lower memory use during back-propagation (see the sketch after this list);
  • potentially use a memory-efficient optimizer that reduces the number of tracking variables, such as Adafactor, for which implementations exist for all popular deep learning frameworks.
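
For the gradient-checkpointing bullet, here is a minimal TensorFlow/Keras sketch using tf.recompute_grad; the layer sizes, random data, and manual GradientTape loop are illustrative assumptions, not part of the original question:

    import tensorflow as tf

    # A wide block whose intermediate activations dominate training memory.
    block = tf.keras.Sequential([
        tf.keras.Input(shape=(1024,)),
        tf.keras.layers.Dense(4096, activation="relu"),
        tf.keras.layers.Dense(4096, activation="relu"),
    ])
    head = tf.keras.layers.Dense(10)

    # tf.recompute_grad drops the block's intermediate activations after the
    # forward pass and recomputes them during back-propagation, trading extra
    # compute for lower memory use.
    checkpointed_block = tf.recompute_grad(block)

    x = tf.random.normal([32, 1024])
    y = tf.random.uniform([32], maxval=10, dtype=tf.int32)

    with tf.GradientTape() as tape:
        logits = head(checkpointed_block(x))
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(y, logits, from_logits=True))

    variables = block.trainable_variables + head.trainable_variables
    grads = tape.gradient(loss, variables)

For the optimizer bullet: recent Keras/TensorFlow releases ship an Adafactor optimizer (e.g. keras.optimizers.Adafactor); if your version predates it, third-party implementations are available.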

Tools to train very large models:

  • Mesh-Tensorflow https://arxiv.org/abs/1811.02084
    https://github.com/tensorflow/mesh
  • Microsoft DeepSpeed:
    https://github.com/microsoft/DeepSpeed https://www.deepspeed.ai/
  • Facebook FairScale: https://github.com/facebookresearch/fairscale
  • Megatron-LM: https://arxiv.org/abs/1909.08053
    https://github.com/NVIDIA/Megatron-LM
  • Article on integration in HuggingFace Transformers: https://huggingface.co/blog/zero-deepspeed-fairscale
