What each line does and what breaks if you move it.
A three-class spiral dataset. Shaded regions show the model's softmax confidence. The boundary sharpens as training progresses.
Building a PyTorch training loop is fairly straightforward, but getting everything in the right place and in the right order can feel surprisingly fragile. There are loads of moving parts, and once the most basic errors are fixed, the remaining mistakes can be pretty hard to spot. Misplace a line and a training run can fail to converge, produce incorrect results, or consume excessive memory.
The sections below go through each operation in sequence, explaining how to write each one and which common mistakes to watch out for. Distributed training, FSDP, and multi-GPU setups are out of scope here, but we'll come back to them in a future essay. (The animation above was produced by running the loop on synthetic data and capturing the decision boundary at each epoch.)
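The essay doesn't show how its synthetic data was generated. As a hypothetical stand-in, the sketch below builds a three-class 2D spiral using a common construction and splits it into the `X_train`, `y_train`, `X_val` and `y_val` tensors that the complete loop refers to. The function name `make_spirals`, the arm offsets, the noise level, the sample count and the 80/20 split are all assumptions, not the author's settings.

```python
import torch

def make_spirals(n_per_class: int = 200, n_classes: int = 3, noise: float = 0.2, seed: int = 0):
    """Hypothetical n-class 2D spiral generator (not the essay's own code).

    Returns X: float32 tensor of shape (n_per_class * n_classes, 2)
            y: int64 class indices of shape (n_per_class * n_classes,)
    """
    g = torch.Generator().manual_seed(seed)
    xs, ys = [], []
    for k in range(n_classes):
        r = torch.linspace(0.0, 1.0, n_per_class)  # radius grows along each arm
        # each arm sweeps its own angular range; noise jitters the angle
        theta = torch.linspace(k * 4.0, (k + 1) * 4.0, n_per_class)
        theta = theta + noise * torch.randn(n_per_class, generator=g)
        xs.append(torch.stack([r * torch.sin(theta), r * torch.cos(theta)], dim=1))
        ys.append(torch.full((n_per_class,), k, dtype=torch.long))
    return torch.cat(xs), torch.cat(ys)

# Shuffle once, then take an (assumed) 80/20 train/validation split
X, y = make_spirals()
perm = torch.randperm(len(X), generator=torch.Generator().manual_seed(1))
split = int(0.8 * len(X))
X_train, y_train = X[perm[:split]], y[perm[:split]]
X_val, y_val = X[perm[split:]], y[perm[split:]]
```

Labels are `torch.long` class indices rather than one-hot vectors, which is the form `nn.CrossEntropyLoss` expects for its target.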
The complete loop
First, let's look at the complete training loop. You don't need to understand or memorise it yet; just get a feel for the structure.
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# --- data ---
dataset = TensorDataset(X_train, y_train)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# --- model, loss, optimiser ---
model = MLP(in_features=2, hidden=128, out_features=3)
criterion = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=100)

# --- loop ---
for epoch in range(100):
    model.train()
    for X_batch, y_batch in loader:
        optimiser.zero_grad()
        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimiser.step()
    scheduler.step()

    model.eval()
    with torch.no_grad():
        val_logits = model(X_val)
        val_loss = criterion(val_logits, y_val)
```
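The loop constructs `MLP(in_features=2, hidden=128, out_features=3)`, but `MLP` is not a built-in PyTorch class and its definition isn't shown. A minimal hypothetical version might look like this; the two hidden layers and the ReLU activations are assumptions. The one detail the loop actually depends on is that `forward` returns raw logits, because `nn.CrossEntropyLoss` applies log-softmax internally.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Hypothetical MLP matching the constructor call in the loop (not the author's definition)."""

    def __init__(self, in_features: int, hidden: int, out_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),  # raw logits: no softmax here
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

model = MLP(in_features=2, hidden=128, out_features=3)
logits = model(torch.randn(5, 2))  # one row of 3 class logits per input point
```

Adding a softmax at the end of `forward` would not raise an error, but the loss would then apply softmax twice, which quietly flattens the gradients.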
Now let's go through each line: what it does, and how not to break it. We'll start with some of the common mistakes.
TL;DR: Where the order really matters
Here are some of the most common failures: ways you can break the training loop by getting the placement slightly wrong. The reason to memorise them is that none of them raises an exception. Over time you'll develop a sense for what kind of errors to look for in your training runs, but for the first few times this crib sheet will help you out.
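To make "none of them raises an exception" concrete, here is a sketch of two placement mistakes that run silently. They are illustrative picks, not necessarily the essay's own list: skipping `zero_grad()` between backward passes, which sums gradients, and moving `scheduler.step()` into the batch loop, which anneals the learning rate per batch instead of per epoch. The helper `lr_after` and the figure of 10 batches per epoch are hypothetical.

```python
import torch

# --- Failure 1: no zero_grad() between two backward passes ---
w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])

(w * x).sum().backward()     # d(loss)/dw == x, so w.grad is [3., 4.]
grad_once = w.grad.clone()

(w * x).sum().backward()     # zero_grad() "forgotten": the new gradient is added on top
grad_twice = w.grad.clone()  # [6., 8.], and nothing warns you

# --- Failure 2: scheduler.step() moved inside the batch loop ---
def lr_after(scheduler_steps: int, T_max: int = 100, base_lr: float = 1e-3) -> float:
    """Hypothetical helper: learning rate after a number of CosineAnnealingLR steps."""
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=base_lr)  # optimiser type doesn't affect the schedule
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=T_max)
    for _ in range(scheduler_steps):
        opt.step()    # optimiser first, then scheduler (avoids PyTorch's ordering warning)
        sched.step()
    return opt.param_groups[0]["lr"]

epochs, batches_per_epoch = 10, 10
lr_per_epoch = lr_after(epochs)                      # as in the loop: one step per epoch
lr_per_batch = lr_after(epochs * batches_per_epoch)  # misplaced: one step per batch
print(f"after {epochs} epochs: per-epoch lr={lr_per_epoch:.2e}, per-batch lr={lr_per_batch:.2e}")
```

After 10 epochs the per-epoch schedule is still close to the base rate (about 9.8e-4), while the per-batch one has already taken all `T_max=100` steps and annealed to zero. Training simply stops improving, with no error to point at.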