Skip to content

vllm.distributed.kv_transfer.kv_connector.lmcache_connector

LMCache KV Cache Connector for Distributed Machine Learning Inference

The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker (KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache; (2) offload and share KV caches.

logger module-attribute

logger = init_logger(__name__)

LMCacheConnector

Bases: KVConnectorBase

Source code in vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
class LMCacheConnector(KVConnectorBase):
    """KV-cache connector backed by LMCache.

    At construction time this class binds the LMCache vLLM-adapter entry
    points (retrieve/store helpers and their status enums) as instance
    attributes; the send/recv methods then simply delegate to them.
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):

        # Keep both the full vLLM config and its KV-transfer sub-config;
        # the latter is used for logging below.
        self.vllm_config = config
        self.transfer_config = config.kv_transfer_config

        # Imported lazily so lmcache is only required when this connector
        # is actually instantiated.
        from lmcache.experimental.cache_engine import LMCacheEngineBuilder
        from lmcache.integration.vllm.utils import ENGINE_NAME
        from lmcache.integration.vllm.vllm_adapter import (
            RetrieveStatus, StoreStatus, init_lmcache_engine,
            lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
            lmcache_store_kv)

        logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                    self.transfer_config)

        self.model_config = config.model_config
        self.parallel_config = config.parallel_config
        self.cache_config = config.cache_config

        # TODO (Jiayi): Find model_config, parallel_config, and cache_config
        self.engine = init_lmcache_engine(self.model_config,
                                          self.parallel_config,
                                          self.cache_config)
        self.lmcache_engine_name = ENGINE_NAME
        self.lmcache_engine_builder = LMCacheEngineBuilder

        # Bind the adapter callables and status enums on self so nothing
        # outside __init__ needs to touch the lmcache namespace directly.
        self.lmcache_retrieve_kv = lmcache_retrieve_kv
        self.lmcache_store_kv = lmcache_store_kv
        self.lmcache_should_retrieve = lmcache_should_retrieve
        self.lmcache_should_store = lmcache_should_store
        self.store_status = StoreStatus
        self.retrieve_status = RetrieveStatus

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """Try to fill ``kv_caches`` (and hidden states) from LMCache.

        Returns the hidden/intermediate states, a flag indicating whether
        model execution can be bypassed, and the (possibly rewritten)
        model input.
        """
        status = self.lmcache_should_retrieve(model_input)
        updated_input, bypass_exec, hidden_states = self.lmcache_retrieve_kv(
            model_executable, model_input, self.cache_config, kv_caches,
            status)
        return hidden_states, bypass_exec, updated_input

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """Offload the KV caches for this batch into LMCache."""
        status = self.lmcache_should_store(model_input)
        self.lmcache_store_kv(
            self.model_config,
            self.parallel_config,
            self.cache_config,
            model_executable,
            model_input,
            kv_caches,
            status,
        )

    def close(self):
        """Tear down the LMCache engine created in ``__init__``."""
        self.lmcache_engine_builder.destroy(self.lmcache_engine_name)

cache_config instance-attribute

cache_config = cache_config

engine instance-attribute

engine = init_lmcache_engine(
    model_config, parallel_config, cache_config
)

lmcache_engine_builder instance-attribute

lmcache_engine_builder = LMCacheEngineBuilder

lmcache_engine_name instance-attribute

lmcache_engine_name = ENGINE_NAME

lmcache_retrieve_kv instance-attribute

lmcache_retrieve_kv = lmcache_retrieve_kv

lmcache_should_retrieve instance-attribute

lmcache_should_retrieve = lmcache_should_retrieve

lmcache_should_store instance-attribute

lmcache_should_store = lmcache_should_store

lmcache_store_kv instance-attribute

lmcache_store_kv = lmcache_store_kv

model_config instance-attribute

model_config = model_config

parallel_config instance-attribute

parallel_config = parallel_config

retrieve_status instance-attribute

retrieve_status = RetrieveStatus

store_status instance-attribute

store_status = StoreStatus

transfer_config instance-attribute

transfer_config = kv_transfer_config

vllm_config instance-attribute

vllm_config = config

__init__

__init__(rank: int, local_rank: int, config: VllmConfig)
Source code in vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
def __init__(
    self,
    rank: int,
    local_rank: int,
    config: VllmConfig,
):
    """Build an LMCache engine and bind the adapter entry points."""

    # Keep both the full vLLM config and its KV-transfer sub-config;
    # the latter is used for logging below.
    self.vllm_config = config
    self.transfer_config = config.kv_transfer_config

    # Imported lazily so lmcache is only required when this connector
    # is actually instantiated.
    from lmcache.experimental.cache_engine import LMCacheEngineBuilder
    from lmcache.integration.vllm.utils import ENGINE_NAME
    from lmcache.integration.vllm.vllm_adapter import (
        RetrieveStatus, StoreStatus, init_lmcache_engine,
        lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
        lmcache_store_kv)

    logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                self.transfer_config)

    self.model_config = config.model_config
    self.parallel_config = config.parallel_config
    self.cache_config = config.cache_config

    # TODO (Jiayi): Find model_config, parallel_config, and cache_config
    self.engine = init_lmcache_engine(self.model_config,
                                      self.parallel_config,
                                      self.cache_config)
    self.lmcache_engine_name = ENGINE_NAME
    self.lmcache_engine_builder = LMCacheEngineBuilder

    # Bind the adapter callables and status enums on self so nothing
    # outside __init__ needs to touch the lmcache namespace directly.
    self.lmcache_retrieve_kv = lmcache_retrieve_kv
    self.lmcache_store_kv = lmcache_store_kv
    self.lmcache_should_retrieve = lmcache_should_retrieve
    self.lmcache_should_store = lmcache_should_store
    self.store_status = StoreStatus
    self.retrieve_status = RetrieveStatus

close

close()
Source code in vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
def close(self):
    """Tear down the LMCache engine created in ``__init__``."""
    self.lmcache_engine_builder.destroy(self.lmcache_engine_name)

recv_kv_caches_and_hidden_states

recv_kv_caches_and_hidden_states(
    model_executable: Module,
    model_input: ModelInputForGPUWithSamplingMetadata,
    kv_caches: list[Tensor],
) -> tuple[
    Union[Tensor, IntermediateTensors],
    bool,
    ModelInputForGPUWithSamplingMetadata,
]
Source code in vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
def recv_kv_caches_and_hidden_states(
    self, model_executable: torch.nn.Module,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
           "ModelInputForGPUWithSamplingMetadata"]:
    """Try to fill ``kv_caches`` (and hidden states) from LMCache.

    Returns the hidden/intermediate states, a flag indicating whether
    model execution can be bypassed, and the (possibly rewritten)
    model input.
    """
    status = self.lmcache_should_retrieve(model_input)
    updated_input, bypass_exec, hidden_states = self.lmcache_retrieve_kv(
        model_executable, model_input, self.cache_config, kv_caches,
        status)
    return hidden_states, bypass_exec, updated_input

send_kv_caches_and_hidden_states

send_kv_caches_and_hidden_states(
    model_executable: Module,
    model_input: ModelInputForGPUWithSamplingMetadata,
    kv_caches: list[Tensor],
    hidden_or_intermediate_states: Union[
        Tensor, IntermediateTensors
    ],
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
def send_kv_caches_and_hidden_states(
    self,
    model_executable: torch.nn.Module,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    kv_caches: list[torch.Tensor],
    hidden_or_intermediate_states: Union[torch.Tensor,
                                         IntermediateTensors],
) -> None:
    """Offload the KV caches for this batch into LMCache."""
    status = self.lmcache_should_store(model_input)
    self.lmcache_store_kv(
        self.model_config,
        self.parallel_config,
        self.cache_config,
        model_executable,
        model_input,
        kv_caches,
        status,
    )