Challenges in Enabling PyTorch Native Pipeline Parallelism for Hugging Face Transformer Models #589
Authors: @hemildesai
Introduction
As large language models (LLMs) continue to grow in scale - from billions to hundreds of billions of parameters - training these models efficiently across multiple GPU nodes has become increasingly challenging. While data parallelism works well for smaller models, larger models often exceed the memory capacity of a single GPU or a single node, necessitating more sophisticated parallelization strategies.
Pipeline parallelism is one such strategy that addresses this challenge by splitting a model's layers across different devices and processing them in a pipelined fashion. Each device processes a different stage of the model, enabling training of models that wouldn't fit on a single device, while maintaining high GPU utilization through overlapped computation. You can read more about pipeline parallelism in this PyTorch guide or in the Megatron-LM paper.
NeMo Automodel is a GPU-accelerated PyTorch library for training LLMs. We recently added support for PyTorch native pipeline parallelism via:
- AutoPipeline for any Hugging Face Transformer language model, including popular LLMs in the AutoModelForCausalLM category such as Llama, Qwen, Mistral, and Gemma, with support for vision language models and additional architectures coming soon.
- functional API for custom models, or for users seeking more granular control. The functional API offers modular building blocks that can be adapted to any PyTorch model architecture, making pipeline parallelism accessible across the entire ecosystem.

This article will focus on AutoPipeline; users can refer to the guide here for more details on the functional API.
While we drew inspiration from TorchTitan during the development of our pipelining component, enabling automatic pipeline parallelism for Hugging Face models presented a unique set of challenges. In this article, we explore those challenges and share the solutions we implemented in AutoPipeline to make pipeline parallelism both robust and user-friendly.
How AutoPipeline Works: High-Level Process
To give you a high-level overview, when you call AutoPipeline(...).build(model, loss_fn), here's what happens under the hood:
- Inspects the model structure (the .model attribute, number of layers, rotary embeddings, etc.)
- Calculates layers_per_stage and validates against the pipeline size
- Creates each PipelineStage with proper stage indexing and device placement
- Applies parallelization to each stage if a parallelize_fn is provided

The result is a complete pipeline-parallel setup with automatic handling of all the challenges described in this article.
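To make this concrete, here is a minimal usage sketch. The import path, the constructor arguments (which are elided, as in the text above), and the per-token loss below are illustrative assumptions; consult the AutoPipeline documentation for the exact API.

```python
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

# Hypothetical import path, shown for illustration only.
from nemo_automodel import AutoPipeline

def per_token_loss(logits, labels):
    # Sum (rather than mean) so gradients can later be normalized by the
    # number of label tokens (see Challenge 4 below).
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")

# Constructor arguments (pipeline size, device, schedule, ...) are elided here;
# see the AutoPipeline documentation for the available options.
pipeline = AutoPipeline(...).build(model, per_token_loss)
```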
Challenge 1: Module Assignment - Understanding Model Structure
When implementing pipeline parallelism, one of the first challenges is determining how to split the model across pipeline stages. This isn't simply a matter of dividing layers equally - certain components need special treatment based on how Hugging Face models are structured.
Let's examine a typical Hugging Face causal language model structure using Qwen3 as an example:
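The following snippet shows an abbreviated view of the module hierarchy; the checkpoint name is illustrative, and exact layer counts and tensor sizes depend on the specific model.

```python
from transformers import AutoModelForCausalLM

# Any Qwen3 causal LM exposes the same overall layout.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")
print(model)
# Abbreviated output:
# Qwen3ForCausalLM(
#   (model): Qwen3Model(
#     (embed_tokens): Embedding(vocab_size, hidden_size)
#     (layers): ModuleList(N x Qwen3DecoderLayer(...))
#     (norm): Qwen3RMSNorm(...)
#     (rotary_emb): Qwen3RotaryEmbedding()
#   )
#   (lm_head): Linear(hidden_size, vocab_size, bias=False)
# )
```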
This creates a hierarchical structure where:
- Qwen3ForCausalLM contains model (inner model) and lm_head (output projection)
- model.model contains embed_tokens, layers, norm, and rotary_emb

When splitting this model across pipeline stages, different components have different placement requirements:
- Token embeddings (model.embed_tokens): Must be in the first stage only - converts token IDs to embeddings
- Transformer layers (model.layers): Distributed across multiple stages - the core computation
- Final normalization (model.norm): Must be in the last or second-to-last stage - applies final layer normalization
- Language modeling head (lm_head): Must be in the last stage only - projects to vocabulary logits
- Rotary embeddings (model.rotary_emb): Must be in all stages - shared position encoding utility

These placement constraints become even more pronounced for vision language models and other complex model architectures.
Our generate_hf_model_fqn_per_model_part function in functional.py handles this complexity automatically for most cases; a simplified sketch of the assignment logic follows the list below. This implementation demonstrates several key insights:
- Hierarchical Naming: The fqn_prefix="model." parameter accounts for Hugging Face's nested structure, where most components live inside model.model
- Mixed Hierarchy Handling: Notice that lm_head has no prefix because it lives at the top level (Qwen3ForCausalLM.lm_head), while norm uses the prefix because it's inside the inner model (Qwen3ForCausalLM.model.norm)
- Shared Component Replication: The rotary_emb is added to every stage because position embeddings are needed by all transformer layers
- Smart Distribution: The function automatically calculates how many layers go to each stage, handling remainder layers by distributing them to the first few stages
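The function name and the fqn_prefix parameter come from the text above; the body below is a simplified approximation of the assignment logic, not the actual implementation in functional.py.

```python
def generate_hf_model_fqn_per_model_part(
    num_stages: int,
    num_layers: int,
    fqn_prefix: str = "model.",
) -> list[list[str]]:
    """Assign fully qualified module names (FQNs) to each pipeline stage (sketch)."""
    base, remainder = divmod(num_layers, num_stages)
    stages, layer_idx = [], 0
    for stage in range(num_stages):
        fqns = []
        # Input embeddings live only on the first stage.
        if stage == 0:
            fqns.append(f"{fqn_prefix}embed_tokens")
        # Early stages absorb any remainder layers.
        layers_here = base + (1 if stage < remainder else 0)
        fqns.extend(f"{fqn_prefix}layers.{layer_idx + i}" for i in range(layers_here))
        layer_idx += layers_here
        # Final norm and the LM head live only on the last stage;
        # lm_head is a top-level attribute, so it takes no prefix.
        if stage == num_stages - 1:
            fqns.append(f"{fqn_prefix}norm")
            fqns.append("lm_head")
        # Rotary embeddings are a shared utility replicated on every stage.
        fqns.append(f"{fqn_prefix}rotary_emb")
        stages.append(fqns)
    return stages
```

Calling this sketch with num_stages=4 and num_layers=32 reproduces the assignment shown below.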
To illustrate how this works in practice, consider a 32-layer Qwen3 model split across 4 stages:
```python
[
    # Stage 0: Input processing + first 8 layers + shared utilities
    ["model.embed_tokens", "model.layers.0", ..., "model.layers.7", "model.rotary_emb"],
    # Stage 1: Middle layers + shared utilities
    ["model.layers.8", ..., "model.layers.15", "model.rotary_emb"],
    # Stage 2: Middle layers + shared utilities
    ["model.layers.16", ..., "model.layers.23", "model.rotary_emb"],
    # Stage 3: Final layers + output processing + shared utilities
    ["model.layers.24", ..., "model.layers.31", "model.norm", "lm_head", "model.rotary_emb"]
]
```

This intelligent assignment ensures that each stage has exactly what it needs, while avoiding duplication of unique components like embeddings and the language modeling head. It can also serve as a reference for automatically splitting custom models for your own use case.
Challenge 2: nn.ModuleList vs nn.ModuleDict: The Indexing Problem
A subtle but critical issue in pipeline parallelism involves how PyTorch's nn.ModuleList and nn.ModuleDict behave when models are split across stages. This seemingly minor implementation detail can cause significant problems with checkpointing and state management.
Most Hugging Face models use nn.ModuleList to store transformer layers. The problem arises when we split the model across pipeline stages: each stage gets a subset of the layers, but nn.ModuleList automatically re-indexes its contents starting from 0. This seemingly innocent re-indexing creates a disaster scenario for checkpointing:
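Here is a minimal sketch of the failure mode, assuming a stage that keeps layers 8-15 of a 32-layer model; the toy layer class is purely illustrative.

```python
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

# The full model stores its layers in an nn.ModuleList.
full_layers = nn.ModuleList([ToyDecoderLayer() for _ in range(32)])

# Stage 1 of a 4-stage pipeline keeps only layers 8-15 ...
stage1_layers = nn.ModuleList(list(full_layers[8:16]))

# ... but nn.ModuleList re-indexes them from 0, so the state dict
# now claims these are layers 0-7.
print(list(stage1_layers.state_dict().keys())[:2])
# ['0.proj.weight', '0.proj.bias']  <- these weights belong to layer 8, not layer 0
# Loading such a checkpoint back would map layer 8's weights onto layer 0,
# silently corrupting the model.
```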
Fortunately, AutoPipeline solves this by converting nn.ModuleList to nn.ModuleDict with explicit layer naming, as sketched below. With this approach, checkpoint saving and loading work correctly across all pipeline stages, maintaining the original layer identities throughout the training process.
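A minimal sketch of the conversion, continuing the toy example above; AutoPipeline's actual implementation differs in detail.

```python
# Keep the original layer indices as explicit string keys.
stage1_layers = nn.ModuleDict({str(i): full_layers[i] for i in range(8, 16)})

print(list(stage1_layers.state_dict().keys())[:2])
# ['8.proj.weight', '8.proj.bias']  <- original layer identities preserved
```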
Challenge 3: Forward Method Patching: Handling Missing Modules
Another complex challenge in pipeline parallelism is ensuring that forward methods work correctly when modules are distributed across different pipeline stages. Standard Hugging Face forward methods assume all components are available locally, but in pipeline parallelism, this assumption breaks down.
To understand the issue, consider a standard Hugging Face model forward method:
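The sketch below is a condensed, simplified version of such a forward method; it omits caching, attention-mask preparation, and several arguments, and is not the exact transformers implementation.

```python
from transformers.modeling_outputs import CausalLMOutputWithPast

# Simplified sketch of a standard HF causal-LM forward.
def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
    hidden_states = self.model.embed_tokens(input_ids)       # assumes embed_tokens exists
    position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
    for layer in self.model.layers:                          # assumes all layers are local
        hidden_states = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
    hidden_states = self.model.norm(hidden_states)           # assumes norm exists
    logits = self.lm_head(hidden_states)                     # assumes lm_head exists
    # Returns a rich output object, not the plain tensor the pipeline API expects.
    return CausalLMOutputWithPast(logits=logits)
```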
Problem 1: When we split the model across stages:
- Stage 0 has embed_tokens, but stages 1-3 don't
- Stage 3 has norm and lm_head, but stages 0-2 don't
- Calling self.embed_tokens(input_ids) on stage 1 results in AttributeError: 'NoneType' object has no attribute '__call__'

Problem 2: PyTorch's Pipeline Parallelism API expects each stage to return a single tensor output, which can be passed to the next stage or used by the loss function in the final stage. However, Hugging Face models typically produce customized outputs, which are not directly compatible with this requirement.
To address these fundamental incompatibilities, AutoPipeline replaces the standard forward methods with pipeline-aware versions that handle missing modules and outputs gracefully. The actual implementation can be found in hf_utils.py; AutoPipeline automatically applies these patches based on model type.
Let's examine how this transformation works in practice.
Before Patching (Fails):
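For example, on a stage that does not own the embedding layer (the variable names below are illustrative, and missing modules are assumed to be set to None):

```python
# stage1_model holds only layers 8-15; embed_tokens, norm, and lm_head are None.
hidden_states_from_stage0 = ...  # activations received from the previous stage
output = stage1_model(hidden_states_from_stage0)
# Fails inside the unpatched forward at self.model.embed_tokens(...),
# raising the AttributeError described above, because embed_tokens is None here.
```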
After Patching (Works):
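Below is a simplified sketch of a pipeline-aware forward, again assuming missing modules are None on stages that do not own them; the actual patched methods in hf_utils.py handle more cases.

```python
def pipeline_forward(self, inputs, attention_mask=None, position_ids=None):
    # The first stage embeds token IDs; later stages receive hidden states directly.
    if self.model.embed_tokens is not None:
        hidden_states = self.model.embed_tokens(inputs)
    else:
        hidden_states = inputs
    position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
    # Only the layers assigned to this stage are present (an nn.ModuleDict, see Challenge 2).
    for layer in self.model.layers.values():
        hidden_states = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
    # Final norm and lm_head exist only on the last stage.
    if self.model.norm is not None:
        hidden_states = self.model.norm(hidden_states)
    if self.lm_head is not None:
        hidden_states = self.lm_head(hidden_states)
    # Return a plain tensor so PyTorch's pipeline API can pass it to the next
    # stage or the loss function, instead of a Hugging Face output object.
    return hidden_states
```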
This comprehensive patching approach solves both the missing module problem and the output compatibility issue, allowing Hugging Face models to work seamlessly with PyTorch's pipeline parallelism API.
While this solution is effective, it does introduce some maintenance considerations. First, we need to keep the patched forward methods in sync whenever the transformers version is upgraded; otherwise, it can cause unexpected errors. Second, not all language models share the same forward method skeleton, which can result in incorrectly patched methods and subtle issues.
Challenge 4: Gradient Scaling
A subtle but critical challenge in pipeline parallelism is ensuring correct gradient scaling when combining multiple parallelism strategies. This issue emerges during real training scenarios and can impact model convergence.
The problem became apparent during convergence testing, where we discovered that pipeline parallel training with mixed parallelism (PP + DP) produced different gradient norms than training with data parallelism alone. This occurred because, when pipeline parallelism was combined with data parallelism, gradients were incorrectly scaled by default, leading to diverging gradient norm curves.
According to PyTorch's pipeline parallelism documentation, the pipeline schedule scales gradients by the number of microbatches by default (scale_grads=True).
However, our training recipes use per-token loss calculation, which required a different approach. As a result, we had to disable automatic scaling in the pipeline schedule (scale_grads=False) and handle gradient normalization manually in the training loop, ensuring proper scaling across all parallelism dimensions. This approach gives us precise control over gradient scaling, while maintaining compatibility with our per-token loss calculation.
Specifically, we scale gradients in pipeline parallelism by dividing by a factor of num_label_tokens_in_batch / dp_group_size. The / dp_group_size is needed because FSDP averages the gradients across the data parallel ranks during reduction (ref). A minimal sketch of this manual scaling is shown below. The result is identical loss curves and gradient norm patterns across all parallelism configurations, ensuring that pipeline parallelism maintains correctness.
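The sketch below assumes a per-token loss summed over label tokens; the names num_label_tokens_in_batch and dp_group_size follow the text above, while the helper function itself is illustrative rather than the actual training-loop code.

```python
import torch
import torch.distributed as dist

def scale_gradients(model: torch.nn.Module, num_label_tokens_in_batch: int, dp_group) -> None:
    """Manually normalize gradients when the pipeline schedule uses scale_grads=False.

    The loss is summed over label tokens, so gradients are divided by the number of
    label tokens in the global batch. FSDP already averages gradients across the
    dp_group_size data-parallel ranks during reduction, so that factor is divided
    back out, giving a net scale of num_label_tokens_in_batch / dp_group_size.
    """
    dp_group_size = dist.get_world_size(dp_group)
    scale = num_label_tokens_in_batch / dp_group_size
    for param in model.parameters():
        if param.grad is not None:
            param.grad.div_(scale)
```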
Verified HF models supported out of the box
After solving these challenges, many Hugging Face models that previously ran into GPU OOMs now train cleanly with AutoPipeline. Below is a summary of the models we successfully fine-tuned out of the box:
Note: This table summarizes models with at least one finished run. Many additional fine-tuned variants also ran successfully; the table groups them by family for brevity.
Conclusion
If you are training any Hugging Face Transformer model - Llama, Qwen, Mistral, Gemma, or any other - AutoPipeline provides the tools needed to scale your training across multiple GPUs efficiently and correctly.
If you are training custom models and prefer more granular control, the functional API provides modular building blocks that can be adapted to any PyTorch model architecture, ensuring that the benefits of pipeline parallelism are accessible across the entire ecosystem.
Ready to get started? Check out an example recipe with pipeline parallelism here, along with more documentation. For questions, issues, or contributions, visit our GitHub repository.
Contributors
This work wouldn't have been possible without the incredible contributions from our team.
Special thanks to Huiying Li, Adil Asif and Alexandros Koumparoulis for their help adding pipelining support into Automodel - including checkpointing support, recipe integration, convergence sweeps, etc.
Additionally, a huge shoutout to Wenwen Gao, Bernard Nguyen, and Jennifer Gerhold for their invaluable guidance on the content — from shaping the narrative to ensuring technical accuracy and clarity.