Skip to content

vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine

DEFAULT_MEM_POOL_SIZE_GB module-attribute

DEFAULT_MEM_POOL_SIZE_GB = 32

logger module-attribute

logger = getLogger(__name__)

P2pNcclEngine

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
class P2pNcclEngine:

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 hostname: str = "",
                 port_offset: int = 0,
                 library_path: Optional[str] = None) -> None:
        """Create the ZMQ endpoints, NCCL handle, buffers and worker threads.

        Args:
            local_rank: Index of the CUDA device (``cuda:<local_rank>``) this
                engine works on.
            config: KV-transfer configuration. ``kv_port`` and
                ``kv_buffer_size`` are read directly; ``http_port``,
                ``proxy_ip``, ``proxy_port``, ``mem_pool_size_gb``,
                ``send_type`` and ``nccl_num_channels`` are read from its
                extra config.
            hostname: Address used for the ZMQ router socket. When empty,
                ``get_ip()`` supplies it.
            port_offset: Added to ``kv_port`` to obtain the listening port;
                also stored as ``self.rank``.
            library_path: Forwarded to ``NCCLLibrary``.

        Raises:
            ValueError: If ``kv_port + port_offset`` is 0.
        """
        self.config = config
        self.rank = port_offset
        self.local_rank = local_rank
        self.device = torch.device(f"cuda:{self.local_rank}")
        self.nccl = NCCLLibrary(library_path)

        if not hostname:
            hostname = get_ip()
        port = int(self.config.kv_port) + port_offset
        if port == 0:
            raise ValueError("Port cannot be 0")
        self._hostname = hostname
        self._port = port

        # Each card corresponds to a ZMQ address.
        self.zmq_address = f"{self._hostname}:{self._port}"

        # The `http_port` must be consistent with the port of OpenAI.
        self.http_address = (
            f"{self._hostname}:"
            f"{self.config.kv_connector_extra_config['http_port']}")

        # If `proxy_ip` or `proxy_port` is `""`,
        # then the ping thread will not be enabled.
        proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
        proxy_port = self.config.get_from_extra_config("proxy_port", "")
        if proxy_ip == "" or proxy_port == "":
            self.proxy_address = ""
        else:
            # Formatting (rather than `+`) keeps a numeric `proxy_port`
            # from raising TypeError.
            self.proxy_address = f"{proxy_ip}:{proxy_port}"

        self.context = zmq.Context()
        self.router_socket = self.context.socket(zmq.ROUTER)
        self.router_socket.bind(f"tcp://{self.zmq_address}")

        self.poller = zmq.Poller()
        self.poller.register(self.router_socket, zmq.POLLIN)

        self.send_store_cv = threading.Condition()
        self.send_queue_cv = threading.Condition()
        self.recv_store_cv = threading.Condition()

        self.send_stream = torch.cuda.Stream()
        self.recv_stream = torch.cuda.Stream()

        mem_pool_size_gb = self.config.get_from_extra_config(
            "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
        self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
                                     1024**3)  # GB

        # The sending type includes three mutually exclusive options:
        # PUT, GET, PUT_ASYNC.
        self.send_type = self.config.get_from_extra_config("send_type", "PUT")
        if self.send_type == "GET":
            # tensor_id: torch.Tensor
            self.send_store: dict[str, torch.Tensor] = {}
        else:
            # PUT or PUT_ASYNC
            # tensor_id: torch.Tensor
            self.send_queue: deque[list[Any]] = deque()

        # Needed in every mode: the listener's GET handler and
        # `get_finished` touch this mapping regardless of `send_type`.
        self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}

        # tensor_id: torch.Tensor/(addr, dtype, shape)
        self.recv_store: dict[str, Any] = {}
        self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
        self.socks: dict[str, Any] = {}  # remote_address: client socket
        self.comms: dict[str, Any] = {}  # remote_address: (ncclComm_t, rank)

        self.buffer_size = 0
        self.buffer_size_threshold = float(self.config.kv_buffer_size)

        self.nccl_num_channels = self.config.get_from_extra_config(
            "nccl_num_channels", "8")

        # Start the worker threads only once every attribute they may read
        # has been assigned.
        if self.send_type == "PUT_ASYNC":
            self._send_thread = threading.Thread(target=self._send_async,
                                                 daemon=True)
            self._send_thread.start()

        self._listener_thread = threading.Thread(
            target=self._listen_for_requests, daemon=True)
        self._listener_thread.start()

        self._ping_thread = None
        if port_offset == 0 and self.proxy_address != "":
            self._ping_thread = threading.Thread(target=self._ping,
                                                 daemon=True)
            self._ping_thread.start()

        logger.info(
            "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
            "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
            "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
            self.http_address, self.zmq_address, self.proxy_address,
            self.send_type, self.buffer_size_threshold, self.nccl_num_channels)

    def _create_connect(self, remote_address: typing.Optional[str] = None):
        assert remote_address is not None
        if remote_address not in self.socks:
            sock = self.context.socket(zmq.DEALER)
            sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
            sock.connect(f"tcp://{remote_address}")
            self.socks[remote_address] = sock
            if remote_address in self.comms:
                logger.info("👋comm exists, remote_address:%s, comms:%s",
                            remote_address, self.comms)
                return sock, self.comms[remote_address]

            unique_id = self.nccl.ncclGetUniqueId()
            data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
            sock.send(msgpack.dumps(data))

            with torch.cuda.device(self.device):
                rank = 0
                with set_p2p_nccl_context(self.nccl_num_channels):
                    comm: ncclComm_t = self.nccl.ncclCommInitRank(
                        2, unique_id, rank)
                self.comms[remote_address] = (comm, rank)
                logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
                            self.zmq_address, remote_address, rank)

        return self.socks[remote_address], self.comms[remote_address]

    def send_tensor(
        self,
        tensor_id: str,
        tensor: torch.Tensor,
        remote_address: typing.Optional[str] = None,
    ) -> bool:
        if remote_address is None:
            with self.recv_store_cv:
                self.recv_store[tensor_id] = tensor
                self.recv_store_cv.notify()
            return True
        else:
            if self.send_type == "PUT":
                return self._send_sync(tensor_id, tensor, remote_address)
            elif self.send_type == "PUT_ASYNC":
                with self.send_queue_cv:
                    self.send_queue.append([tensor_id, remote_address, tensor])
                    self.send_queue_cv.notify()
            else:  # GET
                with self.send_store_cv:
                    tensor_size = tensor.element_size() * tensor.numel()
                    while (self.buffer_size + tensor_size
                           > self.buffer_size_threshold):
                        oldest_tenser_id = next(iter(self.send_store))
                        oldest_tenser = self.send_store.pop(oldest_tenser_id)
                        oldest_tenser_size = oldest_tenser.element_size(
                        ) * oldest_tenser.numel()
                        self.buffer_size -= oldest_tenser_size
                        logger.info(
                            "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                            " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
                            remote_address, tensor_id, tensor_size,
                            self.buffer_size, oldest_tenser_size, self.rank)

                    self.send_store[tensor_id] = tensor
                    self.buffer_size += tensor_size
                    logger.debug(
                        "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
                        "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
                        remote_address, tensor_id, tensor_size, tensor.shape,
                        self.rank, self.buffer_size,
                        self.buffer_size / self.buffer_size_threshold * 100)

        return True

    def recv_tensor(
        self,
        tensor_id: str,
        remote_address: typing.Optional[str] = None,
    ) -> torch.Tensor:
        if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
            start_time = time.time()
            with self.recv_store_cv:
                while tensor_id not in self.recv_store:
                    self.recv_store_cv.wait()
                tensor = self.recv_store[tensor_id]

            if tensor is not None:
                if isinstance(tensor, tuple):
                    addr, dtype, shape = tensor
                    tensor = self.pool.load_tensor(addr, dtype, shape,
                                                   self.device)
                else:
                    self.buffer_size -= (tensor.element_size() *
                                         tensor.numel())
            else:
                duration = time.time() - start_time
                logger.warning(
                    "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
                    "rank:%d", remote_address, tensor_id, duration * 1000,
                    self.rank)
            return tensor

        # GET
        if remote_address is None:
            return None

        if remote_address not in self.socks:
            self._create_connect(remote_address)

        sock = self.socks[remote_address]
        comm, rank = self.comms[remote_address]

        data = {"cmd": "GET", "tensor_id": tensor_id}
        sock.send(msgpack.dumps(data))

        message = sock.recv()
        data = msgpack.loads(message)
        if data["ret"] != 0:
            logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                           remote_address, tensor_id, data["ret"])
            return None

        tensor = torch.empty(data["shape"],
                             dtype=getattr(torch, data["dtype"]),
                             device=self.device)

        self._recv(comm, tensor, rank ^ 1, self.recv_stream)

        return tensor

    def _listen_for_requests(self):
        """Serve control messages arriving on the ROUTER socket, forever.

        Each message is a msgpack-encoded dict whose ``cmd`` field selects
        the action:

        * ``NEW``: build a 2-rank NCCL communicator from the supplied unique
          id (this side is rank 1) and remember it under the sender's
          identity.
        * ``PUT``: allocate a tensor of the announced shape/dtype, reply
          ``b"0"``, receive the data over NCCL and publish it in
          ``recv_store`` (as a memory-pool handle when the buffer threshold
          would be exceeded). On CUDA out-of-memory, reply ``b"1"`` and
          publish ``None`` instead.
        * ``GET``: look the tensor up in ``send_store``, reply with its
          shape/dtype (``ret`` 0) or with ``ret`` 1 when absent, then send
          the data over NCCL.

        This is the target of ``_listener_thread`` and never returns.
        """
        while True:
            # Blocks until the router socket has something to read.
            socks = dict(self.poller.poll())
            if self.router_socket in socks:
                # First frame is the sender's identity, second the payload.
                remote_address, message = self.router_socket.recv_multipart()
                data = msgpack.loads(message)
                if data["cmd"] == "NEW":
                    unique_id = self.nccl.unique_id_from_bytes(
                        bytes(data["unique_id"]))
                    with torch.cuda.device(self.device):
                        # The connecting side uses rank 0 (see
                        # `_create_connect`), so this side takes rank 1.
                        rank = 1
                        with set_p2p_nccl_context(self.nccl_num_channels):
                            comm: ncclComm_t = self.nccl.ncclCommInitRank(
                                2, unique_id, rank)
                        self.comms[remote_address.decode()] = (comm, rank)
                        logger.info(
                            "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
                            self.zmq_address, remote_address.decode(), rank)
                elif data["cmd"] == "PUT":
                    tensor_id = data["tensor_id"]
                    try:
                        # Allocate the destination buffer first so that an
                        # out-of-memory condition is detected before the
                        # sender is told to go ahead.
                        with torch.cuda.stream(self.recv_stream):
                            tensor = torch.empty(data["shape"],
                                                 dtype=getattr(
                                                     torch, data["dtype"]),
                                                 device=self.device)
                        # b"0" tells the sender to start the NCCL send.
                        self.router_socket.send_multipart(
                            [remote_address, b"0"])
                        # NOTE(review): raises KeyError if this peer never
                        # sent a NEW message — confirm peers always do.
                        comm, rank = self.comms[remote_address.decode()]
                        self._recv(comm, tensor, rank ^ 1, self.recv_stream)
                        tensor_size = tensor.element_size() * tensor.numel()
                        if (self.buffer_size + tensor_size
                                > self.buffer_size_threshold):
                            # Store Tensor in memory pool
                            addr = self.pool.store_tensor(tensor)
                            # From here on `tensor` is a pool handle, not a
                            # torch.Tensor; `recv_tensor` loads it back.
                            tensor = (addr, tensor.dtype, tensor.shape)
                            logger.warning(
                                "🔴[PUT]Recv Tensor, Out Of Threshold, "
                                "%s👈%s, data:%s, addr:%d", self.zmq_address,
                                remote_address.decode(), data, addr)
                        else:
                            self.buffer_size += tensor_size

                    except torch.cuda.OutOfMemoryError:
                        # b"1" tells the sender the transfer was refused.
                        self.router_socket.send_multipart(
                            [remote_address, b"1"])
                        tensor = None
                        logger.warning(
                            "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                            "data:%s", self.zmq_address,
                            remote_address.decode(), data)

                    # Publish the result (tensor, pool handle or None) and
                    # wake a waiter blocked in `recv_tensor`.
                    with self.recv_store_cv:
                        self.recv_store[tensor_id] = tensor
                        self._have_received_tensor_id(tensor_id)
                        self.recv_store_cv.notify()

                elif data["cmd"] == "GET":
                    tensor_id = data["tensor_id"]
                    with self.send_store_cv:
                        tensor = self.send_store.pop(tensor_id, None)
                        if tensor is not None:
                            data = {
                                "ret": 0,
                                "shape": tensor.shape,
                                "dtype":
                                str(tensor.dtype).replace("torch.", "")
                            }
                            # LRU
                            # Re-inserting after the pop moves the entry to
                            # the end of the dict's insertion order.
                            self.send_store[tensor_id] = tensor
                            self._have_sent_tensor_id(tensor_id)
                        else:
                            data = {"ret": 1}

                    # Reply with the metadata (or the miss) before any NCCL
                    # traffic starts.
                    self.router_socket.send_multipart(
                        [remote_address, msgpack.dumps(data)])

                    if data["ret"] == 0:
                        comm, rank = self.comms[remote_address.decode()]
                        self._send(comm, tensor.to(self.device), rank ^ 1,
                                   self.send_stream)
                else:
                    logger.warning(
                        "🚧Unexpected, Received message from %s, data:%s",
                        remote_address, data)

    def _have_sent_tensor_id(self, tensor_id: str):
        request_id = tensor_id.split('#')[0]
        if request_id not in self.send_request_id_to_tensor_ids:
            self.send_request_id_to_tensor_ids[request_id] = set()
        self.send_request_id_to_tensor_ids[request_id].add(tensor_id)

    def _have_received_tensor_id(self, tensor_id: str):
        request_id = tensor_id.split('#')[0]
        if request_id not in self.recv_request_id_to_tensor_ids:
            self.recv_request_id_to_tensor_ids[request_id] = set()
        self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)

    def _send_async(self):
        while True:
            with self.send_queue_cv:
                while not self.send_queue:
                    self.send_queue_cv.wait()
                tensor_id, remote_address, tensor = self.send_queue.popleft()
                if not self.send_queue:
                    self.send_queue_cv.notify()
            self._send_sync(tensor_id, tensor, remote_address)

    def wait_for_sent(self):
        if self.send_type == "PUT_ASYNC":
            start_time = time.time()
            with self.send_queue_cv:
                while self.send_queue:
                    self.send_queue_cv.wait()
            duration = time.time() - start_time
            logger.debug(
                "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
                " to be empty, rank:%d", duration * 1000, self.rank)

    def _send_sync(
        self,
        tensor_id: str,
        tensor: torch.Tensor,
        remote_address: typing.Optional[str] = None,
    ) -> bool:
        if remote_address is None:
            return False
        if remote_address not in self.socks:
            self._create_connect(remote_address)

        sock = self.socks[remote_address]
        comm, rank = self.comms[remote_address]
        data = {
            "cmd": "PUT",
            "tensor_id": tensor_id,
            "shape": tensor.shape,
            "dtype": str(tensor.dtype).replace("torch.", "")
        }
        sock.send(msgpack.dumps(data))

        response = sock.recv()
        if response != b"0":
            logger.error(
                "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
                "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
                self.zmq_address, remote_address, rank, data, tensor.shape,
                tensor.element_size() * tensor.numel() / 1024**3,
                response.decode())
            return False

        self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)

        if self.send_type == "PUT_ASYNC":
            self._have_sent_tensor_id(tensor_id)

        return True

    def get_finished(
        self, finished_req_ids: set[str], forward_context: "ForwardContext"
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.

        Returns:
            ids of requests that have finished asynchronous transfer,
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """

        # Clear the buffer upon request completion.
        for request_id in finished_req_ids:
            for layer_name in forward_context.no_compile_layers:
                tensor_id = request_id + "#" + layer_name
                if tensor_id in self.recv_store:
                    with self.recv_store_cv:
                        tensor = self.recv_store.pop(tensor_id, None)
                        self.send_request_id_to_tensor_ids.pop(
                            request_id, None)
                        self.recv_request_id_to_tensor_ids.pop(
                            request_id, None)
                    addr = 0
                    if isinstance(tensor, tuple):
                        addr, _, _ = tensor
                        self.pool.free(addr)

        # TODO:Retrieve requests that have already sent the KV cache.
        finished_sending: set[str] = set()

        # TODO:Retrieve requests that have already received the KV cache.
        finished_recving: set[str] = set()

        return finished_sending or None, finished_recving or None

    def _ping(self):
        """Announce this engine to the proxy every three seconds, forever.

        Opens a DEALER socket identified by ``zmq_address``, connects it to
        ``proxy_address`` and repeatedly sends a msgpack dict carrying the
        role (``"P"`` for a KV producer, otherwise ``"D"``) together with the
        HTTP and ZMQ addresses.
        """
        heartbeat_sock = self.context.socket(zmq.DEALER)
        heartbeat_sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
        logger.debug("ping start, zmq_address:%s", self.zmq_address)
        heartbeat_sock.connect(f"tcp://{self.proxy_address}")

        if self.config.is_kv_producer:
            role = "P"
        else:
            role = "D"
        announcement = {
            "type": role,
            "http_address": self.http_address,
            "zmq_address": self.zmq_address
        }

        while True:
            heartbeat_sock.send(msgpack.dumps(announcement))
            time.sleep(3)

    def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()

        with torch.cuda.stream(stream):
            self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                               ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                               comm, cudaStream_t(stream.cuda_stream))
        stream.synchronize()

    def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()

        with torch.cuda.stream(stream):
            self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                               ncclDataTypeEnum.from_torch(tensor.dtype), src,
                               comm, cudaStream_t(stream.cuda_stream))
        stream.synchronize()

    def close(self) -> None:
        self._listener_thread.join()
        if self.send_type == "PUT_ASYNC":
            self._send_thread.join()
        if self._ping_thread is not None:
            self._ping_thread.join()

_hostname instance-attribute

_hostname = hostname

_listener_thread instance-attribute

_listener_thread = Thread(
    target=_listen_for_requests, daemon=True
)

_ping_thread instance-attribute

_ping_thread = None

_port instance-attribute

_port = port

_send_thread instance-attribute

_send_thread = Thread(target=_send_async, daemon=True)

buffer_size instance-attribute

buffer_size = 0

buffer_size_threshold instance-attribute

buffer_size_threshold = float(kv_buffer_size)

comms instance-attribute

comms: dict[str, Any] = {}

config instance-attribute

config = config

context instance-attribute

context = Context()

device instance-attribute

device = device(f'cuda:{local_rank}')

http_address instance-attribute

http_address = (
    f"{_hostname}:{kv_connector_extra_config['http_port']}"
)

local_rank instance-attribute

local_rank = local_rank

nccl instance-attribute

nccl = NCCLLibrary(library_path)

nccl_num_channels instance-attribute

nccl_num_channels = get_from_extra_config(
    "nccl_num_channels", "8"
)

poller instance-attribute

poller = Poller()

pool instance-attribute

pool = TensorMemoryPool(
    max_block_size=int(mem_pool_size_gb) * 1024**3
)

proxy_address instance-attribute

proxy_address = ''

rank instance-attribute

rank = port_offset

recv_request_id_to_tensor_ids instance-attribute

recv_request_id_to_tensor_ids: dict[str, set[str]] = {}

recv_store instance-attribute

recv_store: dict[str, Any] = {}

recv_store_cv instance-attribute

recv_store_cv = Condition()

recv_stream instance-attribute

recv_stream = Stream()

router_socket instance-attribute

router_socket = socket(ROUTER)

send_queue instance-attribute

send_queue: deque[list[Any]] = deque()

send_queue_cv instance-attribute

send_queue_cv = Condition()

send_request_id_to_tensor_ids instance-attribute

send_request_id_to_tensor_ids: dict[str, set[str]] = {}

send_store instance-attribute

send_store: dict[str, Tensor] = {}

send_store_cv instance-attribute

send_store_cv = Condition()

send_stream instance-attribute

send_stream = Stream()

send_type instance-attribute

send_type = get_from_extra_config('send_type', 'PUT')

socks instance-attribute

socks: dict[str, Any] = {}

zmq_address instance-attribute

zmq_address = f'{_hostname}:{_port}'

__init__

__init__(
    local_rank: int,
    config: KVTransferConfig,
    hostname: str = "",
    port_offset: int = 0,
    library_path: Optional[str] = None,
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def __init__(self,
             local_rank: int,
             config: KVTransferConfig,
             hostname: str = "",
             port_offset: int = 0,
             library_path: Optional[str] = None) -> None:
    """Create the ZMQ endpoints, NCCL handle, buffers and worker threads.

    Args:
        local_rank: Index of the CUDA device (``cuda:<local_rank>``) this
            engine works on.
        config: KV-transfer configuration. ``kv_port`` and
            ``kv_buffer_size`` are read directly; ``http_port``,
            ``proxy_ip``, ``proxy_port``, ``mem_pool_size_gb``,
            ``send_type`` and ``nccl_num_channels`` are read from its
            extra config.
        hostname: Address used for the ZMQ router socket. When empty,
            ``get_ip()`` supplies it.
        port_offset: Added to ``kv_port`` to obtain the listening port;
            also stored as ``self.rank``.
        library_path: Forwarded to ``NCCLLibrary``.

    Raises:
        ValueError: If ``kv_port + port_offset`` is 0.
    """
    self.config = config
    self.rank = port_offset
    self.local_rank = local_rank
    self.device = torch.device(f"cuda:{self.local_rank}")
    self.nccl = NCCLLibrary(library_path)

    if not hostname:
        hostname = get_ip()
    port = int(self.config.kv_port) + port_offset
    if port == 0:
        raise ValueError("Port cannot be 0")
    self._hostname = hostname
    self._port = port

    # Each card corresponds to a ZMQ address.
    self.zmq_address = f"{self._hostname}:{self._port}"

    # The `http_port` must be consistent with the port of OpenAI.
    self.http_address = (
        f"{self._hostname}:"
        f"{self.config.kv_connector_extra_config['http_port']}")

    # If `proxy_ip` or `proxy_port` is `""`,
    # then the ping thread will not be enabled.
    proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
    proxy_port = self.config.get_from_extra_config("proxy_port", "")
    if proxy_ip == "" or proxy_port == "":
        self.proxy_address = ""
    else:
        # Formatting (rather than `+`) keeps a numeric `proxy_port`
        # from raising TypeError.
        self.proxy_address = f"{proxy_ip}:{proxy_port}"

    self.context = zmq.Context()
    self.router_socket = self.context.socket(zmq.ROUTER)
    self.router_socket.bind(f"tcp://{self.zmq_address}")

    self.poller = zmq.Poller()
    self.poller.register(self.router_socket, zmq.POLLIN)

    self.send_store_cv = threading.Condition()
    self.send_queue_cv = threading.Condition()
    self.recv_store_cv = threading.Condition()

    self.send_stream = torch.cuda.Stream()
    self.recv_stream = torch.cuda.Stream()

    mem_pool_size_gb = self.config.get_from_extra_config(
        "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
    self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
                                 1024**3)  # GB

    # The sending type includes three mutually exclusive options:
    # PUT, GET, PUT_ASYNC.
    self.send_type = self.config.get_from_extra_config("send_type", "PUT")
    if self.send_type == "GET":
        # tensor_id: torch.Tensor
        self.send_store: dict[str, torch.Tensor] = {}
    else:
        # PUT or PUT_ASYNC
        # tensor_id: torch.Tensor
        self.send_queue: deque[list[Any]] = deque()

    # Needed in every mode: the listener's GET handler and
    # `get_finished` touch this mapping regardless of `send_type`.
    self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}

    # tensor_id: torch.Tensor/(addr, dtype, shape)
    self.recv_store: dict[str, Any] = {}
    self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.socks: dict[str, Any] = {}  # remote_address: client socket
    self.comms: dict[str, Any] = {}  # remote_address: (ncclComm_t, rank)

    self.buffer_size = 0
    self.buffer_size_threshold = float(self.config.kv_buffer_size)

    self.nccl_num_channels = self.config.get_from_extra_config(
        "nccl_num_channels", "8")

    # Start the worker threads only once every attribute they may read
    # has been assigned.
    if self.send_type == "PUT_ASYNC":
        self._send_thread = threading.Thread(target=self._send_async,
                                             daemon=True)
        self._send_thread.start()

    self._listener_thread = threading.Thread(
        target=self._listen_for_requests, daemon=True)
    self._listener_thread.start()

    self._ping_thread = None
    if port_offset == 0 and self.proxy_address != "":
        self._ping_thread = threading.Thread(target=self._ping,
                                             daemon=True)
        self._ping_thread.start()

    logger.info(
        "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
        "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
        "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
        self.http_address, self.zmq_address, self.proxy_address,
        self.send_type, self.buffer_size_threshold, self.nccl_num_channels)

_create_connect

_create_connect(remote_address: Optional[str] = None)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def _create_connect(self, remote_address: typing.Optional[str] = None):
    """Return ``(socket, (comm, rank))`` for a peer, connecting on demand.

    When no DEALER socket exists for ``remote_address`` one is created and
    cached. If a communicator for that peer is already registered it is
    reused; otherwise a fresh NCCL unique id is sent to the peer in a
    ``NEW`` message and a 2-rank communicator is initialised with this
    side as rank 0.
    """
    assert remote_address is not None

    # Fast path: the peer is already connected.
    if remote_address in self.socks:
        return self.socks[remote_address], self.comms[remote_address]

    dealer = self.context.socket(zmq.DEALER)
    dealer.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    dealer.connect(f"tcp://{remote_address}")
    self.socks[remote_address] = dealer

    if remote_address in self.comms:
        logger.info("👋comm exists, remote_address:%s, comms:%s",
                    remote_address, self.comms)
        return dealer, self.comms[remote_address]

    nccl_id = self.nccl.ncclGetUniqueId()
    handshake = {"cmd": "NEW", "unique_id": bytes(nccl_id.internal)}
    dealer.send(msgpack.dumps(handshake))

    with torch.cuda.device(self.device):
        my_rank = 0
        with set_p2p_nccl_context(self.nccl_num_channels):
            comm: ncclComm_t = self.nccl.ncclCommInitRank(
                2, nccl_id, my_rank)
        self.comms[remote_address] = (comm, my_rank)
        logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
                    self.zmq_address, remote_address, my_rank)

    return self.socks[remote_address], self.comms[remote_address]

_have_received_tensor_id

_have_received_tensor_id(tensor_id: str)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def _have_received_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as received under its request id.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition("#")
    received_ids = self.recv_request_id_to_tensor_ids.setdefault(
        request_id, set())
    received_ids.add(tensor_id)

_have_sent_tensor_id

_have_sent_tensor_id(tensor_id: str)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def _have_sent_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as sent under its request id.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition("#")
    sent_ids = self.send_request_id_to_tensor_ids.setdefault(
        request_id, set())
    sent_ids.add(tensor_id)

_listen_for_requests

_listen_for_requests()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def _listen_for_requests(self):
    """Serve control messages arriving on ``self.router_socket`` forever.

    Every message is read as two frames: the sender's identity and a
    msgpack-encoded dict whose ``"cmd"`` entry selects the action:

    * ``"NEW"``: create a 2-rank NCCL communicator with the sender from the
      unique id it supplied (this side is rank 1) and cache it in
      ``self.comms`` under the decoded sender identity.
    * ``"PUT"``: allocate a tensor with the advertised shape/dtype, reply
      ``b"0"``, receive the payload over NCCL and publish it in
      ``self.recv_store``. On CUDA out-of-memory the reply is ``b"1"`` and
      ``None`` is published instead.
    * ``"GET"``: look the tensor up in ``self.send_store``, reply with its
      shape/dtype (``"ret": 0``) or ``{"ret": 1}``, and on a hit send the
      tensor over NCCL.

    Any other command is logged and ignored. This method never returns.
    """
    while True:
        # poll() is called without a timeout argument.
        socks = dict(self.poller.poll())
        if self.router_socket in socks:
            remote_address, message = self.router_socket.recv_multipart()
            data = msgpack.loads(message)
            if data["cmd"] == "NEW":
                unique_id = self.nccl.unique_id_from_bytes(
                    bytes(data["unique_id"]))
                with torch.cuda.device(self.device):
                    # The connecting side (_create_connect) uses rank 0 for
                    # the same unique id, so the listener takes rank 1.
                    rank = 1
                    with set_p2p_nccl_context(self.nccl_num_channels):
                        comm: ncclComm_t = self.nccl.ncclCommInitRank(
                            2, unique_id, rank)
                    self.comms[remote_address.decode()] = (comm, rank)
                    logger.info(
                        "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
                        self.zmq_address, remote_address.decode(), rank)
            elif data["cmd"] == "PUT":
                tensor_id = data["tensor_id"]
                try:
                    with torch.cuda.stream(self.recv_stream):
                        tensor = torch.empty(data["shape"],
                                             dtype=getattr(
                                                 torch, data["dtype"]),
                                             device=self.device)
                    # Acknowledge before receiving: the sender (_send_sync)
                    # waits for this reply before it starts its NCCL send.
                    self.router_socket.send_multipart(
                        [remote_address, b"0"])
                    comm, rank = self.comms[remote_address.decode()]
                    # rank ^ 1 is the peer's rank in the 2-rank communicator.
                    self._recv(comm, tensor, rank ^ 1, self.recv_stream)
                    tensor_size = tensor.element_size() * tensor.numel()
                    if (self.buffer_size + tensor_size
                            > self.buffer_size_threshold):
                        # Over the buffer threshold: hand the tensor to
                        # self.pool and keep only an (addr, dtype, shape)
                        # tuple; buffer_size is left unchanged in this case.
                        addr = self.pool.store_tensor(tensor)
                        tensor = (addr, tensor.dtype, tensor.shape)
                        logger.warning(
                            "🔴[PUT]Recv Tensor, Out Of Threshold, "
                            "%s👈%s, data:%s, addr:%d", self.zmq_address,
                            remote_address.decode(), data, addr)
                    else:
                        self.buffer_size += tensor_size

                except torch.cuda.OutOfMemoryError:
                    # NOTE(review): this handler covers the whole try body,
                    # so an OOM raised after the b"0" reply above would send
                    # a second reply (b"1") to the same peer -- confirm that
                    # only the torch.empty allocation is expected to fail.
                    self.router_socket.send_multipart(
                        [remote_address, b"1"])
                    tensor = None
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                        "data:%s", self.zmq_address,
                        remote_address.decode(), data)

                # Publish the result (a tensor, a pool tuple, or None after
                # OOM) and wake a thread waiting on recv_store_cv.
                with self.recv_store_cv:
                    self.recv_store[tensor_id] = tensor
                    self._have_received_tensor_id(tensor_id)
                    self.recv_store_cv.notify()

            elif data["cmd"] == "GET":
                tensor_id = data["tensor_id"]
                with self.send_store_cv:
                    tensor = self.send_store.pop(tensor_id, None)
                    if tensor is not None:
                        # From here on `data` is the reply, not the request.
                        data = {
                            "ret": 0,
                            "shape": tensor.shape,
                            "dtype":
                            str(tensor.dtype).replace("torch.", "")
                        }
                        # LRU: pop + re-insert moves the entry to the end of
                        # the store's iteration order; send_tensor evicts
                        # from the front via next(iter(self.send_store)).
                        self.send_store[tensor_id] = tensor
                        self._have_sent_tensor_id(tensor_id)
                    else:
                        data = {"ret": 1}

                # The metadata reply goes out before the NCCL send below.
                self.router_socket.send_multipart(
                    [remote_address, msgpack.dumps(data)])

                if data["ret"] == 0:
                    comm, rank = self.comms[remote_address.decode()]
                    self._send(comm, tensor.to(self.device), rank ^ 1,
                               self.send_stream)
            else:
                logger.warning(
                    "🚧Unexpected, Received message from %s, data:%s",
                    remote_address, data)

_ping

_ping()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def _ping(self):
    """Announce this engine to the proxy every 3 seconds, forever.

    Opens a DEALER socket identified by ``self.zmq_address``, connects it to
    ``self.proxy_address`` and keeps sending a msgpack-encoded registration
    dict carrying the role (``"P"`` for a KV producer, ``"D"`` otherwise),
    the HTTP address and the ZMQ address. This method never returns.
    """
    heartbeat_sock = self.context.socket(zmq.DEALER)
    heartbeat_sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    heartbeat_sock.connect(f"tcp://{self.proxy_address}")

    role = "P" if self.config.is_kv_producer else "D"
    registration = {
        "type": role,
        "http_address": self.http_address,
        "zmq_address": self.zmq_address
    }
    while True:
        heartbeat_sock.send(msgpack.dumps(registration))
        time.sleep(3)

_recv

_recv(comm, tensor: Tensor, src: int, stream=None)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
    """Receive into ``tensor`` from rank ``src`` over ``comm`` and wait.

    Args:
        comm: NCCL communicator handle passed through to ``ncclRecv``.
        tensor: Destination buffer; must live on ``self.device``.
        src: Rank of the sending peer.
        stream: CUDA stream to enqueue on; ``current_stream()`` is used
            when omitted.

    Raises:
        AssertionError: If ``tensor`` is not on ``self.device``.
    """
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    active_stream = current_stream() if stream is None else stream

    with torch.cuda.stream(active_stream):
        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            comm,
            cudaStream_t(active_stream.cuda_stream),
        )
    # Block the caller until the stream has drained.
    active_stream.synchronize()

_send

_send(comm, tensor: Tensor, dst: int, stream=None)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
    """Send ``tensor`` to rank ``dst`` over ``comm`` and wait for the stream.

    Args:
        comm: NCCL communicator handle passed through to ``ncclSend``.
        tensor: Source buffer; must live on ``self.device``.
        dst: Rank of the receiving peer.
        stream: CUDA stream to enqueue on; ``current_stream()`` is used
            when omitted.

    Raises:
        AssertionError: If ``tensor`` is not on ``self.device``.
    """
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    active_stream = current_stream() if stream is None else stream

    with torch.cuda.stream(active_stream):
        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            dst,
            comm,
            cudaStream_t(active_stream.cuda_stream),
        )
    # Block the caller until the stream has drained.
    active_stream.synchronize()

_send_async

_send_async()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def _send_async(self):
    """Drain ``self.send_queue`` forever, sending one entry at a time.

    Each queue entry is unpacked as ``(tensor_id, remote_address, tensor)``
    and forwarded to ``_send_sync`` outside the queue lock. When popping an
    entry leaves the queue empty, ``send_queue_cv`` is notified so a thread
    blocked in ``wait_for_sent`` can re-check the queue.
    """
    while True:
        with self.send_queue_cv:
            # Sleep until a producer has enqueued something.
            while len(self.send_queue) == 0:
                self.send_queue_cv.wait()
            job = self.send_queue.popleft()
            if len(self.send_queue) == 0:
                self.send_queue_cv.notify()
        tensor_id, remote_address, tensor = job
        self._send_sync(tensor_id, tensor, remote_address)

_send_sync

_send_sync(
    tensor_id: str,
    tensor: Tensor,
    remote_address: Optional[str] = None,
) -> bool
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def _send_sync(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[str] = None,
) -> bool:
    """Push ``tensor`` to ``remote_address`` with a PUT handshake.

    A msgpack ``PUT`` header (tensor id, shape, dtype name) is sent on the
    peer's socket; the tensor is only transferred over NCCL when the peer
    answers ``b"0"``.

    Args:
        tensor_id: Identifier the peer stores the tensor under.
        tensor: Tensor to transfer; moved to ``self.device`` before sending.
        remote_address: Peer address; a connection is created on first use.

    Returns:
        ``True`` after a completed send, ``False`` when ``remote_address``
        is ``None`` or the peer refuses the transfer.
    """
    if remote_address is None:
        return False
    if remote_address not in self.socks:
        self._create_connect(remote_address)

    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    dtype_name = str(tensor.dtype).replace("torch.", "")
    data = {
        "cmd": "PUT",
        "tensor_id": tensor_id,
        "shape": tensor.shape,
        "dtype": dtype_name
    }
    sock.send(msgpack.dumps(data))

    response = sock.recv()
    if response == b"0":
        # rank ^ 1 is the peer's rank in the 2-rank communicator.
        self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
        if self.send_type == "PUT_ASYNC":
            self._have_sent_tensor_id(tensor_id)
        return True

    logger.error(
        "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
        "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
        self.zmq_address, remote_address, rank, data, tensor.shape,
        tensor.element_size() * tensor.numel() / 1024**3,
        response.decode())
    return False

close

close() -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def close(self) -> None:
    """Join the engine's background threads.

    The listener thread is always joined; the sender thread only in
    ``PUT_ASYNC`` mode, and the ping thread only when one exists.
    """
    self._listener_thread.join()

    uses_async_sender = self.send_type == "PUT_ASYNC"
    if uses_async_sender:
        self._send_thread.join()

    ping_thread = self._ping_thread
    if ping_thread is not None:
        ping_thread.join()

get_finished

get_finished(
    finished_req_ids: set[str],
    forward_context: ForwardContext,
) -> tuple[Optional[set[str]], Optional[set[str]]]

Notifies worker-side connector ids of requests that have finished generating tokens.

Returns:

tuple[Optional[set[str]], Optional[set[str]]]

ids of requests that have finished asynchronous transfer, as a tuple of (sending/saving ids, recving/loading ids). The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def get_finished(
    self, finished_req_ids: set[str], forward_context: "ForwardContext"
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Release buffered KV tensors of requests that finished generating.

    For every finished request and every layer name in
    ``forward_context.no_compile_layers`` the tensor id
    ``"<request_id>#<layer_name>"`` is dropped from ``self.recv_store``;
    entries that were stored as a ``(addr, dtype, shape)`` tuple also have
    their address freed in ``self.pool``. The per-request sent/received
    bookkeeping is discarded alongside.

    Returns:
        Tuple of (sending/saving ids, recving/loading ids) for requests
        whose asynchronous transfer finished. Neither set is populated yet
        (see the TODOs), so both elements are currently ``None``.
    """

    for request_id in finished_req_ids:
        for layer_name in forward_context.no_compile_layers:
            tensor_id = "#".join((request_id, layer_name))
            if tensor_id not in self.recv_store:
                continue

            with self.recv_store_cv:
                entry = self.recv_store.pop(tensor_id, None)
                self.send_request_id_to_tensor_ids.pop(request_id, None)
                self.recv_request_id_to_tensor_ids.pop(request_id, None)

            # A tuple entry is a memory-pool handle rather than a tensor.
            if isinstance(entry, tuple):
                pool_addr, _, _ = entry
                self.pool.free(pool_addr)

    # TODO:Retrieve requests that have already sent the KV cache.
    finished_sending: set[str] = set()

    # TODO:Retrieve requests that have already received the KV cache.
    finished_recving: set[str] = set()

    return (finished_sending or None), (finished_recving or None)

recv_tensor

recv_tensor(
    tensor_id: str, remote_address: Optional[str] = None
) -> Tensor
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def recv_tensor(
    self,
    tensor_id: str,
    remote_address: typing.Optional[str] = None,
) -> torch.Tensor:
    """Obtain the tensor registered under ``tensor_id``.

    In ``PUT``/``PUT_ASYNC`` mode this blocks until ``tensor_id`` shows up
    in ``self.recv_store``. A pool handle ``(addr, dtype, shape)`` is
    loaded back through ``self.pool``; a plain tensor has its byte size
    subtracted from ``self.buffer_size``; a stored ``None`` is logged and
    returned as-is.

    In ``GET`` mode the tensor is requested from ``remote_address`` and
    received over NCCL.

    Returns:
        The tensor, or ``None`` when the stored entry is ``None``, when no
        ``remote_address`` is given in ``GET`` mode, or when the peer
        answers with a non-zero ``ret``.
    """
    if self.send_type in ("PUT", "PUT_ASYNC"):
        start_time = time.time()
        with self.recv_store_cv:
            while tensor_id not in self.recv_store:
                self.recv_store_cv.wait()
            tensor = self.recv_store[tensor_id]

        if tensor is None:
            duration = time.time() - start_time
            logger.warning(
                "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
                "rank:%d", remote_address, tensor_id, duration * 1000,
                self.rank)
        elif isinstance(tensor, tuple):
            addr, dtype, shape = tensor
            tensor = self.pool.load_tensor(addr, dtype, shape, self.device)
        else:
            self.buffer_size -= tensor.element_size() * tensor.numel()
        return tensor

    # GET
    if remote_address is None:
        return None

    if remote_address not in self.socks:
        self._create_connect(remote_address)

    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    request = {"cmd": "GET", "tensor_id": tensor_id}
    sock.send(msgpack.dumps(request))

    reply = msgpack.loads(sock.recv())
    if reply["ret"] != 0:
        logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                       remote_address, tensor_id, reply["ret"])
        return None

    tensor = torch.empty(reply["shape"],
                         dtype=getattr(torch, reply["dtype"]),
                         device=self.device)

    # rank ^ 1 is the peer's rank in the 2-rank communicator.
    self._recv(comm, tensor, rank ^ 1, self.recv_stream)

    return tensor

send_tensor

send_tensor(
    tensor_id: str,
    tensor: Tensor,
    remote_address: Optional[str] = None,
) -> bool
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def send_tensor(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[str] = None,
) -> bool:
    """Hand ``tensor`` over for delivery according to ``self.send_type``.

    * ``remote_address is None``: the tensor is placed straight into the
      local ``self.recv_store`` and one waiter is notified.
    * ``"PUT"``: the tensor is sent synchronously via ``_send_sync``.
    * ``"PUT_ASYNC"``: the tensor is appended to ``self.send_queue`` for
      the background sender.
    * otherwise (``GET``): the tensor is kept in ``self.send_store`` until
      a peer asks for it. Entries at the front of the store are evicted
      while ``buffer_size + tensor_size`` exceeds
      ``buffer_size_threshold``.

    Eviction stops once the store is empty. If the tensor still does not
    fit the budget at that point, it is stored anyway and a warning is
    logged, rather than raising ``StopIteration`` from the eviction loop.

    Args:
        tensor_id: Identifier the tensor is stored or sent under.
        tensor: Tensor to deliver.
        remote_address: Destination peer, or ``None`` for local delivery.

    Returns:
        The result of ``_send_sync`` in ``PUT`` mode, ``True`` otherwise.
    """
    if remote_address is None:
        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.recv_store_cv.notify()
        return True

    if self.send_type == "PUT":
        return self._send_sync(tensor_id, tensor, remote_address)

    if self.send_type == "PUT_ASYNC":
        with self.send_queue_cv:
            self.send_queue.append([tensor_id, remote_address, tensor])
            self.send_queue_cv.notify()
        return True

    # GET
    with self.send_store_cv:
        tensor_size = tensor.element_size() * tensor.numel()
        # Evict from the front of the store until the new tensor fits.
        # The emptiness check prevents next(iter(...)) from raising
        # StopIteration when nothing is left to evict.
        while (self.send_store and self.buffer_size + tensor_size
               > self.buffer_size_threshold):
            oldest_tensor_id = next(iter(self.send_store))
            oldest_tensor = self.send_store.pop(oldest_tensor_id)
            oldest_tensor_size = (oldest_tensor.element_size() *
                                  oldest_tensor.numel())
            self.buffer_size -= oldest_tensor_size
            logger.info(
                "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
                remote_address, tensor_id, tensor_size,
                self.buffer_size, oldest_tensor_size, self.rank)

        if self.buffer_size + tensor_size > self.buffer_size_threshold:
            logger.warning(
                "🔴[GET]Send to %s, tensor_id:%s, tensor_size:%d does not "
                "fit after evicting every stored tensor, buffer_size:%d, "
                "threshold:%d, rank:%d", remote_address, tensor_id,
                tensor_size, self.buffer_size,
                self.buffer_size_threshold, self.rank)

        self.send_store[tensor_id] = tensor
        self.buffer_size += tensor_size
        logger.debug(
            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
            remote_address, tensor_id, tensor_size, tensor.shape,
            self.rank, self.buffer_size,
            self.buffer_size / self.buffer_size_threshold * 100)

    return True

wait_for_sent

wait_for_sent()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
def wait_for_sent(self):
    """Block until the asynchronous send queue has been emptied.

    Only does anything in ``PUT_ASYNC`` mode: waits on ``send_queue_cv``
    while ``self.send_queue`` is non-empty, then logs how long that took.
    """
    if self.send_type != "PUT_ASYNC":
        return

    began = time.time()
    with self.send_queue_cv:
        while self.send_queue:
            self.send_queue_cv.wait()
    elapsed = time.time() - began
    logger.debug(
        "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
        " to be empty, rank:%d", elapsed * 1000, self.rank)

set_p2p_nccl_context

set_p2p_nccl_context(num_channels: str)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
@contextmanager
def set_p2p_nccl_context(num_channels: str):
    """Temporarily pin NCCL environment variables for communicator setup.

    Inside the ``with`` body ``NCCL_MAX_NCHANNELS`` and
    ``NCCL_MIN_NCHANNELS`` are set to ``num_channels`` and
    ``NCCL_CUMEM_ENABLE`` to ``'1'``. On exit every variable listed in
    ``env_vars`` is restored to the value it had on entry, or removed if it
    was unset -- including the ones this function does not override.
    """
    env_vars = [
        'NCCL_MAX_NCHANNELS',
        'NCCL_MIN_NCHANNELS',
        'NCCL_CUMEM_ENABLE',
        'NCCL_BUFFSIZE',
        'NCCL_PROTO',  # LL,LL128,SIMPLE
        'NCCL_ALGO',  # RING,TREE
    ]
    # Snapshot first so the finally-clause can undo everything.
    original_values: dict[str, Any] = {
        var: os.environ.get(var)
        for var in env_vars
    }

    logger.info("set_p2p_nccl_context, original_values: %s", original_values)

    overrides = {
        'NCCL_MAX_NCHANNELS': num_channels,
        'NCCL_MIN_NCHANNELS': num_channels,
        'NCCL_CUMEM_ENABLE': '1',
    }
    try:
        for var, value in overrides.items():
            os.environ[var] = value
        yield
    finally:
        for var in env_vars:
            previous = original_values[var]
            if previous is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = previous