Skip to content

vllm.distributed.kv_transfer.kv_connector_agent

A centralized entrypoint to perform distributed KV cache transfer.

This implementation is a thin shim around two APIs exposed by `kv_connector`: 1. `send_kv_caches_and_hidden_states` 2. `recv_kv_caches_and_hidden_states`

logger module-attribute

# Module-level logger named after this module (via ``__name__``).
logger = init_logger(__name__)

KVTransferAgent

A class designed for distributed KV transfer.

Target use cases
  1. Disaggregated prefill
  2. Remote KV cache storage
Source code in vllm/distributed/kv_transfer/kv_connector_agent.py
class KVTransferAgent:
    """
    A class designed for distributed KV transfer.

    It is a thin wrapper: it builds a KV connector through
    ``KVConnectorFactory.create_connector_v0`` and forwards send, receive
    and close calls to it unchanged.

    Target use cases:
        1. Disaggregated prefill
        2. Remote KV cache storage
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):
        """
        Validate the KV-transfer configuration and create the connector.

        Args:
            rank: Forwarded to ``KVConnectorFactory.create_connector_v0``.
            local_rank: Forwarded to
                ``KVConnectorFactory.create_connector_v0``.
            config: The vLLM config. Its ``kv_transfer_config`` must be set
                and must report ``is_kv_transfer_instance`` as truthy.

        Raises:
            ValueError: If ``config.kv_transfer_config`` is ``None``, or if
                it does not describe a KV-transfer instance.
        """

        self.config = config

        if config.kv_transfer_config is None:
            raise ValueError("KVTransferConfig is not set in the VllmConfig,"
                             " cannot initialize KVConnector.")

        # An explicit check rather than ``assert``: assertions are stripped
        # when Python runs with ``-O``, which would silently skip this
        # validation. ValueError matches the check above.
        if not self.config.kv_transfer_config.is_kv_transfer_instance:
            raise ValueError("KVTransferAgent should only be used when "
                             "kv_connector is set.")

        self.connector = KVConnectorFactory.create_connector_v0(
            rank, local_rank, config)

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """
        Forward KV caches and hidden/intermediate states to the connector.

        All arguments are passed unchanged, in order, to
        ``self.connector.send_kv_caches_and_hidden_states``.
        """

        self.connector.send_kv_caches_and_hidden_states(
            model_executable, model_input, kv_caches,
            hidden_or_intermediate_states)

    def close(self) -> None:
        """Close the underlying connector via its ``close()`` method."""
        self.connector.close()

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """
        Receive KV caches and hidden/intermediate states via the connector.

        The arguments are passed unchanged to
        ``self.connector.recv_kv_caches_and_hidden_states``, and its result
        is returned as-is.

        Returns:
            A 3-tuple of:
              - the hidden or intermediate states;
              - a boolean flag whose meaning the connector defines;
              - a ``ModelInputForGPUWithSamplingMetadata``.
            NOTE(review): the flag presumably signals whether model
            execution can be bypassed — confirm against the connector
            implementation.
        """

        return self.connector.recv_kv_caches_and_hidden_states(
            model_executable, model_input, kv_caches)

config instance-attribute

config = config

connector instance-attribute

connector = create_connector_v0(rank, local_rank, config)

__init__

__init__(rank: int, local_rank: int, config: VllmConfig)
Source code in vllm/distributed/kv_transfer/kv_connector_agent.py
def __init__(
    self,
    rank: int,
    local_rank: int,
    config: "VllmConfig",
):
    """
    Validate the KV-transfer configuration and create the connector.

    Args:
        rank: Forwarded to ``KVConnectorFactory.create_connector_v0``.
        local_rank: Forwarded to ``KVConnectorFactory.create_connector_v0``.
        config: The vLLM config. Its ``kv_transfer_config`` must be set and
            must report ``is_kv_transfer_instance`` as truthy.

    Raises:
        ValueError: If ``config.kv_transfer_config`` is ``None``, or if it
            does not describe a KV-transfer instance.
    """

    self.config = config

    if config.kv_transfer_config is None:
        raise ValueError("KVTransferConfig is not set in the VllmConfig,"
                         " cannot initialize KVConnector.")

    # An explicit check rather than ``assert``: assertions are stripped when
    # Python runs with ``-O``, which would silently skip this validation.
    # ValueError matches the check above.
    if not self.config.kv_transfer_config.is_kv_transfer_instance:
        raise ValueError("KVTransferAgent should only be used when "
                         "kv_connector is set.")

    self.connector = KVConnectorFactory.create_connector_v0(
        rank, local_rank, config)

close

close() -> None
Source code in vllm/distributed/kv_transfer/kv_connector_agent.py
def close(self) -> None:
    """Shut down the wrapped connector by delegating to its ``close()``."""
    connector = self.connector
    connector.close()

recv_kv_caches_and_hidden_states

recv_kv_caches_and_hidden_states(
    model_executable: Module,
    model_input: ModelInputForGPUWithSamplingMetadata,
    kv_caches: list[Tensor],
) -> tuple[
    Union[Tensor, IntermediateTensors],
    bool,
    ModelInputForGPUWithSamplingMetadata,
]
Source code in vllm/distributed/kv_transfer/kv_connector_agent.py
def recv_kv_caches_and_hidden_states(
    self, model_executable: torch.nn.Module,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
           "ModelInputForGPUWithSamplingMetadata"]:
    """
    Pull KV caches and hidden/intermediate states through the connector.

    This is a pure pass-through. The three arguments go, unchanged and in
    order, to ``self.connector.recv_kv_caches_and_hidden_states``, and its
    result is handed straight back.

    Returns:
        A 3-tuple of:
          - the hidden or intermediate states;
          - a boolean flag whose meaning the connector defines;
          - a ``ModelInputForGPUWithSamplingMetadata``.
        NOTE(review): the flag presumably signals whether model execution
        can be bypassed — confirm against the connector implementation.
    """
    receive = self.connector.recv_kv_caches_and_hidden_states
    received = receive(model_executable, model_input, kv_caches)
    return received

send_kv_caches_and_hidden_states

send_kv_caches_and_hidden_states(
    model_executable: Module,
    model_input: ModelInputForGPUWithSamplingMetadata,
    kv_caches: list[Tensor],
    hidden_or_intermediate_states: Union[
        Tensor, IntermediateTensors
    ],
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector_agent.py
def send_kv_caches_and_hidden_states(
    self,
    model_executable: torch.nn.Module,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    kv_caches: list[torch.Tensor],
    hidden_or_intermediate_states: Union[torch.Tensor,
                                         IntermediateTensors],
) -> None:
    """
    Push one step's KV caches and hidden/intermediate states to the connector.

    This is a pure pass-through. Every argument is forwarded unchanged, in
    the same order, to ``self.connector.send_kv_caches_and_hidden_states``,
    and nothing is returned.
    """
    forward = self.connector.send_kv_caches_and_hidden_states
    forward(
        model_executable,
        model_input,
        kv_caches,
        hidden_or_intermediate_states,
    )