Introduction
Einstein's summation convention, available as einsum in torch and jax, is particularly powerful for reasoning through distributed ML strategies. In this post, I review this notation and introduce some non-standard additions that I find quite helpful when thinking about sharding and parallelization, connecting to torch concepts along the way[1].
Basics
Einstein notation is maximally lazy: whenever possible, it avoids explicit summation symbols, and operations are inferred from the index structure. In the vast majority of cases, the indices carry enough information to remove any ambiguity.
A sum is implied over any index that is repeated on one side of an equation but absent from the other:
\begin{align}
\textrm{dot product:}&\quad z = x_{ k } y_{ k } \nonumber\\
\textrm{matmul:}&\quad C_{ mn } = A_{ mk } B_{ kn } \nonumber
\end{align}
The name given to a summed-over index like $k$ is completely arbitrary (it's a "dummy index"), but I try to choose names that carry some semantic or conventional meaning.
In contrast, for an elementwise vector product the result carries the index, and no sum is performed:
\begin{align}
\textrm{elementwise product:}&\quad z_{ k } = x_{ k } y_{ k } \nonumber
\end{align}
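The three expressions above map directly onto einsum subscript strings. A minimal sketch using numpy's np.einsum as a stand-in for the torch and jax versions (which accept the same subscript syntax); the array sizes are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.standard_normal(5), rng.standard_normal(5)
A, B = rng.standard_normal((3, 5)), rng.standard_normal((5, 4))

# z = x_k y_k: k is repeated on the right and absent on the left, so it is summed.
dot = np.einsum("k,k->", x, y)
# C_mn = A_mk B_kn: k is summed, m and n survive.
C = np.einsum("mk,kn->mn", A, B)
# z_k = x_k y_k: k appears in the output, so nothing is summed.
elementwise = np.einsum("k,k->k", x, y)
```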
The above is all relatively conventional. When operators act on specific indices, I like to use some less-standard notation and explicitly indicate the dimension(s) on the operator itself. So, Softmax over the $v$ index would be $p _{ s v } = \texttt{Soft} _{ v } z _{ s v }$, which results in a tensor obeying $\texttt{sum} _{ v } p _{ sv } = 1$.
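A minimal sketch of $\texttt{Soft}_v$ in numpy, assuming a named index maps to a positional axis (here $v$ is axis 1); the helper name `soft` is illustrative, not a library function:

```python
import numpy as np

def soft(z, axis):
    """Soft_v: softmax acting only on the named index (here, one tensor axis)."""
    e = np.exp(z - z.max(axis=axis, keepdims=True))  # subtract the max for stability
    return e / e.sum(axis=axis, keepdims=True)

z = np.random.default_rng(0).standard_normal((6, 10))  # z_{sv}: s = seq, v = vocab
p = soft(z, axis=1)                                    # p_{sv} = Soft_v z_{sv}
row_sums = np.einsum("sv->s", p)                       # sum_v p_{sv}, one value per s
```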
Example: MLP Layers
A simple MLP layer provides a concrete example. In Einstein notation, the weights $W^{ 0 }_{ ed }, W^{ 1 } _{ de }$ and non-linearity $\phi$ act on a (batch_size, seqlen, dim)-shaped input $x _{ b s d}$ to produce:
z_{ bsd } = W^{ 1 }_{ de }\phi \left ( W^{ 0 }_{ ed' } x_{ bsd' } \right ) \ .
Explicitly, the sums occur over the $d'$ and $e$ indices, which run over range(dim) and range(4 * dim), respectively, in the usual vanilla setup.
For a SwiGLU-type activation function, the above becomes[2]
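A sketch of the MLP expression with one einsum per sum, assuming numpy and using np.tanh as a stand-in for $\phi$ (any elementwise function works, since $\phi$ does not mix indices):

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seqlen, dim = 2, 3, 8
x = rng.standard_normal((batch_size, seqlen, dim))  # x_{bsd}
W0 = rng.standard_normal((4 * dim, dim))            # W^0_{ed}
W1 = rng.standard_normal((dim, 4 * dim))            # W^1_{de}
phi = np.tanh  # stand-in elementwise non-linearity

# Inner sum over d' (spelled "d" in the einsum string, since it is a dummy index).
h = phi(np.einsum("ed,bsd->bse", W0, x))
# Outer sum over e.
z = np.einsum("de,bse->bsd", W1, h)
```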
z_{ bsd } = W^{ 2 }_{ df } W^{ 1 }_{ f d'' }x_{ bsd'' }\phi \left ( W^{ 0 }_{ fd' } x_{ bsd' } \right ) \ .
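The same equation in numpy, assuming $\phi$ is SiLU as in SwiGLU. The gate and up projections share the $f$ index, so the elementwise product happens before the single $f$-sum. The last line checks that reading the equation literally, as one four-operand einsum, gives the same result:

```python
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

rng = np.random.default_rng(0)
b, s, d, f = 2, 3, 8, 16
x = rng.standard_normal((b, s, d))
W0 = rng.standard_normal((f, d))  # gate projection, W^0_{fd'}
W1 = rng.standard_normal((f, d))  # up projection, W^1_{fd''}
W2 = rng.standard_normal((d, f))  # down projection, W^2_{df}

gate = silu(np.einsum("fd,bsd->bsf", W0, x))  # phi(W^0_{fd'} x_{bsd'})
up = np.einsum("fd,bsd->bsf", W1, x)          # W^1_{fd''} x_{bsd''}
z = np.einsum("df,bsf->bsd", W2, up * gate)   # f summed once, after the elementwise product

# Literal reading of the equation: sum over f and d'' (written "e" here) in one einsum.
z_one = np.einsum("df,fe,bse,bsf->bsd", W2, W1, x, gate)
```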
Notation for Sharding
How can sharding be represented in this notation? Two conventions are useful:
- Notation for splitting up an index.
- Notation for indicating that an index's values are physically distributed across different devices.
The latter notation must also indicate exactly which devices a tensor is sharded over, as complex, multi-dimensional device meshes are often involved.
For the first goal, we take inspiration from the excellent einops package and denote the decomposition of one index into several as $b \longrightarrow (r c)$, or at the tensor level:
x_{ bsd } \longrightarrow x_{ (rc)sd } \longrightarrow x_{ r c s d } \ ,
where I've removed the parentheses after it's clear which indices have been split.
If $b$ is the batch index, this might describe a DDP- or FSDP-like sharding pattern in which b in range(B) is split across r in range(R) devices (ranks), such that c in range(B // R) (perfect divisibility assumed). Index ordering is significant here, and I use row-major ordering throughout, so that in the present example every rank gets a contiguous slice of the batch dimension.
For the second goal, add a numerical subscript to sharded indices. The value identifies which particular group of devices the index is sharded over, assuming that the relevant groups of devices have already been enumerated. So, $x_{ a_{ 0 } \ldots }$ indicates that no single device owns all of the $a$-index values for this logical tensor; they're instead spread out over one axis of a particular mesh of devices.
Simple sharding example. The 16-element logical tensor, indexed by $a$, is sharded across four devices. Two indices describe the sharding: $r _{ 0 }$, which indexes the device, and $b$, which indexes the locally-available tensor values. Both indices take values in range(4) here. The zero subscript indicates which axis of the device mesh the given index is sharded over, information which is crucial in more complex examples.
Note that the first convention isn't even really needed! Just attaching subscripts, as in $x _{ a } \longrightarrow x _{ a _{ 0 } }$, is enough to indicate sharding, though this leaves the precise sharding pattern implicit. But for clarity and technical reasons, in this post I will mostly be as explicit as possible: when sharding an index a in range(A) across $R$ devices, I will split it into two, one of which is a device index, as in $a \longrightarrow r b$ with r in range(R) and b in range(A // R). I will only attach the sharding subscript to the device index, $r \longrightarrow r _{ 0 }$, since it is the only index whose range is not held on a single device. So, $x _{ a } \longrightarrow x _{ r _{ 0 } b }$. This is illustrated in the figure above.
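To make the row-major convention concrete: in numpy, a C-order reshape performs exactly this split (einops.rearrange(x, '(r c) s d -> r c s d', r=R) expresses the same thing). A small sketch with arbitrary sizes, where row r of the reshaped array stands in for what rank r holds:

```python
import numpy as np

B, S, D, R = 8, 3, 4, 4                    # R = number of ranks; assumes B % R == 0
x = np.arange(B * S * D).reshape(B, S, D)  # x_{bsd}
x_rcsd = x.reshape(R, B // R, S, D)        # x_{(rc)sd} -> x_{rcsd}; row-major: b = r * (B // R) + c

# Rank r holds x_rcsd[r], the contiguous batch slice b in [r * C, (r + 1) * C).
C = B // R
for r in range(R):
    assert np.array_equal(x_rcsd[r], x[r * C:(r + 1) * C])
```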
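The figure's example as a sketch, assuming numpy and simulating four devices as rows of an array; `global_index` is a hypothetical helper spelling out the row-major map $a = r \cdot (A // R) + b$:

```python
import numpy as np

A, R = 16, 4
x = np.arange(A) * 10          # logical x_a
shards = x.reshape(R, A // R)  # x_{r_0 b}: row r lives on device r

def global_index(r, b, local_size=A // R):
    """Hypothetical helper: row-major decomposition a = r * (A // R) + b."""
    return r * local_size + b

# Every (device, local) pair points back at the right logical element.
for r in range(R):
    for b in range(A // R):
        assert shards[r, b] == x[global_index(r, b)]
```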
The rest of this post is dedicated to examples which illustrate the notation and its utility.
Examples
I will mostly focus on the forward pass, as I plan to discuss gradient considerations (where this notation especially shines) in a later post.
Distributed Data Parallel
In the DDP scenario described above with $R$ total devices and tensor $x_{ r c s d }$, each rank only owns a single $r$-index value but holds all other index values, e.g. all of c in range(batch_size // ddp_degree), s in range(seqlen), etc. Then, letting $0$ correspond to the world group, the tensor would be written as[3]
x_{ r c s d } \longrightarrow x_{ r_{ 0 } c s d } \ .
In a future post, I will discuss how the DDP gradient sum across devices naturally follows from the fact that local compute operates on all the non-$r _{ 0 }$ indices and leaves $r _{ 0 }$ hanging, in a sense which will be made precise.
All-Reduce and Reduce-Scatter
All-reduce illustration, from the NCCL docs.
Let's do a sum over a sharded dimension. Let the logical tensors be $x_{ d }, y_{ d }$ and indicate sharding as above: $x_{ d } \longrightarrow x_{ r _{0}e }$, and similarly for $y$. Then, the notation makes it clear at a glance that for the dot product,
z = x_{ r _{0} e }y_{ r _{0}e } \quad \textrm{(all-reduce)} \ ,
the sum over the $e$ index can be performed locally, but the $r _{0}$-sum involves tensors on different devices and hence necessitates cross-device communication. That is, an all-reduce is required.
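A single-process sketch of the same computation, assuming numpy; the ranks are simulated by rows of an array, and the all-reduce is just a Python sum over per-rank partials (no real communication happens):

```python
import numpy as np

rng = np.random.default_rng(0)
R, E = 4, 5                      # R ranks, E local elements per rank
x = rng.standard_normal(R * E)   # logical x_d
y = rng.standard_normal(R * E)
x_re, y_re = x.reshape(R, E), y.reshape(R, E)  # x_{r_0 e}: row r is rank r's shard

# Local step: each rank sums over e only; the result is still indexed by r_0.
partial = [x_re[r] @ y_re[r] for r in range(R)]
# Communication step: the r_0 sum spans devices (an all-reduce), simulated as a Python sum.
# After a real all-reduce, every rank would hold this same value.
z = sum(partial)
```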
Reduce-scatter illustration, again from the NCCL docs.
Reduce-scatter operations can be expressed similarly; here the result of the sum is sharded rather than replicated, which requires a second tensor dimension. Let the logical global operation be a matrix-vector product: $z_{ d } = W_{ de } x_{ e }$. Now shard the reduction dimension, $e \longrightarrow r _{0} f $. To complete the sum, our options are either to all-reduce and replicate the result across every device as above,
z_{ d } = W_{ d r _{0} f } x_{ r _{0}f } \quad \textrm{(all-reduce)} \ ,
or to also split the output dimension on the same set of devices, $d \longrightarrow r c$, and realize the output non-redundantly across ranks,
z_{ r _{0}c } = W_{ r c r _{0}' f } x_{ r _{0}'f } \quad \textrm{(reduce-scatter)} \ .
Note that index subscripts do not need to match across sides! Primitive operations can move subscript placement. Because the weight's output index is not physically sharded, its unprimed $r$ has no subscript, but in the process of performing the right-hand-side sums, one can store the entire logical result in a sharded manner across devices[4] (hence the subscript on the left side's $r$). More on the rules for subscripts later.
These patterns will appear again in the Tensor- and Sequence-Parallel examples below. The general lesson is that any operation which moves or removes a sharding subscript implies the need for communication.
Sometimes it is useful, though a little inelegant, to write communication as an operator that acts on individual indices, as in $x = \texttt{AllReduce}_{ r _{ 0 } }(x _{ r _{ 0 } })$.
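A simulated sketch of both options, assuming numpy and even divisibility; the einsum index q plays the role of $r'_0$. Each rank's local result still carries a dangling $r'_0$, and the reduce-scatter sums over it while handing rank $r$ only its own output block:

```python
import numpy as np

rng = np.random.default_rng(0)
R, C, F = 4, 3, 5                        # output size D = R * C, reduction size E = R * F
W = rng.standard_normal((R * C, R * F))  # W_{de}
x = rng.standard_normal(R * F)           # x_e

W_rcrf = W.reshape(R, C, R, F)           # W_{r c r' f}, row-major splits of d and e
x_rf = x.reshape(R, F)                   # x_{r' f}

# Local step on rank q = r'_0: sum over f only, keeping the dangling q index.
partial = np.einsum("rcqf,qf->qrc", W_rcrf, x_rf)
# All-reduce option: sum over q, every rank gets the full z_d.
z_full = partial.sum(axis=0).reshape(-1)
# Reduce-scatter option: rank r sums the q-partials for its own output block only, z_{r_0 c}.
z_shards = [partial[:, r, :].sum(axis=0) for r in range(R)]
```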
Tensor-Parallel
Tensor-Parallel[5] (TP) MLP layers simply shard the intermediate reduction dimension across devices, $e \longrightarrow r _{0}f$, in the MLP expression above:
z_{ bsd } = W^{ 1 }_{ dr _{0}f }\phi \left ( W^{ 0 }_{ r _{0}fd' } x_{ bsd' } \right ) \ .
The first MLP weight is sharded across its output dimension (ColwiseParallel in torch lingo) and the second across its input dimension (RowwiseParallel, similarly). An all-reduce is needed to complete the sum over the logical intermediate index after the second matmul, due to the sharded $r_{ 0 }$.
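A single-process check of this decomposition, assuming numpy, np.tanh as a stand-in for $\phi$, and row-major splits of $e$ into $(r f)$. Each list entry is one simulated rank's local result, and the final Python sum stands in for the all-reduce:

```python
import numpy as np

rng = np.random.default_rng(0)
b, s, d, R, F = 2, 3, 8, 4, 8         # intermediate size e = R * F
x = rng.standard_normal((b, s, d))
W0 = rng.standard_normal((R * F, d))  # W^0_{ed}
W1 = rng.standard_normal((d, R * F))  # W^1_{de}
phi = np.tanh                         # stand-in elementwise non-linearity

W0_sh = W0.reshape(R, F, d)           # W^0_{r_0 f d'}: rank r holds rows block r (ColwiseParallel)
W1_sh = W1.reshape(d, R, F)           # W^1_{d r_0 f}: rank r holds column block r (RowwiseParallel)

# Each rank sums over d' and f locally; phi is elementwise, so splitting e is harmless.
partials = [
    np.einsum("df,bsf->bsd", W1_sh[:, r, :], phi(np.einsum("fk,bsk->bsf", W0_sh[r], x)))
    for r in range(R)
]
z = sum(partials)  # the r_0 sum: an all-reduce across the TP group
```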
The above decomposition is also a concise proof of correctness for the usual TP sharding strategy, which I personally find easier to follow than the more common proof-by-drawing approach.
FSDP + TP Activations
For an example of a tensor sharded over multiple dimensions, take an FSDP + TP scenario, with 0 and 1 being the subscripts for these two mesh axes. The logical intermediate, TP-sharded activation tensor $z _{ b s d }$ gets sharded into $z _{ f _{ 0 }c s t _{ 1 } e }$, where:
- f in range(fsdp_degree)
- c in range(global_batch_size // fsdp_degree)
- s in range(seqlen)
- t in range(tp_degree)
- e in range(intermediate_dim // tp_degree)
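A bookkeeping sketch with tiny, arbitrary sizes, assuming numpy and row-major splits; `local_shard` is a hypothetical helper returning what the device at mesh coordinate (f, t) holds:

```python
import numpy as np

fsdp_degree, tp_degree = 2, 2
global_batch_size, seqlen, intermediate_dim = 4, 3, 6
z = np.arange(global_batch_size * seqlen * intermediate_dim).reshape(
    global_batch_size, seqlen, intermediate_dim)         # logical z_{b s d}

C = global_batch_size // fsdp_degree
E = intermediate_dim // tp_degree
z5 = z.reshape(fsdp_degree, C, seqlen, tp_degree, E)   # z_{f c s t e}

def local_shard(f, t):
    """Hypothetical helper: device (f, t) holds all of c, s, e for fixed f_0 and t_1."""
    return z5[f, :, :, t, :]

for f in range(fsdp_degree):
    for t in range(tp_degree):
        assert np.array_equal(local_shard(f, t), z[f * C:(f + 1) * C, :, t * E:(t + 1) * E])
```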
Obviously, this is an unfortunate number of indices, but that mostly just reflects the complexities of sharding multi-dimensional tensors over multiple submeshes.
Shard, Replicate, and Partial
DTensor is the modern pytorch API used to describe distributed tensors. Each DTensor is associated with an n-dimensional mesh (array) of devices, and each mesh axis is associated with a placement that describes the state of the tensor relative to that axis, in ways I will describe. All of these ideas map very cleanly onto the present notation.
The three most fundamental placements are Replicate, Shard, and Partial. The first two are relatively self-explanatory:
- Replicate: the tensor is replicated across this mesh axis.
- Shard: one dimension of the tensor (which must be specified) is distributed across this mesh axis.
These can be illustrated minimally with four GPUs. Arrange them into a 2x2 device mesh and take a 2D tensor $W_{ ab }$. Here are some index shardings and their corresponding pytorch 2-tuples of placements:
- $W_{ ab }$: (Replicate, Replicate). No sharding: every device gets the original tensor.
- $W_{ ab } \longrightarrow W_{ a r_{ 0 } d}$: (Shard(1), Replicate). Distribute the final tensor dimension across the leading device mesh axis, keeping these shards constant across the final mesh axis. Every device gets half of the original tensor.
- $W_{ ab } \longrightarrow W_{ r_{ 0 } c r '_{ 1 } d }$: (Shard(0), Shard(1)). Distribute the leading tensor dimension across the leading device mesh axis, and similarly for the final tensor dimension and final mesh axis. Every device gets a quarter of the original tensor.
These, and more, are illustrated in the figure below. The correspondence between Shard, Replicate, and the sharded index notation is hopefully clear[6].
Note that constraints exist: an expression like $W_{ ab } \longrightarrow W_{ r_{ 0 } c r'_{ 0 } d }$, where two indices are sharded over the same device mesh axis[7], is nonsensical and illegal. This is reflected in the DTensor API: an n-dimensional device mesh requires an n-tuple of placements (only a single placement per axis!) to describe a DTensor. So, another rule: a tensor can never have two indices with the same sharding subscript. This requirement will come up in the Sequence-Parallel discussion below.
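A sketch of what each device holds under these placements. This is not the DTensor API: `local_block` and the ("R",) / ("S", dim) encoding of Replicate / Shard(dim) are hypothetical stand-ins, assuming even, row-major splits on a 2x2 mesh:

```python
import numpy as np

def local_block(W, coord, placements, mesh_shape=(2, 2)):
    """Hypothetical helper: the block of W held at mesh coordinate `coord`.

    placements[i] is ("R",) for Replicate or ("S", dim) for Shard(dim) along mesh axis i.
    Even, contiguous (row-major) splits are assumed, mirroring the a -> r b decomposition.
    """
    block = W
    for axis, placement in enumerate(placements):
        if placement[0] == "S":
            dim = placement[1]
            block = np.split(block, mesh_shape[axis], axis=dim)[coord[axis]]
    return block

W = np.arange(16).reshape(4, 4)  # W_{ab}
```

For example, local_block(W, (1, 0), (("S", 1), ("R",))) returns the right half of W, W[:, 2:], matching $W_{ a r_{ 0 } d}$ at $r_0 = 1$.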
Various ways of splitting a simple 2x2 tensor over a 2x2 device mesh.
What about Partial? This denotes pending operations[8]. Return to the distributed dot product and consider the intermediate which arises after performing only the local sums:
z_{ r_0 } = x_{ r _{0} c }y_{ r _{0}c } \ .
Being a partial result, this temporary naturally has a Partial placement with respect to the 0-dimension of the mesh. It's only after performing the 0-mesh-dim all-reduce that one gets the full result: $z = \texttt{AllReduce}_{ r _{ 0 } }(z _{ r _{ 0 } })$, which has a Replicate placement. In a more general example, the Partial could also be completed via a reduce-scatter, resulting in a Shard placement.
Therefore, dangling device-mesh indices with sharding subscripts indicate Partial placements.
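A simulated sketch of completing a Partial(sum) tensor both ways, assuming numpy and even divisibility; rows of `partials` are per-rank contributions. It also shows that a reduce-scatter followed by an all-gather reproduces the all-reduce:

```python
import numpy as np

rng = np.random.default_rng(0)
R, C = 4, 3
# partials[r] is rank r's Partial(sum) contribution to a logical length-(R * C) vector.
partials = rng.standard_normal((R, R * C))

# Complete via all-reduce: every rank ends up with the full sum (Replicate).
all_reduced = [partials.sum(axis=0) for _ in range(R)]

# Complete via reduce-scatter: rank r keeps only block r of the sum (Shard(0)).
reduce_scattered = [partials[:, r * C:(r + 1) * C].sum(axis=0) for r in range(R)]

# An all-gather just removes the sharding subscript.
gathered = np.concatenate(reduce_scattered)
```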
In the forward pass, distributed reductions are generally completed immediately, so there are no long-lived Partial tensors, but in the backward pass it can be highly beneficial to leave various terms in a Partial state and judiciously delay completing the pending operation. This will be discussed more elsewhere, but for now weight gradients in DDP and FSDP provide a good example: weight gradients live in a Partial state until the time is right to reduce them over the relevant mesh dimension, usually in a bucketed manner for efficiency.
Sequence-Parallel
Sequence-Parallelism (SP), in the NVIDIA sense, removes some of the redundancy in TP that arises because the TP inputs and outputs are fully replicated.
In SP, the inputs and outputs are instead sharded over the sequence dimension, $s \longrightarrow r _{0}t$, reusing the TP mesh axis for the sharding. Making this additional replacement in the previous TP expression gives:
z_{ b r _{0}t d } = W^{ 1 }_{ dr _{0}'f }\phi \left ( W^{ 0 }_{ r _{0}'fd' } x_{ b r _{0}t d' } \right ) \ .
It is useful to step through the intermediates and consider the fine-grained operations that occur:
- The very first operation $W^{ 0 } _{ r _{0}'fd' } x _{ b r _{0}t d' }$ already has a potential issue. If one blindly performed the $d'$ sum and carried the other indices through untouched, this would naively result in $y _{ b r _{ 0 } t r' _{ 0 } f } = W^{ 0 } _{ r _{0}'fd' } x _{ b r _{0}t d' }$. But this breaks the rule that two tensor dimensions can't be sharded over the same mesh axis! The same subscript cannot appear twice, and the notational rules help catch a possible logic mistake. The minimal fix is to first remove one of the offending sharded subscripts (i.e., perform an all-gather), and the best option is to unshard the input tensor, since unsharding the weight instead would undo the TP weight sharding. So, a legal progression producing a valid output is:
W^{ 0 } _{ r _{0}'fd' } x _{ b r _{0}t d' } \xrightarrow{\texttt{AllGather}_{ r _{ 0 } }} W^{ 0 } _{ r _{0}'fd' } x _{ b rt d' } \xrightarrow{\rm TP\ matmul} y_{ b r t r'_{ 0 } f }
- The non-linearity is trivial: $v _{ b r t r' _{ 0 } f } = \phi(y _{ b r t r' _{ 0 } f })$. Elementwise ops are easy.
- The sums in the final matmul $W^{ 1 } _{ dr _{0}'f }v _{ b r t r' _{ 0 } f }$ eliminate the sharded subscript, and so it can legally be moved onto the output tensor without inducing a double-0-subscript: $z _{ b r _{ 0 } t d }= W^{ 1 } _{ dr _{0}'f }v _{ b r t r' _{ 0 } f }$. Said differently: reduce-scatter.
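A single-process sketch of the full SP progression, assuming numpy, np.tanh for $\phi$, and row-major splits; q is the simulated rank's $r'_0$ coordinate, and the collectives are simulated with concatenation and Python sums:

```python
import numpy as np

rng = np.random.default_rng(0)
b, R, T, d, F = 2, 4, 3, 8, 8         # seqlen = R * T, intermediate size = R * F
x = rng.standard_normal((b, R * T, d))
W0 = rng.standard_normal((R * F, d))
W1 = rng.standard_normal((d, R * F))
phi = np.tanh                         # stand-in elementwise non-linearity

x_sh = [x[:, r * T:(r + 1) * T] for r in range(R)]  # x_{b r_0 t d}: rank r's sequence chunk
W0_sh = W0.reshape(R, F, d)                         # W^0_{r'_0 f d'}
W1_sh = W1.reshape(d, R, F)                         # W^1_{d r'_0 f}

x_full = np.concatenate(x_sh, axis=1)               # all-gather over r_0: x_{b r t d}
partials = []
for q in range(R):                                  # q = this rank's r'_0 coordinate
    y = np.einsum("fk,bsk->bsf", W0_sh[q], x_full)  # y_{b r t r'_0 f}
    v = phi(y)                                      # elementwise, sharding unchanged
    partials.append(np.einsum("df,bsf->bsd", W1_sh[:, q, :], v))  # f summed, r'_0 dangling

# Reduce-scatter: the r'_0 sum completes while the output's r picks up the subscript.
z_sh = [sum(p[:, r * T:(r + 1) * T] for p in partials) for r in range(R)]  # z_{b r_0 t d}
```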
Ring Attention
Ring Attention is morally a distributed version of flash attention, in which inputs are sharded along the sequence dimension $s$ and the computation never realizes any intermediates which have a full seqlen-sized dimension.
Let's simplify a bit to keep the presentation short and minimize indices:
- Suppress the batch dimension and focus on a single head.
- Omit the causal mask.
- Ignore the (numerically crucial) maximum attention score tracking.
Illustration of ring attention, from this excellent blog.
With these simplifications, the attention output $z _{ s d }$, given queries $q _{ s d }$ and similarly-indexed keys and values, is:
\begin{align}
z_{ s d} &= \texttt{Soft}_{ s' } \left ( q_{ s d' } k_{ s' d' } \right ) v_{ s' d } \nonumber\\
&=\frac{\exp \left ( q_{ s d' } k_{ s' d' } \right )v_{ s' d }}{Z_{ s }} \quad \textrm{where} \quad Z _{ s } = \texttt{sum} _{ s' } \exp \left ( q _{ s d' } k _{ s' d' } \right ) \nonumber \ .
\end{align}
Shard the sequence dimension over $R$ ranks:
\begin{align}
z_{ r_{ 0 }t d} &=\frac{\exp \left ( q_{ r_{ 0 } t d' } k_{ r'_{ 0 }u d' } \right )v_{ r'_{ 0 }u d }}{Z_{ r_{ 0 }t }} \nonumber
\end{align}
Both the numerator and denominator involve a sum over ranks ($r' _{ 0 }$), and a natural iterative algorithm to perform this sum is as follows:
- Rank $r _{ 0 }$ initializes numerator and denominator accumulators, $N _{ r _{ 0 } t d }$ and $Z _{ r _{ 0 }t }$, to zero.
- Over the course of $R$ rounds, pass the keys $k _{ r' _{ 0 } u d }$ and values $v _{ r' _{ 0 } u d }$ around in a ring so that every block lands on every rank $r _{ 0 }$.
- In each round, perform these updates (note: no sum over repeated $r'_{ 0 }$ here):
\begin{align}
Z _{ r _{ 0 }t } &:= Z _{ r _{ 0 }t } + \texttt{sum}_{ u }\exp \left ( q_{ r_{ 0 } t d' } k _{ r'_{ 0 } u d' } \right ) \nonumber\\
N _{ r _{ 0 }td } &:= N _{ r _{ 0 }td } + \exp \left ( q_{ r_{ 0 } t d' } k_{ r'_{ 0 }u d' } \right )v_{ r'_{ 0 }u d } \nonumber
\end{align}
The final output is then
\begin{align}
z_{ r_{ 0 }t d} &=\frac{N _{ r _{ 0 }td }}{Z_{ r_{ 0 }t }} \nonumber
\end{align}
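A single-process sketch of the simplified algorithm above (single head, no mask, no running-max tracking, so inputs are scaled down to keep exp() tame), assuming numpy. np.roll along the rank axis simulates one hop of the ring, and each rank's update uses only its own queries and the key/value block it currently holds:

```python
import numpy as np

rng = np.random.default_rng(0)
R, T, D = 4, 3, 5                         # R ranks, T tokens per rank, head dim D
q = rng.standard_normal((R, T, D)) * 0.3  # q_{r_0 t d}; small scale keeps exp() tame
k = rng.standard_normal((R, T, D)) * 0.3
v = rng.standard_normal((R, T, D))

N = np.zeros((R, T, D))             # numerator accumulators N_{r_0 t d}
Z = np.zeros((R, T))                # denominator accumulators Z_{r_0 t}
k_buf, v_buf = k.copy(), v.copy()   # the (k, v) block currently resident on each rank
for _ in range(R):
    for r in range(R):              # every rank updates with the block it currently holds
        scores = np.exp(np.einsum("td,ud->tu", q[r], k_buf[r]))  # exp(q_{r t d'} k_{r' u d'})
        Z[r] += scores.sum(axis=1)                               # sum over u
        N[r] += scores @ v_buf[r]                                # sum over u
    # One hop around the ring: rank r receives the block previously held by rank r - 1.
    k_buf, v_buf = np.roll(k_buf, 1, axis=0), np.roll(v_buf, 1, axis=0)

z = N / Z[..., None]                # z_{r_0 t d}
```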
This example is not as clean as some previous ones: it does not reduce to all-reduce or reduce-scatter collectives, and its intermediates don't fit neatly into the Replicate/Shard/Partial placement taxonomy. Still, I find the explicit but relatively uncluttered Einstein notation extremely helpful for precisely describing the algorithm and understanding its correctness.
Placement Interaction
How do tensors with different placements interact? Is communication needed before the compute can be carried out? What's the cheapest output placement? I find the present notation greatly helps answer these kinds of questions, too.
Take logical 2D tensors $x_{ ab }, y_{ ab }$ and a 2D mesh. The possibilities and computational flow are almost entirely determined by the constraint that the output tensor never has two identical subscripts. I will temporarily use the more concise notation where sharding is denoted like $z _{ a } \longrightarrow z _{ a _{ 0 } }$ without also explicitly splitting the index. Various examples:
- $x _{ a _{ 0 } b } + y _{ a b }$: to complete this pointwise op, we need to logically coerce both tensors to the same sharding format. Whereas moving or removing a subscript always incurs communication costs, adding one does not: $ y _{ a b } \longrightarrow y _{ a _{ 0 } b }$ only involves discarding local information. Therefore, creating $z _{ a _{ 0 } b } = x _{ a _{ 0 } b } + y _{ a b }$ is the minimum-cost sharding-propagation pattern. Similar results hold for all pointwise ops.
- $x _{ a _{ 0 } b } + y _{ a b _{ 0 } }$: the double-subscript rule makes it clear that communication is unavoidable: $z _{ a _{ 0 } b _{ 0 } }$ is not a valid output, so subscripts must be either moved or removed to complete the operation, incurring costs. An all-to-all is the cheapest option here: say, $y _{ a b _{ 0 } } \longrightarrow y _{ a _{ 0 } b } $ prior to the sum, creating $z _{ a _{ 0 } b } = x _{ a _{ 0 } b } + y _{ a _{ 0 } b }$. Formalizing the costs of various subscript operations is straightforward. Operations involving tensors with different Shards on the same mesh axis generically require communication.
- Now plan out a sharded matmul, logically $z _{ m n } = x _{ m k } y _{ k n } $, where memory efficiency is the goal. What input and output sharding patterns are best? Clearly, the output should be doubly-sharded, so our end state is $z _{ m _{ 0 } n _{ 1 } }$. Moving subscripts around corresponds to costly comms, so to avoid that, the $m$-index on $x _{ m k }$ and the $n$-index on $y _{ k n }$ should be sharded on mesh dimensions 0 and 1, respectively. By the double-subscript rule, the other maximally memory-efficient $k$-index shardings follow. In conclusion, the compute should be set up as $z _{ m _{ 0 } n _{ 1 } } = x _{ m _{ 0 } k _{ 1 } } y _{ k _{ 0 } n _{ 1 } } $. The last step is to actually perform the $k$-sum. In general, there's room for cleverness[9] in how distributed sums are performed, and solutions exist which go beyond the naive strategy of fully replicating/all-gathering the sum axis on each device and performing the sums in one step, as in $x _{ m _{ 0 } k _{ 1 } } y _{ k _{ 0 } n _{ 1 } } \longrightarrow x _{ m _{ 0 } k } y _{ k n _{ 1 } } \longrightarrow z _{ m _{ 0 } n _{ 1 } }$. I like the present notation because it lets us either go further, split up the distributed indices, and plan out a more detailed implementation (similar to the Ring Attention section above), or stop at the present level of detail and leave the implementation implicit.
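A simulated sketch of the naive strategy on a 2x2 mesh, assuming numpy and even divisibility; `block` is a hypothetical helper, and the two concatenations stand in for the all-gathers over mesh axes 1 and 0 that strip the $k$ subscripts:

```python
import numpy as np

rng = np.random.default_rng(0)
P0, P1 = 2, 2                     # 2x2 device mesh
M, K, N = 4, 6, 8
x = rng.standard_normal((M, K))
y = rng.standard_normal((K, N))

def block(a, i, n_i, j, n_j):
    """Hypothetical helper: block (i, j) of `a` under even row-major splits into n_i x n_j blocks."""
    rows, cols = a.shape[0] // n_i, a.shape[1] // n_j
    return a[i * rows:(i + 1) * rows, j * cols:(j + 1) * cols]

# Device (i, j) holds x_{m_0 k_1} (block (i, j) of x) and y_{k_0 n_1} (block (i, j) of y).
x_local = {(i, j): block(x, i, P0, j, P1) for i in range(P0) for j in range(P1)}
y_local = {(i, j): block(y, i, P0, j, P1) for i in range(P0) for j in range(P1)}

z_local = {}
for i in range(P0):
    for j in range(P1):
        # All-gather x over mesh axis 1, removing k's subscript 1: x_{m_0 k}.
        x_row = np.concatenate([x_local[(i, jj)] for jj in range(P1)], axis=1)
        # All-gather y over mesh axis 0, removing k's subscript 0: y_{k n_1}.
        y_col = np.concatenate([y_local[(ii, j)] for ii in range(P0)], axis=0)
        z_local[(i, j)] = x_row @ y_col  # z_{m_0 n_1}: purely local k-sum
```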
More complex cases, including those involving Partial placements, can be similarly analyzed.
Limitations
I close with a partial list of this notation's shortcomings:
- The ultra-explicit $a \longrightarrow r _{ 0}b$ type splitting, with $r _{ 0 }$ indexing devices on the mesh's 0-dimension, can be overly constraining. It assumes perfect divisibility and can't naturally describe generic all-to-all patterns, like the dynamic, generically-unequal sharding of tokens in Expert Parallel MoE routing[10]. It also creates a coupling between indices (one needs to recall that $b$ goes with $r _{ 0 }$ in the present example), which can be a mental burden to track. In future posts I will probably alternate between this very explicit notation and the simpler, sharding-pattern-implicit $a \longrightarrow a _{ 0 }$ notation, depending on needs and clarity.
- Some relevant sharding patterns which do work for nicely divisible tensors are still not easily expressible in terms of this notation. An example is the optimal zig-zag sharding strategy for causal Ring Attention, which does not correspond to a simple row-major style sharding of indices.
- Some operations just don't fit this notation very nicely. For example, convolving an input $x _{ d }$ with a filter $K _{ w }$ would be written as something like $z _{ e } = x _{ e - w }K _{ w }$, with all padding, strides, etc. left implicit. Not great.
[1] I was inspired to write up my current mental shorthand after reading Edward Yang's recent posts on similar topics. See also the relevant section in jax's How To Scale Your Model. My notes also have an earlier form of this notation.
[2] The prevalence of primes to distinguish indices over equal ranges is unfortunate, but I have not seen a clearer way to indicate semantically-related dimensions. Very open to suggestions.
[3] In jax, 0-dimension or batch-dimension sharding is handled explicitly as described here, in the sense that it's not treated any differently from sharding over any other dimension. In contrast, batch-dimension sharding is special-cased in torch and left implicit.
[4] Since an all-gather just corresponds to unsharding, as in $x_{ r_{ 0 }f } \longrightarrow x_{ e }$, the fact that an all-reduce can be implemented as a reduce-scatter followed by an all-gather is also very clear in this notation. Another operation: moving a subscript on a tensor, as in $x _{ r _{ 0 } a r' b } \longrightarrow x _{ r a r' _{ 0 } b }$, is an all-to-all.
[5] A similar treatment can be found in this old blog.
[6] Though this is a case where the less-explicit $W_{ a b } \longrightarrow W_{ a_{ 0 } b_{ 1 } }$ notation is probably even clearer.
[7] The converse is legal, however: a single logical tensor axis can be sharded over multiple mesh axes.
[8] In general, Partial can be associated with any of the common reduction ops, but I only discuss Partial(sum).
[9] Examples include the GSPMD paper and the pytorch writeup of their async TP implementation.
[10] In such scenarios, the more implicit $a \longrightarrow a _{ 0 }$ notation is likely a better choice, and Partial placements could be indicated by, say, placing the subscript on the tensor itself, as in $z _{ 0 } = x _{ a _{ 0 } } y _{ a _{ 0 } }$, or maybe $z _{ 0|ijkl }= x _{ a _{ 0 } ij } y _{ a _{ 0 }kl}$ is clearer in more complex cases.