TL;DR: Ray Direct Transport enables fast and direct GPU transfers in Ray via RDMA-backed transports. Learn how to use the API to build distributed systems for use cases such as RL for LLMs.
Ray has seen a massive increase in adoption for reinforcement learning (RL) for LLMs, driven by the need to flexibly orchestrate distributed GPUs. Ray provides an API for distributed orchestration, allowing RL infrastructure builders to compose training and inference engines, implement different placement and scheduling strategies, and transfer rollout data and weights between different frameworks and tools.
While Ray’s API simplifies orchestration, many RL workloads require efficient handling and transfer of large tensors between GPUs, which isn’t well supported by Ray’s CPU-based object store. Ray’s object store provides high-bandwidth shared reads for objects that reside in CPU memory, but it can’t take advantage of high-performance transports like InfiniBand and NVLink that move data directly from GPU to GPU.
Today, we are introducing Ray Direct Transport (RDT), a new feature in Ray Core that allows users to easily leverage high-bandwidth inter-GPU communication mechanisms such as NVLink or RDMA (Remote Direct Memory Access) over InfiniBand. Using RDT, we can achieve up to 1000x faster GPU-GPU transfers than Ray’s native object store with just a few lines of code changed.

To show how this works, we’ll introduce the system requirements for RL for LLMs and show how to use RDT APIs to build a simple script for RL training that can fit on one GPU. We’ll use RDT to manage the data transfer between workers, and we’ll show how you can swap out the communication backend between libraries like NCCL and NIXL.
You can try out RDT today with Ray 2.51.1. Start here at the docs.
RL for LLMs has a unique set of system requirements compared to other workloads.

In particular, there are two communication-heavy steps within an RL training loop:
Weight synchronization from training to inference frameworks. In each training step, new model weights are generated, and the training framework needs to send the updated weights to each inference engine replica. For large models and clusters, this can require hundreds of GBs to TBs of data transfer.
Transfers of rollout data from inference to training. While text-only models generate relatively small amounts of data, this can become a performance problem for multimodal models.
We built RDT to support efficient data transfer of large objects via specialized data transports, especially those that use RDMA to reduce software overheads:
Large objects: As models increase in size, the amount of data transfer needed per training step rapidly increases. Reducing software overheads like unnecessary data copies or serialization is therefore critical.
Specialized data transports: Modern GPUs often come packaged with high-bandwidth interconnects like NVIDIA’s NVLink that can speed up data transfers significantly. Meanwhile, the fastest data transport can vary for each cluster. Users should be able to easily select and efficiently utilize the available transports on each node.
RDT supports these use cases by letting you specify data dependencies between actor tasks using the standard Ray API, but allowing you to choose what transport to use for the data transfer. Currently, RDT supports Gloo and NVIDIA’s NCCL and NIXL libraries; future releases will support other transports such as CUDA IPC and make the transport pluggable so that you can bring your own transport.
First, let’s get an idea of how the API works. We’ll build on the Ray Core API but add annotations for RDT-enabled objects.
Figure: Architecture diagram for Ray Direct Transport. Actor tasks can return and load GPU tensors. Actors bypass the Ray object store and exchange data directly using a third-party RDMA-backed transport like NVIDIA’s NCCL. The driver holds the RDT object metadata and coordinates the transfer.

Note: As of Ray v2.50, RDT is only available for Ray actors, and only torch.Tensor data will be transferred via the RDT transport. If torch.Tensors are nested inside of a Ray object, e.g., if an actor produces a list of torch.Tensors, then the non-torch.Tensor data will still be transferred via the Ray object store.
To get started, define an actor class and a task that returns a torch.Tensor:
import torch
import ray

@ray.remote
class MyActor:
    def random_tensor(self):
        return torch.randn(1000, 1000)

Next, decorate the actor task with @ray.method(tensor_transport=transport_mode), where transport_mode can be one of "nccl", "nixl", "gloo", or "object_store". We’ll use NIXL for this example, which supports GPU-GPU and CPU-CPU transfers.
@ray.remote
class MyActor:
    @ray.method(tensor_transport="nixl")
    def random_tensor(self):
        return torch.randn(1000, 1000)

This decorator can be added to any actor task that returns a torch.Tensor, or that returns torch.Tensors nested inside other Python objects. Adding this decorator will change Ray’s behavior in the following ways:
1. When returning the tensor, Ray will store a reference to the tensor instead of copying it to the CPU-based Ray object store.
2. When the ray.ObjectRef is passed to another task, Ray will use NIXL to transfer the tensor to the destination task.
Note that for (2) to work, the @ray.method(tensor_transport) decorator only needs to be added to the actor task that returns the tensor. It should not be added to actor tasks that consume the tensor (unless those tasks also return tensors). This example also assumes that both the producer and consumer of the tensor have NIXL installed. pip install nixl is the easiest way to install NIXL; for best performance, check out the NIXL instructions for building from source.
Now we can create and pass RDT objects between the actors. Here is a full example:
import torch
import ray

@ray.remote
class MyActor:
    @ray.method(tensor_transport="nixl")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)

sender, receiver = MyActor.remote(), MyActor.remote()

# The tensor will be stored by the `sender` actor instead of in Ray's object
# store.
tensor = sender.random_tensor.remote()
result = receiver.sum.remote(tensor)
print(ray.get(result))

When the ray.ObjectRef is passed to another task, Ray will use NIXL to transfer the tensor directly from the source actor to the destination actor instead of the default object store. Note that the @ray.method(tensor_transport) decorator is only added to the actor task that returns the tensor; once this hint has been added, the receiving actor task receiver.sum will automatically use NIXL to receive the tensor. In this example, because MyActor.sum does not have the @ray.method(tensor_transport) decorator, it will use the default Ray object store transport to return torch.sum(tensor).
For more examples, including how to use ray.put and collective-based transports, see the docs.
Figure: When using ray.put, the caller takes the place of the driver and holds both the RDT object data and metadata.
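For example, a pre-existing tensor can be registered with RDT directly via ray.put. The snippet below is a hedged sketch: the keyword argument used to select the transport is an assumption based on the docs, and some_actor is a hypothetical consumer, so check the RDT documentation for the exact signature.

import torch
import ray

# A sketch of ray.put with an RDT transport. The keyword argument name below
# is an assumption; see the RDT docs for the exact ray.put signature.
weights = torch.randn(1000, 1000, device="cuda")
ref = ray.put(weights, _tensor_transport="nixl")

# The caller (here, the driver) now holds both the tensor data and the RDT
# metadata; passing `ref` to an actor task triggers a direct NIXL transfer.
# `some_actor` is a hypothetical actor handle with a `consume` task.
result = some_actor.consume.remote(ref)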
RDT can speed up object transfers between Ray actors significantly. Here is a benchmark showing the performance of different GPU tensor transports as a function of object size. This benchmark creates a CUDA tensor on one actor and sends it to a second actor on a different GPU on the same node. The second actor returns the sum of the tensor to the driver. We measure the end-to-end time to submit and finish the tasks on both actors, using 2 NVIDIA H100 GPUs.

As expected, the Ray object store scales poorly with object size due to copies between CPU and GPU, while RDT speeds up transfers significantly by using fast inter-GPU links.
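A minimal sketch of this kind of benchmark is shown below, assuming two single-GPU actors on the same node. The actor names, tensor sizes, and timing loop are illustrative rather than the exact benchmark script; swapping the tensor_transport string exercises the other transports.

import time
import torch
import ray

@ray.remote(num_gpus=1)
class Producer:
    @ray.method(tensor_transport="nixl")
    def make_tensor(self, num_elements: int):
        # Create a CUDA tensor on this actor's GPU.
        return torch.randn(num_elements, device="cuda")

@ray.remote(num_gpus=1)
class Consumer:
    def sum(self, tensor: torch.Tensor):
        # Return a small scalar to the driver via the object store.
        return torch.sum(tensor).item()

producer, consumer = Producer.remote(), Consumer.remote()

for num_elements in [2**20, 2**24, 2**28]:
    start = time.perf_counter()
    ref = producer.make_tensor.remote(num_elements)
    result = ray.get(consumer.sum.remote(ref))
    elapsed = time.perf_counter() - start
    print(f"{num_elements} elements: {elapsed * 1000:.1f} ms")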
Next we’ll show how to use RDT to speed up weights synchronization in a minimal RL example. In later blog posts, we’ll extend this example to apply RDT to rollout data transfer and add support for multi-GPU LLMs.
Our minimal RL example mirrors the dataflow used to train LLMs using RL. In particular, we solve a toy problem using the Group Relative Policy Optimization (GRPO) algorithm. Our “environment” randomly generates a 2D direction vector, and the model has to predict which of the eight compass directions is closest to this vector:

GRPO is an RL algorithm that has become popular for training LLMs. It works by generating a group of outputs for each input and computing the advantage of each output relative to the mean reward of its group. The algorithm also constrains the current policy model from making predictions that diverge too far from previous versions of the same model, which stabilizes training and mitigates “catastrophic forgetting” of previous experiences.
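As a rough illustration, the core of a GRPO-style update can be written as a clipped policy-gradient loss over group-normalized advantages. The function below is a simplified sketch, not the exact loss used in the full training script; the tensor shapes and the clipping constant are assumptions.

import torch

def grpo_loss(logps, old_logps, rewards, clip_eps=0.2):
    # logps, old_logps, rewards: shape (batch, group_size).
    # Advantage of each sample relative to the mean reward of its group.
    advantages = rewards - rewards.mean(dim=1, keepdim=True)
    advantages = advantages / (rewards.std(dim=1, keepdim=True) + 1e-8)
    # Importance ratio between the current policy and the policy that generated
    # the samples; clipping keeps the update close to the previous policy.
    ratio = torch.exp(logps - old_logps)
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
    return -torch.min(ratio * advantages, clipped * advantages).mean()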
The following diagram shows the different Ray actors involved in this example along with the resources needed for each actor.

Each arrow moves tensors from one actor to another. RDT can be applied to any of these data transfers, but in this case we’ll just apply it to the Learner → Generator data transfer since this requires a GPU-to-GPU copy.
Here are the steps for the application in detail:
[CPU] The “environment” generates random 2D vectors.
[GPU] The Generator policy predicts which of the 8 compass directions is closest to the input vector: W, NW, N, NE, E, SE, S, or SW.
[CPU] The Scorer computes rewards analytically using cosine similarity.
[CPU] The scored slices are added to the ReplayBuffer. The ReplayBuffer allows the model to learn from past experiences.
[GPU] The Learner model samples from the replay buffer and uses the GRPO algorithm to update its weights.
[GPU] The Learner sends the updated weights to the Generator, completing one training step.
For brevity, we just show the key parts of the training script in the post. The full training code is available here.
At each training step, the environment randomly samples a batch of 2D unit vectors. These state vectors are the inputs to the Generator actor. The policy model takes these 2D states as input and outputs logits over the eight possible actions.
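The helper that samples these states, sample_unit_vector, is defined in the full script; a minimal version (an illustrative sketch, not necessarily the exact implementation) could look like this.

import torch

def sample_unit_vector(batch_size: int) -> torch.Tensor:
    # Draw random angles and convert them to 2D unit vectors of shape (batch_size, 2).
    angles = torch.rand(batch_size) * 2 * torch.pi
    return torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)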
from torch.distributions import Categorical

# The number of actions to sample for each state.
GROUP_SIZE = 10

@ray.remote(num_gpus=1)
class Generator:
    def __init__(self, scorer): ...

    def generate(self, states: torch.Tensor):
        # states are randomly sampled unit vectors.
        logits = self.model(states.cuda())
        dist = Categorical(logits=logits)
        actions = dist.sample((GROUP_SIZE,))
        logps = dist.log_prob(actions)
        # Move tensors to CPU and send to the Scorer.
        slice_batch = {
            "policy_version": self.policy_version,
            "state": states.detach().cpu(),
            "actions": actions.transpose(0, 1).contiguous().detach().cpu(),
            "old_logps": logps.transpose(0, 1).contiguous().detach().cpu(),
        }
        self.scorer.enqueue_trajectory_batch.remote(slice_batch)

    def update_weights(self, cuda_weights):
        # Receive CUDA tensors from the Learner and update the model.
        self.model.load_state_dict(cuda_weights)
        self.model.eval()
        self.policy_version += 1

The Scorer actor receives the dictionary of CPU tensors from the Generator and computes the rewards using the dot product of each predicted action’s direction with the original state vector. Then, the Scorer sends the scored trajectories to the ReplayBuffer actor (using the default Ray object store).
@ray.remote
class Scorer:
    def __init__(self, replay_buffer):
        self.replay_buffer = replay_buffer

    def enqueue_trajectory_batch(self, batched_slices: dict):
        rewards = ...  # Score trajectories and send to the ReplayBuffer.
        self.replay_buffer.put.remote(dict(
            policy_version=batched_slices["policy_version"],
            state=batched_slices["state"],
            actions=batched_slices["actions"],
            old_logps=batched_slices["old_logps"],
            rewards=rewards,
        ))
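The reward computation itself is elided above. Since each action indexes one of the eight compass directions, one way to score it (a sketch that assumes actions are integer indices into a table of unit direction vectors, matching the order listed earlier) is:

import torch

# Hypothetical table of the eight compass directions as 2D unit vectors,
# ordered to match the Generator's action indices.
COMPASS = torch.tensor([
    [-1.0, 0.0],          # W
    [-0.7071, 0.7071],    # NW
    [0.0, 1.0],           # N
    [0.7071, 0.7071],     # NE
    [1.0, 0.0],           # E
    [0.7071, -0.7071],    # SE
    [0.0, -1.0],          # S
    [-0.7071, -0.7071],   # SW
])

def compute_rewards(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    # states: (batch, 2) unit vectors; actions: (batch, group_size) indices.
    # The dot product of two unit vectors is their cosine similarity.
    directions = COMPASS[actions]              # (batch, group_size, 2)
    return (directions * states.unsqueeze(1)).sum(dim=-1)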
The ReplayBuffer actor stores scored slices in its local heap memory and exposes another method to allow the Learner to sample from the buffer.
@ray.remote
class ReplayBuffer:
    def __init__(self):
        self.storage = []

    def put(self, slice: dict[str, torch.Tensor]):
        self.storage.append(slice)

    def sample_from(self, n: int) -> list[dict[str, torch.Tensor]]: ...

The Learner actor samples from the ReplayBuffer as well as the current policy in order to do a GRPO-style weight update:
@ray.remote(num_gpus=1)
class Learner:
    def __init__(self, replay_buffer): ...

    def step(self):
        # Sample from the ReplayBuffer.
        slices: list[TrajectorySlice] = ray.get(
            self.replay_buffer.sample_from.remote(BATCH_SIZE)
        )
        # Perform the GRPO update.
        ...

Finally, we expose a method on the Learner to get the current model weights, which we’ll use to synchronize weights with the Generator.
@ray.remote(num_gpus=1)
class Learner:
    @ray.method(tensor_transport="nixl")
    def get_weights(self):
        return self.model.state_dict()

Note that here, we add the @ray.method(tensor_transport="nixl") decorator to use RDT for the weights transfer. Under the hood, this uses a one-sided RDMA read via the UCX library to bypass CPU memory. Without this decorator, Ray would transfer the weights through the Ray object store.
Now, putting it all together, we implement a Ray driver program to execute a “one step off policy” asynchronous training loop, which launches generation on the current weights in parallel with the next Learner step. After each Learner update, we synchronize the weights back to the Generator to ensure that generation is at most one policy version behind.
# Instantiate one instance of each actor.
replay_buf = ReplayBuffer.remote()
learner = Learner.remote(replay_buf)
scorer = Scorer.remote(replay_buf)
generator = Generator.remote(scorer)

# Initialize the generator and replay buffer.
generator.update_weights.remote(learner.get_weights.remote())
generator.generate.remote(sample_unit_vector(BATCH_SIZE))

# Training loop.
for i in range(total_steps):
    states = sample_unit_vector(batch_size=BATCH_SIZE)
    generator.generate.remote(states)
    # Launch the next learner step in parallel with generation.
    learner.step.remote()
    # Update the generator with new weights.
    ray.get(generator.update_weights.remote(learner.get_weights.remote()))

Here, we rely on Ray’s actor task ordering to ensure that the generator correctly alternates between updating weights and generation. For simplicity, we use ray.get to block on the Generator.update_weights task to ensure that the Generator has received the full weights from the Learner before the Learner starts its next step; otherwise, the Generator could receive partially updated weights. See the full training code here for an alternative method that avoids blocking calls on the driver.
On an NVIDIA B200 node, each step takes ~188ms when using the default Ray object store to complete the weights transfer. Adding the @ray.method(tensor_transport="nixl") decorator reduces each step’s run time to 81ms, a 2.3x improvement with a one-line code change!
RDT is in alpha and we are actively looking for feedback! Upcoming work includes performance enhancements, support for alternative tensor transports such as CUDA IPC, and support for bringing your own transport. Our goal is to achieve near parity with the Ray Core API, but for the moment, there are some key limitations to be aware of:
Support for Ray actors only. Non-actor Ray tasks are not supported.
Not yet compatible with asyncio. Follow the tracking issue for updates.
RDT objects are mutable. This means that Ray only holds a reference to the tensor and will not copy the value until a transfer is requested. If user code writes to the tensor before the transfer happens, the receiver may see a partial update. This is different from the Ray object store, which always copies data by value to ensure immutability. A contrived example of this pitfall is sketched below.
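The sketch below is hypothetical code, not part of the training example; the actor names and tensor size are illustrative. It shows how an in-place write on the producer can race with a pending RDT transfer.

import torch
import ray

@ray.remote
class Producer:
    @ray.method(tensor_transport="nixl")
    def make(self):
        # RDT stores a reference to this tensor; no copy is made yet.
        self.tensor = torch.zeros(1000)
        return self.tensor

    def mutate(self):
        # An in-place write that may race with the pending transfer below.
        self.tensor += 1

@ray.remote
class Consumer:
    def total(self, tensor: torch.Tensor):
        return tensor.sum().item()

producer, consumer = Producer.remote(), Consumer.remote()
ref = producer.make.remote()
producer.mutate.remote()
# Depending on when the transfer happens relative to `mutate`, the consumer
# may observe the original tensor, the mutated tensor, or a partial update.
print(ray.get(consumer.total.remote(ref)))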
Check out the docs for more info on usage, file an issue for bug reports or feedback, and meet us at Ray Summit! In part 2 of this blog series, we’ll expand upon our example to demonstrate RDT for rollout data transfer and RL training for multi-GPU LLMs.