Introduction
This post focuses on backwards-pass communication patterns in distributed training setups. I cover the backwards of DDP, FSDP, Tensor-Parallel (an often-confusing topic), Expert-Parallel, and more, and argue that they can all be analyzed in essentially the same way.
The math is phrased in terms of the previously-introduced Sharded Einstein Notation, which promotes concise expressions, and I also use the language of torch's DTensor throughout.
Review: Sharded Einstein Notation
First, a very brief review of the previous post.
Einstein notation/einsum leverages the high information-density of tensor index structures. Examples:
- Dot product: $z = x _{ a } y _{ a }$
- Matmul: $C _{ m n } = A _{ m k } B _{ k n }$
- Element-wise product, $z _{ a } = x _{ a } y _{ a }$, and non-linearity, $z _{ a } = \phi(x _{ a })$
Comparing indices across an equation, the meaning of the operation follows.
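These rules are directly executable with `numpy.einsum`, where comparing the input and output index strings determines which indices get summed:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])             # x_a
y = np.array([4.0, 5.0, 6.0])             # y_a
A = np.arange(6.0).reshape(2, 3)          # A_{mk}
B = np.arange(12.0).reshape(3, 4)         # B_{kn}

z = np.einsum("a,a->", x, y)      # dot product: the repeated a-index is summed
C = np.einsum("mk,kn->mn", A, B)  # matmul: k is summed, m and n survive
w = np.einsum("a,a->a", x, y)     # element-wise: a survives on the output, no sum
```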
Extending this notation to include sharding information is extremely helpful. In the general case, the sharding occurs across a multi-dimensional device mesh, and in the preceding post I proposed adding index subscripts to specify the mesh axis (or axes) over which a given tensor dimension is sharded. Two notations:
- Implicit: attach the subscript and leave the exact sharding pattern implicit, as in $M _{ a b } \longrightarrow M _{ a _{ 0 } b _{ 1 } }$ for sharding a 2D `(A, B)`-shaped tensor across a 2D mesh.
- Explicit: divide a sharded index into multiple indices with device indices explicit. Sharding the above matrix across only one mesh axis comprising `R` devices, this would look like $M _{ a b } \longrightarrow M _{ r _{ 0 } c b }$, with `r in range(R)` and `c in range(A // R)`.
I try to err on the side of being explicit, but use the implicit notation when drowning in indices.
An important rule: a tensor can never have two indices sharded over the same mesh axis, so $M _{ a _{ 0 } b _{ 0 } }$ (implicit) or $M _{ r _{ 0 } c r' _{ 0 }d }$ (explicit) is illegal.
Moving and removing subscripts generically implies communication costs. Some representative operations:
| Comms Pattern | Implicit | Explicit |
|---|---|---|
| all-reduce | $z = x _{ a _{ 0 }} y _{a _{ 0 }}$ | $z = x _{r _{ 0 } b} y _{r _{ 0 } b}$ |
| all-gather | $ x _{a _{ 0 }} \longrightarrow x _{ a } $ | $ x _{r _{ 0 } b} \longrightarrow x _{ a } $ |
| reduce-scatter | $z _{ a _{ 0 } } = x _{ a b _{ 0 } } y _{ b _{ 0 } }$ | $z _{r _{ 0 } c } = x _{ r c r' _{ 0 } d } y _{ r' _{ 0 } d }$ |
| all-to-all | $x _{ a b _{ 0 } } \longrightarrow x _{ a _{ 0 } b}$ | $x _{ r c r' _{ 0 }d } \longrightarrow x _{ r _{ 0 } c r' d }$ |
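For intuition, these collectives can be modeled in a single process by representing a tensor sharded over a 1D mesh as a list of per-device `numpy` shards. A toy sketch of the semantics, not a distributed implementation:

```python
import numpy as np

def all_reduce(shards):
    # every device ends up with the full cross-device sum
    total = sum(shards)
    return [total.copy() for _ in shards]

def all_gather(shards):
    # every device ends up with the full concatenated tensor
    full = np.concatenate(shards)
    return [full.copy() for _ in shards]

def reduce_scatter(shards):
    # cross-device sum, but each device keeps only its slice of the result
    total = sum(shards)
    return list(np.split(total, len(shards)))

def all_to_all(shards):
    # device r sends chunk c of its shard to device c: a transpose over the mesh
    n = len(shards)
    chunked = [np.split(s, n) for s in shards]
    return [np.concatenate([chunked[r][c] for r in range(n)]) for c in range(n)]
```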
Review: The Chain Rule
I also review the multivariate chain rule, mostly for setting notation.
The loss $L$ is a scalar which depends on various tensors. Consider a generic tensor $x _{ a b c }$ which is used (along with other tensors, typically) to compute some output $z _{ d e f }$. The relation between input and output indices is left unspecified for now. Then the chain-rule for the derivative of the loss with respect to $x _{ a b c }$ is:
$$\frac{\partial L}{\partial x _{ a b c }} = \frac{\partial z _{ d e f }}{\partial x _{ a b c }} \frac{\partial L}{\partial z _{ d e f}} \ ,$$
sum over $ d, e, f$ indices implied, as usual. If the input $x _{ a b c }$ is used in multiple locations in the forward graph, say it's also directly involved in creating some $y _{ g h i }$, then the terms sum:
$$\frac{\partial L}{\partial x _{ a b c }} = \frac{\partial z _{ d e f }}{\partial x _{ a b c }} \frac{\partial L}{\partial z _{ d e f}} + \frac{\partial y _{g h i}}{\partial x _{ a b c }} \frac{\partial L}{\partial y _{g h i}} \ ,$$
with the general case proceeding similarly.
Derivatives of the loss with respect to a tensor have the same shape as the original tensor, and to condense notation slightly I will denote loss gradients as in:
$$\frac{\partial L}{\partial x _{ a b c }} = g [x] _{ a b c } \ ,$$
so that the basic chain rule above reads
$$g [x] _{ a b c } = \frac{\partial z _{ d e f }}{\partial x _{ a b c }} g[z] _{ d e f} \ .$$
General Analysis
When a tensor $x _{ a b c }$ is used to create a generic output $z _{ d e f }$, the output indices fall into two classes:
- Output indices which also appear on the input, i.e. those which passed through the computation untouched[^1].
- All other output indices.
When computing derivatives of the input, $g[x] _{ a b c }$, from the upstream gradient, $g[z] _{ d e f }$, only indices belonging to the second class induce non-trivial sums in the gradient computation. When the second class also includes sharded indices, these sums involve communication.
The above is sufficient for understanding where backwards communication is required[^2] in many standard parallelization techniques. This is demonstrated through various examples below.
Example: Activation Matmuls
First, an example without sharding: consider a `(batch_size, seqlen, dim)`-shaped activation $x _{ b s d }$ and act on the hidden dimension with a weight $W _{ e d }$:

$$z _{ b s e } = W _{ e d } x _{ b s d } \ .$$
The weight and activation gradients are
$$\begin{align}
g[W]_{ e d } &= \frac{ \partial z _{ b s e } }{ \partial W _{ ed } }\ g[z]_{ b s e }= x_{ b s d }\ g[z]_{ b s e }\nonumber\\
g[x]_{ b s d} &= \frac{ \partial z _{ b s e } }{ \partial x _{ bsd } }\ g[z]_{ b s e } = W_{e d}\ g[z]_{ b s e }\nonumber \ ,
\end{align}$$
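These index-derived formulas can be checked numerically with `numpy`: since $z$ is linear in both $W$ and $x$, the gradients must satisfy the pairing $\langle g[x], \delta x\rangle + \langle g[W], \delta W\rangle = \langle g[z], \delta z\rangle$ for any perturbations. A quick single-process sanity check:

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, D, E = 2, 3, 4, 5
x = rng.normal(size=(B, S, D))    # x_{bsd}
W = rng.normal(size=(E, D))       # W_{ed}
gz = rng.normal(size=(B, S, E))   # upstream gradient g[z]_{bse}

# forward: z_{bse} = W_{ed} x_{bsd}
z = np.einsum("ed,bsd->bse", W, x)

# backwards, read off from the index structure above
gW = np.einsum("bsd,bse->ed", x, gz)  # non-trivial sums over the mismatched b, s
gx = np.einsum("ed,bse->bsd", W, gz)  # non-trivial sum over the mismatched e

# check the vector-Jacobian pairing for random perturbations
dx = rng.normal(size=x.shape)
dW = rng.normal(size=W.shape)
dz = np.einsum("ed,bsd->bse", W, dx) + np.einsum("ed,bsd->bse", dW, x)
assert np.isclose((gx * dx).sum() + (gW * dW).sum(), (gz * dz).sum())
```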
which illustrates the general principle of where summation is required:
- $W _{ e d }$: only the $e$-index appears in the output, and so there are non-trivial sums over the remaining output indices $b, s$.
- $x _{ b s d }$ carries all output-indices except for $e$, so there's only a non-trivial $e$-sum.
Rephrased: indices on the upstream gradient must get converted into those of the downstream gradient, and any mismatched indices get removed through summation.
Sharded Examples
For the remainder of the post, I apply the above to different distributed examples.
Distributed Data Parallel
In DDP the inputs are sharded on the batch dimension: $x _{ b s d } \longrightarrow x _{ r _{ 0 } l s d}$, with `r in range(ddp_degree)` and `l in range(local_batch_size) = range(batch_size // ddp_degree)`. The forwards computation is embarrassingly parallel across the batch dimension, and all intermediates also carry the $r _{ 0 }l$ sharded batch indices. However, weights are not sharded at all: they look like $W _{ d e }$. It follows that all intermediate computations take the form
$$z _{ r_{ 0 } l f} = F(y _{ r_{ 0 } l s d }, W _{ d e }, \ldots ) \ ,$$
where more tensors may also be involved and the output index $f$ is intended as generic (it may correspond to multiple, semantically distinct dimensions). In words: the device and local batch index will generically pass through every local computation unchanged, but other indices may get transformed.
So, what backwards computations will require collectives? From the general analysis:
- Because weight gradients don't carry the output's device-index $r _{ 0 }$, a sum over $r _{ 0 }$ is required in the gradient computation, i.e. a collective.
- Because inputs and all activations carry the $r_{ 0 }$ device-index, their gradient computations never require comms.
More explicitly, from the general considerations above the common DDP chain rule computations are of the form:
$$\begin{align}
g[W]_{ e d} &= \frac{ \partial z _{ r_{ 0 } l f } }{ \partial W _{e d} } g[z]_{ r_{ 0 } l f }\nonumber\\
g[y]_{ r_{ 0 } l s d } &= \frac{ \partial z _{ r_{ 0 } l f } }{ \partial y _{ r _{ 0 } l s d } } g[z]_{ r_{ 0 } l f }\nonumber
\end{align}$$
and the sum over $r_{ 0 }$ in the first expression indicates an all-reduce. This is the usual DDP behavior: all ranks train independently, apart from an all-reduce for averaging weight-gradients.
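A toy single-process check of this claim, with each DDP rank's shard represented as a `numpy` slice of the full batch: summing the per-rank local weight gradients (what the all-reduce computes, up to the averaging convention) reproduces the full-batch gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
ddp_degree, local_batch, S, D, E = 2, 2, 3, 4, 5
x = rng.normal(size=(ddp_degree * local_batch, S, D))    # full batch x_{bsd}
gz = rng.normal(size=(ddp_degree * local_batch, S, E))   # upstream grad g[z]

# full-batch weight gradient: g[W]_{ed} = x_{bsd} g[z]_{bse}
gW_full = np.einsum("bsd,bse->ed", x, gz)

# per-rank local gradients over the r_0-sharded batch index
gW_local = [
    np.einsum("lsd,lse->ed", xs, gs)
    for xs, gs in zip(np.split(x, ddp_degree), np.split(gz, ddp_degree))
]

# the all-reduce completes the sum over r_0
gW_allreduced = sum(gW_local)
assert np.allclose(gW_allreduced, gW_full)
```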
Precisely when the all-reduce should be computed will be discussed later.
Fully Sharded Data Parallel
FSDP is essentially DDP with redundancy removal and increased communication. The backwards communication patterns are similar to DDP: collectives are needed for weight-gradient computation, but not for activation-gradients. The primary difference is the form of the collective: FSDP reduce-scatters instead of all-reducing.
In FSDP, both inputs and weights are sharded on the same device mesh axis[^3]. Immediately prior to computation, the weights are unsharded[^4], such that the local computations appear exactly as in the DDP section above. A primary difference is that rather than `Replicate`-ing (in DTensor jargon) the weight gradients as in $g[W] _{ e d}$, they are sharded across the FSDP mesh axis: $g[W] _{ e d} \longrightarrow g[W] _{ r _{ 0 }h d}$ with `r in range(fsdp_degree)`. The weight-gradient computation is then arranged as
$$g[W] _{ r _{ 0 }h d} = \frac{ \partial z _{ r' _{ 0 } c f } }{ \partial W _{r h d} } g[z] _{ r' _{ 0 } c f } \ ,$$
which is physically a reduce-scatter[^5], since the results of the cross-device sum on the right are stored in a sharded fashion across devices on the left.
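A toy `numpy` illustration of the difference: with per-rank local gradient contributions, the DDP-style all-reduce materializes the full summed gradient on every rank, while the FSDP-style reduce-scatter only ever materializes each rank's shard of it; all-gathering the shards recovers the all-reduce result:

```python
import numpy as np

rng = np.random.default_rng(0)
fsdp_degree, E, D = 2, 4, 3
# per-rank local weight-gradient contributions (full weight shape)
gW_local = [rng.normal(size=(E, D)) for _ in range(fsdp_degree)]

# DDP-style all-reduce: every rank holds the full summed gradient
gW_ddp = sum(gW_local)

# FSDP-style reduce-scatter: rank r only materializes its row-shard of the sum
gW_fsdp = [sum(np.split(g, fsdp_degree)[r] for g in gW_local)
           for r in range(fsdp_degree)]

# all-gathering the shards recovers the all-reduced gradient
assert np.allclose(np.concatenate(gW_fsdp), gW_ddp)
```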
Tensor-Parallel
A quick review: the usual MLP computation in Einstein notation looks like
$$z_{ bsd } = W^{ 1 }_{ de }\phi \left ( W^{ 0 }_{ ed' } x_{ bsd' } \right ) \ ,$$
and in the Tensor-Parallel version the intermediate $e$-dimension is sharded across devices:
$$z_{ bsd } = W^{ 1 }_{ dt _{0}f }\phi \left ( W^{ 0 }_{ t _{0}fd' } x_{ bsd' } \right ) \ ,$$
with `t in range(tp_degree)` the device index.
The classic TP diagram from NVIDIA
A rite of passage in understanding parallelized ML training is coming to terms with the Megatron-style TP implementation which utilizes the `_CopyToModelParallelRegion` custom autograd function, which (simplified slightly) looks like

```python
class _CopyToModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # all-reduce over the TP group; None is the grad for the `group` arg
        return all_reduce(grad_output, ctx.group), None
```
The above does absolutely nothing in the forward pass, but all-reduces in the backwards, and in the TP MLP layer it is only used to wrap the initial input $x_{ bsd' }$. Many common confusions ensue:
- Why is this needed at all?
- Why does it only wrap the input tensor?
- When else is this needed?
All of these are answered by the general rule. Tracing through the TP MLP specifically, it's only the initial computation using the input $x _{ bsd' }$ which produces an output with a sharded-index that wasn't present on an input tensor, so only the backwards through this operation requires comms.
Index analysis:
- Initial matmul: $y _{b s t _{ 0 } f } = W^{ 0 } _{ t _{0}fd' } x _{ bsd' }$.
- $g[x] _{ bsd' }$ requires an all-reduce[^6] since it does not carry the sharded $t _{ 0 }$ index of the output.
- $g[W^{ 0 }] _{ t _{0}fd' }$ carries $t _{ 0 }$ and thus is collective-free.
- Nonlinearity: $v _{b s t _{ 0 } f } = \phi(y _{b s t _{ 0 } f })$. Indices trivially match; no communication required for $g[y] _{b s t _{ 0 } f }$ .
- Final matmul: $z _{b s d } = W^{ 1 } _{ dt _{0}f }v _{b s t _{ 0 } f }$. No sharded index on the output at all, so no sharded-sum and thus no comms for computing $g[W^{ 1 }] _{ dt _{0}f }$ or $g[v] _{b s t _{ 0 } f }$.
So, the same rule governing DDP and FSDP backwards collectives also dictates those of TP. Analogous patterns and reasoning apply to TP attention layers where attention heads are sharded across devices.
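The index analysis above can be verified numerically in a single process, sharding $W^0$ along its $e$-rows and $W^1$ along its $e$-columns and treating Python-level sums over shards as the forward and backward all-reduces (a toy sketch, not a distributed implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
tp, B, S, D, E = 2, 2, 3, 4, 8
x = rng.normal(size=(B, S, D))
W0 = rng.normal(size=(E, D))   # sharded over its e-rows across tp devices
W1 = rng.normal(size=(D, E))   # sharded over its e-columns across tp devices
phi = np.tanh

# unsharded reference forward
z_ref = np.einsum("de,bse->bsd", W1, phi(np.einsum("ed,bsd->bse", W0, x)))

# per-device forward; the final Python-level sum is the forward all-reduce
W0_s = np.split(W0, tp, axis=0)
W1_s = np.split(W1, tp, axis=1)
y_s = [np.einsum("fd,bsd->bsf", w0, x) for w0 in W0_s]           # y_{b s t0 f}
z = sum(np.einsum("df,bsf->bsd", w1, phi(y)) for w1, y in zip(W1_s, y_s))
assert np.allclose(z, z_ref)

# backwards for the input: the per-device terms must be summed over t_0,
# which is the _CopyToModelParallelRegion backwards all-reduce
gz = rng.normal(size=(B, S, D))
gy_s = [(1 - phi(y) ** 2) * np.einsum("df,bsd->bsf", w1, gz)
        for w1, y in zip(W1_s, y_s)]
gx = sum(np.einsum("fd,bsf->bsd", w0, gy) for w0, gy in zip(W0_s, gy_s))

# unsharded reference backward
y = np.einsum("ed,bsd->bse", W0, x)
gx_ref = np.einsum(
    "ed,bse->bsd", W0, (1 - phi(y) ** 2) * np.einsum("de,bsd->bse", W1, gz)
)
assert np.allclose(gx, gx_ref)
```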
Sequence-Parallel
Now consider Sequence-Parallel, in which the TP inputs and outputs of the previous section are sharded across their sequence-dimension using the same TP device mesh axis.
The classic SP diagram from NVIDIA
In order to avoid a flurry of indices, I will switch to more implicit sharding notation where I still add a subscript to denote a sharded index, but no longer explicitly split that index in two. For instance, the TP weight sharding becomes $W^{ 0 } _{ e d } \longrightarrow W^{ 0 } _{ e _{ 0 }d }$ and the entire computation is
$$\begin{align}
z_{ bs _{ 0 }d } &= W^{ 1 }_{ de _{0} }\phi \left ( W^{ 0 }_{ e _{0}d' } x_{ bs _{ 0 }d' } \right ) \nonumber\\
&\equiv W^{ 1 }_{ de _{0} }\phi \left ( y _{ b s e_{ 0 } } \right ) \nonumber\\
&\equiv W^{ 1 }_{ de _{0} }v _{ b s e _{ 0 } } \nonumber
\end{align}$$
where in the second step the sharded $s _{ 0 }$ index on $x _{ b s _{ 0 } d }$ is all-gathered prior to the $W^{ 0 } _{ e _{ 0 }d }$ matmul to avoid an illegal double-0-subscript, and the final matmul is completed via a reduce-scatter.
What collectives are needed in the backwards? Starting with the output-weight gradients ($g[W^{ 1 }] _{ d e _{ 0 } }$) and comparing input ($W^{ 1 } _{ d e _{ 0 } }$) and output ($z _{ b s _{ 0 } d }$) indices, it's immediately seen that a sum over $s _{ 0 }$ is required and collectives are needed. The explicit chain-rule reads
$$\begin{align}
g[W^{ 1 }]_{ d e_{ 0 } } & = \frac{\partial z _{ b s_{ 0 } d } }{\partial W^{ 1 } _{ d e_{ 0 } } } g[z]_{ b s _{ 0 } d } \nonumber\\
& = v _{ b s e_{ 0 } } g[z]_{ b s _{ 0 } d } \nonumber \ .
\end{align}$$
The $s$-sum between $v _{ b s e _{ 0 } }$ and $g[z] _{ b s _{ 0 } d }$ (whose sequence-dimensions are `Replicate` and `Shard(1)`, respectively) can be completed by unsharding the latter via an all-gather and performing the sum locally. This is generally how reduce-scatters are handled: their backward is an all-gather.
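Since reduce-scatter is a linear map, its backward is its adjoint, and a quick `numpy` check confirms that the adjoint is exactly an all-gather (shards represented as plain arrays):

```python
import numpy as np

rng = np.random.default_rng(0)
R, N = 2, 6
x = [rng.normal(size=N) for _ in range(R)]        # per-device inputs
gy = [rng.normal(size=N // R) for _ in range(R)]  # per-device upstream grads

# forward: reduce-scatter (cross-device sum; each device keeps its chunk)
y = np.split(sum(x), R)

# backward: all-gather of the upstream gradient shards
gx = [np.concatenate(gy) for _ in range(R)]

# adjointness check: <g[y], y> == <g[x], x>, summed over devices
lhs = sum(float(g @ v) for g, v in zip(gy, y))
rhs = sum(float(g @ v) for g, v in zip(gx, x))
assert np.isclose(lhs, rhs)
```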
The backwards comm patterns for other gradient operations are similarly read off by inspecting their chain-rule index structure:
- $g[v] _{ b s e _{ 0 } } = W^{ 1 } _{ d e _{ 0 } } g[z] _{ b s _{ 0 } d }$: all-gather the upstream $z$-gradient to compute the replicated $v$-gradients (comms amortized with the $W^{ 1 }$ grad computation above).
- $g[y] _{ b s e _{ 0 } } = \phi' \left ( y _{ b s e _{ 0 } }\right ) g[v] _{ b s e _{ 0 } }$: entirely element-wise, no comms.
- $g[W^{ 0 } ] _{ e _{ 0 } d' } = x _{ b s _{ 0 }d' }\ g[y] _{ b s e _{ 0 } }$: all-gather the sharded inputs prior to computation (or use cached, unsharded inputs from the forward).
- $g[x ] _{ bs _{ 0 } d' } = W^{ 0 } _{ e _{ 0 } d' }\ g[y] _{ b s e _{ 0 } }$: the distributed $e _{ 0 }$-sum is reduce-scattered into the sequence-sharded $x$-gradients.
The explicit form of the Jacobian factors was written out in the above, but the same conclusions can be reached without computing them directly; only index structure analysis is required.
Rules for DTensors
How does this translate into the language of DTensor?
Assume the sharding patterns are simple and every tensor is either `Replicate` or `Shard(i)`, for some `i`, along every mesh dimension. This information is stored in the `placements` property of every DTensor:

```python
DTensor.placements: tuple[Replicate | Shard, ...]
```
Then for an operation mapping $n$ inputs to an output, communication is certainly required for computing an input's gradient if a `Shard` in the output placement isn't reflected in an equivalent `Shard` on the input placement. A code sketch:
```python
from torch.distributed.tensor import DTensor, Replicate, Shard


def must_have_bwd_comms(
    inputs: tuple[DTensor, ...],
    output: DTensor,
) -> list[bool]:
    """
    Assumptions:
    - Only one output
    - Only Replicate/Shard placements
    - All inputs and outputs sharded over the same DeviceMesh
    """
    needs_bwd_comms = [False for _ in inputs]
    for out_p, *inputs_p in zip(
        output.placements,
        *(input_t.placements for input_t in inputs),
    ):
        if isinstance(out_p, Shard):
            for input_idx, maybe_shard in enumerate(inputs_p):
                if not needs_bwd_comms[input_idx] and maybe_shard != out_p:
                    needs_bwd_comms[input_idx] = True
    return needs_bwd_comms
```
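The sketch can be exercised without a real device mesh by stubbing out minimal `Replicate`/`Shard`/`FakeDTensor` placeholders (hypothetical stand-ins defined purely for illustration; only `placements` and placement equality matter). The DDP matmul from earlier makes a good smoke test: the output and the activation carry the same `Shard`, while the replicated weight does not:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Replicate:
    """Minimal stand-in for the DTensor Replicate placement."""


@dataclass(frozen=True)
class Shard:
    """Minimal stand-in for the DTensor Shard placement; equality compares dims."""
    dim: int


@dataclass(frozen=True)
class FakeDTensor:
    placements: tuple


def must_have_bwd_comms(inputs, output):
    # same logic as the sketch above, operating on the stubs
    needs_bwd_comms = [False for _ in inputs]
    for out_p, *inputs_p in zip(
        output.placements,
        *(input_t.placements for input_t in inputs),
    ):
        if isinstance(out_p, Shard):
            for input_idx, maybe_shard in enumerate(inputs_p):
                if not needs_bwd_comms[input_idx] and maybe_shard != out_p:
                    needs_bwd_comms[input_idx] = True
    return needs_bwd_comms


# DDP matmul on a 1D mesh: z_{r0 l f} = F(x_{r0 l s d}, W_{d e})
x = FakeDTensor(placements=(Shard(0),))     # batch-sharded activation
W = FakeDTensor(placements=(Replicate(),))  # replicated weight
z = FakeDTensor(placements=(Shard(0),))
print(must_have_bwd_comms((x, W), z))  # [False, True]: only g[W] needs comms
```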
Note that this is not exhaustive. The above check only covers the criteria I focus on in this post, and each input gradient may require communication for additional reasons. Examples:

- If an input has a `Shard` placement that does not appear in the output, communication will also be needed for its gradient.
- An operation can consume and produce DTensors which have matching `Replicate`/`Shard` placements while internally performing complicated communication with intermediates that do not conform to this sharding taxonomy. Ring attention is an illustrative example. Communication is also generically needed for the backwards in this case. This corresponds to communication occurring in the Jacobian factor.

These cases (and others) can also be efficiently analyzed using the sharded Einstein notation shorthand, but this post is not intended to be exhaustive, and I leave this to the future.
Expert-Parallel
In Expert-Parallel (EP) MoE, where do the backwards collectives occur? Index tracking again easily determines the answer.
A sketch of the EP MoE logic is as follows. The inputs are batch-sharded across the EP mesh axis: $x _{ b s d } \longrightarrow x _{ b _{ 0 } s d }$, again using the more implicit sharding notation for brevity. A router determines which experts the tokens are routed to, and a dispatch component performs the necessary all-to-all:
$$y _{ e_{ 0 } l d } = \texttt{dispatch}(x _{ b_{ 0 } s d }) \ .$$
Above, `e in range(routed_experts)` indexes the experts which are sharded across devices, per the subscript. The local token index $l$ is generically ragged, logically, meaning that different experts receive different numbers of tokens. The dispatch output has $k$ times as many tokens as its input, for top-$k$ MoE. MoE MLP compute is performed in parallel across experts,
$$v _{ e_{ 0 } l d } = W^{ 1 }_{ e_{ 0 } d f }\phi \left ( W^{ 0 }_{ e_{ 0 } f d' }y _{ e_{ 0 } l d' }\right) \ ,$$
and the results are sent back to their original devices via a combine component:
$$c_{ b_{ 0 } k s d } = \texttt{combine}(v _{ e_{ 0 } l d }) \ ,$$
where `k in range(top_k)`. The final output comes from weighting the above with the tensor $w _{ b _{ 0 } k s }$ also produced by the router: $z _{ b _{ 0 }s d } = c _{ b _{ 0 } k s d }w _{ b _{ 0 } k s }$.
The only locations where sharded indices differ across the sides of an equality are in the dispatch and combine mechanisms, so these are the only EP operations whose backwards require comms.
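In the simplest top-1, one-slot-per-token case, dispatch reduces to a permutation, and the backward of a permutation is its inverse, which is exactly the combine-direction shuffle. A toy single-process `numpy` check:

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, d = 6, 4
x = rng.normal(size=(tokens, d))

# toy top-1 router: with one slot per token, dispatch is a permutation
expert = np.array([1, 0, 1, 1, 0, 0])
order = np.argsort(expert, kind="stable")  # dispatch ordering: group by expert

y = x[order]          # dispatch(x): the forward all-to-all shuffle
z = np.empty_like(y)
z[order] = y          # combine: route expert outputs back to their source slots
assert np.allclose(z, x)

# the backward of dispatch is the combine-direction shuffle (the adjoint of a
# permutation is its inverse), verified via the pairing <g[y], dy> == <g[x], dx>
gy = rng.normal(size=y.shape)
gx = np.empty_like(gy)
gx[order] = gy        # backward of dispatch
dx = rng.normal(size=x.shape)
assert np.isclose((gy * dx[order]).sum(), (gx * dx).sum())
```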
DDP + Context-Parallel
A final concrete example: consider combining DDP with Context-Parallel (CP) so that the inputs are doubly sharded
$$x _{ b s d } \longrightarrow x_{ r_{ 0 } l c_{ 1 } t d} \ .$$
Above, `r in range(ddp_degree)` and `c in range(cp_degree)` are device indices (returning to explicit, index-splitting notation), $l/t$ are local batch/sequence indices, and 0/1 subscripts are for DDP/CP mesh axes.
What will the communication patterns look like for computing the gradient for an unsharded weight $W_{ ab }$ in this setup? The chain rule structure is known on general grounds:
$$g[W]_{ a b } = \frac{ \partial y_{ { r_{ 0 } l c_{ 1 } t d} } }{ \partial W _{ ab } } g[y]_{ { r_{ 0 } l c_{ 1 } t d} } \ .$$
It follows that all-reduces over the batch ($r _{ 0 }$) and sequence ($c _{ 1 }$) indices are required for the weight-grad backwards, and an optimization is available here: rather than sequentially all-reducing on one mesh axis and then the other, the entire sum can be done in a single, 2d-mesh-wide all-reduce, which is generally more efficient. More on this kind of optimization in the next section.
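The numerical equivalence of the two orderings is immediate; the benefit of fusing is fewer, larger collectives, which this toy `numpy` check doesn't capture:

```python
import numpy as np

rng = np.random.default_rng(0)
ddp_degree, cp_degree, E, D = 2, 3, 4, 5
# local weight-gradient contributions on a (ddp, cp) mesh
gW_local = rng.normal(size=(ddp_degree, cp_degree, E, D))

# sequential: all-reduce over the CP axis, then over the DDP axis
sequential = gW_local.sum(axis=1).sum(axis=0)

# fused: a single all-reduce over the flattened 2D mesh
fused = gW_local.reshape(ddp_degree * cp_degree, E, D).sum(axis=0)

assert np.allclose(sequential, fused)
```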
Implementation Note: these considerations can lead to tricky design issues. CP often requires custom code, and a natural approach is to write the implementation so that it does the necessary all-reduces in a CP-only context, allowing for standalone unit-testing. But, if the impl is then used in a DDP + CP context, additional mechanisms are needed to either perform the DDP all-reduce or replace the CP all-reduce with a DDP+CP one. Alternatively, the implementation could delegate the all-reduces to external DDP-like APIs and rely on the user to apply these wrappings appropriately. Both options are demanding of the user and error-prone.
Partial and Backwards
Cartoon forward-pass DAG for input $x$ and parameters $W_i$ computing a scalar loss $L$. Circles are ops.
I will end this post on a general discussion of the use of Partial placements in sharded backwards passes, which closely relates to the preceding sections.
Partial placements don't arise much in the forward pass (with one common exception, discussed below), but occur all the time in sharded backwards. The central reason is that Partial tensors are very limited: the vast majority of operations require that the pending partial reduction be completed before the tensor can be used in further computation. But this is not true for the common operations on weight gradients.
A brief review: Partial placements correspond to pending cross-device operations. In practice, this occurs most often in all-reduces where the local sums have been performed, but the cross-device ones have not:
$$x_{ r_{ 0 } } = y _{ r_{ 0 } b } z _{ r _{ 0 } b } \ .$$
There's not much we can do with such tensors due to the pending $\texttt{sum} _{ r _{ 0 } }$ op. Reason: for a general operation $z = F(x, y, \ldots )$, indices suppressed, it's usually the case that
$$F(\texttt{sum} _{ r_{ 0 } } x _{ r _{ 0 } }, y, \ldots ) \neq \texttt{sum} _{ r_{ 0 } }F( x _{ r _{ 0 } }, y, \ldots ) \ ,$$
i.e. it's not generally possible to delay the sum. The partial must be completed immediately.
The main case where Partial(sum) tensors can be used directly without immediate communication is when the operation is (of course) also a sum and the input tensors all have identical placements[^7], e.g. $z _{ r _{ 0 } c d } = x _{ r _{ 0 } c d } + y _{ r _{ 0 } c d }$ with $r _{ 0 }$ the pending device index and $c, d$ some replicated indices. Otherwise, Partial tensors can only be kept in this state when they are not involved in any more imminent operations.
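Both halves of this claim are easy to demonstrate numerically: a nonlinear op does not commute with the pending sum, while addition of identically-placed Partials does:

```python
import numpy as np

partials = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]  # per-device x_{r_0}

# a nonlinear op does not commute with the pending cross-device sum
F = np.square
delayed = sum(F(p) for p in partials)  # F applied before completing the sum
eager = F(sum(partials))               # sum completed first (the correct value)
assert not np.allclose(delayed, eager)

# but adding identically-placed Partials commutes with completing the sum,
# which is why gradient accumulation can stay in a Partial state
other = [np.array([10.0, 20.0]), np.array([30.0, 40.0])]
assert np.allclose(sum(p + q for p, q in zip(partials, other)),
                   sum(partials) + sum(other))
```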
The only place this occurs regularly in the forwards is the loss computation, whose final operation is typically a logical sum (mean) over a sharded batch index. This is the only sink in the typical forward-pass DAG; see the figure above. Except for aggregating statistics across devices, there's no need to complete the cross-device sum on the loss, and the tensor can remain a Partial.
But in the backwards-pass DAG there are sinks everywhere, one for every parameter. As seen above, weight gradient computations commonly involve all-reduces and hence can be put into a Partial state. When a weight is only used a single time, as in $W _{ 0, 1, 3 }$ in the figure, there is only a single computation per weight gradient (one incoming edge) and one can manifestly delay completing its Partial until the time suits. When a weight is used multiple times, as in $W _{ 2 }$ in the figure, multiple computations contribute to its gradient. But because these computations sum, as in the chain-rule review, the multi-Partial operations are also of precisely the right form to avoid immediate communication requirements[^8]. In DDP and FSDP this property is used to delay and bucket the pending all-reduces into a condensed number of collectives, which are overlapped with other backwards compute, for efficiency. Activation gradients generally cannot be profitably held in long-lived Partial states, as they are usually needed for immediate downstream computation.
[^1]: More precisely, these are indices for which the Jacobian is proportional to a Kronecker delta, e.g. $\frac{\partial z _{ d e f }}{\partial x _{ a b c }}\propto \delta _{ a d }$. ↩︎

[^2]: The backwards may also require additional comms beyond those dictated by this criterion. The conditions discussed here are general, but not exhaustive. ↩︎

[^3]: I'm restricting to the simplest scenario for simplicity. It's also possible to use different meshes for input and weight sharding, at the cost of additional communication. ↩︎

[^4]: This can be seen as a consequence of the double-subscript rule: if the weights were left unsharded during computation, the putative outputs would have multiple indices sharded on the same mesh axis, which is illegal/nonsensical. ↩︎

[^5]: `torch` implementation note: the reduce-scatter and all-reduce operations in FSDP1/2 and DDP all default to `avg` ops rather than sums, effectively assuming that the loss function is also performing a mean over the batch dimension rather than a sum. This default can cause many headaches: see `torchtitan`'s #1551 and #2206 for related work. ↩︎

[^6]: Technically, the chain-rule sum could also be performed by a reduce-scatter, but because downstream computations will immediately require the fully-replicated $g[x] _{ b s d' }$, an all-reduce is the way to go. ↩︎

[^7]: When placements aren't identical, it's still possible to hack together operations that work. Example: on a 1D mesh with `R` devices, a partial $x _{ r _{ 0 } }$ cannot be directly added to a replicated $y _{ a }$, because the local-then-global-sum ordering would produce an incorrect tensor, $\texttt{sum} _{ r _{ 0 } } \left( x _{ r _{ 0 } } + y _{ a }\right) \neq \texttt{sum} _{ r _{ 0 } } \left( x _{ r _{ 0 } } \right)+ y _{ a }$, overcounting the elements of $y _{ a }$ $R$ times. But a valid implementation would be to first zero out the elements of $y _{ a }$ on all but one device and then complete the distributed $r _{ 0 }$ sum. Kinda gross, though. ↩︎

[^8]: Also see Edward Yang's posts on similar topics, such as here, where it's argued that `Partial` grad placements are the natural default. ↩︎