Skip to content

vllm.distributed.kv_transfer.kv_connector.v1.mooncake.store.connector

MooncakeStoreConnector - KV cache connector using MooncakeDistributedStore.

Unlike MooncakeConnector, which does direct P2P transfer, this connector uses MooncakeDistributedStore as a shared KV cache pool. Both producer and consumer instances read/write KV to/from the store independently, enabling prefix caching via hash-based deduplication.

MooncakeStoreConnector

Bases: KVConnectorBase_V1, SupportsHMA

KV connector using MooncakeDistributedStore as shared KV pool.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py
class MooncakeStoreConnector(KVConnectorBase_V1, SupportsHMA):
    """KV connector using MooncakeDistributedStore as shared KV pool.

    Depending on the role passed to ``__init__``, an instance owns either a
    ``MooncakeStoreScheduler`` (scheduler role) or a ``MooncakeStoreWorker``
    (any other role) and forwards the connector API to it. Methods of the
    side that was not created assert on use.
    """

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """Whether the ``enable_cross_layers_blocks`` extra-config flag is on.

        The configured value is stringified and compared case-insensitively
        against ``"true"``, so both ``True`` and ``"true"`` enable it. A
        missing key is treated as ``"False"``.
        """
        extra_config = self._kv_transfer_config.kv_connector_extra_config
        return (
            str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
            == "true"
        )

    @staticmethod
    def _validate_kv_cache_config(
        vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
    ) -> None:
        """Reject KV cache layouts this connector cannot handle.

        All problems are collected first so a single error reports every
        unsupported feature at once.

        Raises:
            ValueError: if any KV cache group uses ``CrossAttentionSpec``,
                if a ``MambaSpec`` group's block size differs from
                ``cache_config.block_size``, or if there is more than one
                KV cache group while ``pcp * dcp > 1``.
        """
        from vllm.v1.kv_cache_interface import CrossAttentionSpec, MambaSpec

        unsupported: list[str] = []
        cache_block_size = vllm_config.cache_config.block_size
        for g_idx, g in enumerate(kv_cache_config.kv_cache_groups):
            spec = g.kv_cache_spec
            if isinstance(spec, CrossAttentionSpec):
                unsupported.append(f"group {g_idx}: CrossAttentionSpec")
            # Enforce Mamba align mode
            if isinstance(spec, MambaSpec) and spec.block_size != cache_block_size:
                unsupported.append(
                    f"group {g_idx}: MambaSpec with block_size="
                    f"{spec.block_size} != cache_config.block_size="
                    f"{cache_block_size} (mamba_cache_mode != 'align')"
                )
        pcp = vllm_config.parallel_config.prefill_context_parallel_size
        dcp = vllm_config.parallel_config.decode_context_parallel_size
        # More than one KV cache group is reported as "hybrid attention";
        # that combination is rejected when context parallelism is active.
        if len(kv_cache_config.kv_cache_groups) > 1 and pcp * dcp > 1:
            unsupported.append(
                f"PCP/DCP > 1 (pcp={pcp}, dcp={dcp}) with hybrid attention"
            )
        if unsupported:
            raise ValueError(
                "MooncakeStoreConnector does not support: " + "; ".join(unsupported)
            )

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        """Validate the configuration and build the role-specific backend.

        Args:
            vllm_config: Engine configuration; its ``kv_transfer_config``
                must be set.
            role: ``KVConnectorRole.SCHEDULER`` creates the scheduler-side
                backend; any other role creates the worker-side backend.
            kv_cache_config: Required despite the ``None`` default.

        Raises:
            AssertionError: if ``kv_transfer_config`` or ``kv_cache_config``
                is ``None``.
            ValueError: if ``kv_cache_config`` describes an unsupported
                layout (see ``_validate_kv_cache_config``).
        """
        super().__init__(
            vllm_config=vllm_config,
            role=role,
            kv_cache_config=kv_cache_config,  # type: ignore[arg-type]
        )
        assert vllm_config.kv_transfer_config is not None
        assert kv_cache_config is not None, "kv_cache_config is required"
        self._validate_kv_cache_config(vllm_config, kv_cache_config)
        self._kv_cache_config = kv_cache_config
        self.kv_role = vllm_config.kv_transfer_config.kv_role
        # Events accumulated by update_connector_output() until take_events()
        # drains them; None means nothing is pending.
        self._kv_cache_events: MooncakeStoreKVEvents | None = None

        # Exactly one of these is populated, depending on the role.
        self.connector_scheduler: MooncakeStoreScheduler | None = None
        self.connector_worker: MooncakeStoreWorker | None = None

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = MooncakeStoreScheduler(
                vllm_config, kv_cache_config
            )
        else:
            self.connector_worker = MooncakeStoreWorker(vllm_config, kv_cache_config)

    # ============================================================
    # Scheduler-side methods
    # ============================================================

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """Forward to the scheduler backend's ``get_num_new_matched_tokens``.

        Requires an instance created with the scheduler role.
        """
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens
        )

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: KVCacheBlocks,
        num_external_tokens: int,
    ):
        """Forward to the scheduler backend's ``update_state_after_alloc``.

        Requires an instance created with the scheduler role.
        """
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Forward to the scheduler backend's ``build_connector_meta``.

        Requires an instance created with the scheduler role.
        """
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Single-group variant of ``request_finished_all_groups``.

        Wraps ``block_ids`` in a one-element tuple and defers to the
        multi-group method.
        """
        return self.request_finished_all_groups(request, (block_ids,))

    def request_finished_all_groups(
        self,
        request: Request,
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Forward to the scheduler backend's ``request_finished``.

        Requires an instance created with the scheduler role.
        """
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """Accumulate KV cache events carried by ``connector_output``.

        Output without events, or with events that are not
        ``MooncakeStoreKVEvents``, is ignored. The first batch is stored
        as-is; later batches are merged into it by appending their events
        and adding their worker count.
        """
        kv_cache_events = connector_output.kv_cache_events
        if not kv_cache_events or not isinstance(
            kv_cache_events, MooncakeStoreKVEvents
        ):
            return

        if self._kv_cache_events is None:
            self._kv_cache_events = kv_cache_events
        else:
            self._kv_cache_events.add_events(kv_cache_events.get_all_events())
            self._kv_cache_events.increment_workers(
                kv_cache_events.get_number_of_workers()
            )

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Yield the aggregated pending events, then discard them.

        This is a generator: aggregation happens when iteration starts, and
        the pending events are cleared only once the caller has consumed the
        generator to the end. Yields nothing when no events are pending.
        """
        if self._kv_cache_events is not None:
            self._kv_cache_events.aggregate()
            yield from self._kv_cache_events.get_all_events()
            self._kv_cache_events.clear_events()
            self._kv_cache_events = None

    # ============================================================
    # Worker-side methods
    # ============================================================

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Forward to the worker backend's ``register_kv_caches``.

        Requires an instance created with a non-scheduler role.
        """
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type
    ):
        """Register a single cross-layer KV cache tensor with the worker.

        Only allowed when the KV cache config has exactly one group.
        ``attn_backend`` is accepted but not used here.
        """
        assert self.connector_worker is not None
        assert (
            self._kv_cache_config is not None
            and len(self._kv_cache_config.kv_cache_groups) == 1
        ), "Cross-layer KV cache is not supported with hybrid models"
        self.connector_worker.register_cross_layers_kv_caches(kv_cache)

    def start_load_kv(self, forward_context: ForwardContext, **kwargs: Any) -> None:
        # No-op: loads are issued in get_finished() for compute overlap.
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        # No layerwise support - no-op
        return

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        # No layerwise support - no-op
        return

    def wait_for_save(self):
        # No-op: stores are issued in get_finished() for compute overlap.
        pass

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """Forward to the worker backend's ``get_finished``.

        The current connector metadata is fetched and must be a
        ``MooncakeStoreConnectorMetadata``; it is passed to the worker
        together with ``finished_req_ids``.
        """
        assert self.connector_worker is not None
        metadata = self._get_connector_metadata()
        assert isinstance(metadata, MooncakeStoreConnectorMetadata)
        return self.connector_worker.get_finished(finished_req_ids, metadata)

    def get_block_ids_with_load_errors(self) -> set[int]:
        """Forward to the worker backend's ``get_block_ids_with_load_errors``."""
        assert self.connector_worker is not None
        return self.connector_worker.get_block_ids_with_load_errors()

    def get_kv_connector_kv_cache_events(
        self,
    ) -> MooncakeStoreKVEvents | None:
        """Package the worker's pending KV events for the scheduler side.

        Returns:
            A ``MooncakeStoreKVEvents`` created with ``num_workers=1`` that
            holds the worker's events, or ``None`` when the worker reports
            no events.
        """
        assert self.connector_worker is not None
        events = self.connector_worker.get_kv_events()
        if not events:
            return None

        kv_events = MooncakeStoreKVEvents(num_workers=1)
        kv_events.add_events(events)
        return kv_events

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        """Return the worker backend's stats, or ``None`` without a worker.

        Unlike the other worker-side methods this does not assert, so it is
        safe to call on a scheduler-role instance.
        """
        if self.connector_worker is None:
            return None
        return self.connector_worker.get_kv_connector_stats()

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> KVConnectorStats | None:
        """Build a ``MooncakeStoreConnectorStats``.

        ``data`` is passed through when given; otherwise the stats object is
        constructed with its own defaults.
        """
        return (
            MooncakeStoreConnectorStats(data=data)
            if data is not None
            else MooncakeStoreConnectorStats()
        )

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> KVConnectorPromMetrics:
        """Build a ``MooncakeStorePromMetrics`` from the given arguments."""
        return MooncakeStorePromMetrics(
            vllm_config, metric_types, labelnames, per_engine_labelvalues
        )

MooncakeStoreKVEvents

Bases: KVConnectorKVEvents

KV event aggregation for MooncakeStoreConnector.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py
class MooncakeStoreKVEvents(KVConnectorKVEvents):
    """KV event container used by MooncakeStoreConnector.

    A thin facade over a ``KVEventAggregator``: every method forwards to the
    aggregator created in ``__init__``.
    """

    def __init__(self, num_workers: int) -> None:
        """Create the backing aggregator for ``num_workers`` workers."""
        self._aggregator = KVEventAggregator(num_workers)

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """Hand ``events`` to the aggregator."""
        aggregator = self._aggregator
        aggregator.add_events(events)

    def aggregate(self) -> "MooncakeStoreKVEvents":
        """Keep only the aggregator's common events and reset its workers.

        The stored events are replaced by whatever ``get_common_events()``
        reports; the worker count is reset afterwards. Returns ``self`` so
        calls can be chained.
        """
        aggregator = self._aggregator
        shared = aggregator.get_common_events()
        # Order matters: drop everything, re-add the shared subset, then
        # reset the worker bookkeeping.
        aggregator.clear_events()
        aggregator.add_events(shared)
        aggregator.reset_workers()
        return self

    def increment_workers(self, count: int = 1) -> None:
        """Raise the aggregator's worker count by ``count``."""
        aggregator = self._aggregator
        aggregator.increment_workers(count)

    def get_all_events(self) -> list[KVCacheEvent]:
        """Return every event currently held by the aggregator."""
        aggregator = self._aggregator
        return aggregator.get_all_events()

    def get_number_of_workers(self) -> int:
        """Return the aggregator's current worker count."""
        aggregator = self._aggregator
        return aggregator.get_number_of_workers()

    def clear_events(self) -> None:
        """Drop all stored events and reset the worker count."""
        aggregator = self._aggregator
        aggregator.clear_events()
        aggregator.reset_workers()

    def __repr__(self) -> str:
        pending = self.get_all_events()
        return f"<MooncakeStoreKVEvents events={pending}>"