# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use Ray Data for data parallel batch inference.

Ray Data is a data processing framework that can process very large datasets
with first-class support for vLLM.

Ray Data provides functionality for:
* Reading and writing to most popular file formats and cloud object storage.
* Streaming execution, so you can run inference on datasets that far exceed
  the aggregate RAM of the cluster.
* Scaling up the workload without code changes.
* Automatic sharding, load balancing, and autoscaling across a Ray cluster,
  with built-in fault tolerance and retry semantics.
* Continuous batching that keeps vLLM replicas saturated and maximizes GPU
  utilization.
* Compatibility with tensor/pipeline parallel inference.

Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""

import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

assert Version(ray.__version__) >= Version("2.44.1"), (
    "Ray version must be at least 2.44.1"
)

# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False

# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
print(ds.schema())

size = ds.count()
print(f"Size of dataset: {size} prompts")

# Configure the vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=1,  # set the number of parallel vLLM replicas
    batch_size=64,
)

# Create a Processor object, which will be used to
# do batch inference on the dataset.
vllm_processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[
            {
                "role": "system",
                "content": "You are a bot that responds with haikus.",
            },
            {"role": "user", "content": row["text"]},
        ],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        ),
    ),
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row,  # This will return all the original columns in the dataset.
    ),
)

ds = vllm_processor(ds)

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
outputs = ds.take(limit=10)

for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}")
    print(f"Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
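
# The docstring above notes compatibility with tensor/pipeline parallel
# inference. A minimal sketch of what that could look like is shown below;
# the specific parallelism values and GPU counts are illustrative
# assumptions, not part of the example above, and should be matched to your
# hardware. engine_kwargs are forwarded to the vLLM engine, so standard
# engine arguments such as tensor_parallel_size apply.
#
# config = vLLMEngineProcessorConfig(
#     model_source="unsloth/Llama-3.1-8B-Instruct",
#     engine_kwargs={
#         "tensor_parallel_size": 2,    # split each replica across 2 GPUs
#         "pipeline_parallel_size": 1,  # increase to shard layers across GPUs
#         "max_model_len": 16384,
#     },
#     concurrency=1,  # number of replicas; each replica uses TP x PP GPUs
#     batch_size=64,
# )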