class MooncakeStoreScheduler:
"""Scheduler-side component for MooncakeStoreConnector."""
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_config: KVCacheConfig,
):
assert vllm_config.kv_transfer_config is not None
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", True
)
self.client = LookupKeyClient(vllm_config)
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.original_block_size = vllm_config.cache_config.block_size
# LCM for multi-group HMA; bs * pcp * dcp for single-group. Matches
# the engine's own scheduler block size by construction.
self._block_size, self._hash_block_size = resolve_kv_cache_block_sizes(
kv_cache_config, vllm_config
)
# Per-request state
self.load_specs: dict[str, LoadSpec] = {} # to be loaded
self._request_trackers: dict[str, RequestTracker] = {} # scheduled new requests
self._preempted_req_ids: set[str] = set() # preempted requests
self._unfinished_requests: dict[str, tuple[Request, tuple[list[int], ...]]] = {}
self._unfinished_request_ids: set[str] = set()
def get_num_new_matched_tokens(
self,
request: Request,
num_computed_tokens: int,
) -> tuple[int, bool]:
"""Check for external KV cache hit."""
# Look up against the full prefill range, not just the prompt.
token_len = request.num_tokens // self._block_size * self._block_size
if token_len < self._block_size:
return 0, False
num_external_hit_tokens = self.client.lookup(token_len, request.block_hashes)
if num_external_hit_tokens == request.num_tokens:
# Leave a sub-block tail uncomputed for sampling, on a block
# boundary so the recv-side load mask covers every yielded chunk.
num_external_hit_tokens = max(
0,
(request.num_tokens - 1) // self._block_size * self._block_size,
)
if num_external_hit_tokens < num_computed_tokens:
need_to_allocate = 0
else:
need_to_allocate = num_external_hit_tokens - num_computed_tokens
logger.debug(
"Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d",
request.request_id,
request.num_tokens,
num_external_hit_tokens,
need_to_allocate,
)
if need_to_allocate <= 0:
return 0, False
self.load_specs[request.request_id] = LoadSpec(
vllm_cached_tokens=num_computed_tokens,
kvpool_cached_tokens=num_external_hit_tokens,
can_load=False,
)
return need_to_allocate, self.load_async
def update_state_after_alloc(
self,
request: Request,
blocks: KVCacheBlocks,
num_external_tokens: int,
):
"""Update state after block allocation."""
local_block_ids: tuple[list[int], ...] = ()
if num_external_tokens > 0:
local_block_ids = blocks.get_block_ids()
self._unfinished_requests[request.request_id] = (request, local_block_ids)
self._unfinished_request_ids.add(request.request_id)
if request.request_id not in self.load_specs:
return
if num_external_tokens == 0:
self.load_specs[request.request_id].can_load = False
return
assert (
num_external_tokens > 0
and num_external_tokens
== self.load_specs[request.request_id].kvpool_cached_tokens
- self.load_specs[request.request_id].vllm_cached_tokens
), (
f"Mismatch in number of tokens: {num_external_tokens} vs "
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}"
)
self.load_specs[request.request_id].can_load = True
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
"""Build connector metadata for this scheduler step."""
force_skip_save = self.kv_role == "kv_consumer"
for finished_req_id in scheduler_output.finished_req_ids:
self.load_specs.pop(finished_req_id, None)
self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id)
self._preempted_req_ids.discard(finished_req_id)
preempted_ids = scheduler_output.preempted_req_ids or set()
self._preempted_req_ids.update(preempted_ids)
for req_id in preempted_ids:
self.load_specs.pop(req_id, None)
if request_tracker := self._request_trackers.get(req_id):
request_tracker.reset()
self._unfinished_requests.pop(req_id, None)
meta = MooncakeStoreConnectorMetadata(
self._unfinished_request_ids,
preempted_ids,
)
# Handle new requests
for request in scheduler_output.scheduled_new_reqs:
load_spec = self.load_specs.pop(request.req_id, None)
num_tokens_to_compute = (
request.num_computed_tokens
+ scheduler_output.num_scheduled_tokens[request.req_id]
)
assert request.req_id in self._unfinished_requests
request_tuple = self._unfinished_requests.get(request.req_id)
request_real = request_tuple[0] # type: ignore[index]
if isinstance(request.block_ids, tuple):
# Multi-group: preserve per-group structure.
unfolded_block_ids = tuple(b.copy() for b in request.block_ids)
else:
# Single-group legacy: list[int] -> 1-tuple.
unfolded_block_ids = (request.block_ids.copy(),)
prefill_tokens = _new_req_prefill_tokens(request)
request_tracker = RequestTracker(
req_id=request.req_id,
token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
token_ids=prefill_tokens[:num_tokens_to_compute],
prefill_end_tokens=len(prefill_tokens),
)
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = (
len(prefill_tokens) // self._block_size * self._block_size
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=(request_tracker.token_len >= last_chunk_tokens_num),
original_block_size=self.original_block_size,
)
if req_meta is not None:
meta.add_request(req_meta)
# Handle cached (running, or MRV1 resumed-from-preemption) requests
cached_reqs = scheduler_output.scheduled_cached_reqs
if not force_skip_save:
for i, req_id in enumerate(cached_reqs.req_ids):
new_block_ids = cached_reqs.new_block_ids[i]
if not new_block_ids:
continue
req_meta = None
if req_id in self._preempted_req_ids:
# Resumed after preemption
if isinstance(new_block_ids, tuple):
new_block_ids = tuple(b.copy() for b in new_block_ids)
else:
new_block_ids = (new_block_ids.copy(),)
self._preempted_req_ids.discard(req_id)
load_spec = self.load_specs.pop(req_id, None)
request_tuple = self._unfinished_requests.get(req_id)
request_real = request_tuple[0] # type: ignore[index]
num_tokens_to_compute = (
request_real.num_computed_tokens
+ scheduler_output.num_scheduled_tokens[req_id]
)
# On resume, the request re-prefills prompt + previously
# generated tokens (all_token_ids).
prefill_tokens = list(request_real.all_token_ids)
request_tracker = RequestTracker(
req_id=req_id,
token_len=num_tokens_to_compute,
allocated_block_ids=new_block_ids,
num_saved_tokens=0,
token_ids=prefill_tokens[:num_tokens_to_compute].copy(),
prefill_end_tokens=len(prefill_tokens),
)
self._request_trackers[req_id] = request_tracker
last_chunk_tokens_num = (
len(prefill_tokens) // self._block_size * self._block_size
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=(
request_tracker.token_len >= last_chunk_tokens_num
),
original_block_size=self.original_block_size,
)
else:
# Decode/chunked request
request_tracker = self._request_trackers[req_id]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_tuple = self._unfinished_requests.get(req_id)
if req_tuple:
unfinished_req = req_tuple[0]
num_current_tokens = request_tracker.token_len
new_token_ids = unfinished_req.all_token_ids[
num_current_tokens : num_current_tokens + num_new_tokens
]
request_tracker.token_len += len(new_token_ids)
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests"
)
num_computed_token = cached_reqs.num_computed_tokens[i]
# Use the tracker's snapshot of the prefill range so resumed
# requests keep saving past the original prompt boundary.
prefill_end = request_tracker.prefill_end_tokens
if num_computed_token >= prefill_end:
continue
request_tracker.update(new_block_ids)
last_chunk_tokens_num = (
prefill_end // self._block_size * self._block_size
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
block_hashes=unfinished_req.block_hashes,
is_last_chunk=(
request_tracker.token_len >= last_chunk_tokens_num
),
original_block_size=self.original_block_size,
)
if req_meta is not None:
meta.add_request(req_meta)
# Handle requests with pending load specs not yet scheduled
request_ids = [req.req_id for req in scheduler_output.scheduled_new_reqs]
for request_id, (
unfinished_req,
block_ids,
) in self._unfinished_requests.items():
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
load_spec = self.load_specs.pop(request_id, None)
if not load_spec:
continue
num_tokens_to_compute = load_spec.kvpool_cached_tokens
request_tracker = RequestTracker(
req_id=request_id,
token_len=num_tokens_to_compute,
allocated_block_ids=block_ids,
num_saved_tokens=0,
)
self._request_trackers[request_id] = request_tracker
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=None,
block_hashes=unfinished_req.block_hashes,
)
if req_meta is not None:
meta.add_request(req_meta)
return meta
def request_finished(
self,
request: Request,
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
"""Determine whether to delay freeing blocks for async save."""
if self.kv_role == "kv_consumer":
return False, None
tracker = self._request_trackers.get(request.request_id)
# Missing tracker can happen when the request is aborted before the
# connector observes the normal finished lifecycle or is preempted
# before finishing.
if tracker is None or tracker.num_saved_tokens <= 0:
return False, None
total_blocks = sum(len(g) for g in block_ids)
delay_free_blocks = total_blocks > 0
if delay_free_blocks:
logger.debug(
"Delaying free of %d blocks for request %s",
total_blocks,
request.request_id,
)
return delay_free_blocks, None