Computing sharding with einsum
Mental arithmetic in grade school (e.g., memorizing your times tables) is typically justified on the grounds that facility in basic calculations makes it easier to focus on higher-level problems that require being able to do these manipulations. When working on DTensor, I have also found it important to be able to quickly calculate what shardings you get when you do matrix multiplies on sharded tensors. Without being able to do this quickly and accurately, working through examples becomes a slog. I’ve also found that while diagrammatic approaches (e.g., drawing a matrix and slicing it into shards) are intuitive, they are slow and unwieldy to do calculations with.
Recently, I’ve found that working on sharding with einsum is nice and efficient, and I hope to persuade you to do it this way when you need to reason about sharding! This post somewhat overlaps with Sharded Matrices and How to Multiply Them, but with some different emphasis and some different notation.
Einsum primer
Einstein summation is a compact way of representing many multi-dimensional linear algebra operations, including matrix multiplies. It is nice because you don’t have to puzzle through the abstruse differences between matrix multiply operations like @, torch.matmul, torch.bmm and torch.mm: for any “matrix multiply”, as long as you know the input and output shapes of your tensors, you can directly write out an einsum equation. For example, classic matrix multiply as you see it in math has a signature like mm(x: f32[A, B], y: f32[B, C]) -> f32[A, C]. In einsum notation, you would simply write torch.einsum("ij,jk->ik", x, y): each of the indices lines up exactly with one of the input sizes (i with A, j with B, k with C). As another example, in nn.Linear, your weight has shape (out_features, in_features). You don’t have to remember how to set up the transposition; just write torch.einsum("bi,oi->bo", input, weight).
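To make this concrete, here is a minimal sketch (the shapes and sizes are arbitrary choices of mine, not from anything above) checking that both einsum spellings agree with the usual PyTorch ops:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3)   # f32[A, B]
    y = torch.randn(3, 4)   # f32[B, C]
    assert torch.allclose(torch.einsum("ij,jk->ik", x, y), x @ y)

    inp = torch.randn(5, 7)      # f32[batch, in_features]
    weight = torch.randn(11, 7)  # f32[out_features, in_features], as nn.Linear stores it
    assert torch.allclose(torch.einsum("bi,oi->bo", inp, weight), F.linear(inp, weight))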
A useful piece of terminology that pops up for einsum is the contraction dimension: any index that appears in the input tensors but not in the output. The indices that show up in both the inputs and the output are free dimensions: if a free dimension appears in every input it is a batch dimension, and if it is missing from some inputs, we will broadcast those tensors.
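To make the terminology concrete, here is a small sketch (shapes arbitrary) of a batched matrix multiply with each kind of dimension labeled:

    import torch

    # "bij,bjk->bik": j is a contraction dimension (inputs only); i and k are free;
    # b appears in every input and the output, so it is a batch dimension.
    x = torch.randn(8, 2, 3)
    y = torch.randn(8, 3, 4)
    out = torch.einsum("bij,bjk->bik", x, y)   # same result as torch.bmm(x, y)

    # "bij,jk->bik": b is missing from the second input, so the second tensor
    # is broadcast (reused) across the batch dimension.
    w = torch.randn(3, 4)
    out2 = torch.einsum("bij,jk->bik", x, w)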
Einsum backwards
Do you always forget how exactly you should transpose your tensors in the backward formula for matrix multiply? As long as you aren’t doing weird things in your einsum (e.g., no index repeated within a single input, and every input index also shows up in another input or in the output), there is a very simple way to compute the backward: keep every input constant except the one you want to compute the gradient for, and swap its index set with the output index set.
For example, linear is "bi,oi->bo" for (input, weight -> output). Then we have:

    grad_input = torch.einsum("bo,oi->bi", grad_output, weight)
    grad_weight = torch.einsum("bi,bo->oi", input, grad_output)
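As a sanity check, here is a sketch (arbitrary shapes again) confirming that these two formulas match what autograd computes for the forward einsum:

    import torch

    inp = torch.randn(5, 7, requires_grad=True)      # bi
    weight = torch.randn(11, 7, requires_grad=True)  # oi
    out = torch.einsum("bi,oi->bo", inp, weight)

    grad_output = torch.randn_like(out)              # bo
    out.backward(grad_output)

    grad_input = torch.einsum("bo,oi->bi", grad_output, weight)
    grad_weight = torch.einsum("bi,bo->oi", inp, grad_output)
    assert torch.allclose(grad_input, inp.grad)
    assert torch.allclose(grad_weight, weight.grad)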