# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning using vLLM and Ray, with vLLM's native
weight-syncing APIs at the engine-instance level.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformers model occupies one GPU for training, whereas a
2x tensor-parallel vLLM inference engine occupies two GPUs.

The example performs the following steps:

* Load the training model on one GPU (scheduled via Ray).
* Initialize the inference model with dummy weights across two GPUs using
  vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
  inference engine.
* Update the weights of the training model and broadcast the updated weights
  to the inference engine using a Ray collective RPC group.
* Generate from the same prompts again; after the weight sync the outputs
  should be sensible.

This example assumes a single-node cluster with three GPUs, but Ray also
supports multi-node clusters. vLLM expects its GPUs to be used exclusively
for vLLM workloads; residual GPU activity interferes with vLLM's memory
profiling and causes unexpected behavior.
"""

import os

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port

MODEL_NAME = "facebook/opt-125m"
# MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128"


class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, *args, **kwargs):
        # Pin the engine's workers to the two bundles of the placement group.
        os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
        super().__init__(*args, **kwargs)


@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
        ).to("cuda:0")
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes for weight transfer."""
        names = []
        dtype_names = []
        shapes = []
        for name, p in self.model.named_parameters():
            names.append(name)
            dtype_names.append(str(p.dtype).split(".")[-1])
            shapes.append(list(p.shape))
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer."""
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.master_address,
                master_port=self.port,
                world_size=world_size,
            ),
        )

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine."""
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            group=self.model_update_group,
            packed=packed,
        )


# Initialize Ray. The training model will occupy one GPU and the vLLM
# inference engine will be placed on the other two.
ray.init()

# Launch the training model actor. Ray's resource scheduler allocates one GPU
# to it (via num_gpus=1 in the decorator), ensuring that pg_inference is
# placed on different GPUs.
train_model = TrainModel.remote(MODEL_NAME)

# Create a placement group that reserves two GPUs for the vLLM inference
# engine. Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Note: the weight transfer APIs (init_weight_transfer_engine, update_weights)
# are native to vLLM workers.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model=MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=2,
    data_parallel_size=1,
    distributed_executor_backend="ray",
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
    load_format="dummy",
    quantization="fp8",
)

# Prompts used for both generations.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0)

# Generate text with the initial model. The output is expected to be nonsense
# because the weights are randomly initialized.
outputs = ray.get(llm.generate.remote(prompts, sampling_params))

print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(
    train_model.get_master_address_and_port.remote()
)
world_size = ray.get(llm.get_world_size.remote()) + 1  # +1 for the trainer
inference_handle = llm.init_weight_transfer_engine.remote(
    dict(
        init_info=dict(
            master_address=master_address,
            master_port=master_port,
            rank_offset=1,
            world_size=world_size,
        )
    )
)
# Initialize the weight transfer group on both the training actor and the
# inference engine.
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])

# Synchronize the updated weights to the inference engine using the batched
# API. First, collect all weight metadata from the training actor.
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# Issue the update_weights call with NCCL-specific update info.
# packed=True enables efficient batched tensor broadcasting.
inference_handle = llm.update_weights.remote(
    dict(
        update_info=dict(
            names=names,
            dtype_names=dtype_names,
            shapes=shapes,
            packed=True,
        )
    )
)
# Broadcast all weights from the trainer using the weight transfer API.
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])

# Generate text with the updated model. The output should now be sensible
# because the weights have been synced from the trainer.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates asynchronous reinforcement learning using vLLM and Ray, with
vLLM's native weight-syncing APIs at the engine-instance level.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformers model occupies one GPU for training, whereas a
2x tensor-parallel vLLM inference engine occupies two GPUs.

The example performs the following steps:

* Load the training model on one GPU (scheduled via Ray).
* Initialize the inference model with dummy weights across two GPUs using
  vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
  inference engine.
* Pause generation once one of the sequences has finished generating.
* Update the weights of the training model and broadcast the updated weights
  to the inference engine using a Ray collective RPC group.
* Resume generation and print out the results.

This example assumes a single-node cluster with three GPUs, but Ray also
supports multi-node clusters. vLLM expects its GPUs to be used exclusively
for vLLM workloads; residual GPU activity interferes with vLLM's memory
profiling and causes unexpected behavior.
"""

import os
import uuid
from dataclasses import asdict

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer

import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLWeightTransferEngine,
    NCCLWeightTransferInitInfo,
    NCCLWeightTransferUpdateInfo,
)
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor

MODEL_NAME = "facebook/opt-125m"


class MyLLM(vllm.AsyncLLMEngine):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, **kwargs):
        # Pin the engine's workers to the two bundles of the placement group.
        os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
        engine_args = vllm.AsyncEngineArgs(**kwargs)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)
        super().__init__(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
        )

    async def generate_with_retry(
        self,
        prompt_token_ids: list[int],
        sampling_params: vllm.SamplingParams,
    ) -> vllm.RequestOutput:
        # Requests that are still in flight when generation is paused finish
        # with an "abort" reason. Resubmit them with the already generated
        # tokens appended to the prompt so generation continues where it
        # stopped.
        finish_reason = "abort"
        while finish_reason == "abort":
            async for request_output in self.generate(
                {"prompt_token_ids": prompt_token_ids},
                sampling_params,
                request_id=str(uuid.uuid4()),
            ):
                output = request_output
            finish_reason = output.outputs[0].finish_reason
            if finish_reason == "abort":
                print(
                    f"ABORT, prompt_token_ids: {prompt_token_ids}, "
                    f"generated token_ids: {list(output.outputs[0].token_ids)}"
                )
                prompt_token_ids = prompt_token_ids + list(
                    output.outputs[0].token_ids
                )
        return output


@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.bfloat16
        ).to("cuda:0")
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes for weight transfer."""
        names = []
        dtype_names = []
        shapes = []
        for name, p in self.model.named_parameters():
            names.append(name)
            dtype_names.append(str(p.dtype).split(".")[-1])
            shapes.append(list(p.shape))
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer."""
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.master_address,
                master_port=self.port,
                world_size=world_size,
            ),
        )

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine."""
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            group=self.model_update_group,
            packed=packed,
        )


# Initialize Ray. The training model will occupy one GPU and the vLLM
# inference engine will be placed on the other two.
ray.init()

# Launch the training model actor. Ray's resource scheduler allocates one GPU
# to it (via num_gpus=1 in the decorator), ensuring that pg_inference is
# placed on different GPUs.
train_model = TrainModel.remote(MODEL_NAME)

# Create a placement group that reserves two GPUs for the vLLM inference
# engine. Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Note: the weight transfer APIs (init_weight_transfer_engine, update_weights)
# are native to vLLM workers.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model=MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
    load_format="dummy",
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
)

# Prompts to generate from.
prompts = [
    "My name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Tokenize the prompts to token IDs.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
prompt_token_ids_list = [
    tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts
]

# The first request gets a small token budget so it finishes early and
# triggers the pause while the other requests are still generating.
sampling_params = [
    SamplingParams(temperature=0, max_tokens=2),
    SamplingParams(temperature=0, max_tokens=32),
    SamplingParams(temperature=0, max_tokens=32),
    SamplingParams(temperature=0, max_tokens=32),
]

# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(
    train_model.get_master_address_and_port.remote()
)
world_size = 3  # 1 trainer + 2 inference workers (tensor_parallel_size=2)
inference_handle = llm.init_weight_transfer_engine.remote(
    WeightTransferInitRequest(
        init_info=asdict(
            NCCLWeightTransferInitInfo(
                master_address=master_address,
                master_port=master_port,
                rank_offset=1,
                world_size=world_size,
            )
        )
    )
)
# Initialize the weight transfer group on both the training actor and the
# inference engine.
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])

# Submit all generation requests asynchronously.
generation_futures = [
    llm.generate_with_retry.remote(prompt_token_ids, params)
    for prompt_token_ids, params in zip(prompt_token_ids_list, sampling_params)
]

# Wait until the first request finishes.
finished, pending = ray.wait(generation_futures, num_returns=1)

# Pause generation in preparation for the weight sync.
ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False))

# Synchronize the updated weights to the inference engine using the batched
# API. First, collect all weight metadata from the training actor.
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# Issue the update_weights call with NCCL-specific update info.
# packed=True enables efficient batched tensor broadcasting.
inference_handle = llm.update_weights.remote(
    WeightTransferUpdateRequest(
        update_info=asdict(
            NCCLWeightTransferUpdateInfo(
                names=names,
                dtype_names=dtype_names,
                shapes=shapes,
                packed=True,
            )
        )
    )
)
# Broadcast all weights from the trainer using the weight transfer API.
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])

# Resume generation now that the weight sync is complete.
ray.get(llm.resume_generation.remote())

# Fetch the outputs separately: `finished` completed before the pause,
# `pending` requests were paused and then resumed.
finished_outputs = ray.get(finished)
pending_outputs = ray.get(pending)

# Requests that finished before the pause generated all of their text with
# the original (random) weights.
print("-" * 50)
print("Requests that completed BEFORE weight change:")
print("-" * 50)
for output in finished_outputs:
    prompt_text = tokenizer.decode(output.prompt_token_ids)
    print(f"Prompt: {prompt_text!r}")
    print(f"Generated (with original weights): {output.outputs[0].text!r}")
    print("-" * 50)

# Requests that were paused mid-generation produced some text before and
# some text after the weight change.
print("Requests that were PAUSED and RESUMED after weight change:")
print("-" * 50)
for output in pending_outputs:
    # Decode the full prompt token IDs (original prompt + tokens generated
    # before the pause).
    full_prompt_text = tokenizer.decode(output.prompt_token_ids)
    # Find the original prompt by checking which one this output started with.
    original_prompt = next(p for p in prompts if full_prompt_text.startswith(p))
    # output.prompt_token_ids contains the original prompt plus the tokens
    # generated before the pause; output.outputs[0].text is what was generated
    # after resuming with the new weights.
    text_before_pause = full_prompt_text[len(original_prompt):]
    text_after_pause = output.outputs[0].text
    print(f"Original prompt: {original_prompt!r}")
    print(f"Generated before weight change: {text_before_pause!r}")
    print(f"Generated after weight change: {text_after_pause!r}")
    print("-" * 50)