Collaborators: Less Wright, Howard Huang, Chien-Chin Huang; Crusoe: Martin Cala, Ethan Petersen
tl;dr: We used torchft and torchtitan to train a model in a real-world environment with extreme synthetic failure rates to demonstrate the reliability and correctness of fault-tolerant training.
Training loss across 1200 failures with no checkpoints.
NOTE: Each small spike is a non-participating worker recovering, which affects the metrics but not the model.
Introduction
We want to demonstrate torchft in worst-case scenarios by running a training job with the most extreme failure rates possible.
Most LLM pre-training shards the model with FSDP. torchft supports sharded models via HSDP2, which combines FSDP2 sharding within each replica group with torchft's fault-tolerant DDP all-reduce across replica groups. We've integrated torchft into torchtitan so you can use fault tolerance out of the box. torchft + torchtitan also support other sharding/parallelism schemes within each replica group, such as tensor parallelism (TP), pipeline parallelism (PP), and more.
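To make the integration concrete, here is a minimal sketch of a torchft training loop, loosely following the example in the torchft repository. It is not the exact torchtitan integration: constructor arguments and import paths may differ between releases, the model and hyperparameters are placeholders, and the `state_dict`/`load_state_dict` callbacks are simplified stand-ins for what a real job would replicate to a recovering worker.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# torchft wrappers; names follow the torchft README example and may
# differ slightly between releases.
from torchft import DistributedDataParallel, Manager, Optimizer, ProcessGroupGloo

m = nn.Linear(128, 128)  # placeholder model
inner_optim = optim.AdamW(m.parameters())

# The Manager handles quorum, live recovery, and the fault-tolerant
# all-reduce across replica groups. It expects the torchft environment
# (e.g. a lighthouse server) to be configured for the job.
manager = Manager(
    pg=ProcessGroupGloo(),
    min_replica_size=1,
    # What a healthy worker sends to a recovering worker, and how the
    # recovering worker restores it (simplified placeholders).
    state_dict=lambda: {"model": m.state_dict(), "optim": inner_optim.state_dict()},
    load_state_dict=lambda sd: (m.load_state_dict(sd["model"]),
                                inner_optim.load_state_dict(sd["optim"])),
)

m = DistributedDataParallel(manager, m)        # fault-tolerant gradient sync
optimizer = Optimizer(manager, inner_optim)    # steps only commit if the quorum succeeds

for _ in range(1000):
    batch = torch.rand(8, 128)
    optimizer.zero_grad()   # starts quorum computation for this step
    loss = m(batch).sum()
    loss.backward()         # gradient all-reduce overlaps with backward
    optimizer.step()        # may skip the step if a failure occurred mid-all-reduce
```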
Here’s the structure of a training job with torchft:
The structure of the training job. torchft's fault-tolerant DDP implementation is used across the replica groups to synchronize the gradients. Standard FSDP2 and other parallelisms are used within each replica group.
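For illustration, a plain HSDP2 setup over a 2D device mesh looks roughly like the sketch below. The mesh sizes are hypothetical, and in the torchft + torchtitan setup the gradient all-reduce along the "replicate" dimension is handled by torchft's fault-tolerant process group rather than a static one; the sharding inside each replica group is unchanged.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2; import path varies by PyTorch release

# Hypothetical sizes: 2 replica groups x 4 shards per group = 8 ranks.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))
for layer in model:
    fully_shard(layer, mesh=mesh)  # shard parameters within each replica group
fully_shard(model, mesh=mesh)      # replicate across groups; gradients all-reduced across groups
```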