Many linear and multilinear operations, such as sums, matrix products, dot products, and Hadamard products, can be represented as an einsum (short for Einstein summation).
This post explains a simple trick to backpropagate through any einsum, regardless of what operations it represents.
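To make that claim concrete, here is a small sketch (just an illustration, not part of the derivation that follows) of a few of those operations written as einsums in NumPy:

import numpy as np

x = np.arange(3)          # vector [0, 1, 2]
y = np.arange(3) + 1      # vector [1, 2, 3]
M = np.arange(6).reshape(2, 3)

total = np.einsum('i->', x)           # sum of the elements of x
dot = np.einsum('i,i->', x, y)        # dot product of x and y
hadamard = np.einsum('i,i->i', x, y)  # elementwise (Hadamard) product
matvec = np.einsum('ij,j->i', M, x)   # matrix-vector product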
Example Einsum
For example, an einsum for matrix multiplication can be written like so:
import numpy as np

A = np.arange(2 * 3).reshape(2, 3)
# A = [
#     [0, 1, 2],
#     [3, 4, 5]
# ]

B = np.arange(3 * 4).reshape(3, 4)
# B = [
#     [0, 1, 2, 3],
#     [4, 5, 6, 7],
#     [8, 9, 10, 11]
# ]

# C = A @ B (matrix multiplication)

# calculate with einsum
# A uses i and j, B uses j and k, and C uses i and k
C_einsum = np.einsum('ij,jk->ik', A, B)
# C_einsum = [
#     [20, 23, 26, 29],
#     [56, 68, 80, 92]
# ]

# calculate with for loops
C_forloop = np.zeros((2, 4))
# order of the loops doesn't matter
# you can swap the order of the loops,
# just like you can swap the order of some double integrals
for i in range(2):
    for j in range(3):
        for k in range(4):
            # A uses i and j, B uses j and k, and C uses i and k,
            # just like before with the einsum
            C_forloop[i, k] += A[i, j] * B[j, k]
# C_forloop = [
#     [20, 23, 26, 29],
#     [56, 68, 80, 92]
# ]
Backpropagating Through Einsum
Here comes the fun part. Backpropagating through an einsum is easy with a simple trick.
Let $L$ be the loss function, and assume we know $\frac{\partial L}{\partial C}$ and need to compute $\frac{\partial L}{\partial A}$.
The trick is to swap the roles of $C$ and $A$ in the einsum: $\frac{\partial L}{\partial C}$ takes $A$'s place as an input, $\frac{\partial L}{\partial A}$ takes $C$'s place as the output, and each tensor keeps the same letters it had in the forward pass. The code should explain it best.
# forward pass
# A uses i and j, B uses j and k, and C uses i and k
C = np.einsum('ij,jk->ik', A, B)

# backward pass
# dL_dC is computed somewhere in the backward pass
# A uses i and j, B uses j and k, and C uses i and k,
# again, just like before with the original einsum
# we swapped the letters for C with the letters for A
# we used dL_dC instead of A for the first parameter
# we computed dL_dA instead of C for the output
dL_dA = np.einsum('ik,jk->ij', dL_dC, B)
This simple swapping trick makes it very easy to create formulas when backpropagating through einsums.
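For instance, the same trick gives the gradient with respect to $B$: swap $B$'s letters with $C$'s letters, and pass $\frac{\partial L}{\partial C}$ in $B$'s place. A sketch, assuming dL_dC is available from the backward pass as above:

# backward pass for B
# A keeps i and j, dL_dC takes B's position with C's letters (i and k),
# and the output uses B's letters (j and k)
dL_dB = np.einsum('ij,ik->jk', A, dL_dC)
# this is equivalent to A.T @ dL_dC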
Verifying the Shape
The result $\frac{\partial L}{\partial A}$ has the same shape as $A$ because the output letters we used (i and j) during the backward pass are exactly the letters that indexed $A$ (i and j) during the forward pass.
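As a quick sanity check, here is a sketch using the NumPy arrays from the earlier example, assuming for simplicity that $\frac{\partial L}{\partial C}$ is all ones:

# dL_dC has the shape of C: (2, 4)
dL_dC = np.ones((2, 4))
dL_dA = np.einsum('ik,jk->ij', dL_dC, B)
assert dL_dA.shape == A.shape  # (2, 3)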
Interpreting the Einsum
If you relabel the letters, the einsum used for backpropagation becomes easier to interpret as a kind of matrix multiplication.
# replacing j with k, and k with j
# doesn't change the result
dL_dA = np.einsum('ij,kj->ik', dL_dC, B)
This corresponds to multiplying $\frac{\partial L}{\partial C}$ by $B^\intercal$. We can tell because it looks like the example einsum for matrix multiplication, except the letters for $B$ are swapped, meaning $B$ is transposed.
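You can confirm this equivalence numerically. A small sketch, reusing the dL_dC and B defined above:

# the einsum and the explicit matrix product give the same result
dL_dA_einsum = np.einsum('ij,kj->ik', dL_dC, B)
dL_dA_matmul = dL_dC @ B.T
assert np.allclose(dL_dA_einsum, dL_dA_matmul)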
Verifying the Values
We can use JAX's automatic differentiation to verify the values.
import jax
import jax.numpy as jnp

# compute the loss
def loss(A, B):
    C = jnp.einsum('ij,jk->ik', A, B)
    return jnp.sum(C)

# compute the gradient of the loss with respect to A
def grad_A(A, B):
    # argnums=0 corresponds to A
    return jax.grad(loss, argnums=0)(A, B)

A = jnp.arange(2 * 3).reshape(2, 3).astype(jnp.float32)
B = jnp.arange(3 * 4).reshape(3, 4).astype(jnp.float32)

# autograd output
print(grad_A(A, B))

C = jnp.einsum('ij,jk->ik', A, B)

# gradient of the loss with respect to C
dL_dC = jnp.ones_like(C)

# gradient of the loss with respect to A
# this is the backpropagation formula we derived
dL_dA = jnp.einsum('ik,jk->ij', dL_dC, B)

# manually computed output
print(dL_dA)

assert (dL_dA == grad_A(A, B)).all()
print("Success!")
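The same check works for the gradient with respect to $B$. A sketch continuing the JAX snippet above, using argnums=1 and the swapping trick from earlier:

# autograd gradient with respect to B
def grad_B(A, B):
    # argnums=1 corresponds to B
    return jax.grad(loss, argnums=1)(A, B)

# swapping trick: the output uses B's letters (j and k),
# and dL_dC takes B's place while keeping C's letters (i and k)
dL_dB = jnp.einsum('ij,ik->jk', A, dL_dC)

assert (dL_dB == grad_B(A, B)).all()
print("Success!")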
Conclusion
Einsums are a powerful tool for representing and reasoning about linear and multilinear operations, and this simple swapping trick makes backpropagating through them straightforward. I hope you found this informative!