Multimodal AI models, which process information across modalities like text, images, and audio, are driving the next wave of AI innovation, powering everything from advanced search to autonomous systems. However, training state-of-the-art multimodal models is complex and computationally intensive.
This blog post demonstrates how disaggregated hybrid parallelism with Ray addresses this challenge. The approach improves training efficiency at scale by applying a parallelism strategy tailored to each distinct module of a multimodal model.
We implemented this strategy on Ray and tested it with Qwen2.5-VL 32B. Compared to uniformly applying tensor parallelism to the entire model, it achieved a 1.26–1.37x throughput improvement. Compared to DeepSpeed ZeRO3, it trained sequences up to 7x longer.
Check out our repository to try it out yourself.
Unlike traditional large language models (LLMs), which are typically built from a uniform stack of identical Transformer layers, multimodal models often consist of highly specialized modules with vastly different computational and memory requirements.
For example, many Vision-Language Models (VLMs) combine a vision encoder with an LLM. Figure 1 shows the architecture of Qwen-VL, one of the most popular VLMs. First, the vision encoder processes high-resolution images as a sequence of tokens. Next, a projector adapts the vision encoder's hidden states to the LLM, compressing the sequence length to one quarter of the original. Finally, the LLM processes the concatenated projected vision tokens and text tokens. The vision encoder is typically much smaller than the LLM backbone: in Qwen2.5-VL, for instance, the vision encoder has 670 million parameters, while the LLM ranges from 3 billion to 72 billion parameters.
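The one-quarter compression comes from merging each 2×2 block of neighboring patch tokens into a single token. The NumPy sketch below illustrates only that regrouping, assuming a 14-pixel patch size and a 2×2 merge (our assumption about Qwen2.5-VL's defaults). `merge_patches` is a hypothetical helper; the real projector additionally runs an MLP over each concatenated group to match the LLM's hidden size.

```python
import numpy as np

PATCH = 14  # assumed patch size in pixels
MERGE = 2   # assumed 2x2 spatial merge in the projector


def merge_patches(tokens, grid_h, grid_w, merge=MERGE):
    """Group each merge x merge block of neighboring patch tokens into one token.

    tokens: (grid_h * grid_w, dim) in row-major patch order.
    Returns (grid_h // merge * grid_w // merge, merge * merge * dim).
    """
    dim = tokens.shape[-1]
    x = tokens.reshape(grid_h // merge, merge, grid_w // merge, merge, dim)
    x = x.transpose(0, 2, 1, 3, 4)  # bring the patches of each 2x2 block together
    return x.reshape((grid_h // merge) * (grid_w // merge), merge * merge * dim)


height = width = 448                               # example image size
grid_h, grid_w = height // PATCH, width // PATCH   # 32 x 32 patch grid
tokens = np.random.default_rng(0).standard_normal((grid_h * grid_w, 8))
merged = merge_patches(tokens, grid_h, grid_w)
print(tokens.shape[0], "->", merged.shape[0])      # 1024 -> 256
```

Each merged token concatenates four neighbors, so the LLM sees a quarter as many vision tokens while no pixel information is dropped before the projector's MLP.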
Figure 1. Qwen-VL's architecture

Training large-scale VLMs with high-resolution images and videos (which produce long sequences) presents several challenges. For LLMs with billions of parameters, fitting model weights, gradients, and optimizer states into a single GPU's memory is often impossible. To address this, techniques such as tensor parallelism, DeepSpeed ZeRO3, and PyTorch FSDP are widely adopted to shard the model across multiple devices. In addition, to fit very long sequences into GPU memory, sequence parallelism algorithms such as DeepSpeed-Ulysses [1] and Ring-Attention [2] are often used. Below, we review these methods and examine their overheads when applied to multimodal model training.
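A quick calculation shows why a 32B-parameter model cannot be trained on one GPU. The sketch assumes the common mixed-precision Adam layout of 16 bytes per parameter (BF16 weights and gradients, plus FP32 master weights and two FP32 Adam moments) and ignores activations entirely; `model_state_gb` is a hypothetical helper.

```python
# Bytes per parameter for mixed-precision training with Adam (assumed layout):
# BF16 weights (2) + BF16 gradients (2) + FP32 master weights (4)
# + FP32 Adam first moment (4) + FP32 Adam second moment (4)
BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4  # = 16


def model_state_gb(num_params, num_gpus=1, bytes_per_param=BYTES_PER_PARAM):
    """Model-state memory per GPU in GB, assuming states are sharded evenly (no activations)."""
    return num_params * bytes_per_param / num_gpus / 1e9


params_32b = 32_000_000_000  # nominal parameter count
print(model_state_gb(params_32b))              # 512.0 (GB on one GPU)
print(model_state_gb(params_32b, num_gpus=8))  # 64.0 (GB per GPU across 8 GPUs)
```

Even sharded evenly across 8 GPUs, the model states alone would occupy 64 GB of an 80 GB H100 before any activations are stored, which is why how the model and the sequence are partitioned matters so much.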
First, let's look at tensor parallelism (Figure 2). Because it performs allreduce communication on activations, its communication overhead scales linearly with sequence length. This creates a bottleneck for the vision encoder: it processes long sequences but has relatively few parameters, so partitioning its weights is unnecessary, and the resulting activation traffic makes it expensive.
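To see why the cost of tensor parallelism tracks sequence length, here is a back-of-the-envelope estimate of its per-layer traffic. The sketch assumes Megatron-style tensor parallelism with two allreduces per layer in the forward pass (the backward pass adds a matching pair), a ring allreduce, and BF16 activations. `hidden = 1280` is a hypothetical vision-encoder width chosen for illustration, and 12h² is the usual parameter estimate for a standard Transformer layer, not Qwen-VL's exact count.

```python
def tp_forward_bytes_per_layer(seq_len, hidden, tp_size, batch=1, bytes_per_elem=2):
    """Per-GPU bytes sent in one layer's forward pass under tensor parallelism.

    Assumes two allreduces per layer (after attention and after the MLP),
    each over the full (batch, seq_len, hidden) activation, using a ring
    allreduce in which every GPU sends 2 * (p - 1) / p of the tensor.
    """
    activation_bytes = batch * seq_len * hidden * bytes_per_elem
    ring_factor = 2 * (tp_size - 1) / tp_size
    return 2 * ring_factor * activation_bytes


def approx_layer_param_bytes(hidden, bytes_per_elem=2):
    """Rough weight size of a standard layer: 4h^2 (attention) + 8h^2 (MLP)."""
    return 12 * hidden * hidden * bytes_per_elem


hidden = 1280  # hypothetical vision-encoder width, for illustration only
for seq_len in (4096, 16384, 65536):
    comm_mb = tp_forward_bytes_per_layer(seq_len, hidden, tp_size=8) / 1e6
    print(f"{seq_len:>6} tokens: {comm_mb:7.1f} MB sent per layer (forward)")
print(f"layer weights: {approx_layer_param_bytes(hidden) / 1e6:.1f} MB")
```

Under these assumptions, quadrupling the sequence length quadruples the traffic, and at 16k tokens the activations sent per layer are nearly four times the size of the layer's own weights.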
Figure 2. Tensor Parallelism

Alternatively, let's look at DeepSpeed ZeRO3 (Figure 3). It shards parameters, gradients, and optimizer states across GPUs and uses allgather to reconstruct each layer's full parameters just before that layer runs. Its communication therefore scales with model size rather than sequence length, but every GPU must temporarily hold a layer's full parameters alongside the activations of its entire, unsharded sequence. As a result, peak memory usage grows significantly for long sequences, which can cause out-of-memory (OOM) errors.
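For contrast, a back-of-the-envelope estimate of ZeRO3 traffic depends only on parameter count. The sketch assumes each layer's parameters are allgathered once in the forward pass and once in the backward pass, and its gradients are reduce-scattered once, each as a ring collective in BF16. `layer_params` uses the standard 12h² per-layer estimate with a hypothetical hidden size of 5120, purely for illustration.

```python
def zero3_bytes_per_layer(layer_params, dp_size, bytes_per_elem=2):
    """Per-GPU bytes sent for one layer in one ZeRO3 training step.

    Assumes three ring collectives over the layer's full parameter tensor:
    an allgather in the forward pass, another allgather in the backward pass,
    and a reduce-scatter of the gradients; each sends (p - 1) / p of it.
    Sequence length does not appear anywhere in this cost.
    """
    per_collective = (dp_size - 1) / dp_size * layer_params * bytes_per_elem
    return 3 * per_collective


layer_params = 12 * 5120 ** 2  # hypothetical LLM-sized layer, for illustration
print(f"{zero3_bytes_per_layer(layer_params, dp_size=8) / 1e6:.0f} MB per layer per step")
```

The flip side is memory: the communication stays flat as sequences grow, but each GPU still materializes the gathered layer and keeps activations for its whole sequence, so memory rather than bandwidth becomes the limit.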
Figure 3. DeepSpeed ZeRO3 / PyTorch FSDP

Finally, let's look at sequence parallelism (Figure 4). It partitions activations along the sequence dimension but leaves model parameters unsharded, so on its own it cannot fit a model whose parameters, gradients, and optimizer states exceed a single GPU's memory.
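At the heart of DeepSpeed-Ulysses [1] is an all-to-all exchange that converts a sequence-sharded layout into a head-sharded one, so each GPU can run ordinary attention over the full sequence for a subset of heads. The single-process NumPy simulation below shows the data movement for one tensor such as Q; the helper `ulysses_all_to_all` is hypothetical and performs no real communication.

```python
import numpy as np


def ulysses_all_to_all(shards):
    """Simulate a Ulysses-style all-to-all in a single process.

    shards[r] has shape (seq_len // p, num_heads, head_dim): rank r holds a
    contiguous slice of the sequence with all heads. out[r] has shape
    (seq_len, num_heads // p, head_dim): rank r holds the full sequence for
    its own group of heads.
    """
    p = len(shards)
    num_heads = shards[0].shape[1]
    assert num_heads % p == 0, "heads must divide evenly across ranks"
    hp = num_heads // p
    out = []
    for dst in range(p):
        # Rank dst receives head group dst from every source rank; stacking the
        # pieces in source-rank order restores the original sequence order.
        pieces = [shards[src][:, dst * hp:(dst + 1) * hp, :] for src in range(p)]
        out.append(np.concatenate(pieces, axis=0))
    return out


rng = np.random.default_rng(0)
seq_len, num_heads, head_dim, p = 16, 8, 4, 4
q = rng.standard_normal((seq_len, num_heads, head_dim))
shards = np.split(q, p, axis=0)           # sequence-parallel layout
head_sharded = ulysses_all_to_all(shards)
print(head_sharded[0].shape)              # (16, 2, 4)
```

In the full algorithm, the same exchange is applied to Q, K, and V before attention, and a reverse all-to-all returns the attention output to the sequence-sharded layout. Each rank keeps 1/p of its shard and sends the rest, so per-GPU traffic per all-to-all is (p − 1)/p · s·h/p elements, compared with 2(p − 1)/p · s·h for a ring allreduce of the full activation under tensor parallelism.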
Figure 4. Sequence Parallelism

In summary, applying a widely used partitioning technique such as tensor parallelism, DeepSpeed ZeRO3 (PyTorch FSDP), or sequence parallelism uniformly across the entire VLM is inefficient, because it ignores the distinct computational characteristics of the vision and language modules.
To overcome this challenge, we propose disaggregated hybrid parallelism with Ray, which applies a parallelization strategy tailored to each distinct module of a multimodal model.
Ray is an open-source unified framework for scaling AI and Python applications. At its core, Ray provides actors, which are stateful distributed objects. This abstraction allows developers to orchestrate complex, scalable workloads such as large-scale data processing (e.g., at Amazon, Apple, and Netflix) and reinforcement learning pipelines (e.g., with verl and OpenRLHF).
By leveraging this architecture, developers can map distinct modules to independent groups of actors and express complex orchestration logic within a simple training loop. Ray actors also allow independent resource allocation: for example, developers can assign a specific number of GPUs or nodes to each group, tailoring the hardware setup to the requirements of each module.
Figure 5. Overview of Disaggregated Hybrid Parallelism with Ray

This architectural flexibility translates into straightforward, intuitive code. For instance, we can implement an ActorGroup class that encapsulates the parallelization of a single module across multiple GPUs:

```python
import ray


class ActorGroup:
    """Manages the Ray actors that jointly train one module (e.g., the vision encoder)."""

    def __init__(self, config, actor_cls, num_actors, num_cpus, num_gpus):
        # Turn the plain trainer class into a Ray actor class
        remote_actor_cls = ray.remote(actor_cls)
        # Create one actor per rank; num_cpus and num_gpus are reserved per actor
        self._actors = [
            remote_actor_cls.options(
                num_cpus=num_cpus,
                num_gpus=num_gpus,
            ).remote(config, rank)
            for rank in range(num_actors)
        ]
        # Initialize the collective-communication group shared by these actors
        self._setup_process_group()

    def _setup_process_group(self):
        ...  # omitted

    ...  # omitted: methods such as forward(), backward(), and optimizer_step()


def main():
    ...
    # Create the actor group for the vision encoder
    vision_group = ActorGroup(
        vision_config, VisionTrainerClass,
        num_actors=vision_parallel_size, num_cpus=..., num_gpus=...,
    )
    # Create the actor group for the LLM
    text_group = ActorGroup(
        text_config, TextTrainerClass,
        num_actors=text_parallel_size, num_cpus=..., num_gpus=...,
    )
```

With these actor groups established, Ray can act as a single controller that orchestrates the training flow. The following pseudo-code illustrates a typical multimodal training loop that coordinates the forward and backward passes between the vision encoder group and the LLM group.
```python
import ray
import torch

for _ in range(num_iterations):
    # Forward pass: vision outputs are handed to the text group as ObjectRefs,
    # so the controller never fetches the activations itself
    vision_output_refs = vision_group.forward(...)
    text_output_refs = text_group.forward(vision_output_refs, ...)

    # Wait for the LLM group to return its losses
    losses = ray.get(text_output_refs)

    # Convert tensor losses to Python numbers (e.g., for logging)
    loss_values = [loss.item() if torch.is_tensor(loss) else loss for loss in losses]

    # Backward pass: the text group's backward results feed the vision group's backward
    text_backward_refs = text_group.backward()
    vision_backward_refs = vision_group.backward(text_backward_refs)

    # Optimizer steps for each module
    text_group.optimizer_step(...)
    vision_group.optimizer_step(...)
```

This clear separation of orchestration logic from module execution is a core benefit of Ray's approach. It lets developers define complex hybrid parallel strategies without cluttering the main training loop.
To demonstrate the real-world impact of using Ray, we benchmarked a state-of-the-art model, Qwen-VL. As shown in Figure 1, Qwen-VL couples a vision encoder with an LLM. As discussed earlier, the vision encoder is significantly smaller than the LLM but processes much longer sequences. Consequently, applying tensor parallelism uniformly causes communication bottlenecks, while applying DeepSpeed ZeRO3 uniformly often leads to out-of-memory (OOM) errors as sequence length increases.
Our approach applies sequence parallelism to the vision encoder and tensor parallelism to the LLM. This strategy offers the following benefits:
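Lower Overhead: Sequence parallelism (DeepSpeed-Ulysses) incurs less communication overhead than tensor parallelism when GPUs are connected by fast NVLink.
Leveraging Window Attention: Window attention restricts each token to attending within a local window, so as long as each GPU's sequence shard contains whole windows, those layers need no cross-GPU communication. This significantly reduces synchronization points: in Qwen-VL's vision encoder, for instance, communication is required in only 4 out of 32 layers.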
Memory Efficiency: We can still fit the LLM parameters into GPU memory using tensor parallelism.
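The window-attention argument can be checked numerically. The NumPy sketch below is a single-head toy, not Qwen-VL's actual encoder: it computes window attention once over the full sequence and once independently on each of `p` sequence shards. Because every shard holds whole windows, the results match without any data exchanged between shards. We model windows as contiguous 1-D token blocks, an assumption that holds only if the encoder orders tokens so that each window is contiguous.

```python
import numpy as np


def attention(q, k, v, mask):
    """Single-head scaled dot-product attention; mask[i, j] = True lets i attend to j."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def window_mask(seq_len, window):
    """Block-diagonal mask: tokens attend only within their own window."""
    window_id = np.arange(seq_len) // window
    return window_id[:, None] == window_id[None, :]


rng = np.random.default_rng(0)
seq_len, window, dim, p = 64, 8, 16, 4
q, k, v = (rng.standard_normal((seq_len, dim)) for _ in range(3))

# Reference: window attention over the full, unsharded sequence
full = attention(q, k, v, window_mask(seq_len, window))

# Sequence parallel: each of p "GPUs" sees only its own contiguous shard
shard = seq_len // p
assert shard % window == 0, "each shard must contain whole windows"
local = [
    attention(q[s:s + shard], k[s:s + shard], v[s:s + shard], window_mask(shard, window))
    for s in range(0, seq_len, shard)
]
sharded = np.concatenate(local)
print(np.allclose(full, sharded))  # True
```

Only the full-attention layers need tokens from other shards, and those are the layers where the all-to-all exchange is required.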
To quantify these benefits, we benchmarked the Qwen2.5-VL 32B model under the following three conditions:
DeepSpeed ZeRO3 (the original authors’ implementation)
Tensor Parallelism
Disaggregated hybrid parallelism: sequence parallelism + DeepSpeed ZeRO1 for the vision encoder, tensor parallelism for the LLM
The experiments were conducted on a single node with 8× H100 GPUs, using BF16 mixed precision, Flash Attention 2, and activation checkpointing. We used the MSCOCO dataset and resized images to measure performance across different sequence lengths. The global batch size was capped at 8 (a micro-batch size of 1 per GPU, the minimum supported by ZeRO3). When memory limits made this infeasible, we used the largest batch size that fit on the GPUs. ZeRO3 results were obtained with the implementation from the Qwen-VL repository, while the tensor parallelism baseline was implemented in-house without Ray.
Figure 6 shows the resulting throughput for different vision token counts (TP = tensor parallelism; SP = sequence parallelism).
Figure 6. Throughput comparison of three parallelization strategies (SP+TP, pure TP, ZeRO3) on Qwen2.5-VL-32B across varying vision sequence lengths (1k–65k tokens).

Here are our key findings:
SP+TP achieves a 1.26–1.37× speedup over pure TP consistently across all sequence lengths (1k–65k tokens).
SP+TP enables training at extreme sequence lengths (up to 65k tokens) where ZeRO3 fails: ZeRO3 hits OOM errors at 16k+ tokens due to the memory overhead of fully gathered parameters and unsharded activations, whereas SP+TP distributes activation memory across GPUs.
ZeRO3's performance depends on sequence length: it peaks at ~9k tokens, struggles with short sequences, and hits OOM errors on long ones. SP+TP, on the other hand, maintains stable, high throughput across all evaluated sequence lengths.
These results demonstrate the advantage of the disaggregated hybrid parallelism (SP+TP) approach: a consistent speedup over pure tensor parallelism, plus better memory efficiency and more consistent performance than DeepSpeed ZeRO3, which enables training at extreme sequence lengths where monolithic approaches fail.
We introduced disaggregated hybrid parallelism (DHP) to address the computational challenges of training heterogeneous multimodal AI models such as VLMs. Instead of using a single parallelization strategy for the entire model, DHP uses Ray to apply tailored parallelism to distinct components (e.g., sequence parallelism for the vision encoder and tensor parallelism for the LLM). Ray orchestrates this through dedicated actor groups for each module and a single controller, resolving the inefficiencies caused by mismatched parallelization strategies and resource needs.
Validation on the Qwen2.5-VL 32B model showed that Ray-enabled DHP (SP+TP) achieved a 1.26–1.37x speedup over the pure tensor parallelism baseline. Crucially, DHP demonstrated superior memory efficiency, successfully training at extreme sequence lengths (16k+ tokens) where DeepSpeed ZeRO3 failed with OOM errors. This shows that Ray's flexible, disaggregated architecture delivers the speed and memory capacity needed to train state-of-the-art multimodal AI models at scale.
Looking ahead, we plan to extend our evaluations to larger-scale GPU clusters and heterogeneous hardware configurations. We also aim to apply this strategy to a wider variety of multimodal architectures to further validate its versatility.
We invite you to try out the implementation in our GitHub repository. We welcome any feedback or contributions.
[1] Sam Ade Jacobs, Masahiro Tanaka, Chengming Zhang, Minjia Zhang, Reza Yazdani Aminabadi, Shuaiwen Leon Song, Samyam Rajbhandari, Yuxiong He. "DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models." arXiv:2309.14509, 2023. https://arxiv.org/abs/2309.14509
[2] Hao Liu, Matei Zaharia, Pieter Abbeel. "Ring Attention with Blockwise Transformers for Near-Infinite Context." arXiv:2310.01889, 2023. https://arxiv.org/abs/2310.01889