Do you know what inm,kij,jnm->knm is all about?

For an interactive version of this post, see this Colab notebook.

Introduction

Linear combinations are ubiquitous in machine learning and statistics. Many algorithms and models in the statistics and machine learning literature can be written (or approximated) as matrix-vector multiplications. Einsums are a way of representing linear interactions among vectors, matrices and higher-dimensional arrays.
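As a small illustration of this claim, a matrix-vector product $y_m = \sum_n A_{mn} x_n$ can be written as a single einsum call using the string notation introduced later in this post. The matrix A and vector x below are arbitrary example values; this is a sketch, not code from the accompanying notebook.

```python
import numpy as np

# Arbitrary example values: A is 2 x 3, x has length 3.
A = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
x = np.array([1.0, 0.0, -1.0])

# y_m = sum_n A_{mn} x_n: n appears on the left of "->" but not on the
# right, so einsum sums over it; m is kept in the output.
y = np.einsum("mn,n->m", A, x)

print(y)                      # [-2. -2.]
print(np.allclose(y, A @ x))  # True: same result as the usual matrix product
```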

In this post, I lay out examples that make use of einsums. I assume that the reader is familiar with the basics of einsums; nevertheless, I provide a quick introduction in the next section. For reference, see also [1] and [2]. Throughout this post, we borrow the numpy terminology and refer to an element ${\bf x} \in \mathbb{R}^{M_1 \times M_2 \times \ldots \times M_N}$ as an $N$-dimensional array.
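In numpy terms, $N$ is the number of axes of an array (its ndim attribute) and $(M_1, \ldots, M_N)$ is its shape. The sizes below are hypothetical and chosen only for illustration. Note that numpy indices start at 0, whereas the math in this post indexes from 1.

```python
import numpy as np

# Hypothetical sizes M_1, M_2, M_3 for a 3-dimensional array x.
M1, M2, M3 = 2, 3, 4
x = np.arange(M1 * M2 * M3).reshape(M1, M2, M3)

print(x.ndim)      # 3, i.e. N
print(x.shape)     # (2, 3, 4), i.e. (M_1, M_2, M_3)
print(x[1, 2, 3])  # 23: one element, indexed by one integer per dimension
```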

In the next section, we present a brief summary of einsum expressions and their usage in numpy / jax.numpy.

A quick introduction to einsums: from sums to indices

Let ${\bf a}\in\mathbb{R}^M$ be a 1-dimensional array. We denote by $a_m$ the $m$-th element of $\bf a$. Suppose we want to express the sum of all elements in $\bf a$. This can be written as

$$ \sum_{m=1}^M a_m $$
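Read literally, the summation symbol is a loop over the index $m$. The sketch below translates it one-to-one into Python, using the same example array as the numpy snippet later in this section, and checks it against numpy's built-in sum.

```python
import numpy as np

a = np.array([1, 2, 3, 4])

# Direct translation of sum_{m=1}^{M} a_m (Python indices run 0..M-1).
total = 0
for m in range(a.shape[0]):
    total += a[m]

print(total)             # 10
print(total == a.sum())  # True
```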

To introduce the einsum notation, we note that the summation symbol ($\Sigma$) in this equation simply states that we should consider all elements of $\bf a$ and sum them. If we assume that 1) there is no ambiguity about the number of dimensions of $\bf a$ and 2) we sum over all of its elements, we can define the einsum notation for the sum over all elements of the 1-dimensional array $\bf a$ as

$$ \sum_{m=1}^M a_m \stackrel{\text{einsum}}{\equiv} {\bf a}_m $$

To keep our notation consistent, we denote indices in parentheses as static dimensions, that is, dimensions that are kept rather than summed over. Static dimensions allow us to expand the expressive power of einsums. For example, we denote the collection of all elements of $\bf a$ (without summing them) in einsum notation as ${\bf a}_{(m)}$.
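Static dimensions become more useful once an array has more than one index. The sketch below uses a hypothetical 2-dimensional array B and numpy's einsum string syntax (described in the next paragraph): keeping m static while summing over n gives row sums, keeping n static gives column sums, keeping nothing static gives the total, and keeping everything static performs no summation.

```python
import numpy as np

# Hypothetical 2-dimensional array B with 2 rows (index m) and 3 columns (index n).
B = np.array([[1, 2, 3],
              [4, 5, 6]])

row_sums = np.einsum("mn->m", B)   # m static, sum over n
col_sums = np.einsum("mn->n", B)   # n static, sum over m
total = np.einsum("mn->", B)       # nothing static: sum over every element
same = np.einsum("mn->mn", B)      # everything static: no summation

print(row_sums)  # [ 6 15]
print(col_sums)  # [5 7 9]
print(total)     # 21
```

Each result matches the corresponding B.sum(axis=...) call; the einsum string simply names the axes instead of numbering them.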

Since the names of the arrays carry no meaning in these expressions, we define einsum expressions in numpy using only the indices. To indicate which dimensions are static and which should be summed over, we introduce the -> notation. Indices to the left of -> label the dimensions of the array, and indices to the right of -> are the ones we do not sum over (the static dimensions). For example, the sum over all elements in $\bf a$ is written as

$$ {\bf a}_m \equiv \texttt{m->} $$

and the selection of all elements of $\bf a$ is written as

$$ {\bf a}_{(m)} \equiv \texttt{m->m} $$

In the following snippet, we show this notation in action.

>>> import numpy as np
>>> a = np.array([1, 2, 3, 4])
>>> np.einsum("m->", a)
10
>>> np.einsum("m->m", a)
array([1, 2, 3, 4])
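To make the rule "indices missing on the right of -> are summed over" concrete, here is a minimal loop-based evaluator for single-operand expressions. The function einsum_single is hypothetical and written only for illustration; it assumes an explicit ->, no repeated indices on the left (so no diagonals) and no "..." broadcasting. It is not how numpy implements einsum.

```python
import itertools

import numpy as np


def einsum_single(spec: str, a: np.ndarray) -> np.ndarray:
    """Hypothetical loop-based evaluator for specs such as "mn->m".

    Indices left of "->" label the axes of `a`; indices right of "->" are the
    static (output) indices. Every other index is summed over.
    """
    lhs, rhs = spec.split("->")
    if len(lhs) != a.ndim:
        raise ValueError("one index per axis of `a` is required")
    if len(set(lhs)) != len(lhs):
        raise ValueError("repeated indices are not supported in this sketch")
    size = dict(zip(lhs, a.shape))  # index letter -> axis length
    out = np.zeros([size[i] for i in rhs], dtype=a.dtype)
    # Visit every element of `a` once and add it to the output cell
    # selected by its static indices; the other indices are summed away.
    for idx in itertools.product(*(range(n) for n in a.shape)):
        value = dict(zip(lhs, idx))
        out[tuple(value[i] for i in rhs)] += a[idx]
    return out


a = np.array([1, 2, 3, 4])
print(einsum_single("m->", a))   # 10
print(einsum_single("m->m", a))  # [1 2 3 4]
```

The order of the indices on the right also fixes the layout of the output, so in this sketch "mn->nm" returns the transpose of a 2-dimensional input, as np.einsum does.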