Accelerate ND-Parallel: A Guide to Efficient Multi-GPU Training
Training large models across multiple GPUs can be challenging due to the complexities of different parallelism strategies. In Accelerate, together with Axolotl, we have integrated a quick and easy way to use any combination of parallelism strategies in your training script!
Here is how to add it to your training script:
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin
# configure your desired parallelisms here - this particular configuration requires at least 2 nodes with 8 GPUs each.
# setting any parallelism degree to 1 disables it, e.g. dp_replicate_size=1 disables DP.
pc = ParallelismConfig(
    dp_shard_size=2,      # Fully Sharded Data Parallel degree
    dp_replicate_size=2,  # Data Parallel degree
    cp_size=2,            # Context Parallel degree
    tp_size=2,            # Tensor Parallel degree
)

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
    state_dict_type="SHARDED_STATE_DICT",
)

accelerator = Accelerator(
    parallelism_config=pc,
    fsdp_plugin=fsdp_plugin,
)

model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Hermes-3-Llama-3.1-8B",
    device_mesh=accelerator.torch_device_mesh,
)
model = accelerator.prepare(model)
We've also included a more comprehensive end-to-end training script in the Accelerate repo which demonstrates how to setup your dataloader, optimizer, and training loop, and how to save your model after training.
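For illustration, below is a minimal sketch of what those remaining pieces might look like. The dummy dataset, optimizer choice, and hyperparameters are placeholders rather than recommendations, and the example script covers details omitted here (such as preparing batches for context parallelism).

import torch
from torch.utils.data import DataLoader, TensorDataset

# dummy token ids purely for illustration - substitute your own tokenised dataset
dummy_ids = torch.randint(0, 32_000, (64, 512))
dataloader = DataLoader(TensorDataset(dummy_ids), batch_size=1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# prepare the remaining training objects so they respect the parallelism config
dataloader, optimizer = accelerator.prepare(dataloader, optimizer)

model.train()
for (batch,) in dataloader:
    outputs = model(input_ids=batch, labels=batch)
    accelerator.backward(outputs.loss)
    optimizer.step()
    optimizer.zero_grad()

accelerator.wait_for_everyone()
accelerator.save_state("checkpoint/")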
To further streamline fine-tuning models at scale and compose parallelism strategies with a variety of fine-tuning techniques, we've also integrated this technique into Axolotl. To help you get started right away we've tested some example configs which you can modify to suit your needs - try one out with:
# note: this requires a minimum world size of 16
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
You can also check out the Axolotl ND-Parallelism docs for more details - adding ND parallel techniques to your existing configs is as simple as adding one or more of the following fields to your Axolotl config file:
# Fully Sharded Data Parallel degree (note: also requires the fsdp_config field)
# see https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp for more details
dp_shard_size: 2
# Data Parallel degree
dp_replicate_size: 2
# Context Parallel Degree
context_parallel_size: 2
# Tensor Parallel Degree
tensor_parallel_size: 2
We've made it easy to configure the degrees of different parallelism strategies, and how they are combined, through the ParallelismConfig class in Accelerate or through config fields in Axolotl. But how do we know which configuration will work best for our use case? As we scale to training models with tens or even hundreds of billions of parameters, the primary challenge lies in understanding the different parallelism strategies and how they interact to minimise communication overhead across devices. In this post, we'll walk through how the different parallelism strategies work, and when and how you might want to compose them.
Contents
- Data Parallelism
- Fully Sharded Data Parallelism
- Tensor Parallelism
- Context Parallelism
- ND Parallelisms
- Usage Notes
Data Parallelism

Data parallelism (DP) is the most common technique for training models across multiple GPUs. It involves replicating the model, gradients, and optimizer states on each device, evenly distributing batches of data between GPUs, and synchronising gradients across devices before updating parameters. This can significantly increase throughput compared to single-device training, but requires that your model fits on a single device.
We can control the number of replicas of the model with the dp_replicate_size parameter in Accelerate's ParallelismConfig, or the config field of the same name in Axolotl. It's worth noting that DP is the top-most-level parallelism strategy: if we set dp_replicate_size=2 and compose it with other parallelism strategies, there will be 2 replicas of the model, each further partitioned by those other strategies. For example, with dp_replicate_size=2 and tp_size=2, we would have 2 replicas of the model, each split into 2 tensor parallel shards.
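For example, the dp_replicate_size=2, tp_size=2 combination just described corresponds to the following sketch (hypothetical degrees, expecting a world size of 2 * 2 = 4 GPUs):

from accelerate.parallelism_config import ParallelismConfig

# 2 replicas of the model, each split into 2 tensor parallel shards
pc = ParallelismConfig(dp_replicate_size=2, tp_size=2)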
We use the term shard to describe the partition of a larger piece of data that lives on a single device.
Fully Sharded Data Parallelism

What if our model is too large to fit on a single device? Fully sharded data parallel (FSDP) addresses this issue by sharding (distributing evenly) the model’s weights, gradients, and optimizer states across GPUs (this is inspired by DeepSpeed’s ZeRO-3), whilst each device still receives its portion of the full batch of data. As you may notice from the diagram above, rather than requiring a full copy of the entire model on each device, we only gather the weights for a single layer at a time before the forward pass, after which the weights may be sharded again.
In this way, we trade memory usage for the communication overhead of gathering the sharded parameters before each forward and backward pass, and of reduce-scattering the local gradients. We can control this trade-off in FSDP by tuning the granularity at which parameters are gathered. At one extreme, we can gather and re-shard every individual layer of our model, which results in the lowest peak memory usage but incurs the highest communication cost. In practice, a common approach is to gather the weights for an entire transformer decoder block at a time.
Whilst we can make further memory-compute trade-offs and offload model parameters and gradients to the CPU to train larger models, this can be prohibitively slow. Instead, let’s consider how we can effectively utilise even more devices to train larger models whilst maintaining high data throughput.
We use the term node to refer to a single machine which hosts multiple GPUs (up to a maximum of 8), with fast intra-node communication channels using e.g. NVLink between GPUs. When using multiple nodes for training, we rely on relatively slower inter-node communication channels between machines using e.g. Infiniband. We also refer to the total number of devices in the process pool as the world size - e.g. a single node with 8 GPUs represents a world size of 8, and 4 nodes would represent a world size of 32.
When using FSDP across multiple nodes, we treat the entire set of devices across nodes as if we were training on a single node. For example, with 4 nodes containing 8 GPUs each, we perform our sharding across 32 devices, and perform our collective all-gather and reduce-scatter operations using both inter- and intra-node communication backends. In this manner, FSDP alone can scale to a substantial number of GPUs, with a large global batch size to increase data throughput. However, there comes a point where several challenges arise that may require composing FSDP with other parallelism techniques. We usually try to avoid applying FSDP across more than a full node, as the communication overhead can become too high; we'll discuss how to address this in the section on Hybrid Sharded Data Parallelism.
You can use the dp_shard_size parameter in Accelerate's ParallelismConfig together with a prepared FullyShardedDataParallelPlugin, or set the dp_shard_size config field in Axolotl, to control the degree of FSDP applied to your model.
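As a sketch, an FSDP-only setup on a single 8-GPU node might look like the following (the wrap policy and layer class name mirror the Llama example from the start of this post and are assumptions about your model):

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin

pc = ParallelismConfig(dp_shard_size=8)  # shard across all 8 GPUs
fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
)
accelerator = Accelerator(parallelism_config=pc, fsdp_plugin=fsdp_plugin)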
Tensor Parallelism

Tensor Parallel (TP) is a kind of model parallelism technique, where shards of the model permanently live on separate devices, and in contrast to data parallel techniques, each device receives an identical batch of data. TP works by distributing the computation of linear layers across devices, so each device only computes a portion of the matrix multiplication. This technique works best when there are large linear layers, such as the feed-forward layers in transformer models, which can be split across devices. We can also use TP on each of the query, key, value, and output projections in the attention layers with almost no extra communication cost.
To achieve the best performance, parameters of consecutive layers can be distributed in a specific fashion, minimizing the required communication. When working with pairs of linear layers, we can split the first layer column-wise, and the subsequent layer row-wise, allowing us to compute the output with only a single all-reduce operation to combine the sharded outputs.
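To make the column-/row-wise trick concrete, here is a small single-process numeric check (plain tensors stand in for two devices; real TP places each shard on a different GPU and replaces the final sum with an all-reduce):

import torch

d, hidden = 16, 64
x = torch.randn(4, d)            # a batch of 4 token embeddings
w1 = torch.randn(d, hidden)      # first linear layer, split column-wise
w2 = torch.randn(hidden, d)      # second linear layer, split row-wise

full = torch.relu(x @ w1) @ w2   # reference output computed on a single "device"

w1_a, w1_b = w1.chunk(2, dim=1)  # "device 0" holds the first halves, "device 1" the second
w2_a, w2_b = w2.chunk(2, dim=0)
partial_a = torch.relu(x @ w1_a) @ w2_a
partial_b = torch.relu(x @ w1_b) @ w2_b

# summing the partial outputs (the single all-reduce in real TP) recovers the full result
assert torch.allclose(full, partial_a + partial_b, atol=1e-4)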
Unlike the dynamic sharding behaviour of FSDP, TP creates static memory partitions, giving a constant reduction in memory usage that scales with the TP group size. This becomes crucial for massive models where even a single decoder layer is too large to fit into memory during the FSDP all-gather (recall that common practice in FSDP is to gather the weights of an entire decoder layer at a time). However, unlike FSDP, which scales relatively linearly across nodes (up to a point - ~512 GPUs on a homogeneous cluster, significantly fewer across lower-bandwidth connections), TP is only effective within the boundaries of a single node. TP requires frequent activation synchronisation between devices during computation: each device computes only a portion of the output, so the outputs from other devices must be communicated before the forward pass can continue. Thus, if we wish to utilise TP in a multi-node setup, we must compose TP with other parallelism techniques, keeping TP within a single node. Due to its large communication overhead, TP is also not recommended for PCIe-linked GPUs.
In Accelerate, the TP size is configured through tp_size in ParallelismConfig, whilst in Axolotl you can use the tensor_parallel_size config field.
Context Parallelism
Recently, reasoning capabilities in LLMs have caused sequence lengths to skyrocket, as models use more and more tokens to solve complex tasks. To achieve this behaviour through fine-tuning, we need a way to train models on very long sequences - which can sometimes reach up to a million tokens!
Since the attention operation in transformers scales quadratically with context length, this becomes impossible on a single GPU. For example, when fine-tuning a relatively small model such as Mistral-7B (which uses 32 attention heads) with a sequence length of 128k, a single attention matrix takes 128k * 128k * 2 bytes = ~32GB per head, or ~1TB of activation memory across all 32 heads! Whilst this example is not realistic when using optimised attention implementations such as FlashAttention (which never materialise the full attention matrix), it helps illustrate how memory requirements grow as the context length increases.
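The back-of-the-envelope numbers above can be reproduced in a few lines (assuming 2-byte bf16 activations):

seq_len = 128 * 1024
num_heads = 32
bytes_per_head = seq_len * seq_len * 2    # one full 128k x 128k attention matrix
total_bytes = bytes_per_head * num_heads  # across all heads
print(f"{bytes_per_head / 2**30:.0f} GiB per head, {total_bytes / 2**40:.0f} TiB in total")
# -> 32 GiB per head, 1 TiB in total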
With context parallelism (CP), we can shard the inputs across the sequence dimension, resulting in each device only processing a chunk of the full context and computing a smaller portion of the full, prohibitively large, attention matrix. To see how this works, recall that the attention computation is described by the equation:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where $Q$, $K$, and $V$ are the query, key, and value matrices respectively. Each query vector (row, or input embedding) of $Q$ must compute attention scores against every key vector of $K$ in the entire sequence to correctly apply the softmax normalisation. These attention scores are then used to weight all value vectors in $V$.

The crucial detail here lies in the fact that each row in $Q$ can compute its attention scores independently of the others, but each query vector still requires the full $K$ and $V$ matrices. In other words, given an input with sequence length $n$, we can expand our above attention equation as:

$$\text{Attention}(Q, K, V) = \begin{bmatrix} \text{softmax}\left(\frac{q_1 K^T}{\sqrt{d_k}}\right)V \\ \vdots \\ \text{softmax}\left(\frac{q_n K^T}{\sqrt{d_k}}\right)V \end{bmatrix}$$

where we denote each row of the query matrix as $q_i$. This can be generalized as:

$$\text{Attention}(q_i, K, V) = \text{softmax}\left(\frac{q_i K^T}{\sqrt{d_k}}\right)V$$

When we shard the inputs across devices, the resulting $Q$, $K$, and $V$ matrices (computed from these input shards) are also automatically sharded along the sequence dimension - each GPU computes queries, keys, and values only for its portion of the sequence. For example, with a world size of $N$ GPUs and sequence length $n$:
- GPU 0 computes $Q_0$, $K_0$, $V_0$ from the first $n/N$ tokens
- GPU 1 computes $Q_1$, $K_1$, $V_1$ from the next $n/N$ tokens
- ...
- GPU $N-1$ computes $Q_{N-1}$, $K_{N-1}$, $V_{N-1}$ from the final $n/N$ tokens
How do we ensure the attention is computed correctly? As established above, each device only needs its own shard of $Q$, but requires the full $K$ and $V$ matrices to compute its attention outputs correctly. We can achieve this by using a technique called RingAttention, which works as follows (a toy simulation of this schedule is sketched below):
- Initially, each GPU holds its shard of $Q$, $K$, $V$ (e.g., GPU 0 holds $Q_0$, $K_0$, $V_0$).
- Each GPU then computes a partial attention matrix for its shard of $Q$ and its local shard of $K$, $V$.
- Each GPU sends its shard of $K$, $V$ to the next GPU in the ring.
- Each GPU receives a different shard of $K$, $V$ from the previous GPU in the ring.
- Each GPU computes additional partial attention matrices using the received $K$, $V$ shards.
- Each GPU repeats this process until all shards of $K$, $V$ have been received and all partial attention matrices have been computed.
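This communication schedule is easy to simulate in plain Python - a toy illustration with 4 hypothetical GPUs and no real communication:

world_size = 4
kv_shard = list(range(world_size))  # rank i starts with its own K/V shard i

for step in range(world_size):
    for rank in range(world_size):
        print(f"step {step}: rank {rank} attends Q_{rank} against K/V shard {kv_shard[rank]}")
    # each rank passes its current K/V shard to the next rank in the ring
    kv_shard = [kv_shard[(rank - 1) % world_size] for rank in range(world_size)]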

Accelerate enables this with the accelerator.maybe_context_parallel
decorator, which is also showcased in the Accelerate example script. You can also learn more about how it works and its limitations in our CP concept guide.
Similar to TP, in Accelerate the CP size is configured through cp_size in ParallelismConfig, whilst in Axolotl you can use the context_parallel_size config field.
ND Parallelisms
In the multi-node setting, data parallel techniques such as FSDP treat the entire network topology as if it existed along a single dimension. You may find this approach limiting for a variety of reasons:
- When scaling to more nodes, FSDP's collective operations become bottlenecked by inter-node latency, making training prohibitively slow.
- As we mentioned above, massive models may have decoder layers which cannot fit into GPU memory, or which may be too large to perform a forward pass with, even in a sharded state.
- It could be impossible to achieve your ideal batch size - either the batch becomes too large for pure data parallelism to handle efficiently, or too small due to memory constraints from model size.
To try and address some of these problems, we can think of multi-node clusters as having a two-dimensional topology: fast intra-node communication between devices along one axis, and relatively slower inter-node communication along another axis. Let’s consider how we can compose the parallelism techniques we’ve introduced so far to take advantage of this.
Hybrid Sharded Data Parallelism

Hybrid Sharded Data Parallelism (HSDP) is a kind of 2D parallelism which performs FSDP within a node, and DP across nodes - that is to say the model is replicated across each node, and sharded using FSDP within each node. This allows the greater communication overhead of FSDP to utilize the faster intra-node links, whilst DP minimises the slower inter-node communication overhead to a single gradient synchronisation step. You might consider this approach if you were facing problem 1 and wished to speed up training at the cost of increased memory usage.
It's important to note that we can freely configure the shape of our 2D network topology, as we aren't constrained to the dimensions being aligned with physical node boundaries - you might apply FSDP across pairs of nodes whilst replicating across 2-node groups, which would result in lower memory usage but slower throughput than node-aligned HSDP, whilst still reducing the inter-node FSDP communication overhead by a factor of two compared to sharding across every node. This is a knob we encourage you to tune to your specific hardware setup and fine-tuning needs.
You can enable HSDP by defining both dp_shard_size and dp_replicate_size in Accelerate's ParallelismConfig, or through Axolotl's config fields.
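For example, on 4 nodes of 8 GPUs each (world size 32), a node-aligned HSDP layout could be sketched as follows (combined with a FullyShardedDataParallelPlugin as in the first example):

from accelerate.parallelism_config import ParallelismConfig

# shard within each 8-GPU node, replicate across the 4 nodes: 4 * 8 = 32
pc = ParallelismConfig(dp_replicate_size=4, dp_shard_size=8)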
Fully Sharded Data Parallelism + Tensor Parallelism
As we mentioned earlier, TP should be applied within a node to utilise the high-bandwidth intra-node communication; thus, combining TP and FSDP involves sharding the model across nodes using FSDP, and within a node using TP. To a certain degree, this potentially offers a neat solution to all three of the issues above (assuming 8-GPU nodes): the latency costs from FSDP could be reduced by a factor of 8, layers that are too large to fit on a single device are now evenly distributed across devices, and since each TP group receives an identical batch of data, we can also reduce our global batch size by a factor of 8. However, if this remains insufficient, we are unable to increase the TP size across nodes and must consider an alternative approach.
In Accelerate you can combine TP and FSDP by defining both dp_shard_size and tp_size in ParallelismConfig, whilst in Axolotl you can add both of the dp_shard_size and tensor_parallel_size config fields.
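As a sketch, for 2 nodes of 8 GPUs each (world size 16), with TP kept inside each node and FSDP spanning the nodes:

from accelerate.parallelism_config import ParallelismConfig

# each layer split across the 8 GPUs of a node, sharded across the 2 nodes: 2 * 8 = 16
pc = ParallelismConfig(dp_shard_size=2, tp_size=8)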
Fully Sharded Data Parallelism + Context Parallelism
This is a 2D parallelism strategy that combines FSDP and CP. While it is not very commonly used, as CP already combines with FSDP (more on why in the Accelerate concept guide), it can be useful in some cases, e.g. when a very long sequence length requires a large cp_size. If this still doesn't fit into your memory budget, you can apply FSDP on top, further reducing memory usage.
In Accelerate you can combine CP and FSDP by defining both dp_shard_size and cp_size in ParallelismConfig, whilst in Axolotl you can add both of the dp_shard_size and context_parallel_size config fields.
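As a sketch, on a single 8-GPU node you might split long sequences 4 ways with CP while sharding parameters over 2 FSDP groups:

from accelerate.parallelism_config import ParallelismConfig

# 2 (FSDP) * 4 (CP) = 8 GPUs
pc = ParallelismConfig(dp_shard_size=2, cp_size=4)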
Hybrid Sharded Data Parallelism + Tensor Parallelism
With a sufficiently large world size (note: while the minimum world size for 3D parallelism is 8, it is most effective at much larger scales), we can consider combining HSDP with TP. This creates a hierarchy where DP first replicates the model across groups of nodes, FSDP then shards the model within each group, and TP splits individual layers within each node. You might consider this approach when facing all of the scaling constraints we mentioned above, as it provides the most flexibility to adapt to your specific training setup by trading off memory usage against throughput.
In Accelerate you can combine HSDP and TP by defining all of dp_shard_size, dp_replicate_size, and tp_size in ParallelismConfig. Similarly, in Axolotl you can add all of the dp_shard_size, dp_replicate_size, and tensor_parallel_size config fields.
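As a final sketch, on 4 nodes of 8 GPUs each (world size 32), a 3D layout could look like:

from accelerate.parallelism_config import ParallelismConfig

# 2 replicas, each sharded over 2 FSDP groups, each layer split over 8 TP ranks: 2 * 2 * 8 = 32
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=8)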
Usage Notes
There are additional ways to combine multiple parallelisms which we haven't covered, such as 4D parallel using HSDP + TP + CP, but they operate very similarly to the techniques we've already covered. Most of all, we encourage you to play with different techniques and configurations - this is the best way to gain an intuition for the different ways in which you can make memory/throughput trade-offs.
Below are some additional tips you may find useful when working in distributed settings:
When using FSDP to train models that are too large to fit on a single device, enabling both CPU-RAM-efficient loading and sharded state dict checkpointing is crucial. You can enable these through the cpu_ram_efficient_loading and state_dict_type parameters in Accelerate's FullyShardedDataParallelPlugin:

fsdp2_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
    state_dict_type="SHARDED_STATE_DICT",
    cpu_ram_efficient_loading=True,
)

or through the cpu_ram_efficient_loading and state_dict_type config fields inside the fsdp_config in Axolotl:

fsdp_version: 2
fsdp_config:
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: SHARDED_STATE_DICT
  cpu_ram_efficient_loading: True
The total batch size used during training plays an important role in training stability, memory usage, and data throughput. When using DP and/or FSDP, the effective batch size is calculated as:

effective_batch_size = micro_batch_size * gradient_accumulation_steps * dp_world_size

where dp_world_size = (dp_shard_size * dp_replicate_size) / tp_size. You can increase your batch size by increasing the micro batch size or the number of gradient accumulation steps in your training loop (or by setting the micro_batch_size and gradient_accumulation_steps config fields in Axolotl), or by increasing the total dp_world_size by adding more GPUs. As we mentioned above, this imposes a minimum total batch size of dp_world_size - when using pure DP/FSDP this will be your total world size, and if it is too high, the only way to decrease the total batch size is by introducing tensor parallelism. Finally, with a fixed number of GPUs and in memory-constrained scenarios, we recommend increasing gradient_accumulation_steps rather than micro_batch_size to achieve larger effective batch sizes, and vice-versa.

Correspondingly, when your effective batch size increases due to introducing data parallelism, you should scale your learning rate to maintain training stability. Common approaches include linear scaling, scaled_lr = base_lr * (effective_batch_size / base_batch_size), or square root scaling, scaled_lr = base_lr * sqrt(effective_batch_size / base_batch_size).
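As a quick worked example with hypothetical numbers (4 data-parallel ranks, and a base recipe tuned for a batch size of 8):

micro_batch_size = 2
gradient_accumulation_steps = 4
dp_world_size = 4
effective_batch_size = micro_batch_size * gradient_accumulation_steps * dp_world_size  # 32

base_lr, base_batch_size = 2e-5, 8
linear_lr = base_lr * (effective_batch_size / base_batch_size)        # 8e-5
sqrt_lr = base_lr * (effective_batch_size / base_batch_size) ** 0.5   # 4e-5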
When memory constraints persist even with parallelism strategies, gradient checkpointing can provide additional memory savings by trading compute for memory. During the forward pass, only a subset of activations is kept in memory (typically at transformer block boundaries), and intermediate activations are recomputed during the backward pass. This technique works seamlessly with all of the parallelism strategies covered above. In Accelerate, you can enable it by setting activation_checkpointing=True in FullyShardedDataParallelPlugin:

fsdp2_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
    state_dict_type="SHARDED_STATE_DICT",
    cpu_ram_efficient_loading=True,
    activation_checkpointing=True,
)
and similarly in Axolotl:
fsdp_version: 2
fsdp_config:
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: SHARDED_STATE_DICT
  cpu_ram_efficient_loading: True
  activation_checkpointing: True
Note that gradient checkpointing typically increases training time by ~20-30% due to activation recomputation, but can reduce activation memory by 60-80%, making it particularly valuable when training very large models or using long sequence lengths.