vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector

logger module-attribute

logger = init_logger(__name__)

LMCacheConnectorV1

Bases: KVConnectorBase_V1

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
class LMCacheConnectorV1(KVConnectorBase_V1):

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)

    # ==============================
    # Worker-side methods
    # ==============================
    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be 
            the same.

        """
        self._lmcache_engine.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within the attention layer to ensure
        async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        self._lmcache_engine.wait_for_layer_load(layer_name)

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within the attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current 
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
                                           **kwargs)

    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of the paged KV buffer before saving is done.
        """
        self._lmcache_engine.wait_for_save()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies the worker-side connector of the ids of requests that have
        finished generating tokens.

        Returns:
            A tuple of (sending/saving ids, recving/loading ids) for
            requests that have finished asynchronous transfer (requests
            that previously returned True from request_finished()).
            The finished saves/sends req ids must belong to a set provided
            in a call to this method (this call or a prior one).
        """
        return self._lmcache_engine.get_finished(finished_req_ids)

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        Get the number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed, and
            whether the load is asynchronous (always False for this
            connector).
        """
        return self._lmcache_engine.get_num_new_matched_tokens(
            request, num_computed_tokens), False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.
        """
        self._lmcache_engine.update_state_after_alloc(request,
                                                      num_external_tokens)

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        return self._lmcache_engine.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return self._lmcache_engine.request_finished(request, block_ids)
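
Usage example

A minimal sketch (not from the source above) of selecting this connector through vLLM's offline API; the model name is a placeholder, and LMCache itself must be installed and configured separately (for example via its LMCACHE_* environment variables).

from vllm import LLM
from vllm.config import KVTransferConfig

# kv_connector selects this class; kv_role="kv_both" lets the engine both
# save KV caches to and load KV caches from the LMCache engine.
ktc = KVTransferConfig(
    kv_connector="LMCacheConnectorV1",
    kv_role="kv_both",
)

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    kv_transfer_config=ktc,
)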

_lmcache_engine instance-attribute

_lmcache_engine = LMCacheConnectorV1Impl(
    vllm_config, role, self
)

__init__

__init__(vllm_config: VllmConfig, role: KVConnectorRole)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
    super().__init__(vllm_config=vllm_config, role=role)
    self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)

build_connector_meta

build_connector_meta(
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata

Build the connector metadata for this step.

This function should NOT modify fields in the scheduler_output. Also, calling this function will reset the state of the connector.

Parameters:

Name              Type             Description                   Default
scheduler_output  SchedulerOutput  the scheduler output object.  required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def build_connector_meta(
        self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
    """
    Build the connector metadata for this step.

    This function should NOT modify fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.
    """
    return self._lmcache_engine.build_connector_meta(scheduler_output)
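
The "reset the state" contract can be made concrete with a toy, hypothetical sketch (ToyConnectorMetadata and ToySchedulerSideConnector are illustrative names, not part of vLLM): a scheduler-side connector accumulates per-step state and hands it off, clearing itself, when the metadata is built.

from dataclasses import dataclass, field

@dataclass
class ToyConnectorMetadata:
    # request_id -> block ids whose KV should be loaded this step
    reqs_to_load: dict[str, list[int]] = field(default_factory=dict)

class ToySchedulerSideConnector:
    def __init__(self) -> None:
        self._reqs_to_load: dict[str, list[int]] = {}

    def record_load(self, request_id: str, block_ids: list[int]) -> None:
        self._reqs_to_load[request_id] = block_ids

    def build_connector_meta(self) -> ToyConnectorMetadata:
        # Hand the accumulated state to the metadata and reset it, so the
        # next scheduler step starts from a clean slate.
        meta = ToyConnectorMetadata(reqs_to_load=self._reqs_to_load)
        self._reqs_to_load = {}
        return meta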

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[Optional[set[str]], Optional[set[str]]]

Notifies the worker-side connector of the ids of requests that have finished generating tokens.

Returns:

Type: tuple[Optional[set[str]], Optional[set[str]]]

A tuple of (sending/saving ids, recving/loading ids) for requests that
have finished asynchronous transfer (requests that previously returned
True from request_finished()). The finished saves/sends req ids must
belong to a set provided in a call to this method (this call or a prior
one).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def get_finished(
    self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Notifies the worker-side connector of the ids of requests that have
    finished generating tokens.

    Returns:
        A tuple of (sending/saving ids, recving/loading ids) for
        requests that have finished asynchronous transfer (requests
        that previously returned True from request_finished()).
        The finished saves/sends req ids must belong to a set provided
        in a call to this method (this call or a prior one).
    """
    return self._lmcache_engine.get_finished(finished_req_ids)
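
As a hedged illustration of this contract (toy code, not the LMCache implementation): a worker-side connector might track in-flight async saves and report an id back only after its transfer completes, and only for ids it was previously given.

from typing import Optional

class ToyWorkerSideConnector:
    def __init__(self) -> None:
        self._inflight_saves: set[str] = set()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        # Ids handed to us by the caller join the in-flight set.
        self._inflight_saves |= finished_req_ids
        # Only ids we were previously given may be reported as done.
        done = self._poll_completed_saves() & self._inflight_saves
        self._inflight_saves -= done
        # This toy performs no async loads, so the second element is None.
        return (done or None, None)

    def _poll_completed_saves(self) -> set[str]:
        return set()  # placeholder for real transfer-completion polling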

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int, bool]

Get the number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens.

Parameters:

Name                 Type     Description                                             Default
request              Request  the request object.                                     required
num_computed_tokens  int      the number of locally computed tokens for this request  required

Returns:

Type: tuple[int, bool]

The number of tokens that can be loaded from the external KV cache
beyond what is already computed, and whether the load is asynchronous
(always False for this connector).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """
    Get the number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        the number of tokens that can be loaded from the
        external KV cache beyond what is already computed, and
        whether the load is asynchronous (always False for this
        connector).
    """
    return self._lmcache_engine.get_num_new_matched_tokens(
        request, num_computed_tokens), False
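
For intuition, a worked example of how the scheduler can combine this value with its local prefix cache (the numbers are illustrative): if a 1,000-token prompt has 256 tokens already computed locally and the connector reports 512 further tokens available externally, only the remainder needs to be prefilled.

prompt_len = 1000
num_computed_tokens = 256                 # local prefix-cache hit
external_tokens, load_async = 512, False  # from get_num_new_matched_tokens()

tokens_to_prefill = prompt_len - num_computed_tokens - external_tokens
print(tokens_to_prefill)  # 232 tokens still have to be computed locally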

request_finished

request_finished(
    request: Request, block_ids: list[int]
) -> tuple[bool, Optional[dict[str, Any]]]

Called when a request has finished, before its blocks are freed.

Returns:

Type: tuple[bool, Optional[dict[str, Any]]]

True if the request is being saved/sent asynchronously and its blocks
should not be freed until the request_id is returned from
get_finished(), plus optional KVTransferParams to be included in the
request outputs returned by the engine.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
    """
    Called when a request has finished, before its blocks are freed.

    Returns:
        True if the request is being saved/sent asynchronously and blocks
        should not be freed until the request_id is returned from
        get_finished().
        Optional KVTransferParams to be included in the request outputs
        returned by the engine.
    """
    return self._lmcache_engine.request_finished(request, block_ids)
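
A toy sketch of how a caller can honor the returned flag (hypothetical names; the real block lifecycle lives in vLLM's scheduler): blocks are freed immediately for synchronous saves, otherwise parked until get_finished() reports the request id.

class ToyBlockLifecycle:
    def __init__(self, connector) -> None:
        self._connector = connector
        self._pending_free: dict[str, list[int]] = {}

    def on_request_finished(self, request, block_ids: list[int]) -> None:
        delay_free, _kv_xfer_params = self._connector.request_finished(
            request, block_ids)
        if delay_free:
            # An async save is in flight: park the blocks for later.
            self._pending_free[request.request_id] = block_ids
        else:
            self._free(block_ids)

    def on_step(self) -> None:
        sending, _recving = self._connector.get_finished(set())
        for req_id in sending or set():
            self._free(self._pending_free.pop(req_id, []))

    def _free(self, block_ids: list[int]) -> None:
        pass  # placeholder: return blocks to the KV block pool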

save_kv_layer

save_kv_layer(
    layer_name: str,
    kv_layer: Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None

Start saving a layer of KV cache from vLLM's paged buffer to the connector. This is called from within the attention layer to enable async copying during execution.

Parameters:

Name           Type               Description                                        Default
layer_name     str                the name of the layer.                             required
kv_layer       Tensor             the paged KV buffer of the current layer in vLLM.  required
attn_metadata  AttentionMetadata  the attention metadata.                            required
**kwargs                          additional arguments for the save operation.       {}
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                  attn_metadata: "AttentionMetadata", **kwargs) -> None:
    """
    Start saving a layer of KV cache from vLLM's paged buffer
    to the connector. This is called from within the attention layer to
    enable async copying during execution.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current 
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation.
    """
    self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
                                       **kwargs)
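
The intended call pattern, sketched from the docstring (a simplified pseudo-model, not vLLM's actual attention code; layers is assumed to be an iterable of (name, module) pairs): each layer kicks off its save right after it runs, so copies overlap with the rest of the forward pass.

def forward_with_layerwise_save(connector, layers, kv_caches,
                                attn_metadata, hidden_states):
    for layer_name, layer in layers:
        hidden_states = layer(hidden_states)
        # Fire-and-forget save of this layer's paged KV; completion is
        # awaited later by wait_for_save().
        connector.save_kv_layer(layer_name, kv_caches[layer_name],
                                attn_metadata)
    return hidden_states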

start_load_kv

start_load_kv(
    forward_context: ForwardContext, **kwargs
) -> None

Start loading the KV cache from the connector to vLLM's paged KV buffer. This is called from the forward context before the forward pass to enable async loading during model execution.

Parameters:

Name             Type            Description                                   Default
forward_context  ForwardContext  the forward context.                          required
**kwargs                         additional arguments for the load operation.  {}
Note

The number of elements in kv_caches and layer_names should be the same.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def start_load_kv(self, forward_context: "ForwardContext",
                  **kwargs) -> None:
    """
    Start loading the KV cache from the connector to vLLM's paged
    KV buffer. This is called from the forward context before the
    forward pass to enable async loading during model execution.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation

    Note:
        The number of elements in kv_caches and layer_names should be 
        the same.

    """
    self._lmcache_engine.start_load_kv(forward_context, **kwargs)

update_state_after_alloc

update_state_after_alloc(
    request: Request,
    blocks: KVCacheBlocks,
    num_external_tokens: int,
)

Update KVConnector state after block allocation.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    """
    Update KVConnector state after block allocation.
    """
    self._lmcache_engine.update_state_after_alloc(request,
                                                  num_external_tokens)

wait_for_layer_load

wait_for_layer_load(layer_name: str) -> None

Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within the attention layer to ensure async copying from start_load_kv is complete.

This interface will be useful for layer-by-layer pipelining.

Parameters:

Name        Type  Description             Default
layer_name  str   the name of that layer  required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def wait_for_layer_load(self, layer_name: str) -> None:
    """
    Block until the KV for a specific layer is loaded into vLLM's
    paged buffer. This is called from within the attention layer to ensure
    async copying from start_load_kv is complete.

    This interface will be useful for layer-by-layer pipelining.

    Args:
        layer_name: the name of that layer
    """
    self._lmcache_engine.wait_for_layer_load(layer_name)

wait_for_save

wait_for_save()

Block until all the save operations are done. This is called as the forward context exits to ensure that the async saving from save_kv_layer is complete before finishing the forward.

This prevents overwrites of the paged KV buffer before saving is done.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
def wait_for_save(self):
    """
    Block until all the save operations are done. This is called
    as the forward context exits to ensure that the async saving
    from save_kv_layer is complete before finishing the forward.

    This prevents overwrites of the paged KV buffer before saving is done.
    """
    self._lmcache_engine.wait_for_save()
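
Putting the worker-side hooks together, the per-step lifecycle implied by the docstrings above is roughly the following (a hedged sketch with placeholder names, not vLLM's actual model-runner code; layers is an iterable of (name, module) pairs).

def run_forward_step(connector, forward_context, layers, kv_caches,
                     attn_metadata, hidden_states):
    # 1. Kick off async loading of externally cached KV before the forward.
    connector.start_load_kv(forward_context)

    for layer_name, layer in layers:
        # 2. Ensure this layer's KV has landed, then compute the layer.
        connector.wait_for_layer_load(layer_name)
        hidden_states = layer(hidden_states)
        # 3. Start saving the freshly written KV for this layer.
        connector.save_kv_layer(layer_name, kv_caches[layer_name],
                                attn_metadata)

    # 4. Block until saves finish so the paged buffers can be safely reused.
    connector.wait_for_save()
    return hidden_states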