A Trick for Backpropagation of Linear Transformations

Linear transformations such as sums, matrix products, dot products, Hadamard products, and many more can often be represented using an einsum (short for Einstein summation).

This post explains a simple trick to backpropagate through any einsum, regardless of what operations it represents.
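
To make that concrete, here is a minimal sketch, not from the original post, of a few of those operations written as einsums (the arrays are just small illustrative examples):

import numpy as np

x = np.arange(3.0)          # [0., 1., 2.]
y = np.arange(3.0) + 1.0    # [1., 2., 3.]
M = np.arange(6.0).reshape(2, 3)

# sum of all elements: the index i is summed away, the output has no indices
total = np.einsum('i->', x)           # same as x.sum()

# dot product: the shared index i is summed away
dot = np.einsum('i,i->', x, y)        # same as x @ y

# Hadamard (elementwise) product: the shared index i is kept in the output
hadamard = np.einsum('i,i->i', x, y)  # same as x * y

# matrix-vector product: j is summed away, i survives in the output
mv = np.einsum('ij,j->i', M, x)       # same as M @ x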

Example Einsum

For example, an einsum for matrix multiplication can be written like so:

import numpy as np

A = np.arange(2 * 3).reshape(2, 3)
# A = [
#   [0, 1, 2],
#   [3, 4, 5]
# ]

B = np.arange(3 * 4).reshape(3, 4)
# B = [
#   [0, 1,  2,  3],
#   [4, 5,  6,  7],
#   [8, 9, 10, 11]
# ]

# C = A @ B (matrix multiplication)

# calculate with einsum
# A uses i and j, B uses j and k, and C uses i and k
C_einsum = np.einsum('ij,jk->ik', A, B)
# C_einsum = [
#   [20, 23, 26, 29],
#   [56, 68, 80, 92]
# ]

# calculate with for loops
C_forloop = np.zeros((2, 4))
# order of the loops doesn't matter:
# you can swap the order of the loops,
# just like you can swap the order of some double integrals
for i in range(2):
    for j in range(3):
        for k in range(4):
            # A uses i and j, B uses j and k, and C uses i and k,
            # just like before with the einsum
            C_forloop[i, k] += A[i, j] * B[j, k]
# C_forloop = [
#   [20, 23, 26, 29],
#   [56, 68, 80, 92]
# ]
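
As a quick sanity check, not in the original post, both versions above agree with NumPy's built-in matrix multiplication (this continues the snippet above):

# the einsum, the explicit loops, and the @ operator all produce the same matrix
assert np.array_equal(C_einsum, A @ B)
assert np.array_equal(C_forloop, A @ B)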

Backpropagating Through Einsum

Here comes the fun part. Backpropagating through an einsum is easy with a simple trick.

Let $L$ be the loss function, and assume we know $\frac{\partial L}{\partial C}$ and need to compute $\frac{\partial L}{\partial A}$.

We can reuse the forward einsum with the roles of $C$ and $A$ swapped: keep each tensor's index letters the same, make $A$'s letters the output, and pass $\frac{\partial L}{\partial C}$ in the slot where $A$ used to be. The short derivation and code below should make the pattern clear.
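
To see why the swap is valid for the matrix-multiplication example, write the chain rule out elementwise (this derivation is not in the original post, but it is the standard chain-rule argument):

$$C_{ik} = \sum_j A_{ij} B_{jk} \quad\Longrightarrow\quad \frac{\partial L}{\partial A_{ij}} = \sum_k \frac{\partial L}{\partial C_{ik}} \frac{\partial C_{ik}}{\partial A_{ij}} = \sum_k \frac{\partial L}{\partial C_{ik}} B_{jk}$$

The last sum is itself an einsum over the same letters: the output now carries $A$'s indices $ij$, while $\frac{\partial L}{\partial C}$ carries $C$'s indices $ik$ and $B$ keeps $jk$.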

# forward pass
# A uses i and j, B uses j and k, and C uses i and k
C = np.einsum('ij,jk->ik', A, B)

# backward pass
# dL_dC is computed somewhere in the backward pass
#
# A uses i and j, B uses j and k, and C uses i and k,
# again, just like before with the original einsum:
# we swapped the letters for C with the letters for A,
# we used dL_dC instead of A for the first parameter,
# and we computed dL_dA instead of C for the output
dL_dA = np.einsum('ik,jk->ij', dL_dC, B)
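
The same trick gives the gradient with respect to $B$: swap $C$'s letters with $B$'s letters and pass $\frac{\partial L}{\partial C}$ in $B$'s slot. Below is a minimal sketch, not from the original post, reusing the $A$ and $B$ from the example above with an arbitrary upstream gradient, and checking both results against the familiar matrix-calculus identities $\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C} B^\top$ and $\frac{\partial L}{\partial B} = A^\top \frac{\partial L}{\partial C}$:

# pretend this upstream gradient arrived from the rest of the backward pass
dL_dC = np.ones((2, 4))

# gradient w.r.t. A: output gets A's letters ij, dL_dC takes C's letters ik
dL_dA = np.einsum('ik,jk->ij', dL_dC, B)

# gradient w.r.t. B: output gets B's letters jk, dL_dC takes C's letters ik
dL_dB = np.einsum('ij,ik->jk', A, dL_dC)

# sanity check against the standard matrix-calculus results
assert np.allclose(dL_dA, dL_dC @ B.T)
assert np.allclose(dL_dB, A.T @ dL_dC)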
