Operations that are linear in each of their inputs, such as sums, matrix products, dot products, Hadamard products, and many more, can often be represented as an einsum (short for Einstein summation).
This post explains a simple trick to backpropagate through any einsum, regardless of what operations it represents.
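Here are several of those operations written as NumPy `np.einsum` calls. This is a small sketch, and the arrays `x`, `y`, `M`, and `N` are arbitrary values chosen for illustration. The rule is the same every time. An index that appears in the inputs but not after `->` is summed over, and the indices after `->` set the axes of the output.

```python
import numpy as np

x = np.arange(3.0)                # [0., 1., 2.]
y = np.arange(3.0, 6.0)           # [3., 4., 5.]
M = np.arange(6.0).reshape(2, 3)  # [[0., 1., 2.], [3., 4., 5.]]
N = np.full((2, 3), 2.0)          # all twos

total = np.einsum('i->', x)              # sum of all elements: i is summed away
dot = np.einsum('i,i->', x, y)           # dot product: 0*3 + 1*4 + 2*5
hadamard = np.einsum('ij,ij->ij', M, N)  # elementwise product, nothing summed
row_sums = np.einsum('ij->i', M)         # sum over j for each row i
transpose = np.einsum('ij->ji', M)       # relabel the output axes
outer = np.einsum('i,j->ij', x, y)       # outer product, nothing summed
```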
Example Einsum
For example, an einsum for matrix multiplication can be written like so:
```python
import numpy as np

A = np.arange(2 * 3).reshape(2, 3)
# A = [
#     [0, 1, 2],
#     [3, 4, 5]
# ]
B = np.arange(3 * 4).reshape(3, 4)
# B = [
#     [0, 1, 2, 3],
#     [4, 5, 6, 7],
#     [8, 9, 10, 11]
# ]

# C = A @ B (matrix multiplication)

# calculate with einsum
# A uses i and j, B uses j and k, and C uses i and k
C_einsum = np.einsum('ij,jk->ik', A, B)
# C_einsum = [
#     [20, 23, 26, 29],
#     [56, 68, 80, 92]
# ]

# calculate with for loops
C_forloop = np.zeros((2, 4))
# order of the loops doesn't matter
# you can swap the order of the loops,
# just like you can swap the order of some double integrals
for i in range(2):
    for j in range(3):
        for k in range(4):
            # A uses i and j, B uses j and k, and C uses i and k,
            # just like before with the einsum
            C_forloop[i, k] += A[i, j] * B[j, k]
# C_forloop = [
#     [20, 23, 26, 29],
#     [56, 68, 80, 92]
# ]
```
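In index notation, the einsum `'ij,jk->ik'` and the triple loop compute the same formula. The index $j$ appears in the inputs but not in the output, so it is summed over, while $i$ and $k$ label the entries of $C$:

$$C_{ik} = \sum_{j} A_{ij} B_{jk}$$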
Backpropagating Through Einsum
Here comes the fun part. Backpropagating through an einsum is easy with a simple trick.
Let $L$ be the loss function, and assume we know $\frac{\partial L}{\partial C}$ and need to compute $\frac{\partial L}{\partial A}$.
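To see where the trick comes from, apply the chain rule to $C_{i'k} = \sum_{j'} A_{i'j'} B_{j'k}$ one entry at a time. The entry $A_{ij}$ only affects row $i$ of $C$, so $\frac{\partial C_{i'k}}{\partial A_{ij}} = \delta_{i'i} B_{jk}$, where the Kronecker delta $\delta_{i'i}$ is 1 if $i' = i$ and 0 otherwise:

$$\frac{\partial L}{\partial A_{ij}} = \sum_{i',k} \frac{\partial L}{\partial C_{i'k}} \frac{\partial C_{i'k}}{\partial A_{ij}} = \sum_{i',k} \frac{\partial L}{\partial C_{i'k}} \, \delta_{i'i} B_{jk} = \sum_{k} \frac{\partial L}{\partial C_{ik}} B_{jk}$$

The result is itself an einsum. $\frac{\partial L}{\partial C}$ carries the letters $ik$, $B$ keeps $jk$, and the output carries $A$'s letters $ij$.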
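We can just swap the roles of $C$ and $A$ while keeping the letters attached to each tensor the same. $\frac{\partial L}{\partial C}$ takes $A$'s place as an input, using $C$'s letters, and the output becomes $\frac{\partial L}{\partial A}$, using $A$'s letters. Every other operand, here $B$, stays exactly as it was. The code should explain it best.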
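```python
import numpy as np

A = np.arange(2 * 3).reshape(2, 3)
B = np.arange(3 * 4).reshape(3, 4)

# forward pass
# A uses i and j, B uses j and k, and C uses i and k
C = np.einsum('ij,jk->ik', A, B)

# backward pass
# dL_dC is computed somewhere in the backward pass;
# an all-ones array stands in for it here so the snippet runs
dL_dC = np.ones_like(C, dtype=float)
# A uses i and j, B uses j and k, and C uses i and k,
# again just like before with the original einsum
# we swapped the letters for C with the letters for A:
# we passed dL_dC (with C's letters ik) instead of A as the first operand,
# and we computed dL_dA (with A's letters ij) instead of C as the output
dL_dA = np.einsum('ik,jk->ij', dL_dC, B)
```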
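The swap is mechanical enough to write as a helper. `einsum_grad` below is a hypothetical name, not a NumPy API, and this is only a minimal sketch. It assumes explicit `->` notation, no `...` ellipsis, no index repeated within a single operand, and that every index of the differentiated operand also appears in another operand or in the output. The code checks the matrix-multiplication gradients against the closed forms $\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C} B^\top$ and $\frac{\partial L}{\partial B} = A^\top \frac{\partial L}{\partial C}$. It then checks a batched matrix multiplication against a central finite difference.

```python
import numpy as np


def einsum_grad(subscripts, operands, grad_out, wrt):
    """Gradient of the loss w.r.t. operands[wrt], given grad_out = dL/d(output).

    Sketch of the swap trick: the differentiated operand trades places with
    the output, and every other operand keeps its letters.
    Assumed (not handled): explicit '->' output, no '...' ellipsis, no index
    repeated inside one operand, and every index of operands[wrt] also
    appears in another operand or in the output.
    """
    inputs, output = subscripts.split('->')
    specs = inputs.split(',')
    target = specs[wrt]
    # the upstream gradient takes the target's slot, carrying the output's letters
    new_specs = specs[:wrt] + [output] + specs[wrt + 1:]
    new_operands = list(operands[:wrt]) + [grad_out] + list(operands[wrt + 1:])
    # the result carries the target's letters
    return np.einsum(','.join(new_specs) + '->' + target, *new_operands)


rng = np.random.default_rng(0)

# matrix multiplication: compare against the closed-form gradients
A = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))
dL_dC = rng.standard_normal((2, 4))  # stand-in for the real upstream gradient
dL_dA = einsum_grad('ij,jk->ik', [A, B], dL_dC, wrt=0)  # runs 'ik,jk->ij'
dL_dB = einsum_grad('ij,jk->ik', [A, B], dL_dC, wrt=1)  # runs 'ij,ik->jk'
assert np.allclose(dL_dA, dL_dC @ B.T)
assert np.allclose(dL_dB, A.T @ dL_dC)

# batched matmul: check one entry against a central finite difference,
# using the loss L = sum(G * C) so that dL/dC = G
X = rng.standard_normal((5, 2, 3))
Y = rng.standard_normal((5, 3, 4))
G = rng.standard_normal((5, 2, 4))
spec = 'bij,bjk->bik'


def loss(X):
    return np.sum(G * np.einsum(spec, X, Y))


eps = 1e-6
E = np.zeros_like(X)
E[1, 0, 2] = eps
numeric = (loss(X + E) - loss(X - E)) / (2 * eps)
dL_dX = einsum_grad(spec, [X, Y], G, wrt=0)  # runs 'bik,bjk->bij'
assert np.isclose(dL_dX[1, 0, 2], numeric, atol=1e-6)
```

The last assumption matters. For a reduction like `'ij->i'`, the swapped string `'i->ij'` is rejected by `np.einsum` because `j` appears only in the output. There, the gradient is $\frac{\partial L}{\partial C}$ broadcast along $j$, which this sketch does not handle.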