vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector

logger module-attribute

logger = init_logger(__name__)

P2pNcclConnector

Bases: KVConnectorBase_V1

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
class P2pNcclConnector(KVConnectorBase_V1):

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._block_size = vllm_config.cache_config.block_size
        self._requests_need_load: dict[str, Any] = {}
        self.config = vllm_config.kv_transfer_config
        self.is_producer = self.config.is_kv_producer
        self.chunked_prefill: dict[str, Any] = {}

        self._rank = get_world_group().rank \
            if role == KVConnectorRole.WORKER else 0
        self._local_rank = get_world_group().local_rank \
            if role == KVConnectorRole.WORKER else 0

        self.p2p_nccl_engine = P2pNcclEngine(
            local_rank=self._local_rank,
            config=self.config,
            hostname="",
            port_offset=self._rank,
        ) if role == KVConnectorRole.WORKER else None

    # ==============================
    # Worker-side methods
    # ==============================

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Start loading the KV cache from the connector buffer to vLLM's
        paged KV buffer.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """

        # Only consumer/decode loads KV Cache
        if self.is_producer:
            return

        assert self.p2p_nccl_engine is not None

        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            return

        def inject_kv_into_layer(
            dst_kv_cache_layer: torch.Tensor,
            src_kv_cache: torch.Tensor,
            slot_mapping: torch.Tensor,
            request_id: str,
        ) -> None:
            """Inject the KV cache into the layer.

            Args:
                dst_kv_cache_layer (torch.Tensor): the destination KV cache
                    layer. In shape [2, num_pages, page_size, xxx] if not
                    using MLA, [num_pages, page_size, xxx] otherwise.
                src_kv_cache (torch.Tensor): the source KV cache. In shape
                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
                    otherwise.
                slot_mapping (torch.Tensor): the slot mapping. In shape
                    [num_tokens].
                request_id (str): request id for log
            """
            dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
            if isinstance(attn_metadata, MLACommonMetadata):
                num_pages = dst_kv_cache_layer_shape[0]
                page_size = dst_kv_cache_layer_shape[1]
                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                    num_pages * page_size, -1)
                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
                                              0)
                num_token = src_kv_cache.shape[0]
                if len(slot_mapping) == num_token:
                    dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
                else:
                    dst_kv_cache_layer[slot_mapping[:num_token],
                                       ...] = src_kv_cache
                    logger.warning(
                        "🚧src_kv_cache does not match, num_slot:%d, "
                        "num_token:%d, request_id:%s", len(slot_mapping),
                        num_token, request_id)

                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
            else:
                num_pages = dst_kv_cache_layer_shape[1]
                page_size = dst_kv_cache_layer_shape[2]
                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                    2, num_pages * page_size, -1)
                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
                                              1)
                num_token = src_kv_cache.shape[1]
                if len(slot_mapping) == num_token:
                    dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
                else:
                    dst_kv_cache_layer[:, slot_mapping[:num_token],
                                       ...] = src_kv_cache
                    logger.warning(
                        "🚧src_kv_cache does not match, num_slot:%d, "
                        "num_token:%d, request_id:%s", len(slot_mapping),
                        num_token, request_id)

                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)

        # Get the metadata
        metadata: KVConnectorMetadata = self._get_connector_metadata()
        if metadata is None:
            return
        assert isinstance(metadata, P2pNcclConnectorMetadata)

        # Load the KV for each request each layer
        for request in metadata.requests:
            for layer_name in forward_context.no_compile_layers:
                attn_layer = forward_context.no_compile_layers[layer_name]
                kv_cache_layer = attn_layer.kv_cache[
                    forward_context.virtual_engine]

                kv_cache = self.p2p_nccl_engine.recv_tensor(
                    request.request_id + "#" + layer_name)

                if kv_cache is None:
                    logger.warning("🚧src_kv_cache is None, %s",
                                   request.request_id)
                    continue

                inject_kv_into_layer(kv_cache_layer, kv_cache,
                                     request.slot_mapping, request.request_id)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Blocking until the KV for a specific layer is loaded into vLLM's
        paged buffer.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        return

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """Start saving the KV cache of the layer from vLLM's paged buffer
        to the connector.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """

        # Only producer/prefill saves KV Cache
        if not self.is_producer:
            return

        assert self.p2p_nccl_engine is not None

        def extract_kv_from_layer(
            layer: torch.Tensor,
            slot_mapping: torch.Tensor,
        ) -> torch.Tensor:
            """Extract the KV cache from the layer.

            Assume the shape of the layer is (2, num_pages, page_size, xxx)
            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
            """
            if isinstance(attn_metadata, MLACommonMetadata):
                num_pages, page_size = layer.shape[0], layer.shape[1]
                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
                                                                ...]
            num_pages, page_size = layer.shape[1], layer.shape[2]
            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
                                                               ...]

        connector_metadata = self._get_connector_metadata()
        assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
        for request in connector_metadata.requests:
            request_id = request.request_id
            ip, port = self.parse_request_id(request_id, True)
            remote_address = ip + ":" + str(port + self._rank)
            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
            self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                             kv_cache, remote_address)

    def wait_for_save(self):
        if self.is_producer:
            assert self.p2p_nccl_engine is not None
            self.p2p_nccl_engine.wait_for_sent()

    def get_finished(
            self, finished_req_ids: set[str],
            **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies the worker-side connector of the ids of requests that
        have finished generating tokens.

        Returns:
            ids of requests that have finished asynchronous transfer,
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """

        assert self.p2p_nccl_engine is not None

        forward_context: ForwardContext = get_forward_context()
        return self.p2p_nccl_engine.get_finished(finished_req_ids,
                                                 forward_context)

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        if self.is_producer:
            return 0, False

        num_external_tokens = (len(request.prompt_token_ids) - 1 -
                               num_computed_tokens)

        if num_external_tokens < 0:
            num_external_tokens = 0

        return num_external_tokens, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.
        """
        if not self.is_producer and num_external_tokens > 0:
            self._requests_need_load[request.request_id] = (
                request, blocks.get_block_ids()[0])

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """

        meta = P2pNcclConnectorMetadata()

        for new_req in scheduler_output.scheduled_new_reqs:
            if self.is_producer:
                num_scheduled_tokens = (
                    scheduler_output.num_scheduled_tokens)[new_req.req_id]
                num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
                # the prompt is only partially scheduled (chunked prefill)
                if num_tokens < len(new_req.prompt_token_ids):
                    # 'CachedRequestData' has no attribute 'prompt_token_ids'
                    self.chunked_prefill[new_req.req_id] = (
                        new_req.block_ids[0], new_req.prompt_token_ids)
                    continue
                # the prompt is scheduled in full (no chunked prefill)
                meta.add_request(request_id=new_req.req_id,
                                 token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids[0],
                                 block_size=self._block_size)
                continue
            if new_req.req_id in self._requests_need_load:
                meta.add_request(request_id=new_req.req_id,
                                 token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids[0],
                                 block_size=self._block_size)
                self._requests_need_load.pop(new_req.req_id)

        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            num_computed_tokens = cached_reqs.num_computed_tokens[i]
            new_block_ids = cached_reqs.new_block_ids[i]
            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]

            if self.is_producer:
                num_scheduled_tokens = (
                    scheduler_output.num_scheduled_tokens)[req_id]
                num_tokens = (num_scheduled_tokens + num_computed_tokens)
                assert req_id in self.chunked_prefill
                block_ids = new_block_ids[0]
                if not resumed_from_preemption:
                    block_ids = (self.chunked_prefill[req_id][0] + block_ids)
                prompt_token_ids = self.chunked_prefill[req_id][1]
                # the prompt is chunk-prefilled again (still incomplete)
                if num_tokens < len(prompt_token_ids):
                    self.chunked_prefill[req_id] = (block_ids,
                                                    prompt_token_ids)
                    continue
                # the whole prompt has now been prefilled
                meta.add_request(request_id=req_id,
                                 token_ids=prompt_token_ids,
                                 block_ids=block_ids,
                                 block_size=self._block_size)
                self.chunked_prefill.pop(req_id, None)
                continue

            # NOTE(rob): here we rely on the resumed requests being
            # the first N requests in the list scheduled_cached_reqs.
            if not resumed_from_preemption:
                break
            if req_id in self._requests_need_load:
                request, _ = self._requests_need_load.pop(req_id)
                total_tokens = num_computed_tokens + 1
                token_ids = request.all_token_ids[:total_tokens]

                # NOTE(rob): For resumed req, new_block_ids is all
                # of the block_ids for the request.
                block_ids = new_block_ids[0]

                meta.add_request(request_id=req_id,
                                 token_ids=token_ids,
                                 block_ids=block_ids,
                                 block_size=self._block_size)

        # Requests loaded asynchronously are not in the scheduler_output.
        # for request_id in self._requests_need_load:
        #     request, block_ids = self._requests_need_load[request_id]
        #     meta.add_request(request_id=request.request_id,
        #                      token_ids=request.prompt_token_ids,
        #                      block_ids=block_ids,
        #                      block_size=self._block_size)

        self._requests_need_load.clear()
        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """

        self.chunked_prefill.pop(request.request_id, None)

        return False, None

    # ==============================
    # Static methods
    # ==============================

    @staticmethod
    def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
        # Regular expression to match the string hostname and integer port
        if is_prefill:
            pattern = r"___decode_addr_(.*):(\d+)"
        else:
            pattern = r"___prefill_addr_(.*):(\d+)___"

        # Use re.search to find the pattern in the request_id
        match = re.search(pattern, request_id)
        if match:
            # Extract the ranks
            ip = match.group(1)
            port = int(match.group(2))

            return ip, port
        raise ValueError(
            f"Request id {request_id} does not contain hostname and port")

    @staticmethod
    def check_tensors_except_dim(tensor1, tensor2, dim):
        shape1 = tensor1.size()
        shape2 = tensor2.size()

        if len(shape1) != len(shape2) or not all(
                s1 == s2
                for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
            raise NotImplementedError(
                "Currently, only symmetric TP is supported. Asymmetric TP, PP,"
                "and others will be supported in future PRs.")

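The connector is selected through vLLM's KV-transfer configuration. A minimal sketch (assuming the standard KVTransferConfig fields; the values are illustrative, not defaults):

from vllm.config import KVTransferConfig

# Prefill (producer) side; a decode instance would use kv_role="kv_consumer".
ktc = KVTransferConfig(
    kv_connector="P2pNcclConnector",
    kv_role="kv_producer",
)
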
_block_size instance-attribute

_block_size = block_size

_local_rank instance-attribute

_local_rank = local_rank if role == WORKER else 0

_rank instance-attribute

_rank = rank if role == WORKER else 0

_requests_need_load instance-attribute

_requests_need_load: dict[str, Any] = {}

chunked_prefill instance-attribute

chunked_prefill: dict[str, Any] = {}

config instance-attribute

config = kv_transfer_config

is_producer instance-attribute

is_producer = is_kv_producer

p2p_nccl_engine instance-attribute

p2p_nccl_engine = (
    P2pNcclEngine(
        local_rank=_local_rank,
        config=config,
        hostname="",
        port_offset=_rank,
    )
    if role == WORKER
    else None
)

__init__

__init__(vllm_config: VllmConfig, role: KVConnectorRole)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
    super().__init__(vllm_config=vllm_config, role=role)
    self._block_size = vllm_config.cache_config.block_size
    self._requests_need_load: dict[str, Any] = {}
    self.config = vllm_config.kv_transfer_config
    self.is_producer = self.config.is_kv_producer
    self.chunked_prefill: dict[str, Any] = {}

    self._rank = get_world_group().rank \
        if role == KVConnectorRole.WORKER else 0
    self._local_rank = get_world_group().local_rank \
        if role == KVConnectorRole.WORKER else 0

    self.p2p_nccl_engine = P2pNcclEngine(
        local_rank=self._local_rank,
        config=self.config,
        hostname="",
        port_offset=self._rank,
    ) if role == KVConnectorRole.WORKER else None

build_connector_meta

build_connector_meta(
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata

Build the connector metadata for this step.

This function should NOT modify any fields in the scheduler_output. Also, calling this function will reset the state of the connector.

Parameters:

scheduler_output (SchedulerOutput): the scheduler output object. Required.
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build the connector metadata for this step.

    This function should NOT modify any fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.
    """

    meta = P2pNcclConnectorMetadata()

    for new_req in scheduler_output.scheduled_new_reqs:
        if self.is_producer:
            num_scheduled_tokens = (
                scheduler_output.num_scheduled_tokens)[new_req.req_id]
            num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
            # the prompt is only partially scheduled (chunked prefill)
            if num_tokens < len(new_req.prompt_token_ids):
                # 'CachedRequestData' has no attribute 'prompt_token_ids'
                self.chunked_prefill[new_req.req_id] = (
                    new_req.block_ids[0], new_req.prompt_token_ids)
                continue
            # the prompt is scheduled in full (no chunked prefill)
            meta.add_request(request_id=new_req.req_id,
                             token_ids=new_req.prompt_token_ids,
                             block_ids=new_req.block_ids[0],
                             block_size=self._block_size)
            continue
        if new_req.req_id in self._requests_need_load:
            meta.add_request(request_id=new_req.req_id,
                             token_ids=new_req.prompt_token_ids,
                             block_ids=new_req.block_ids[0],
                             block_size=self._block_size)
            self._requests_need_load.pop(new_req.req_id)

    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        num_computed_tokens = cached_reqs.num_computed_tokens[i]
        new_block_ids = cached_reqs.new_block_ids[i]
        resumed_from_preemption = cached_reqs.resumed_from_preemption[i]

        if self.is_producer:
            num_scheduled_tokens = (
                scheduler_output.num_scheduled_tokens)[req_id]
            num_tokens = (num_scheduled_tokens + num_computed_tokens)
            assert req_id in self.chunked_prefill
            block_ids = new_block_ids[0]
            if not resumed_from_preemption:
                block_ids = (self.chunked_prefill[req_id][0] + block_ids)
            prompt_token_ids = self.chunked_prefill[req_id][1]
            # the prompt is chunk-prefilled again (still incomplete)
            if num_tokens < len(prompt_token_ids):
                self.chunked_prefill[req_id] = (block_ids,
                                                prompt_token_ids)
                continue
            # the whole prompt has now been prefilled
            meta.add_request(request_id=req_id,
                             token_ids=prompt_token_ids,
                             block_ids=block_ids,
                             block_size=self._block_size)
            self.chunked_prefill.pop(req_id, None)
            continue

        # NOTE(rob): here we rely on the resumed requests being
        # the first N requests in the list scheduled_cached_reqs.
        if not resumed_from_preemption:
            break
        if req_id in self._requests_need_load:
            request, _ = self._requests_need_load.pop(req_id)
            total_tokens = num_computed_tokens + 1
            token_ids = request.all_token_ids[:total_tokens]

            # NOTE(rob): For resumed req, new_block_ids is all
            # of the block_ids for the request.
            block_ids = new_block_ids[0]

            meta.add_request(request_id=req_id,
                             token_ids=token_ids,
                             block_ids=block_ids,
                             block_size=self._block_size)

    # Requests loaded asynchronously are not in the scheduler_output.
    # for request_id in self._requests_need_load:
    #     request, block_ids = self._requests_need_load[request_id]
    #     meta.add_request(request_id=request.request_id,
    #                      token_ids=request.prompt_token_ids,
    #                      block_ids=block_ids,
    #                      block_size=self._block_size)

    self._requests_need_load.clear()
    return meta

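To make the producer-side bookkeeping concrete, a worked example with plain integers (the numbers are illustrative): a 100-token prompt scheduled in two chunks is stashed in chunked_prefill after the first step and emitted as connector metadata after the second.

prompt_len = 100                      # len(prompt_token_ids)

# Step 1: only part of the prompt is scheduled.
num_scheduled, num_computed = 64, 0
assert num_scheduled + num_computed < prompt_len   # chunked prefill:
# block_ids and prompt_token_ids are stashed in self.chunked_prefill.

# Step 2: the remainder arrives as a cached request.
num_scheduled, num_computed = 36, 64
assert num_scheduled + num_computed >= prompt_len  # prefill complete:
# meta.add_request(...) is called and the stash entry is popped.
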
check_tensors_except_dim staticmethod

check_tensors_except_dim(tensor1, tensor2, dim)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@staticmethod
def check_tensors_except_dim(tensor1, tensor2, dim):
    shape1 = tensor1.size()
    shape2 = tensor2.size()

    if len(shape1) != len(shape2) or not all(
            s1 == s2
            for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
        raise NotImplementedError(
            "Currently, only symmetric TP is supported. Asymmetric TP, PP,"
            "and others will be supported in future PRs.")

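A minimal sketch of what this check enforces (shapes are illustrative): the two tensors must agree in every dimension except dim, which is the token dimension in the callers above.

import torch

from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector import (
    P2pNcclConnector)

dst = torch.empty(2, 1024, 576)   # [2, num_pages * page_size, hidden]
src = torch.empty(2, 17, 576)     # [2, num_tokens, hidden]
P2pNcclConnector.check_tensors_except_dim(dst, src, 1)    # passes

bad = torch.empty(2, 17, 288)     # hidden size differs, e.g. asymmetric TP
# P2pNcclConnector.check_tensors_except_dim(dst, bad, 1)  # raises NotImplementedError
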
get_finished

get_finished(
    finished_req_ids: set[str], **kwargs
) -> tuple[Optional[set[str]], Optional[set[str]]]

Notifies the worker-side connector of the ids of requests that have finished generating tokens.

Returns:

tuple[Optional[set[str]], Optional[set[str]]]: ids of requests that have
finished asynchronous transfer, as a tuple of (sending/saving ids,
recving/loading ids). The finished saves/sends req ids must belong to a
set provided in a call to this method (this call or a prior one).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def get_finished(
        self, finished_req_ids: set[str],
        **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Notifies the worker-side connector of the ids of requests that
    have finished generating tokens.

    Returns:
        ids of requests that have finished asynchronous transfer,
        tuple of (sending/saving ids, recving/loading ids).
        The finished saves/sends req ids must belong to a set provided in a
        call to this method (this call or a prior one).
    """

    assert self.p2p_nccl_engine is not None

    forward_context: ForwardContext = get_forward_context()
    return self.p2p_nccl_engine.get_finished(finished_req_ids,
                                             forward_context)

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int, bool]

Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens.

Parameters:

request (Request): the request object. Required.
num_computed_tokens (int): the number of locally computed tokens for this
request. Required.

Returns:

tuple[int, bool]: the number of tokens that can be loaded from the
external KV cache beyond what is already computed, plus a flag for
asynchronous loading (always False for this connector).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """
    Get number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        the number of tokens that can be loaded from the
        external KV cache beyond what is already computed.
    """
    if self.is_producer:
        return 0, False

    num_external_tokens = (len(request.prompt_token_ids) - 1 -
                           num_computed_tokens)

    if num_external_tokens < 0:
        num_external_tokens = 0

    return num_external_tokens, False

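The consumer requests the whole prompt except the last token, presumably so the decode side computes that token itself to produce its first logits. With illustrative numbers:

prompt_len, num_computed_tokens = 100, 16
num_external_tokens = max(prompt_len - 1 - num_computed_tokens, 0)
assert num_external_tokens == 83
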
parse_request_id staticmethod

parse_request_id(
    request_id: str, is_prefill=True
) -> tuple[str, int]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
    # Regular expression to match the string hostname and integer port
    if is_prefill:
        pattern = r"___decode_addr_(.*):(\d+)"
    else:
        pattern = r"___prefill_addr_(.*):(\d+)___"

    # Use re.search to find the pattern in the request_id
    match = re.search(pattern, request_id)
    if match:
        # Extract the ranks
        ip = match.group(1)
        port = int(match.group(2))

        return ip, port
    raise ValueError(
        f"Request id {request_id} does not contain hostname and port")

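A short example of the address parsing (the id format is a hypothetical illustration; only the ___prefill_addr_/___decode_addr_ markers come from the patterns above):

from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector import (
    P2pNcclConnector)

rid = "cmpl-42___prefill_addr_10.0.0.1:22001___decode_addr_10.0.0.2:21001"

# The producer (is_prefill=True) extracts the decode side's address:
assert P2pNcclConnector.parse_request_id(rid, True) == ("10.0.0.2", 21001)

# The consumer (is_prefill=False) extracts the prefill side's address:
assert P2pNcclConnector.parse_request_id(rid, False) == ("10.0.0.1", 22001)
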
request_finished

request_finished(
    request: Request, block_ids: list[int]
) -> tuple[bool, Optional[dict[str, Any]]]

Called when a request has finished, before its blocks are freed.

Returns:

tuple[bool, Optional[dict[str, Any]]]: True if the request is being
saved/sent asynchronously and blocks should not be freed until the
request_id is returned from get_finished(), and optional KVTransferParams
to be included in the request outputs returned by the engine.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
    """
    Called when a request has finished, before its blocks are freed.

    Returns:
        True if the request is being saved/sent asynchronously and blocks
        should not be freed until the request_id is returned from
        get_finished().
        Optional KVTransferParams to be included in the request outputs
        returned by the engine.
    """

    self.chunked_prefill.pop(request.request_id, None)

    return False, None

save_kv_layer

save_kv_layer(
    layer_name: str,
    kv_layer: Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None

Start saving the KV cache of the layer from vLLM's paged buffer to the connector.

Parameters:

layer_name (str): the name of the layer. Required.
kv_layer (Tensor): the paged KV buffer of the current layer in vLLM. Required.
attn_metadata (AttentionMetadata): the attention metadata. Required.
**kwargs: additional arguments for the save operation. Default: {}.
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                  attn_metadata: "AttentionMetadata", **kwargs) -> None:
    """Start saving the KV cache of the layer from vLLM's paged buffer
    to the connector.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation.
    """

    # Only producer/prefill saves KV Cache
    if not self.is_producer:
        return

    assert self.p2p_nccl_engine is not None

    def extract_kv_from_layer(
        layer: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> torch.Tensor:
        """Extract the KV cache from the layer.

        Assume the shape of the layer is (2, num_pages, page_size, xxx)
        if MLA is not used, and (num_pages, page_size, xxx) otherwise.
        """
        if isinstance(attn_metadata, MLACommonMetadata):
            num_pages, page_size = layer.shape[0], layer.shape[1]
            return layer.reshape(num_pages * page_size, -1)[slot_mapping,
                                                            ...]
        num_pages, page_size = layer.shape[1], layer.shape[2]
        return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
                                                           ...]

    connector_metadata = self._get_connector_metadata()
    assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
    for request in connector_metadata.requests:
        request_id = request.request_id
        ip, port = self.parse_request_id(request_id, True)
        remote_address = ip + ":" + str(port + self._rank)
        kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
        self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                         kv_cache, remote_address)

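A standalone sketch of the gather that extract_kv_from_layer performs for the non-MLA layout (sizes are illustrative): the paged layer is viewed as one flat token axis and the request's slots are selected from it.

import torch

num_pages, page_size, hidden = 4, 8, 16
layer = torch.randn(2, num_pages, page_size, hidden)
slot_mapping = torch.tensor([3, 9, 10])   # global slot indices

kv = layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
assert kv.shape == (2, 3, hidden)
# Slot 9 lives in page 9 // 8 == 1 at offset 9 % 8 == 1:
torch.testing.assert_close(kv[0, 1], layer[0, 1, 1])
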
start_load_kv

start_load_kv(
    forward_context: ForwardContext, **kwargs
) -> None

Start loading the KV cache from the connector buffer to vLLM's paged KV buffer.

Parameters:

forward_context (ForwardContext): the forward context. Required.
**kwargs: additional arguments for the load operation. Default: {}.
Note

The number of elements in kv_caches and layer_names should be the same.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def start_load_kv(self, forward_context: "ForwardContext",
                  **kwargs) -> None:
    """Start loading the KV cache from the connector buffer to vLLM's
    paged KV buffer.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation

    Note:
        The number of elements in kv_caches and layer_names should be
        the same.
    """

    # Only consumer/decode loads KV Cache
    if self.is_producer:
        return

    assert self.p2p_nccl_engine is not None

    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return

    def inject_kv_into_layer(
        dst_kv_cache_layer: torch.Tensor,
        src_kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        request_id: str,
    ) -> None:
        """Inject the KV cache into the layer.

        Args:
            dst_kv_cache_layer (torch.Tensor): the destination KV cache
                layer. In shape [2, num_pages, page_size, xxx] if not
                using MLA, [num_pages, page_size, xxx] otherwise.
            src_kv_cache (torch.Tensor): the source KV cache. In shape
                [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
                otherwise.
            slot_mapping (torch.Tensor): the slot mapping. In shape
                [num_tokens].
            request_id (str): request id for log
        """
        dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
        if isinstance(attn_metadata, MLACommonMetadata):
            num_pages = dst_kv_cache_layer_shape[0]
            page_size = dst_kv_cache_layer_shape[1]
            dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                num_pages * page_size, -1)
            self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
                                          0)
            num_token = src_kv_cache.shape[0]
            if len(slot_mapping) == num_token:
                dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
            else:
                dst_kv_cache_layer[slot_mapping[:num_token],
                                   ...] = src_kv_cache
                logger.warning(
                    "🚧src_kv_cache does not match, num_slot:%d, "
                    "num_token:%d, request_id:%s", len(slot_mapping),
                    num_token, request_id)

            dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
        else:
            num_pages = dst_kv_cache_layer_shape[1]
            page_size = dst_kv_cache_layer_shape[2]
            dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                2, num_pages * page_size, -1)
            self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
                                          1)
            num_token = src_kv_cache.shape[1]
            if len(slot_mapping) == num_token:
                dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
            else:
                dst_kv_cache_layer[:, slot_mapping[:num_token],
                                   ...] = src_kv_cache
                logger.warning(
                    "🚧src_kv_cache does not match, num_slot:%d, "
                    "num_token:%d, request_id:%s", len(slot_mapping),
                    num_token, request_id)

            dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)

    # Get the metadata
    metadata: KVConnectorMetadata = self._get_connector_metadata()
    if metadata is None:
        return
    assert isinstance(metadata, P2pNcclConnectorMetadata)

    # Load the KV for each request each layer
    for request in metadata.requests:
        for layer_name in forward_context.no_compile_layers:
            attn_layer = forward_context.no_compile_layers[layer_name]
            kv_cache_layer = attn_layer.kv_cache[
                forward_context.virtual_engine]

            kv_cache = self.p2p_nccl_engine.recv_tensor(
                request.request_id + "#" + layer_name)

            if kv_cache is None:
                logger.warning("🚧src_kv_cache is None, %s",
                               request.request_id)
                continue

            inject_kv_into_layer(kv_cache_layer, kv_cache,
                                 request.slot_mapping, request.request_id)

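The consumer-side inverse: a standalone sketch of the scatter that inject_kv_into_layer performs for the non-MLA layout (sizes are illustrative). Writing through the reshaped view updates the paged destination in place.

import torch

num_pages, page_size, hidden = 4, 8, 16
dst = torch.zeros(2, num_pages, page_size, hidden)   # paged KV layer
src = torch.randn(2, 3, hidden)                      # received [2, num_tokens, hidden]
slot_mapping = torch.tensor([3, 9, 10])

dst.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] = src
# Slot 10 lives in page 10 // 8 == 1 at offset 10 % 8 == 2:
torch.testing.assert_close(dst[1, 1, 2], src[1, 2])
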
update_state_after_alloc

update_state_after_alloc(
    request: Request,
    blocks: KVCacheBlocks,
    num_external_tokens: int,
)

Update KVConnector state after block allocation.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    """
    Update KVConnector state after block allocation.
    """
    if not self.is_producer and num_external_tokens > 0:
        self._requests_need_load[request.request_id] = (
            request, blocks.get_block_ids()[0])

wait_for_layer_load

wait_for_layer_load(layer_name: str) -> None

Blocks until the KV for a specific layer is loaded into vLLM's paged buffer.

This interface will be useful for layer-by-layer pipelining.

Parameters:

layer_name (str): the name of that layer. Required.
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def wait_for_layer_load(self, layer_name: str) -> None:
    """Blocking until the KV for a specific layer is loaded into vLLM's
    paged buffer.

    This interface will be useful for layer-by-layer pipelining.

    Args:
        layer_name: the name of that layer
    """
    return

wait_for_save

wait_for_save()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def wait_for_save(self):
    if self.is_producer:
        assert self.p2p_nccl_engine is not None
        self.p2p_nccl_engine.wait_for_sent()

P2pNcclConnectorMetadata dataclass

Bases: KVConnectorMetadata

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata):
    requests: list[ReqMeta]

    def __init__(self):
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        self.requests.append(
            ReqMeta.make_meta(request_id, token_ids, block_ids, block_size))

requests instance-attribute

requests: list[ReqMeta] = []

__init__

__init__()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def __init__(self):
    self.requests = []

add_request

add_request(
    request_id: str,
    token_ids: list[int],
    block_ids: list[int],
    block_size: int,
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
def add_request(
    self,
    request_id: str,
    token_ids: list[int],
    block_ids: list[int],
    block_size: int,
) -> None:
    self.requests.append(
        ReqMeta.make_meta(request_id, token_ids, block_ids, block_size))

ReqMeta dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@dataclass
class ReqMeta:
    # Request Id
    request_id: str
    # Request tokens
    token_ids: torch.Tensor
    # Slot mappings, should have the same length as token_ids
    slot_mapping: torch.Tensor

    @staticmethod
    def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                  block_size: int) -> "ReqMeta":
        valid_num_tokens = len(token_ids)
        token_ids_tensor = torch.tensor(token_ids)
        block_ids_tensor = torch.tensor(block_ids)
        num_blocks = block_ids_tensor.shape[0]
        block_offsets = torch.arange(0, block_size)
        slot_mapping = block_offsets.reshape((1, block_size)) + \
                block_ids_tensor.reshape((num_blocks, 1)) * block_size
        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]

        return ReqMeta(
            request_id=request_id,
            token_ids=token_ids_tensor,
            slot_mapping=slot_mapping,
        )

request_id instance-attribute

request_id: str

slot_mapping instance-attribute

slot_mapping: Tensor

token_ids instance-attribute

token_ids: Tensor

__init__

__init__(
    request_id: str, token_ids: Tensor, slot_mapping: Tensor
) -> None

make_meta staticmethod

make_meta(
    request_id: str,
    token_ids: list[int],
    block_ids: list[int],
    block_size: int,
) -> ReqMeta
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
              block_size: int) -> "ReqMeta":
    valid_num_tokens = len(token_ids)
    token_ids_tensor = torch.tensor(token_ids)
    block_ids_tensor = torch.tensor(block_ids)
    num_blocks = block_ids_tensor.shape[0]
    block_offsets = torch.arange(0, block_size)
    slot_mapping = block_offsets.reshape((1, block_size)) + \
            block_ids_tensor.reshape((num_blocks, 1)) * block_size
    slot_mapping = slot_mapping.flatten()[:valid_num_tokens]

    return ReqMeta(
        request_id=request_id,
        token_ids=token_ids_tensor,
        slot_mapping=slot_mapping,
    )
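A worked example of the slot-mapping construction above (values are illustrative): block ids [3, 5] with block_size 4 cover global slots 12..15 and 20..23, and a six-token request keeps the first six of them.

from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector import (
    ReqMeta)

meta = ReqMeta.make_meta("req-0", token_ids=[11, 12, 13, 14, 15, 16],
                         block_ids=[3, 5], block_size=4)
assert meta.slot_mapping.tolist() == [12, 13, 14, 15, 20, 21]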