
Integrating Ray with Flyte for Efficient ML Parallel Processing


Ray is an open-source framework designed for building distributed applications. It provides a simple, flexible API for scaling Python code from a single machine to a large cluster for parallel processing, making it a powerful tool for a wide range of tasks, from machine learning to distributed data processing. Ray and Flyte share a common objective: simplifying the compute orchestration experience. This alignment allows data scientists, ML engineers, and other professionals to focus on their areas of expertise instead of learning the intricacies of infrastructure management. 

Integrating Ray with Flyte brings several significant advantages:

  • Parallelizing tasks without the overhead of constantly spinning Kubernetes pods up and down: Without going into implementation details, Flyte achieves best-in-class traceability and reproducibility by executing each task within a container on a Kubernetes pod. Though this brings many benefits, spinning pods up and down can add significant overhead. Ray worker node pods can be spun up once when the cluster is provisioned, and all incoming parallel tasks can then be directed to one of these running workers.

  • Leveraging helpful ML features of Ray within Flyte workflows: Ray offers a rich set of machine learning features for batch inference, embedding generation, training, hyperparameter tuning, model serving, and reinforcement learning, all of which can serve as valuable components in many AI workloads. 

  • Easily including existing Ray jobs in Flyte workflows: This minimizes refactoring and reinventing the wheel when connecting the components that make up a truly production-grade workflow.

The goal of this integration is to enhance the capabilities of both platforms, enabling more efficient, scalable, and manageable workflows for data scientists and ML engineers. This can facilitate a wide variety of practices that benefit from parallel computation, including hyperparameter tuning, batch inference, and distributed training.

The Integration in Action on Union

Union, the platform to create AI products, integrates Flyte. Let's explore how we can run Ray jobs in Union using Flyte’s SDK. We will look at two ways of running jobs in Ray: submitting jobs to an existing cluster or letting Flyte create and manage the Ray cluster for you. 

Submitting Jobs on an Existing Cluster

Organizations may have existing Ray clusters used across various teams and applications. Keeping a cluster always available ensures quick access, reduces latency caused by resource provisioning, and can improve the overall efficiency of job execution.

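import typing

import ray
from flytekit import task
from flytekitplugins.ray import RayJobConfig


@ray.remote
def f(x):
    return x * x


@task(
    task_config=RayJobConfig(
        address="<RAY_CLUSTER_ADDRESS>",  # address of the existing Ray cluster (placeholder)
    )
)
def ray_task() -> typing.List[int]:
    # Submit five remote calls that Ray can run in parallel, then wait for all results
    futures = [f.remote(i) for i in range(5)]
    return ray.get(futures)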
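For readers new to Ray's futures model, the fan-out/fan-in pattern in ray_task can be mimicked locally with Python's standard concurrent.futures module. This is an analogy swapped in for illustration, not how Ray executes work: submit() plays the role of f.remote(), and reading result() in submission order mirrors ray.get(futures), which returns results in the same order as the list of futures passed to it.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def f(x: int) -> int:
    # Same body as the Ray remote function above, but as a plain Python function.
    return x * x


def fan_out_fan_in(n: int, max_workers: int = 4) -> List[int]:
    # Local stand-in for the Ray pattern (an analogy, not Ray itself).
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # submit() starts work asynchronously, like f.remote(i).
        futures = [pool.submit(f, i) for i in range(n)]
        # Reading results in submission order mirrors ray.get(futures),
        # which preserves the order of the futures list.
        return [fut.result() for fut in futures]


print(fan_out_fan_in(5))  # [0, 1, 4, 9, 16]
```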

Letting Flyte Create and Manage an Ephemeral Cluster

Letting Flyte provision resources only when they are needed avoids paying for idle cluster capacity, making this an attractive, economical choice. Defining a Ray cluster in a Flyte decorator also brings version control to the cluster settings, aiding reproducibility. Because Flyte handles all aspects of cluster management, from setup to resource allocation, you do not need expertise in Ray cluster configuration and maintenance.

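import typing

import ray
from flytekit import task
from flytekitplugins.ray import RayJobConfig, WorkerNodeConfig


@task(
    task_config=RayJobConfig(
        # Flyte provisions the cluster; its worker group runs 10 replicas
        worker_node_config=[
            WorkerNodeConfig(group_name="test-group", replicas=10)
        ],
    ),
)
def ray_task() -> typing.List[int]:
    # f is the @ray.remote function defined in the previous example
    futures = [f.remote(i) for i in range(5)]
    return ray.get(futures)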

As seen in the above two examples, integrating Ray with Flyte is as simple as including a RayJobConfig in a task decorator. With the RayJobConfig, you can control a variety of fields, including:

  • Any initialization parameters

  • The number of worker nodes and their associated names and replicas

  • Whether or not worker node replicas should scale up or down with changing resource demand and if so, the minimum and maximum number of replicas

  • And more (see the Flyte Ray plugin documentation)
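To make the minimum and maximum replica bounds concrete, here is a toy model of what they constrain. It is not Ray's autoscaling algorithm (which reacts to the resource demands of pending tasks and actors); the helper desired_replicas and its tasks_per_replica parameter are hypothetical, used only to show that demand-driven scaling is clamped between the configured floor and ceiling.

```python
import math


def desired_replicas(pending_tasks: int, tasks_per_replica: int,
                     min_replicas: int, max_replicas: int) -> int:
    """Hypothetical toy model: size a worker group to demand, clamped to [min, max].

    This is NOT Ray's autoscaling logic; it only illustrates what the
    minimum and maximum replica bounds mean.
    """
    if tasks_per_replica <= 0:
        raise ValueError("tasks_per_replica must be positive")
    if min_replicas > max_replicas:
        raise ValueError("min_replicas cannot exceed max_replicas")
    # Replicas needed to cover the current demand, rounded up.
    needed = math.ceil(pending_tasks / tasks_per_replica)
    # Never scale below the floor or above the ceiling.
    return max(min_replicas, min(needed, max_replicas))


print(desired_replicas(0, 2, 1, 10))    # 1 (idle: stays at the floor)
print(desired_replicas(7, 2, 1, 10))    # 4 (enough replicas for 7 tasks)
print(desired_replicas(100, 2, 1, 10))  # 10 (capped at the ceiling)
```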

Finally, when a Ray task is run in a Flyte workflow, Union conveniently links to the dashboard for the Ray cluster, allowing for visual exploration of cluster resources, utilization, logs, and more.


Use Cases

Ray shines in the context of parallel computation. In ML and AI, there are many scenarios where this is valuable, including batch inference, training of parallelizable models, and hyperparameter tuning. Ray provides examples in its documentation. Let's look at how easy it can be to run one of these examples in a Flyte workflow.

Hyperparameter Tuning

The following example showcases Ray Tune for hyperparameter tuning:

from ray import tune

def objective(config):
    score = config["a"] ** 2 + config["b"]
    return {"score": score}

search_space = {
    "a": tune.grid_search([0.001, 0.01, 0.1, 1.0]),
    "b": tune.grid_search([1, 2, 3]),
}

tuner = tune.Tuner(objective, param_space=search_space)

results = tuner.fit()
print(results.get_best_result(metric="score", mode="min").config)
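Because both parameters use tune.grid_search, Tune evaluates every combination: 4 values of a times 3 values of b gives 12 trials. The plain-Python loop below reproduces that exhaustive search serially, with no Ray involved, to show which config the tuner should report. It illustrates grid-search semantics only and is not Tune's implementation.

```python
from itertools import product
from typing import Dict, List, Tuple


def objective(config: Dict[str, float]) -> Dict[str, float]:
    # Same objective as the Ray Tune example.
    score = config["a"] ** 2 + config["b"]
    return {"score": score}


def grid_search(a_values: List[float], b_values: List[float]) -> Tuple[Dict[str, float], int]:
    """Evaluate every (a, b) pair serially; return the min-score config and the trial count."""
    trials = [{"a": a, "b": b} for a, b in product(a_values, b_values)]
    best = min(trials, key=lambda cfg: objective(cfg)["score"])
    return best, len(trials)


best, n_trials = grid_search([0.001, 0.01, 0.1, 1.0], [1, 2, 3])
print(n_trials)  # 12
print(best)      # {'a': 0.001, 'b': 1}
```

Each of those 12 trials is independent, which is exactly what lets Tune spread them across Ray workers.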

In this example, we select the values of a and b that produce the lowest score. Suppose we want Flyte to provision a Ray cluster so that we can take advantage of the parallelism built into ray.tune.Tuner. This is as simple as refactoring the above code to:

from flytekit import Resources, task
from flytekitplugins.ray import RayJobConfig, WorkerNodeConfig
from ray import tune


@task(
    container_image="<custom-image-with-ray>",
    requests=Resources(mem="10Gi", cpu="2"),
    task_config=RayJobConfig(
        worker_node_config=[
            WorkerNodeConfig(group_name="ray-job", replicas=5)
        ],
    ),
)
def ray_task() -> dict:
    # objective is the same function defined in the previous example
    search_space = {
        "a": tune.grid_search([0.001, 0.01, 0.1, 1.0]),
        "b": tune.grid_search([1, 2, 3]),
    }

    tuner = tune.Tuner(objective, param_space=search_space)
    results = tuner.fit()
    return results.get_best_result(metric="score", mode="min").config

Note that we can even specify the resource requirements of our worker node replicas; in the above example each replica will have 10Gi of memory and two CPU cores. 

Simple Batch Inference

Let's set up a very simple, yet realistic, example of using the Ray plugin to parallelize batch inference in a Flyte workflow. In this example, we will pull some images from a remote S3 directory, organize them into batches, and pass each batch to a Ray worker for inference with a pretrained torchvision model. We'll use Flyte's ImageSpec to manage the dependencies of our Ray jobs, Flyte's task decorator to specify the resources available to them, and the simple ray.get() pattern to parallelize inference.
> Note that we use ray.get() for simplicity; however, efficiency gains can likely be found with other patterns, such as ray.wait() or map_batches() from Ray Data.
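The difference between the two patterns can be sketched with the standard library: calling ray.get() on a list blocks until every result is ready, while a ray.wait() loop lets the caller handle results as they finish. Below, concurrent.futures.as_completed stands in for a ray.wait() loop (an analogy, not Ray's API), and the hypothetical classify_batch function is a placeholder for per-batch inference that returns a {filename: class} dictionary.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List


def classify_batch(filenames: List[str]) -> Dict[str, int]:
    # Hypothetical stand-in for inference: sleep briefly, then assign a fake class id.
    time.sleep(0.01 * len(filenames))
    return {name: len(name) % 10 for name in filenames}


def merge_as_completed(batches: List[List[str]], max_workers: int = 4) -> Dict[str, int]:
    merged: Dict[str, int] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(classify_batch, batch) for batch in batches]
        # Like a ray.wait() loop: merge each batch's predictions as soon as it
        # finishes, instead of waiting for the slowest batch before doing anything.
        for fut in as_completed(futures):
            merged.update(fut.result())
    return merged


batches = [["a.jpg", "bb.jpg"], ["ccc.jpg"], ["dddd.jpg", "e.jpg", "ff.jpg"]]
print(len(merge_as_completed(batches)))  # 6
```

The merged dictionary is the same either way; the gain comes from overlapping post-processing with batches that are still running.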


Putting some of these patterns into practice, our Ray job dependencies and resources can be defined using Flyte's ImageSpec and task decorator as follows:

from typing import Dict

import ray
from flytekit import ImageSpec, Resources, task, workflow
from flytekit.types.directory import FlyteDirectory
from flytekitplugins.ray import RayJobConfig, WorkerNodeConfig

# define dependencies needed for the worker nodes
custom_image = ImageSpec(
    name="<image-name>",
    registry="<image-repository>",
    packages=["flytekitplugins-ray", "torch", "torchvision", "pillow"],
    apt_packages=["wget"],
)

# define resource and replica requirements for worker nodes
@task(
    container_image=custom_image,
    requests=Resources(mem="5Gi", cpu="2", gpu="1"),
    task_config=RayJobConfig(
        shutdown_after_job_finishes=True,
        worker_node_config=[
            WorkerNodeConfig(group_name="ray-job", replicas=4)
        ],
    ),
)
def process_images_in_batches(
    image_dir: FlyteDirectory, ray_batch_size: int, torch_batch_size: int
) -> Dict[str, int]:
    """Collect image files and pass them in batches to Ray worker nodes."""
    # List the directory's contents, dropping the first entry returned by listdir
    image_files = FlyteDirectory.listdir(image_dir)[1:]

    # Submit one remote process_batch call per chunk of ray_batch_size images
    futures = [
        process_batch.remote(image_files[i:i + ray_batch_size], torch_batch_size)
        for i in range(0, len(image_files), ray_batch_size)
    ]
    # Collect the output from all worker nodes
    pred_dicts = ray.get(futures)
    combined_preds_dict = {}
    for d in pred_dicts:
        combined_preds_dict.update(d)
    return combined_preds_dict

As seen above, in process_images_in_batches we can simply collect the images in our FlyteDirectory, chunk them into batches, and pass the batches to a remote ray function defined as follows:

import ray
import torch
import torchvision.models as models
from flytekit.types.file import FlyteFile


@ray.remote(num_gpus=1)
def process_batch(ray_batch_keys: list[FlyteFile], torch_batch_size: int) -> dict:
    """Infer object class for a batch of images. Load sub-batches onto GPU for inference."""
    # IMAGENET1K_V1 is the weight set the deprecated pretrained=True flag loaded
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).to("cuda")
    model.eval()
    ray_batch_class_preds = []
    with torch.no_grad():  # inference only, so skip gradient tracking
        for j in range(0, len(ray_batch_keys), torch_batch_size):
            ...  # load images on gpu as tensors of size torch_batch_size
            out = model(batch_tensor)
            ray_batch_class_preds += torch.argmax(out, dim=1).cpu().tolist()
    # collect and return prediction class for each image
    class_dict = {
        im.remote_source.split("/")[-1]: class_val
        for im, class_val in zip(ray_batch_keys, ray_batch_class_preds)
    }
    return class_dict

Within process_batch, we further split each Ray batch into sub-batches of torch_batch_size images, each of which is passed to our PyTorch model. The predictions are compiled into a dictionary and returned to our Flyte task. Note the use of @ray.remote(num_gpus=1), which tells Ray to reserve one GPU for each call so that the GPU provisioned in our Flyte task decorator is actually used by Ray.
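The two levels of batching can be checked without Ray or a GPU. The sketch below uses a hypothetical chunk helper and made-up file names: the outer split mirrors the list comprehension in process_images_in_batches (one Ray task per chunk of ray_batch_size files), and the inner split mirrors the torch_batch_size loop inside process_batch.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def chunk(items: Sequence[T], size: int) -> List[Sequence[T]]:
    """Split items into consecutive slices of at most `size` elements (hypothetical helper)."""
    if size <= 0:
        raise ValueError("size must be positive")
    return [items[i:i + size] for i in range(0, len(items), size)]


files = [f"img_{k}.jpg" for k in range(10)]  # made-up file names
ray_batches = chunk(files, 4)                # one Ray task per outer batch
for ray_batch in ray_batches:
    # Each Ray task further splits its batch into GPU-sized sub-batches.
    sub_batches = chunk(ray_batch, 3)
    print(len(ray_batch), [len(s) for s in sub_batches])
# 4 [3, 1]
# 4 [3, 1]
# 2 [2]
```

The last chunk at each level is simply shorter, so no image is dropped when the batch sizes do not divide the file count evenly.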

Using a small subset of the MS COCO dataset (5,320 images, 841 MiB), we can compare the runtime of this workflow when each batch is processed serially without Ray and when four Ray workers process batches in parallel.


As expected, we see a very substantial improvement moving from serial to parallel executions. We also see an improvement when running inference on a GPU compared to a CPU. Quantitatively, we see a 72% reduction in runtime when parallelizing on a CPU and an 84% reduction in runtime when parallelizing on a GPU. We would expect the gap between serial and parallel executions to widen with more data, as auxiliary tasks such as provisioning the Ray cluster or downloading the torchvision model become a smaller proportion of the runtime. Additionally, other factors such as the number of replicas, the batch sizes, and the resources used would have a meaningful impact on these measurements.
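Two pieces of arithmetic help interpret these numbers. First, a runtime reduction r corresponds to a speedup of 1 / (1 - r), so the reported 72% and 84% reductions are roughly 3.6x and 6.25x speedups relative to whichever serial baseline each was measured against. Second, the expectation that the gap widens with more data follows from a simple fixed-overhead model, sketched below with hypothetical overheads and work amounts rather than measured values: as the parallelizable work grows, the modeled speedup approaches the number of workers.

```python
def speedup_from_reduction(reduction: float) -> float:
    """Convert a fractional runtime reduction (e.g. 0.72) into a speedup factor."""
    if not 0 <= reduction < 1:
        raise ValueError("reduction must be in [0, 1)")
    return 1.0 / (1.0 - reduction)


def modeled_speedup(work_s: float, workers: int,
                    serial_overhead_s: float, parallel_overhead_s: float) -> float:
    """Toy model: serial = overhead + work; parallel = overhead + work / workers.

    parallel_overhead_s stands for extras such as provisioning the Ray cluster.
    All inputs are hypothetical; this is not fitted to the benchmark above.
    """
    serial = serial_overhead_s + work_s
    parallel = parallel_overhead_s + work_s / workers
    return serial / parallel


print(round(speedup_from_reduction(0.72), 2))  # 3.57
print(round(speedup_from_reduction(0.84), 2))  # 6.25
# Hypothetical: 30 s of fixed serial overhead, 90 s with cluster provisioning, 4 workers.
for work in (100, 1_000, 10_000):
    print(work, round(modeled_speedup(work, 4, 30, 90), 2))
```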

The full batch inference example can be found in the unionai-examples repository. The example can be registered using the following commands:

pip install 'unionai[byoc]' flytekitplugins-envd
export IMAGE_SPEC_REGISTRY="<your-container-registry>"
git clone https://github.com/unionai/unionai-examples
cd unionai-examples/_blogs/ray-batch-inference
unionai register ray_inference_wf.py

Then, provided you have added images to an accessible S3 directory, you can use the Union UI to launch your workflow.

Integrating Ray with Flyte on Union enhances the capabilities of both platforms by providing a seamless experience for building and managing distributed applications. The integration allows for parallelizing tasks without frequent pod spin-ups, leveraging Ray's machine learning features within Flyte workflows, and including existing Ray jobs with minimal refactoring. This results in efficient, scalable, and manageable workflows, benefiting data scientists and ML engineers by enabling practices like hyperparameter tuning, batch inference, and distributed training.
