Introduction
Einstein's summation convention, alternatively known as einsum in torch and jax, is particularly powerful for reasoning through distributed ML strategies. In this post, I review this notation and introduce some non-standard additions that I find quite helpful when thinking about sharding and parallelization, connecting to torch concepts along the way.
Basics
Einstein notation is maximally lazy: one avoids explicit summations (as much as possible) and operations are generally inferred by index structures. There is enough information in the indices to remove any ambiguity in the vast majority of cases.
A sum is implied when a repeated index appears on only one side of an equation:
\begin{align}
\textrm{dot product:}&\quad z = x_{ k } y_{ k } \nonumber\\
\textrm{matmul:}&\quad C_{ mn } = A_{ mk } B_{ kn } \nonumber
\end{align}
The name given to a summed-over index like $k$ is completely arbitrary (it's a "dummy-index"), but I attempt to make choices that carry some semantic or conventional meaning.
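In torch, these contractions are literal transcriptions of the index strings; a minimal sketch with arbitrary shapes:

```python
import torch

x, y = torch.randn(8), torch.randn(8)
A, B = torch.randn(4, 8), torch.randn(8, 5)

# dot product: z = x_k y_k  (k repeated and absent from the output, so it is summed)
z = torch.einsum("k,k->", x, y)

# matmul: C_mn = A_mk B_kn  (k summed, m and n survive)
C = torch.einsum("mk,kn->mn", A, B)

assert torch.allclose(z, x @ y) and torch.allclose(C, A @ B)
```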
In contrast, for an element-wise vector product, the resulting tensor would carry the index:
\begin{align}
\textrm{elementwise product:}&\quad z_{ k } = x_{ k } y_{ k } \nonumber
\end{align}
It is sometimes also useful to attach indices to various operations to indicate the dimension they act on. Softmax is a common example: $p _{ s v } = \texttt{Soft} _{ v } z _{ s v }$ results in a tensor for which $\texttt{sum} _{ v } p _{ sv } = 1$, while $\texttt{sum} _{ s } p _{ sv } \neq 1$, in general.
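The elementwise product and the index-decorated softmax translate just as directly, with the $\texttt{Soft}_{ v }$ subscript becoming a dim argument; a small sketch:

```python
import torch
import torch.nn.functional as F

x, y = torch.randn(8), torch.randn(8)
z_sv = torch.randn(3, 7)                     # scores indexed by (s, v)

# elementwise product: z_k = x_k y_k  (k appears on both sides, so no sum)
z = torch.einsum("k,k->k", x, y)

# Soft_v: normalize over the v dimension only
p_sv = F.softmax(z_sv, dim=1)
assert torch.allclose(p_sv.sum(dim=1), torch.ones(3))   # sum_v p_sv = 1
# p_sv.sum(dim=0) is generally not all ones
```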
Example: MLP Layers
A simple MLP layer with non-linearity $\phi$ acting on an input tensor $x_{ b s d}$ of shape (batch_size, seqlen, dim) produces output:
\begin{align}
z_{ bsd } = W^{ 1 }_{ de }\phi \left ( W^{ 0 }_{ ed' } x_{ bsd' } \right ) \nonumber \ .
\end{align}
Explicitly, the above sums over the $d'$ and $e$ indices which span over range(dim) and range(4 * dim), respectively, in the usual vanilla MLP setup.
For a SwiGLU-type activation function, the above would change to [1]
\begin{align}
z_{ bsd } = W^{ 2 }_{ df } W^{ 1 }_{ f d'' }x_{ bsd'' }\phi \left ( W^{ 0 }_{ fd' } x_{ bsd' } \right ) \nonumber \ .
\end{align}
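Both layers transcribe directly into einsum calls. A sketch with hypothetical sizes, bias-free weights, and $\phi$ taken to be ReLU (SiLU for the SwiGLU gate):

```python
import torch
import torch.nn.functional as F

B, S, D = 2, 16, 64                        # batch_size, seqlen, dim
E = 4 * D                                  # vanilla MLP intermediate width
x = torch.randn(B, S, D)                   # x_{bsd}

# z_{bsd} = W1_{de} phi(W0_{ed'} x_{bsd'})
W0, W1 = torch.randn(E, D), torch.randn(D, E)
z = torch.einsum("de,bse->bsd", W1, torch.relu(torch.einsum("ed,bsd->bse", W0, x)))

# SwiGLU: z_{bsd} = W2_{df} (W1_{fd''} x_{bsd''}) * phi(W0_{fd'} x_{bsd'})
Fdim = 4 * D                               # hypothetical gated intermediate width
W0g, W1g, W2g = torch.randn(Fdim, D), torch.randn(Fdim, D), torch.randn(D, Fdim)
gate = F.silu(torch.einsum("fd,bsd->bsf", W0g, x))
z_swiglu = torch.einsum("df,bsf->bsd", W2g, torch.einsum("fd,bsd->bsf", W1g, x) * gate)
```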
Notation for Sharding
How can sharding be represented in this notation? Two conventions are useful:
- Notation for splitting up an index.
- Notation for indicating that an index's values are physically distributed across different devices.
The latter notation must also indicate exactly which devices a tensor is sharded over, as complex, multi-dimensional device meshes are often involved.
For the first goal, we take inspiration from the excellent einops package and denote decomposing one index into multiple as in $b \longrightarrow (r c)$, or at the tensor level:
\begin{align}
x_{ bsd } \longrightarrow x_{ (rc)sd } \longrightarrow x_{ r c s d } \nonumber \ ,
\end{align}
where I've removed the parentheses after it's clear which indices have been split.
If the above were a batch index, this might describe a DDP- or FSDP-like sharding pattern in which one splits up b in range(B) indices across r in range(R) devices (ranks), such that c in range(B // R) (perfect divisibility assumed). Index ordering is significant here and I use row-major throughout, so that every rank gets a contiguous slice of the batch dimension in the present example.
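This is exactly the split that einops' rearrange expresses; a small sketch of the row-major $b \longrightarrow (r c)$ decomposition with made-up sizes:

```python
import torch
from einops import rearrange

B, S, D, R = 8, 16, 4, 2                  # R = number of ranks
x = torch.randn(B, S, D)                  # x_{bsd}

# b -> (r c): a row-major split, so each rank r sees a contiguous chunk of the batch
x_split = rearrange(x, "(r c) s d -> r c s d", r=R)
assert torch.equal(x_split[1], x[B // R : 2 * B // R])
```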
For the second goal, add a numerical subscript to sharded indices. The value corresponds to which particular group of devices the index is sharded over, assuming that the relevant groups of devices have already been enumerated. So, $x_{ a_{ 0 } \ldots }$ indicates that no single device owns all of $a$-index values for this logical tensor, and they're instead spread out over one axis of a particular mesh of devices.
Simple sharding example. The 16-element logical tensor, indexed by $a$, is sharded across four devices. Two indices describe the sharding: $r _{ 0 }$, which indexes the device, and $b$, which indexes the locally-available tensor values. Both indices take values in range(4) here. The zero subscript indicates which axis of the device mesh the given index is sharded over, information which is crucial in more complex examples.
Note that the first rule isn't even really needed! Just attaching subscripts, like $x _{ a } \longrightarrow x _{ a _{ 0 } }$, is enough to indicate sharding, though this leaves the precise sharding pattern implicit. But for clarity and technical reasons, in this post I will be as explicit as possible, and when a given index a in range(A) is split across $R$ devices I will generally express the sharding by splitting the logical index into two, $a \longrightarrow r b$, with the range of one index matching the number of devices, r in range(R) and b in range(A // R), and then only attach the sharding subscript to $r _{ 0 }$, since it is the only index whose range is not held on a single device. So, $x _{ a } \longrightarrow x _{ r _{ 0 } b }$. This is illustrated in the figure above.
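Concretely, for the 16-element figure above, rank $r$'s local shard under this row-major convention is just a contiguous block; a tiny sketch:

```python
import torch

A, R = 16, 4                    # 16 logical elements over 4 devices
x = torch.arange(A)             # the logical x_a

# x_a -> x_{r_0 b}: rank r locally holds the contiguous row x.view(R, A // R)[r]
x_rb = x.view(R, A // R)
print(x_rb[2])                  # tensor([ 8,  9, 10, 11]) lives on rank 2
```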
The rest of this post is dedicated to examples which illustrate the notation and its utility.
Examples
I will mostly focus on the forward-pass, as I plan to discuss gradient considerations (where this notation especially shines) in a later post.
Distributed Data Parallel
In the DDP scenario described above with $R$ total devices and tensor $x_{ r c s d }$, each rank owns only a single $r$-index value but holds all other index values, e.g. all of c in range(B // R), s in range(seqlen), etc. Then, letting $0$ correspond to the world group, the tensor would be written as
\begin{align}
x_{ r c s d } \longrightarrow x_{ r_{ 0 } c s d } \nonumber \ .
\end{align}
In a future post, I will discuss how the DDP gradient-sum across devices naturally follows from the fact that local compute operates on all the non-$r _{ 0 }$ indices and leaves $r _{ 0 }$ hanging, in a sense which will be made precise.
All-Reduce and Reduce-Scatter
All-reduce illustration, from the NCCL docs.
Let's do a sum over a sharded dimension. Let the logical tensors be $x_{ d }, y_{ d }$ and indicate sharding as above: $x_{ d } \longrightarrow x_{ r _{0}e }$ and similar for $y$. Then, the notation makes it clear at a glance that for the dot product,
\begin{align}
z = x_{ r _{0} e }y_{ r _{0}e } \quad \textrm{(all-reduce)} \nonumber \ ,
\end{align}
while the sum over the $e$ index can be performed locally, the $r _{0}$-sum involves tensors on different devices and hence necessitates cross-device communication. That is, an all-reduce is required.
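As a hedged torch.distributed sketch (assuming a process group is already initialized, e.g. via torchrun, and that each rank holds its local shards x_local and y_local):

```python
import torch
import torch.distributed as dist

def sharded_dot(x_local: torch.Tensor, y_local: torch.Tensor) -> torch.Tensor:
    """z = x_{r0 e} y_{r0 e}: a local sum over e, then an all-reduce over r0."""
    z = torch.einsum("e,e->", x_local, y_local)   # local e-sum on this rank
    dist.all_reduce(z, op=dist.ReduceOp.SUM)      # cross-device sum over r0
    return z
```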
Reduce-scatter illustration, again from the NCCL docs.
Reduce-scatter operations can be similarly expressed, in which the result of the sum is sharded rather than replicated. A second tensor dimension is required in this case. Let the logical global operation be a matrix-vector dot product: $z_{ d } = W_{ de } x_{ e }$. Now shard the reduction dimension, $e \longrightarrow r _{0} f $. To complete the sum, our options are either to all-reduce and replicate the result across every device like above,
\begin{align}
z_{ d } = W_{ d r _{0} f } x_{ r _{0}f } \quad \textrm{(all-reduce)} \nonumber \ ,
\end{align}
or to also split the output dimension on the same set of devices, $d \longrightarrow r c$, and realize the output non-redundantly across ranks,
\begin{align}
z_{ r _{0}c } = W_{ r c r _{0}' f } x_{ r _{0}'f } \quad \textrm{(reduce-scatter)} \nonumber \ .
\end{align}
Note the subscript on a given index does not need to match across sides! Primitive operations can move subscript placement. Because the weight's output index is not physically sharded, its unprimed $r$ has no subscript, but in the process of performing the right-hand-side sums, one can store the entire logical result in a sharded manner across devices (hence the subscript on the left side's $r$). More on the rules for subscripts later. Since an all-gather just corresponds to unsharding, as in $x_{ r_{ 0 }f } \longrightarrow x_{ e }$, the fact that an all-reduce can be implemented as a reduce-scatter followed by an all-gather is also very clear in this notation.
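The two ways of completing the matrix-vector sum can be sketched the same way (same process-group assumptions; reduce_scatter_tensor additionally requires the output dimension to divide evenly across ranks):

```python
import torch
import torch.distributed as dist

def matvec_all_reduce(W_local: torch.Tensor, x_local: torch.Tensor) -> torch.Tensor:
    """z_d = W_{d r0 f} x_{r0 f}: local f-sum, then all-reduce; z_d ends up replicated."""
    z = torch.einsum("df,f->d", W_local, x_local)
    dist.all_reduce(z, op=dist.ReduceOp.SUM)
    return z                                              # shape (D,), on every rank

def matvec_reduce_scatter(W_local: torch.Tensor, x_local: torch.Tensor) -> torch.Tensor:
    """z_{r0 c} = W_{r c r0' f} x_{r0' f}: each rank keeps only its c-slice of z."""
    z = torch.einsum("df,f->d", W_local, x_local)         # full d range, but still Partial
    out = torch.empty(z.numel() // dist.get_world_size(), dtype=z.dtype, device=z.device)
    dist.reduce_scatter_tensor(out, z, op=dist.ReduceOp.SUM)
    return out                                            # shape (D // R,), sharded over r0
```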
These patterns are seen again in the Tensor- and Sequence-Parallel examples below.
Sometimes it is useful, though a little inelegant, to write communication as an operator that acts on individual indices, as in $x = \texttt{AllReduce}_{ r }(x _{ r })$.
Tensor-Parallel
Tensor-Parallel [2] (TP) MLP layers just shard the intermediate reduction dimension across devices: $e \longrightarrow r _{0}f$ in the MLP expression above,
\begin{align}
z_{ bsd } = W^{ 1 }_{ dr _{0}f }\phi \left ( W^{ 0 }_{ r _{0}fd' } x_{ bsd' } \right ) \nonumber \ .
\end{align}
The first MLP weight is sharded across its output dimension (ColwiseParallel in torch lingo) and the second across its input dimension (RowwiseParallel, similarly). An all-reduce is needed to complete the sum over the logical intermediate index after the second matmul, due to the sharded $r_{ 0 }$.
The above decomposition is also a concise proof of correctness for the usual TP sharding strategy, which I personally find easier to follow than the more common proof-by-drawing approach.
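The claim can also be checked numerically. A minimal single-process sketch with hypothetical sizes, where the $R$ "ranks" are simulated by a Python loop and the final sum plays the role of the all-reduce:

```python
import torch

B, S, D, E, R = 2, 8, 16, 64, 4            # hypothetical sizes; E divisible by R
x = torch.randn(B, S, D)
W0, W1 = torch.randn(E, D), torch.randn(D, E)

# e -> r0 f: W0 is ColwiseParallel (output dim sharded), W1 is RowwiseParallel (input dim sharded)
W0_sh, W1_sh = W0.view(R, E // R, D), W1.view(D, R, E // R)

partials = []
for r in range(R):                          # each "rank" computes with only its own shards
    h_r = torch.relu(torch.einsum("fd,bsd->bsf", W0_sh[r], x))
    partials.append(torch.einsum("df,bsf->bsd", W1_sh[:, r], h_r))

z_tp = sum(partials)                        # the all-reduce over r0
z_ref = torch.einsum("de,bse->bsd", W1, torch.relu(torch.einsum("ed,bsd->bse", W0, x)))
assert torch.allclose(z_tp, z_ref, atol=1e-3)
```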
FSDP + TP Activations
For an example of a tensor sharded over multiple dimensions, take an FSDP + TP scenario, with 0 and 1 labeling the mesh axes for these two groups. The logical intermediate, TP-sharded activation tensor $z _{ b s d }$ gets sharded into $z _{ f _{ 0 }c s t _{ 1 } e }$, where:
- f in range(fsdp_degree)
- c in range(global_batch_size // fsdp_degree)
- s in range(seqlen)
- t in range(tp_degree)
- e in range(intermediate_dim // tp_degree)
Obviously, this is an unfortunate number of indices, but that mostly just reflects the complexities of sharding multi-dimensional tensors over multiple submeshes.
Shard, Replicate, and Partial
DTensor is the modern pytorch API used to describe distributed tensors. Each DTensor is associated with an n-dimensional mesh (array) of devices, and each mesh axis is associated with a placement that describes the state of the tensor relative to that axis, in ways I will describe. All of these ideas map very cleanly onto the present notation.
The three most fundamental placements are Replicate, Shard, and Partial. The first two are relatively self-explanatory:
- Replicate: the tensor is replicated across this dimension.
- Shard: one dimension of the tensor (which must be specified) is distributed across this axis.
These can be illustrated minimally with four GPUs. Dividing them into a 2x2 device mesh and taking 2D tensor $W_{ ab }$, here are some index shardings and their corresponding pytorch 2-tuple of placements:
- $W_{ ab }$: (Replicate, Replicate). No sharding: every device gets the original tensor.
- $W_{ ab } \longrightarrow W_{ a r_{ 0 } d}$: (Shard(1), Replicate). Distribute the final tensor dimension across the leading device mesh axis, keeping these shards constant across the final mesh axis. Every device gets half of the original tensor.
- $W_{ ab } \longrightarrow W_{ r_{ 0 } c r '_{ 1 } d }$: (Shard(0), Shard(1)). Distribute the leading tensor dimension across the leading device mesh axis, and similarly for the final tensor and device axes. Every device gets a quarter of the original tensor.
These, and more, are illustrated in the figure below. The correspondence between Shard, Replicate, and the sharded index notation is hopefully clear [3].
Note that constraints exist: an expression like $W_{ ab } \longrightarrow W_{ r_{ 0 } c r'_{ 0 } d }$ where two indices are sharded over the same device mesh axis [4] is nonsensical and illegal. This is reflected in the DTensor API: an n-dimensional device mesh requires an n-tuple of placements (only a single placement per axis!) to describe a DTensor. So, another rule: a tensor can never have two indices with the same sharding subscript. This requirement will come up in the Sequence-Parallel discussion below.
Various ways of splitting a simple 2x2 tensor over a 2x2 device mesh.
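In code, the placement tuples above look roughly as follows. This is a sketch: depending on the torch release, DTensor lives under torch.distributed.tensor (or the older torch.distributed._tensor), and it assumes four GPUs with a process group launched via torchrun.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cuda", (2, 2))     # a 2x2 device mesh over 4 GPUs
W = torch.randn(4, 6)                       # the logical W_{ab}

# W_{ab}: fully replicated
W_rep = distribute_tensor(W, mesh, [Replicate(), Replicate()])

# W_{ab} -> W_{a r0 d}: tensor dim 1 sharded over mesh axis 0
W_col = distribute_tensor(W, mesh, [Shard(1), Replicate()])

# W_{ab} -> W_{r0 c r'1 d}: tensor dim 0 over mesh axis 0, tensor dim 1 over mesh axis 1
W_both = distribute_tensor(W, mesh, [Shard(0), Shard(1)])
```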
What about Partial? This denotes pending operations [5]. Return to the distributed dot-product and consider the intermediate which arises after only performing the local sums:
\begin{align}
z_{ r_0 } = x_{ r _{0} c }y_{ r _{0}c } \nonumber \ .
\end{align}
Being a partial result, this temporary naturally has a Partial placement with respect to the 0-dimension of the mesh. It's only after performing the 0-mesh-dim all-reduce that one gets the full result: $z = \texttt{AllReduce}_{ r _{ 0 } }(z _{ r _{ 0 } })$, which has a Replicate placement. In a more general example, the Partial could also be completed via a reduce-scatter, resulting in a Shard placement.
Therefore, dangling device-mesh indices with sharding subscripts indicate Partial placements.
In the forward, distributed reductions are generally completed immediately so that there are no long-lived Partial tensors, but in the backwards it can be highly beneficial to leave various terms in a Partial state and judiciously delay completing the pending operation. This will be discussed more elsewhere, but for now weight gradients in DDP and FSDP provide a good example: weight gradients live in a Partial state until the time is right to reduce them over the relevant mesh dimension, usually in a bucketed manner for efficiency.
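For the dot-product example, the Partial-to-Replicate progression might be sketched with DTensor like so (same caveats as above about the module path and launch setup; the mesh is a single 4-device axis):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

mesh = init_device_mesh("cuda", (4,))                 # the r0 mesh axis

# z_{r0} = x_{r0 c} y_{r0 c} with only the local c-sum performed: a Partial temporary
x_local = torch.randn(8, device="cuda")
y_local = torch.randn(8, device="cuda")
z_local = torch.einsum("c,c->", x_local, y_local)
z_partial = DTensor.from_local(z_local, mesh, [Partial()])

# Completing the pending r0-sum: Partial -> Replicate is exactly an all-reduce
z = z_partial.redistribute(mesh, [Replicate()])
```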
Sequence-Parallel
Sequence-Parallelism (SP), in the NVIDIA sense, removes some of the redundancy in TP that comes from its inputs and outputs being fully replicated.
In SP, the inputs and outputs are instead sharded over the sequence dimension, $s \longrightarrow r _{0}t$, reusing the TP mesh axis for the sharding. Making this additional replacement in the previous TP expression gives:
\begin{align}
z_{ b r _{0}t d } = W^{ 1 }_{ dr _{0}'f }\phi \left ( W^{ 0 }_{ r _{0}'fd' } x_{ b r _{0}t d' } \right ) \nonumber \ .
\end{align}
It is useful to step through the intermediates (a code sketch of the full progression follows this list):
- The very first operation $W^{ 0 } _{ r _{0}'fd' } x _{ b r _{0}t d' }$ already has a potential issue. If one blindly performed the $d'$ sum and carried the other indices through, the result would look like $y _{ b r _{ 0 } t r' _{ 0 } f } = W^{ 0 } _{ r _{0}'fd' } x _{ b r _{0}t d' }$. But this breaks the rule that two tensor dimensions can't be sharded over the same mesh axis! The same subscript cannot appear twice, and the notational rules help catch a possible logic mistake. The solution is to unshard one of the offending input dimensions via an all-gather, and the best option is to temporarily unshard the input's sequence dimension. So, the legal progression is:
\begin{align}
W^{ 0 } _{ r _{0}'fd' } x _{ b r _{0}t d' } \xrightarrow{\rm all-gather} W^{ 0 } _{ r _{0}'fd' } x _{ b s d' } \xrightarrow{\rm TP\ matmul} y_{ b s r'_{ 0 } f } \nonumber
\end{align}
- The non-linearity is trivial: $v _{ b s r' _{ 0 } f } = \phi(y _{ b s r' _{ 0 } f })$. Elementwise ops are easy.
- The sums in the final matmul $W^{ 1 } _{ dr _{0}'f }v _{ b s r' _{ 0 } f }$ eliminate the sharded subscript, and so it can legally be moved onto the output tensor without inducing a double-0-subscript: $z _{ b r _{ 0 } t d }= W^{ 1 } _{ dr _{0}'f }v _{ b s r' _{ 0 } f }$. Said differently: reduce-scatter.
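A hedged torch.distributed sketch of this progression for one rank, assuming a TP/SP group of size $R$ is initialized, x_local, W0_local, and W1_local are this rank's shards, and $\phi$ is taken to be ReLU:

```python
import torch
import torch.distributed as dist

def sp_mlp_forward(x_local, W0_local, W1_local, group=None):
    """x_local: x_{b r0 t d'} of shape (B, S // R, D); W0_local: (F // R, D); W1_local: (D, F // R)."""
    R = dist.get_world_size(group)
    B, T, D = x_local.shape

    # all-gather: x_{b r0 t d'} -> x_{b s d'}  (temporarily unshard the sequence dim)
    gathered = [torch.empty_like(x_local) for _ in range(R)]
    dist.all_gather(gathered, x_local, group=group)
    x_full = torch.cat(gathered, dim=1)                          # (B, S, D)

    # TP matmuls over the r0'-sharded intermediate dim, plus the elementwise phi
    h = torch.relu(torch.einsum("fd,bsd->bsf", W0_local, x_full))
    z_partial = torch.einsum("df,bsf->bsd", W1_local, h)         # Partial over r0'

    # reduce-scatter: complete the r0' sum while resharding the sequence dim
    out = torch.empty(B, T, D, dtype=z_partial.dtype, device=z_partial.device)
    chunks = [c.contiguous() for c in z_partial.chunk(R, dim=1)]
    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM, group=group)
    return out                                                   # z_{b r0 t d}
```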
Ring Attention
Ring Attention is morally a distributed version of flash attention, in which inputs are sharded along the sequence dimension $s$ and the computation never realizes any intermediates which have a full seqlen-sized dimension.
Let's simplify a bit to keep the presentation short and minimize indices:
- Suppress the batch dimension and focus on a single head.
- Omit the causal mask.
- Ignore the (numerically crucial) maximum attention score tracking.
Illustration of ring attention, from this excellent blog.
With these simplifications, the attention output $z _{ s d }$ given queries $q _{ s d }$ and similar keys and values is:
\begin{align}
z_{ s d} &= \texttt{Soft}_{ s' } \left ( q_{ s d' } k_{ s' d' } \right ) v_{ s' d } \nonumber\\
&=\frac{\exp \left ( q_{ s d' } k_{ s' d' } \right )v_{ s' d }}{Z_{ s }} \quad \textrm{where} \quad Z _{ s } = \texttt{sum} _{ s' } \exp \left ( q _{ s d' } k _{ s' d' } \right ) \nonumber \ .
\end{align}
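Before sharding anything, this expression is a few lines of einsum; a sketch with the same simplifications (single head, no mask, no max-score subtraction, and no $1/\sqrt{d}$ scaling, matching the formula above):

```python
import torch

S, D = 16, 8
q, k, v = torch.randn(S, D), torch.randn(S, D), torch.randn(S, D)

scores = torch.einsum("sd,td->st", q, k)     # q_{sd'} k_{s'd'}, with t playing the role of s'
p = torch.softmax(scores, dim=1)             # Soft_{s'}
z = torch.einsum("st,td->sd", p, v)          # p_{ss'} v_{s'd}
```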
Shard the sequence dimension over $R$ ranks:
\begin{align}
z_{ r_{ 0 }t d} &=\frac{\exp \left ( q_{ r_{ 0 } t d' } k_{ r'_{ 0 }u d' } \right )v_{ r'_{ 0 }u d }}{Z_{ r_{ 0 }t }} \nonumber
\end{align}
Both the numerator and denominator involve a sum over ranks ($r' _{ 0 }$), and a natural iterative algorithm to perform this sum is as follows:
- Rank $r _{ 0 }$ initializes numerator $N _{ r _{ 0 } t d }$ and denominator $Z _{ r _{ 0 }t }$ aggregators to zero.
- Over the course of $R$ rounds, pass the $k _{ r' _{ 0 } u d }$ and $v _{ r' _{ 0 } u d }$ blocks around in a ring so they land on every rank $r _{ 0 }$.
- In each round, perform these updates (note: no sum over repeated $r'_{ 0 }$ here):
\begin{align}
Z _{ r _{ 0 }t } &:= Z _{ r _{ 0 }t } + \texttt{sum}_{ u }\exp \left ( q_{ r_{ 0 } t d' } k _{ r'_{ 0 } u d' } \right ) \nonumber\\
N _{ r _{ 0 }td } &:= N _{ r _{ 0 }td } + \exp \left ( q_{ r_{ 0 } t d' } k_{ r'_{ 0 }u d' } \right )v_{ r'_{ 0 }u d } \nonumber
\end{align}
The final output is then
\begin{align}
z_{ r_{ 0 }t d} &=\frac{N _{ r _{ 0 }td }}{Z_{ r_{ 0 }t }} \nonumber
\end{align}
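A single-process simulation of the whole algorithm, with hypothetical sizes, the ring modeled by indexing into pre-split blocks, and the max-score tracking still omitted:

```python
import torch

S, D, R = 16, 8, 4                      # seqlen, head dim, number of ranks
T = S // R                              # local sequence length per rank

q, k, v = torch.randn(S, D), torch.randn(S, D), torch.randn(S, D)
q_sh, k_sh, v_sh = (x.view(R, T, D) for x in (q, k, v))   # s -> r0 t sharding

N = torch.zeros(R, T, D)                # numerator accumulator N_{r0 t d}
Z = torch.zeros(R, T)                   # denominator accumulator Z_{r0 t}

for step in range(R):                   # R rounds of the ring
    for r in range(R):                  # every "rank", in parallel in a real run
        rp = (r + step) % R             # which k/v block has arrived at rank r
        scores = torch.exp(torch.einsum("td,ud->tu", q_sh[r], k_sh[rp]))
        Z[r] += scores.sum(dim=1)                          # sum_u exp(q k)
        N[r] += torch.einsum("tu,ud->td", scores, v_sh[rp])

z_ring = N / Z.unsqueeze(-1)            # z_{r0 t d} = N_{r0 t d} / Z_{r0 t}

# agrees with vanilla (unscaled, unmasked) attention
z_ref = torch.softmax(q @ k.T, dim=1) @ v
assert torch.allclose(z_ring.view(S, D), z_ref, atol=1e-4)
```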
This example is not as clean as some previous ones: it doesn't use only all-reduce or reduce-scatter collectives, and its intermediates don't all fit neatly into the Replicate/Shard/Partial placement taxonomy. Even so, I find the explicit but relatively uncluttered Einstein notation extremely helpful for precisely describing the algorithm and understanding its correctness.
Limitations
I close with a partial list of this notation's shortcomings:
- The ultra-explicit $a \longrightarrow r _{ 0}b$ type splitting, with $r _{ 0 }$ indexing devices on the mesh's 0-dimension, can be overly constraining. It assumes perfect divisibility and can't naturally describe generic all-to-all patterns, like the dynamic, generically-unequal sharding of tokens in Expert Parallel MoE routing [6]. Also, it creates a coupling between indices (need to recall that $b$ goes with $r _{ 0 }$ in the present example), which can be a mental burden to track. In future posts I will probably alternate between this very explicit notation and the simpler $a \longrightarrow a _{ 0 }$, sharding-pattern-implicit notation depending on needs and clarity.
- Some relevant sharding patterns which do work for nicely divisible tensors are still not easily expressible in terms of this notation. An example is the optimal zig-zag sharding strategy for causal Ring Attention, which does not correspond to a simple row-major style sharding of indices.
- Some operations just don't fit this notation very nicely. For example, convolving an input $x _{ d }$ with a filter $K _{ w }$ would be written as something like $z _{ e } = x _{ e - w }K _{ w }$ with all padding, strides, etc. left implicit? Not great.
Footnotes
[1] The prevalence of primes to distinguish indices over equal ranges is unfortunate, but I have not seen a clearer way to indicate semantically-related dimensions. Very open to suggestions.
[2] A similar treatment can be found in this old blog.
[3] Though this is a case where the less-explicit $W_{ a b } \longrightarrow W_{ a_{ 0 } b_{ 1 } }$ notation is probably even clearer.
[4] The converse is legal, however: a single logical tensor axis can be sharded over multiple mesh axes.
[5] In general, Partial can be associated with any of the common reduction ops, but I only discuss Partial(sum).
[6] In such scenarios, the more implicit $a \longrightarrow a _{ 0 }$ notation is likely a better choice, and Partial placements could be indicated by, say, placing the subscript on the tensor itself, as in $z _{ 0 } = x _{ a _{ 0 } } y _{ a _{ 0 } }$.