class MooncakeStoreConnector(KVConnectorBase_V1, SupportsHMA):
    """KV connector using MooncakeDistributedStore as shared KV pool.

    The connector is a thin facade: depending on the ``role`` it is built
    with, it owns either a ``MooncakeStoreScheduler`` (scheduler role) or a
    ``MooncakeStoreWorker`` (any other role) and forwards the corresponding
    connector API calls to that delegate.
    """

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """Return whether cross-layer blocks were requested in the config.

        Reads ``enable_cross_layers_blocks`` from
        ``kv_connector_extra_config``. The value is stringified and compared
        case-insensitively against ``"true"``; a missing key defaults to
        ``"False"``.
        """
        extra_config = self._kv_transfer_config.kv_connector_extra_config
        return (
            str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
            == "true"
        )

    @staticmethod
    def _validate_kv_cache_config(
        vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
    ) -> None:
        """Reject KV cache layouts this connector cannot handle.

        The following are collected as unsupported:

        * any KV cache group whose spec is a ``CrossAttentionSpec``;
        * any ``MambaSpec`` group whose ``block_size`` differs from
          ``cache_config.block_size``;
        * more than one KV cache group combined with
          ``prefill_context_parallel_size * decode_context_parallel_size > 1``.

        Args:
            vllm_config: Engine configuration (cache and parallel settings
                are read from it).
            kv_cache_config: KV cache configuration whose groups are checked.

        Raises:
            ValueError: If at least one unsupported condition was found. All
                findings are reported together in a single message.
        """
        # Imported locally, as in the original code, rather than at module
        # level.
        from vllm.v1.kv_cache_interface import CrossAttentionSpec, MambaSpec

        unsupported: list[str] = []
        cache_block_size = vllm_config.cache_config.block_size
        for g_idx, g in enumerate(kv_cache_config.kv_cache_groups):
            spec = g.kv_cache_spec
            if isinstance(spec, CrossAttentionSpec):
                unsupported.append(f"group {g_idx}: CrossAttentionSpec")
            # Enforce Mamba align mode
            if isinstance(spec, MambaSpec) and spec.block_size != cache_block_size:
                unsupported.append(
                    f"group {g_idx}: MambaSpec with block_size="
                    f"{spec.block_size} != cache_config.block_size="
                    f"{cache_block_size} (mamba_cache_mode != 'align')"
                )

        # Context parallelism is only rejected when there are multiple KV
        # cache groups (reported as "hybrid attention").
        pcp = vllm_config.parallel_config.prefill_context_parallel_size
        dcp = vllm_config.parallel_config.decode_context_parallel_size
        if len(kv_cache_config.kv_cache_groups) > 1 and pcp * dcp > 1:
            unsupported.append(
                f"PCP/DCP > 1 (pcp={pcp}, dcp={dcp}) with hybrid attention"
            )

        if unsupported:
            raise ValueError(
                "MooncakeStoreConnector does not support: " + "; ".join(unsupported)
            )

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        """Validate the configuration and build the role-specific delegate.

        Args:
            vllm_config: Engine configuration; ``kv_transfer_config`` must be
                set.
            role: ``KVConnectorRole.SCHEDULER`` creates a
                ``MooncakeStoreScheduler``; any other role creates a
                ``MooncakeStoreWorker``.
            kv_cache_config: KV cache configuration. Although the parameter
                defaults to ``None``, a value is required.

        Raises:
            AssertionError: If ``vllm_config.kv_transfer_config`` or
                ``kv_cache_config`` is ``None``.
            ValueError: If ``kv_cache_config`` describes an unsupported
                layout (see ``_validate_kv_cache_config``).
        """
        super().__init__(
            vllm_config=vllm_config,
            role=role,
            kv_cache_config=kv_cache_config,  # type: ignore[arg-type]
        )
        assert vllm_config.kv_transfer_config is not None
        assert kv_cache_config is not None, "kv_cache_config is required"
        self._validate_kv_cache_config(vllm_config, kv_cache_config)

        self._kv_cache_config = kv_cache_config
        self.kv_role = vllm_config.kv_transfer_config.kv_role
        # Events accumulated by update_connector_output() until take_events()
        # drains them.
        self._kv_cache_events: MooncakeStoreKVEvents | None = None

        # Exactly one of the two delegates is populated, depending on role.
        self.connector_scheduler: MooncakeStoreScheduler | None = None
        self.connector_worker: MooncakeStoreWorker | None = None
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = MooncakeStoreScheduler(
                vllm_config, kv_cache_config
            )
        else:
            self.connector_worker = MooncakeStoreWorker(vllm_config, kv_cache_config)

    # ============================================================
    # Scheduler-side methods
    # ============================================================
    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """Forward to the scheduler delegate's ``get_num_new_matched_tokens``.

        Raises:
            AssertionError: If this connector has no scheduler delegate.
        """
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens
        )

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: KVCacheBlocks,
        num_external_tokens: int,
    ):
        """Forward to the scheduler delegate's ``update_state_after_alloc``.

        Raises:
            AssertionError: If this connector has no scheduler delegate.
        """
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Forward to the scheduler delegate's ``build_connector_meta``.

        Raises:
            AssertionError: If this connector has no scheduler delegate.
        """
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Single-group variant of ``request_finished_all_groups``.

        Wraps ``block_ids`` in a one-element tuple and delegates.
        """
        return self.request_finished_all_groups(request, (block_ids,))

    def request_finished_all_groups(
        self,
        request: Request,
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Forward to the scheduler delegate's ``request_finished``.

        Args:
            request: The finished request.
            block_ids: One list of block ids per entry of the tuple.

        Raises:
            AssertionError: If this connector has no scheduler delegate.
        """
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """Accumulate KV cache events carried by ``connector_output``.

        Outputs whose ``kv_cache_events`` is falsy or is not a
        ``MooncakeStoreKVEvents`` instance are ignored. The first accepted
        events object is stored as-is (by reference); subsequent ones have
        their events and worker counts merged into it.
        """
        kv_cache_events = connector_output.kv_cache_events
        if not kv_cache_events or not isinstance(
            kv_cache_events, MooncakeStoreKVEvents
        ):
            return

        if self._kv_cache_events is None:
            self._kv_cache_events = kv_cache_events
        else:
            self._kv_cache_events.add_events(kv_cache_events.get_all_events())
            self._kv_cache_events.increment_workers(
                kv_cache_events.get_number_of_workers()
            )

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Yield the accumulated KV cache events and then reset the buffer.

        This is a generator: nothing happens until it is iterated, and the
        clear/reset step runs only after every event has been consumed.
        Yields nothing if no events have been accumulated.
        """
        if self._kv_cache_events is not None:
            self._kv_cache_events.aggregate()
            yield from self._kv_cache_events.get_all_events()
            self._kv_cache_events.clear_events()
            self._kv_cache_events = None

    # ============================================================
    # Worker-side methods
    # ============================================================
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Forward to the worker delegate's ``register_kv_caches``.

        Raises:
            AssertionError: If this connector has no worker delegate.
        """
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type
    ):
        """Register a cross-layer KV cache tensor with the worker delegate.

        Args:
            kv_cache: Tensor passed to the worker's
                ``register_cross_layers_kv_caches``.
            attn_backend: Not used by this method.

        Raises:
            AssertionError: If this connector has no worker delegate, or if
                the KV cache config does not have exactly one KV cache group.
        """
        assert self.connector_worker is not None
        assert (
            self._kv_cache_config is not None
            and len(self._kv_cache_config.kv_cache_groups) == 1
        ), "Cross-layer KV cache is not supported with hybrid models"
        self.connector_worker.register_cross_layers_kv_caches(kv_cache)

    def start_load_kv(self, forward_context: ForwardContext, **kwargs: Any) -> None:
        """Do nothing; see the comment below."""
        # No-op: loads are issued in get_finished() for compute overlap.
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Do nothing; layerwise loading is not supported."""
        # No layerwise support - no-op
        return

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        """Do nothing; layerwise saving is not supported."""
        # No layerwise support - no-op
        return

    def wait_for_save(self):
        """Do nothing; see the comment below."""
        # No-op: stores are issued in get_finished() for compute overlap.
        pass

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """Forward to the worker delegate's ``get_finished``.

        The current connector metadata is fetched via
        ``_get_connector_metadata()`` and passed along with
        ``finished_req_ids``.

        Raises:
            AssertionError: If this connector has no worker delegate, or if
                the metadata is not a ``MooncakeStoreConnectorMetadata``.
        """
        assert self.connector_worker is not None
        metadata = self._get_connector_metadata()
        assert isinstance(metadata, MooncakeStoreConnectorMetadata)
        return self.connector_worker.get_finished(finished_req_ids, metadata)

    def get_block_ids_with_load_errors(self) -> set[int]:
        """Forward to the worker delegate's ``get_block_ids_with_load_errors``.

        Raises:
            AssertionError: If this connector has no worker delegate.
        """
        assert self.connector_worker is not None
        return self.connector_worker.get_block_ids_with_load_errors()

    def get_kv_connector_kv_cache_events(
        self,
    ) -> MooncakeStoreKVEvents | None:
        """Wrap the worker's pending KV events in a ``MooncakeStoreKVEvents``.

        Returns:
            A ``MooncakeStoreKVEvents`` created with ``num_workers=1`` holding
            the events returned by the worker's ``get_kv_events()``, or
            ``None`` when that result is falsy.

        Raises:
            AssertionError: If this connector has no worker delegate.
        """
        assert self.connector_worker is not None
        events = self.connector_worker.get_kv_events()
        if not events:
            return None

        kv_events = MooncakeStoreKVEvents(num_workers=1)
        kv_events.add_events(events)
        return kv_events

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        """Return the worker delegate's stats, or ``None`` without a worker."""
        if self.connector_worker is None:
            return None
        return self.connector_worker.get_kv_connector_stats()

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> KVConnectorStats | None:
        """Build a ``MooncakeStoreConnectorStats`` object.

        Args:
            data: Passed as ``data=`` when not ``None``; otherwise the stats
                object is constructed with no arguments.
        """
        return (
            MooncakeStoreConnectorStats(data=data)
            if data is not None
            else MooncakeStoreConnectorStats()
        )

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> KVConnectorPromMetrics:
        """Build a ``MooncakeStorePromMetrics`` from the given arguments.

        All four arguments are passed through positionally and unchanged.
        """
        return MooncakeStorePromMetrics(
            vllm_config, metric_types, labelnames, per_engine_labelvalues
        )