vllm.distributed.kv_transfer.kv_connector.base

KVConnectorBase Class for Distributed KV Cache & Hidden State communication

The class provides two primary abstract methods:

1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states

KVConnectorBaseType module-attribute

KVConnectorBaseType = Union[
    KVConnectorBase, KVConnectorBase_V1
]
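
KVConnectorBaseType lets callers accept either the legacy connector interface documented on this page or the newer KVConnectorBase_V1 interface. A minimal sketch of how code might use the alias; the helper name is hypothetical and not part of this module:

from vllm.distributed.kv_transfer.kv_connector.base import (
    KVConnectorBase,
    KVConnectorBaseType,
)


def is_legacy_connector(connector: KVConnectorBaseType) -> bool:
    # Legacy connectors subclass KVConnectorBase; any other member of the
    # union is a KVConnectorBase_V1 implementation.
    return isinstance(connector, KVConnectorBase)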

KVConnectorBase

Bases: ABC

Abstract base class for a KV connector.

The class provides two primary abstract methods:

1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states

Source code in vllm/distributed/kv_transfer/kv_connector/base.py
class KVConnectorBase(ABC):
    """
    Abstract base class for a KV connector.

    The class provides two primary abstract methods:
    1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
    2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
    """

    @abstractmethod
    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """Close the buffer and release resources.

        This method is responsible for cleaning up resources related to the 
        connector when it is no longer needed.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """
        Send KV caches and hidden states to the connector.

        This method processes the input tokens, KV caches, and 
        hidden/intermediate states for a given model and sends the data to the 
        decode instance.

        Args:
            model_executable (torch.nn.Module): The model executable containing 
                start and end layer information.
            model_input (ModelInputForGPUWithSamplingMetadata): The input
                metadata from vLLM.
            kv_caches (list[torch.Tensor]): List of KV caches (keys and values) 
                for each layer.
            hidden_or_intermediate_states (Union[torch.Tensor, 
            IntermediateTensors]): 
                The hidden or intermediate states associated with the tokens.

        Returns:
            None

        """

        raise NotImplementedError

    @abstractmethod
    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """
        Receive KV caches and hidden states from the connector.

        This method attempts to retrieve KV caches and hidden states for input
        tokens. If all required KV caches and hidden states are received, it
        will bypass model input, else it will fall back to normal vLLM model 
        forwarding.

        Args:
            model_executable (torch.nn.Module): 
                The model executable from vLLM modelrunner.
            model_input (ModelInputForGPUWithSamplingMetadata): 
                The model input from vLLM modelrunner.
            kv_caches (list[torch.Tensor]): 
                List of KV caches for each layer.

        Returns:
            - hidden_or_intermediate_states (torch.Tensor or
            IntermediateTensors): 
                Concatenated hidden states if all required data is retrieved, 
                otherwise `None`.
            - bypass_model_exec (bool): 
                Indicates whether the model execution can be skipped (True) or 
                needs to be redone (False).
            - model_input (ModelInputForGPUWithSamplingMetadata): 
                Optionally adjusted input metadata for re-execution when 
                `bypass_model_exec=False`.

        """

        raise NotImplementedError
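
For orientation, here is a minimal, hypothetical subclass that satisfies the abstract interface with no-op behavior. It is only a sketch under the signatures shown above, and the import paths assume the standard vLLM module layout; a real connector would actually serialize and transfer the KV caches and hidden states between the prefill and decode instances.

from typing import Union

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.sequence import IntermediateTensors


class NoOpConnector(KVConnectorBase):
    """Hypothetical connector that never transfers anything."""

    def __init__(self, rank: int, local_rank: int, config: VllmConfig):
        self.rank = rank
        self.local_rank = local_rank
        self.config = config

    def close(self) -> None:
        # Nothing to release in this sketch.
        pass

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input,
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        # A real connector would serialize the KV caches and hidden states
        # here and ship them to the decode instance.
        pass

    def recv_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input,
        kv_caches: list[torch.Tensor],
    ):
        # Nothing was received, so model execution cannot be bypassed.
        return None, False, model_input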

__init__ abstractmethod

__init__(rank: int, local_rank: int, config: VllmConfig)

Source code in vllm/distributed/kv_transfer/kv_connector/base.py
@abstractmethod
def __init__(
    self,
    rank: int,
    local_rank: int,
    config: "VllmConfig",
):
    raise NotImplementedError

close abstractmethod

close() -> None

Close the buffer and release resources.

This method is responsible for cleaning up resources related to the connector when it is no longer needed.

Raises:

    NotImplementedError: This method must be implemented in subclasses.

Source code in vllm/distributed/kv_transfer/kv_connector/base.py
@abstractmethod
def close(self) -> None:
    """Close the buffer and release resources.

    This method is responsible for cleaning up resources related to the 
    connector when it is no longer needed.

    Raises:
        NotImplementedError: This method must be implemented in subclasses.
    """
    raise NotImplementedError
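
A hedged caller-side sketch: whoever owns a connector is expected to call close() during teardown, for example from a try/finally block. run_and_cleanup and step are illustrative names, not vLLM APIs.

from typing import Callable

from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase


def run_and_cleanup(connector: KVConnectorBase,
                    step: Callable[[KVConnectorBase], None]) -> None:
    # `step` stands in for whatever drives send/recv during model execution.
    try:
        step(connector)
    finally:
        # Always release the connector's transport resources, even if the
        # step above raised.
        connector.close()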

recv_kv_caches_and_hidden_states abstractmethod

recv_kv_caches_and_hidden_states(
    model_executable: Module,
    model_input: ModelInputForGPUWithSamplingMetadata,
    kv_caches: list[Tensor],
) -> tuple[
    Union[Tensor, IntermediateTensors],
    bool,
    ModelInputForGPUWithSamplingMetadata,
]

Receive KV caches and hidden states from the connector.

This method attempts to retrieve KV caches and hidden states for the input tokens. If all required KV caches and hidden states are received, it bypasses model execution; otherwise it falls back to normal vLLM model forwarding.

Parameters:

    model_executable (Module): The model executable from the vLLM model runner. Required.

    model_input (ModelInputForGPUWithSamplingMetadata): The model input from the vLLM model runner. Required.

    kv_caches (list[Tensor]): List of KV caches for each layer. Required.

Returns:

    tuple[Union[Tensor, IntermediateTensors], bool, ModelInputForGPUWithSamplingMetadata]:

    - hidden_or_intermediate_states (torch.Tensor or IntermediateTensors): Concatenated hidden states if all required data is retrieved, otherwise None.

    - bypass_model_exec (bool): Indicates whether the model execution can be skipped (True) or needs to be redone (False).

    - model_input (ModelInputForGPUWithSamplingMetadata): Optionally adjusted input metadata for re-execution when bypass_model_exec=False.

Source code in vllm/distributed/kv_transfer/kv_connector/base.py
@abstractmethod
def recv_kv_caches_and_hidden_states(
    self, model_executable: torch.nn.Module,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
           "ModelInputForGPUWithSamplingMetadata"]:
    """
    Receive KV caches and hidden states from the connector.

    This method attempts to retrieve KV caches and hidden states for input
    tokens. If all required KV caches and hidden states are received, it
    will bypass model input, else it will fall back to normal vLLM model 
    forwarding.

    Args:
        model_executable (torch.nn.Module): 
            The model executable from vLLM modelrunner.
        model_input (ModelInputForGPUWithSamplingMetadata): 
            The model input from vLLM modelrunner.
        kv_caches (list[torch.Tensor]): 
            List of KV caches for each layer.

    Returns:
        - hidden_or_intermediate_states (torch.Tensor or
        IntermediateTensors): 
            Concatenated hidden states if all required data is retrieved, 
            otherwise `None`.
        - bypass_model_exec (bool): 
            Indicates whether the model execution can be skipped (True) or 
            needs to be redone (False).
        - model_input (ModelInputForGPUWithSamplingMetadata): 
            Optionally adjusted input metadata for re-execution when 
            `bypass_model_exec=False`.

    """

    raise NotImplementedError
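
A hedged sketch of how a decode-side caller might consume the returned tuple; maybe_skip_forward is an illustrative helper, not part of vLLM, and the real integration lives in the model runner:

import torch

from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase


def maybe_skip_forward(connector: KVConnectorBase,
                       model_executable: torch.nn.Module,
                       model_input,
                       kv_caches: list[torch.Tensor]):
    # Ask the connector for whatever the prefill instance already computed.
    hidden_states, bypass_model_exec, model_input = (
        connector.recv_kv_caches_and_hidden_states(model_executable,
                                                   model_input, kv_caches))

    if bypass_model_exec:
        # All required KV caches were injected and the hidden states were
        # received, so the forward pass can be skipped entirely.
        return hidden_states, model_input

    # Partial or no data arrived: the caller should run the normal vLLM
    # forward pass using the (possibly adjusted) model_input.
    return None, model_input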

send_kv_caches_and_hidden_states abstractmethod

send_kv_caches_and_hidden_states(
    model_executable: Module,
    model_input: ModelInputForGPUWithSamplingMetadata,
    kv_caches: list[Tensor],
    hidden_or_intermediate_states: Union[
        Tensor, IntermediateTensors
    ],
) -> None

Send KV caches and hidden states to the connector.

This method processes the input tokens, KV caches, and hidden/intermediate states for a given model and sends the data to the decode instance.

Parameters:

    model_executable (Module): The model executable containing start and end layer information. Required.

    model_input (ModelInputForGPUWithSamplingMetadata): The input metadata from vLLM. Required.

    kv_caches (list[Tensor]): List of KV caches (keys and values) for each layer. Required.

    hidden_or_intermediate_states (Union[Tensor, IntermediateTensors]): The hidden or intermediate states associated with the tokens. Required.

Returns:

    None

Source code in vllm/distributed/kv_transfer/kv_connector/base.py
@abstractmethod
def send_kv_caches_and_hidden_states(
    self,
    model_executable: torch.nn.Module,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    kv_caches: list[torch.Tensor],
    hidden_or_intermediate_states: Union[torch.Tensor,
                                         IntermediateTensors],
) -> None:
    """
    Send KV caches and hidden states to the connector.

    This method processes the input tokens, KV caches, and 
    hidden/intermediate states for a given model and sends the data to the 
    decode instance.

    Args:
        model_executable (torch.nn.Module): The model executable containing 
            start and end layer information.
        model_input (ModelInputForGPUWithSamplingMetadata): The input
            metadata from vLLM.
        kv_caches (list[torch.Tensor]): List of KV caches (keys and values) 
            for each layer.
        hidden_or_intermediate_states (Union[torch.Tensor, 
        IntermediateTensors]): 
            The hidden or intermediate states associated with the tokens.

    Returns:
        None

    """

    raise NotImplementedError
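
The mirror image on the prefill side, again only a sketch with an illustrative helper name: after the forward pass produces KV caches and hidden (or intermediate) states, they are handed to the connector, which sends them to the decode instance.

from typing import Union

import torch

from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.sequence import IntermediateTensors


def ship_prefill_results(
    connector: KVConnectorBase,
    model_executable: torch.nn.Module,
    model_input,
    kv_caches: list[torch.Tensor],
    hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors],
) -> None:
    # Hand the freshly computed KV caches and hidden states to the
    # connector, which is responsible for routing them to the decode
    # instance.
    connector.send_kv_caches_and_hidden_states(
        model_executable, model_input, kv_caches,
        hidden_or_intermediate_states)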