Skip to content

vllm.distributed.kv_transfer.kv_connector.v1

Modules:

Name Description
base

KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State

lmcache_connector
multi_connector
nixl_connector
p2p
shared_storage_connector

__all__ module-attribute

__all__ = ['KVConnectorRole', 'KVConnectorBase_V1']

KVConnectorBase_V1

Bases: ABC

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
class KVConnectorBase_V1(ABC):
    """Base class for v1 distributed KV cache (and hidden state) connectors.

    A connector moves KV cache data between vLLM's paged KV buffer and an
    external store or peer. An instance runs either in the scheduler
    process or in the worker process, selected by ``role``.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        logger.warning(
            "Initializing KVConnectorBase_V1. This API is experimental and "
            "subject to change in the future as we iterate the design.")
        # Per-step metadata pushed from the scheduler; replaced via
        # bind_connector_metadata() / clear_connector_metadata().
        self._connector_metadata = KVConnectorMetadata()
        self._vllm_config = vllm_config
        self._role = role

    @property
    def role(self) -> KVConnectorRole:
        """The process role (scheduler or worker) of this connector."""
        return self._role

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        KV cache loading and saving.

        Args:
            connector_metadata (KVConnectorMetadata): the connector
                metadata.
        """
        self._connector_metadata = connector_metadata

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution. Resets the bound metadata back to an
        empty KVConnectorMetadata.
        """
        self._connector_metadata = KVConnectorMetadata()

    def _get_connector_metadata(self) -> KVConnectorMetadata:
        """Get the connector metadata.

        This function should only be called inside the connector.

        Returns:
            KVConnectorMetadata: the connector metadata.
        """
        return self._connector_metadata

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args:
            kv_caches: dictionary mapping layer names to their KV cache
                tensors.
        """
        return

    @abstractmethod
    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """
        pass

    @abstractmethod
    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within attention layer to ensure
        async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name (str): the name of that layer.
        """
        pass

    @abstractmethod
    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        pass

    @abstractmethod
    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of paged KV buffer before saving done.
        """
        pass

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies the worker-side connector of the ids of requests that
        have finished generating tokens.

        Args:
            finished_req_ids: ids of requests that have finished.

        Returns:
            A tuple of (sending/saving ids, recving/loading ids): ids of
            requests that have finished asynchronous transfer (requests
            that previously returned True from request_finished()). The
            finished saves/sends req ids must belong to a set provided in
            a call to this method (this call or a prior one).
        """
        return None, None

    # ==============================
    # Scheduler-side methods
    # ==============================

    @abstractmethod
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded from the
                  external KV cache beyond what is already computed.
                - `True` if external KV cache tokens will be loaded
                  asynchronously (between scheduler steps). Must be
                  `False` if the first element is 0.
        """
        pass

    @abstractmethod
    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.

        If get_num_new_matched_tokens previously returned True for a
        request, this function may be called twice for that same request -
        first when blocks are allocated for the connector tokens to be
        asynchronously loaded into, and second when any additional blocks
        are allocated, after the load/transfer is complete.

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the blocks allocated for the request.
            num_external_tokens (int): the number of tokens that will be
                loaded from the external KV cache.
        """
        pass

    @abstractmethod
    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.

        Returns:
            KVConnectorMetadata: the metadata to bind on the worker side
            for this scheduling step.
        """
        pass

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Called when a request has finished, before its blocks are freed.

        Args:
            request (Request): the finished request object.
            block_ids (list[int]): the block ids of the request.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return False, None

_connector_metadata instance-attribute

_connector_metadata = KVConnectorMetadata()

_role instance-attribute

_role = role

_vllm_config instance-attribute

_vllm_config = vllm_config

role property

__init__

__init__(vllm_config: VllmConfig, role: KVConnectorRole)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
    """Create a connector for the given vLLM config and role.

    Args:
        vllm_config (VllmConfig): the vLLM configuration.
        role (KVConnectorRole): whether this instance runs in the
            scheduler or the worker process.
    """
    logger.warning(
        "Initializing KVConnectorBase_V1. This API is experimental and "
        "subject to change in the future as we iterate the design.")
    self._connector_metadata = KVConnectorMetadata()
    self._vllm_config = vllm_config
    self._role = role

_get_connector_metadata

_get_connector_metadata() -> KVConnectorMetadata

Get the connector metadata.

This function should only be called inside the connector.

Returns:

Name Type Description
ConnectorMetadata KVConnectorMetadata

the connector metadata.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
def _get_connector_metadata(self) -> KVConnectorMetadata:
    """Get the connector metadata.

    This function should only be called inside the connector.

    Returns:
        KVConnectorMetadata: the connector metadata.
    """
    return self._connector_metadata

bind_connector_metadata

bind_connector_metadata(
    connector_metadata: KVConnectorMetadata,
) -> None

Set the connector metadata from the scheduler.

This function should be called by the model runner every time before the model execution. The metadata will be used for runtime KV cache loading and saving.

Parameters:

Name Type Description Default
connector_metadata KVConnectorMetadata

the connector metadata.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
def bind_connector_metadata(
        self, connector_metadata: KVConnectorMetadata) -> None:
    """Set the connector metadata from the scheduler.

    This function should be called by the model runner every time
    before the model execution. The metadata will be used for runtime
    KV cache loading and saving.

    Args:
        connector_metadata (KVConnectorMetadata): the connector metadata.
    """
    self._connector_metadata = connector_metadata

build_connector_meta abstractmethod

build_connector_meta(
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata

Build the connector metadata for this step.

This function should NOT modify fields in the scheduler_output. Also, calling this function will reset the state of the connector.

Parameters:

Name Type Description Default
scheduler_output SchedulerOutput

the scheduler output object.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
@abstractmethod
def build_connector_meta(
        self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
    """
    Build the connector metadata for this step.

    This function should NOT modify fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        KVConnectorMetadata: the metadata to bind on the worker side
        for this scheduling step.
    """
    pass

clear_connector_metadata

clear_connector_metadata() -> None

Clear the connector metadata.

This function should be called by the model runner every time after the model execution.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
def clear_connector_metadata(self) -> None:
    """Clear the connector metadata.

    This function should be called by the model runner every time
    after the model execution. Resets the bound metadata back to an
    empty KVConnectorMetadata.
    """
    self._connector_metadata = KVConnectorMetadata()

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[Optional[set[str]], Optional[set[str]]]

Notifies worker-side connector ids of requests that have finished generating tokens.

Returns:

Type Description
tuple[Optional[set[str]], Optional[set[str]]]

A tuple of (sending/saving ids, recving/loading ids): ids of requests
that have finished asynchronous transfer (requests that previously
returned True from request_finished()). The finished saves/sends req ids
must belong to a set provided in a call to this method (this call or a
prior one).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
def get_finished(
    self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Notifies the worker-side connector of the ids of requests that
    have finished generating tokens.

    Args:
        finished_req_ids: ids of requests that have finished.

    Returns:
        A tuple of (sending/saving ids, recving/loading ids): ids of
        requests that have finished asynchronous transfer (requests
        that previously returned True from request_finished()). The
        finished saves/sends req ids must belong to a set provided in
        a call to this method (this call or a prior one).
    """
    return None, None

get_num_new_matched_tokens abstractmethod

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int, bool]

Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens.

Parameters:

Name Type Description Default
request Request

the request object.

required
num_computed_tokens int

the number of locally computed tokens for this request

required

Returns:

Type Description
tuple[int, bool]

A tuple with the following elements: - The number of tokens that can be loaded from the external KV cache beyond what is already computed. - True if external KV cache tokens will be loaded asynchronously (between scheduler steps). Must be 'False' if the first element is 0.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
@abstractmethod
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """
    Get number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A tuple with the following elements:
            - The number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            - `True` if external KV cache tokens will be loaded
              asynchronously (between scheduler steps). Must be
              `False` if the first element is 0.
    """
    pass

register_kv_caches

register_kv_caches(kv_caches: dict[str, Tensor])

Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """
    Initialize with the KV caches. Useful for pre-registering the
    KV Caches in the KVConnector (e.g. for NIXL).

    Args:
        kv_caches: dictionary mapping layer names to their KV cache
            tensors.
    """
    return

request_finished

request_finished(
    request: Request, block_ids: list[int]
) -> tuple[bool, Optional[dict[str, Any]]]

Called when a request has finished, before its blocks are freed.

Returns:

Type Description
tuple[bool, Optional[dict[str, Any]]]

A tuple whose first element is True if the request is being saved/sent
asynchronously and blocks should not be freed until the request_id is
returned from get_finished(), and whose second element is an optional
KVTransferParams dict to be included in the request outputs returned by
the engine.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
    """
    Called when a request has finished, before its blocks are freed.

    Args:
        request (Request): the finished request object.
        block_ids (list[int]): the block ids of the request.

    Returns:
        True if the request is being saved/sent asynchronously and blocks
        should not be freed until the request_id is returned from
        get_finished().
        Optional KVTransferParams to be included in the request outputs
        returned by the engine.
    """
    return False, None

save_kv_layer abstractmethod

save_kv_layer(
    layer_name: str,
    kv_layer: Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None

Start saving a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution.

Parameters:

Name Type Description Default
layer_name str

the name of the layer.

required
kv_layer Tensor

the paged KV buffer of the current layer in vLLM.

required
attn_metadata AttentionMetadata

the attention metadata.

required
**kwargs

additional arguments for the save operation.

{}
Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
@abstractmethod
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                  attn_metadata: "AttentionMetadata", **kwargs) -> None:
    """
    Start saving a layer of KV cache from vLLM's paged buffer
    to the connector. This is called from within attention layer to
    enable async copying during execution.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation.
    """
    pass

start_load_kv abstractmethod

start_load_kv(
    forward_context: ForwardContext, **kwargs
) -> None

Start loading the KV cache from the connector to vLLM's paged KV buffer. This is called from the forward context before the forward pass to enable async loading during model execution.

Parameters:

Name Type Description Default
forward_context ForwardContext

the forward context.

required
**kwargs

additional arguments for the load operation

{}
Note

The number of elements in kv_caches and layer_names should be the same.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext",
                  **kwargs) -> None:
    """
    Start loading the KV cache from the connector to vLLM's paged
    KV buffer. This is called from the forward context before the
    forward pass to enable async loading during model execution.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation

    Note:
        The number of elements in kv_caches and layer_names should be
        the same. (NOTE(review): these are not parameters of this
        method -- presumably they refer to the caches registered via
        register_kv_caches; confirm against implementations.)
    """
    pass

update_state_after_alloc abstractmethod

update_state_after_alloc(
    request: Request,
    blocks: KVCacheBlocks,
    num_external_tokens: int,
)

Update KVConnector state after block allocation.

If get_num_new_matched_tokens previously returned True for a request, this function may be called twice for that same request - first when blocks are allocated for the connector tokens to be asynchronously loaded into, and second when any additional blocks are allocated, after the load/transfer is complete.

Parameters:

Name Type Description Default
request Request

the request object.

required
blocks KVCacheBlocks

the blocks allocated for the request.

required
num_external_tokens int

the number of tokens that will be loaded from the external KV cache.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
@abstractmethod
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    """
    Update KVConnector state after block allocation.

    If get_num_new_matched_tokens previously returned True for a
    request, this function may be called twice for that same request -
    first when blocks are allocated for the connector tokens to be
    asynchronously loaded into, and second when any additional blocks
    are allocated, after the load/transfer is complete.

    Args:
        request (Request): the request object.
        blocks (KVCacheBlocks): the blocks allocated for the request.
        num_external_tokens (int): the number of tokens that will be
            loaded from the external KV cache.
    """
    pass

wait_for_layer_load abstractmethod

wait_for_layer_load(layer_name: str) -> None

Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within attention layer to ensure async copying from start_load_kv is complete.

This interface will be useful for layer-by-layer pipelining.

Parameters:

Name Type Description Default
layer_name str

the name of that layer

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
@abstractmethod
def wait_for_layer_load(self, layer_name: str) -> None:
    """
    Block until the KV for a specific layer is loaded into vLLM's
    paged buffer. This is called from within attention layer to ensure
    async copying from start_load_kv is complete.

    This interface will be useful for layer-by-layer pipelining.

    Args:
        layer_name (str): the name of that layer.
    """
    pass

wait_for_save abstractmethod

wait_for_save()

Block until all the save operations is done. This is called as the forward context exits to ensure that the async saving from save_kv_layer is complete before finishing the forward.

This prevents overwrites of paged KV buffer before saving done.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
@abstractmethod
def wait_for_save(self):
    """
    Block until all the save operations are done. This is called
    as the forward context exits to ensure that the async saving
    from save_kv_layer is complete before finishing the forward.

    This prevents overwrites of paged KV buffer before saving done.
    """
    pass

KVConnectorRole

Bases: Enum

Source code in vllm/distributed/kv_transfer/kv_connector/v1/base.py
class KVConnectorRole(enum.Enum):
    """Identifies which vLLM process hosts a KV connector instance."""

    SCHEDULER = 0  # connector running in the scheduler process
    WORKER = 1  # connector running in the worker process

SCHEDULER class-attribute instance-attribute

SCHEDULER = 0

WORKER class-attribute instance-attribute

WORKER = 1