
vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector

logger module-attribute

logger = init_logger(__name__)

ReqMeta dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@dataclass
class ReqMeta:
    # Request tokens
    token_ids: torch.Tensor
    # Slot mappings, should have the same length as token_ids
    slot_mapping: torch.Tensor
    # Is store or load
    is_store: bool

    @staticmethod
    def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
                  is_store: bool) -> "ReqMeta":
        valid_num_tokens = align_to_block_size(len(token_ids), block_size)
        token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
        block_ids_tensor = torch.tensor(block_ids)
        num_blocks = block_ids_tensor.shape[0]
        block_offsets = torch.arange(0, block_size)
        slot_mapping = block_offsets.reshape((1, block_size)) + \
                block_ids_tensor.reshape((num_blocks, 1)) * block_size
        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
        return ReqMeta(
            token_ids=token_ids_tensor,
            slot_mapping=slot_mapping,
            is_store=is_store,
        )

is_store instance-attribute

is_store: bool

slot_mapping instance-attribute

slot_mapping: Tensor

token_ids instance-attribute

token_ids: Tensor

__init__

__init__(
    token_ids: Tensor, slot_mapping: Tensor, is_store: bool
) -> None

make_meta staticmethod

make_meta(
    token_ids: list[int],
    block_ids: list[int],
    block_size: int,
    is_store: bool,
) -> ReqMeta
Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@staticmethod
def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
              is_store: bool) -> "ReqMeta":
    valid_num_tokens = align_to_block_size(len(token_ids), block_size)
    token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
    block_ids_tensor = torch.tensor(block_ids)
    num_blocks = block_ids_tensor.shape[0]
    block_offsets = torch.arange(0, block_size)
    slot_mapping = block_offsets.reshape((1, block_size)) + \
            block_ids_tensor.reshape((num_blocks, 1)) * block_size
    slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
    return ReqMeta(
        token_ids=token_ids_tensor,
        slot_mapping=slot_mapping,
        is_store=is_store,
    )
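
To make the slot-mapping arithmetic concrete, here is a small worked example; the token ids, block ids, and block size are hypothetical.

from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
    ReqMeta)

# Hypothetical inputs: 9 tokens mapped onto blocks 3 and 5, block_size = 4.
meta = ReqMeta.make_meta(token_ids=list(range(100, 109)),
                         block_ids=[3, 5], block_size=4, is_store=False)
# Only the first 8 block-aligned tokens are kept; block b contributes
# slots b * block_size ... b * block_size + block_size - 1.
assert meta.token_ids.tolist() == list(range(100, 108))
assert meta.slot_mapping.tolist() == [12, 13, 14, 15, 20, 21, 22, 23]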

SharedStorageConnector

Bases: KVConnectorBase_V1
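
The connector is selected through the engine's kv_transfer_config. Below is a minimal offline sketch, assuming the KVTransferConfig fields shown here match the installed vLLM release; the model name and storage path are placeholders.

from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    kv_transfer_config=KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ),
)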

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
class SharedStorageConnector(KVConnectorBase_V1):
    # NOTE: This is a simple debug implementation of the KV connector.
    # It saves / loads the KV cache to / from disk.
    # It does extra work that overwrites the existing prefix cache in the GPU;
    # to remove that overhead, a "mask" would need to be added to ReqMeta.

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._block_size = vllm_config.cache_config.block_size
        self._requests_need_load: dict[str, Request] = {}
        transfer_config = vllm_config.kv_transfer_config
        self._storage_path = transfer_config.get_from_extra_config(
            "shared_storage_path", "/tmp")
        logger.info(vllm_config.kv_transfer_config)
        logger.info("Shared storage path is %s", self._storage_path)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Start loading the KV cache from the connector buffer to vLLM's 
        paged KV buffer.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be 
            the same.
        """
        attn_metadata = forward_context.attn_metadata

        def inject_kv_into_layer(
            dst_kv_cache_layer: torch.Tensor,
            src_kv_cache: torch.Tensor,
            slot_mapping: torch.Tensor,
        ) -> None:
            """Inject the KV cache into the layer.

            Args:
                dst_kv_cache_layer (torch.Tensor): the destination KV cache 
                    layer. In shape [2, num_pages, page_size, xxx] if not 
                    using MLA, [num_pages, page_size, xxx] otherwise.
                src_kv_cache (torch.Tensor): the source KV cache. In shape
                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] 
                    otherwise.
                slot_mapping (torch.Tensor): the slot mapping. In shape 
                    [num_tokens].
            """
            dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
            if isinstance(attn_metadata, MLACommonMetadata):
                num_pages = dst_kv_cache_layer_shape[0]
                page_size = dst_kv_cache_layer_shape[1]
                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                    num_pages * page_size, -1)
                dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
            else:
                num_pages = dst_kv_cache_layer_shape[1]
                page_size = dst_kv_cache_layer_shape[2]
                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                    2, num_pages * page_size, -1)
                dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
            # NOTE: reshape on the contiguous layer returns a view, so the
            # in-place writes above already updated the original buffer.

        # Get the metadata
        metadata: KVConnectorMetadata = self._get_connector_metadata()
        if metadata is None:
            logger.warning(
                "In connector.start_load_kv, but the connector metadata is None"
            )
            return
        assert isinstance(metadata, SharedStorageConnectorMetadata)

        if attn_metadata is None:
            logger.warning(
                "In connector.start_load_kv, but the attn_metadata is None")
            return

        # Load the KV cache for each request, layer by layer
        for request in metadata.requests:
            if request.is_store:
                continue
            logger.info("Inject KV cache of %d tokens to the paged memory",
                        len(request.slot_mapping))
            for layer_name in forward_context.no_compile_layers:
                attn_layer = forward_context.no_compile_layers[layer_name]
                kv_cache_layer = attn_layer.kv_cache[
                    forward_context.virtual_engine]

                filename = self._generate_filename_debug(
                    layer_name, request.token_ids)
                kv_cache = safetensors.torch.load_file(
                    filename)["kv_cache"].cuda()
                inject_kv_into_layer(kv_cache_layer, kv_cache,
                                     request.slot_mapping)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Blocking until the KV for a specific layer is loaded into vLLM's
        paged buffer. 

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        return

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """Start saving the KV cache of the layer from vLLM's paged buffer 
        to the connector.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current 
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """

        def extract_kv_from_layer(
            layer: torch.Tensor,
            slot_mapping: torch.Tensor,
        ) -> torch.Tensor:
            """Extract the KV cache from the layer.

            Assume the shape of the layer is (2, num_pages, page_size, xxx)
            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
            """
            if isinstance(attn_metadata, MLACommonMetadata):
                num_pages, page_size = layer.shape[0], layer.shape[1]
                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
                                                                ...]
            num_pages, page_size = layer.shape[1], layer.shape[2]
            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
                                                               ...]

        connector_metadata = self._get_connector_metadata()
        assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
        for request in connector_metadata.requests:
            if request.is_store:
                filename = self._generate_filename_debug(
                    layer_name, request.token_ids)
                kv_cache = extract_kv_from_layer(kv_layer,
                                                 request.slot_mapping)
                tensors = {"kv_cache": kv_cache.detach().cpu()}
                safetensors.torch.save_file(tensors, filename)

    def wait_for_save(self):
        return

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        Get the number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the 
            external KV cache beyond what is already computed.
        """
        # NOTE: in this debug implementation, we assume that the prompt is
        # cached_prompt + a newly generated single token. Therefore, we use
        # prompt_token_ids[:-1] to determine the folder name.

        # NOTE: in the current v1 scheduler, num_computed_tokens is aligned
        # to the block granularity, and it expects the returned blocks and
        # num_computed_tokens to also be block-aligned.
        if not self._found_match_for_request(request):
            return 0, False

        logger.info("External Cache Hit!")

        # Now that the first num_tokens_to_check tokens are hit, prepare the
        # metadata for the worker connector to correctly load the KV.
        num_tokens_to_check = align_to_block_size(
            len(request.prompt_token_ids) - 1, self._block_size)

        return num_tokens_to_check - num_computed_tokens, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.

        If blocks were allocated, add to _requests_need_load,
        such that we load the KVs in the next forward pass.
        """
        if num_external_tokens > 0:
            self._requests_need_load[request.request_id] = request

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        meta = SharedStorageConnectorMetadata()

        total_need_load = 0
        for new_req in scheduler_output.scheduled_new_reqs:
            if new_req.req_id in self._requests_need_load:
                meta.add_request(token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids[0],
                                 block_size=self._block_size,
                                 is_store=False)
                total_need_load += 1
            else:
                # NOTE: here we treat store and load as mutually exclusive,
                # but a single request can have both store and load.
                # NOTE(rob): for this debug implementation, we only cache
                # the original prompt tokens.
                if not self._found_match_for_request(new_req):
                    meta.add_request(token_ids=new_req.prompt_token_ids,
                                     block_ids=new_req.block_ids[0],
                                     block_size=self._block_size,
                                     is_store=True)

        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            num_computed_tokens = cached_reqs.num_computed_tokens[i]
            num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            new_block_ids = cached_reqs.new_block_ids[i]
            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]

            # NOTE(rob): here we rely on the resumed requests being
            # the first N requests in scheduled_cached_reqs.
            if not resumed_from_preemption:
                break
            if req_id in self._requests_need_load:
                # NOTE(rob): cached_req_data does not have the full
                # list of token ids (only new tokens). So we look it
                # up in the actual request object.
                request = self._requests_need_load[req_id]
                total_tokens = num_computed_tokens + num_new_tokens
                token_ids = request.all_token_ids[:total_tokens]

                # NOTE(rob): For resumed req, new_block_ids is all
                # of the block_ids for the request.
                block_ids = new_block_ids[0]

                meta.add_request(token_ids=token_ids,
                                 block_ids=block_ids,
                                 block_size=self._block_size,
                                 is_store=False)
                total_need_load += 1

        assert total_need_load == len(self._requests_need_load)
        self._requests_need_load.clear()
        return meta

    # ==============================
    # Helper functions
    # ==============================

    def _found_match_for_request(
        self,
        request: "Request",
    ) -> bool:
        """Check if the cache is hit for the request.
        """
        num_tokens_to_check = align_to_block_size(
            len(request.prompt_token_ids) - 1, self._block_size)
        foldername = self._generate_foldername_debug(torch.tensor(
            request.prompt_token_ids)[:num_tokens_to_check],
                                                     create_folder=False)
        return os.path.exists(foldername)

    def _generate_foldername_debug(
        self,
        input_ids: torch.Tensor,
        create_folder=False,
    ) -> str:
        """Generate a folder name based on the hash of the bytes of the input 
        ids.
        """
        input_ids_bytes = input_ids.numpy().tobytes()
        input_ids_hash = hashlib.md5(input_ids_bytes,
                                     usedforsecurity=False).hexdigest()
        foldername = os.path.join(self._storage_path, input_ids_hash)
        if create_folder:
            os.makedirs(foldername, exist_ok=True)
        return foldername

    def _generate_filename_debug(
        self,
        layer_name: str,
        input_ids: torch.Tensor,
    ) -> str:
        """Generate a file name based on the layer name and the hash 
        of the bytes of the input ids.
        """
        foldername = self._generate_foldername_debug(input_ids,
                                                     create_folder=True)
        return os.path.join(foldername, f"{layer_name}.safetensors")

_block_size instance-attribute

_block_size = block_size

_requests_need_load instance-attribute

_requests_need_load: dict[str, Request] = {}

_storage_path instance-attribute

_storage_path = get_from_extra_config(
    "shared_storage_path", "/tmp"
)

__init__

__init__(vllm_config: VllmConfig, role: KVConnectorRole)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
    super().__init__(vllm_config=vllm_config, role=role)
    self._block_size = vllm_config.cache_config.block_size
    self._requests_need_load: dict[str, Request] = {}
    transfer_config = vllm_config.kv_transfer_config
    self._storage_path = transfer_config.get_from_extra_config(
        "shared_storage_path", "/tmp")
    logger.info(vllm_config.kv_transfer_config)
    logger.info("Shared storage path is %s", self._storage_path)

_found_match_for_request

_found_match_for_request(request: Request) -> bool

Check if the cache is hit for the request.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def _found_match_for_request(
    self,
    request: "Request",
) -> bool:
    """Check if the cache is hit for the request.
    """
    num_tokens_to_check = align_to_block_size(
        len(request.prompt_token_ids) - 1, self._block_size)
    foldername = self._generate_foldername_debug(torch.tensor(
        request.prompt_token_ids)[:num_tokens_to_check],
                                                 create_folder=False)
    return os.path.exists(foldername)

_generate_filename_debug

_generate_filename_debug(
    layer_name: str, input_ids: Tensor
) -> str

Generate a file name based on the layer name and the hash of the bytes of the input ids.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def _generate_filename_debug(
    self,
    layer_name: str,
    input_ids: torch.Tensor,
) -> str:
    """Generate a file name based on the layer name and the hash 
    of the bytes of the input ids.
    """
    foldername = self._generate_foldername_debug(input_ids,
                                                 create_folder=True)
    return os.path.join(foldername, f"{layer_name}.safetensors")

_generate_foldername_debug

_generate_foldername_debug(
    input_ids: Tensor, create_folder=False
) -> str

Generate a folder name based on the hash of the bytes of the input ids.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def _generate_foldername_debug(
    self,
    input_ids: torch.Tensor,
    create_folder=False,
) -> str:
    """Generate a folder name based on the hash of the bytes of the input 
    ids.
    """
    input_ids_bytes = input_ids.numpy().tobytes()
    input_ids_hash = hashlib.md5(input_ids_bytes,
                                 usedforsecurity=False).hexdigest()
    foldername = os.path.join(self._storage_path, input_ids_hash)
    if create_folder:
        os.makedirs(foldername, exist_ok=True)
    return foldername
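
Taken together, the two helpers define the on-disk layout: each block-aligned token prefix hashes to one folder, which holds one safetensors file per attention layer. A small sketch, with a hypothetical layer name and the default /tmp storage path:

import hashlib

import torch

token_ids = torch.tensor([1, 2, 3, 4])  # hypothetical aligned token prefix
digest = hashlib.md5(token_ids.numpy().tobytes(),
                     usedforsecurity=False).hexdigest()
# One file per layer, e.g.:
#   /tmp/<digest>/model.layers.0.self_attn.attn.safetensors
print(f"/tmp/{digest}/model.layers.0.self_attn.attn.safetensors")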

build_connector_meta

build_connector_meta(
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata

Build the connector metadata for this step.

This function should NOT modify any fields in the scheduler_output. Also, calling this function will reset the state of the connector.

Parameters:

    scheduler_output (SchedulerOutput, required): the scheduler output object.
Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build the connector metadata for this step.

    This function should NOT modify any fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.
    """
    meta = SharedStorageConnectorMetadata()

    total_need_load = 0
    for new_req in scheduler_output.scheduled_new_reqs:
        if new_req.req_id in self._requests_need_load:
            meta.add_request(token_ids=new_req.prompt_token_ids,
                             block_ids=new_req.block_ids[0],
                             block_size=self._block_size,
                             is_store=False)
            total_need_load += 1
        else:
            # NOTE: here we treat store and load as mutually exclusive,
            # but a single request can have both store and load.
            # NOTE(rob): for this debug implementation, we only cache
            # the original prompt tokens.
            if not self._found_match_for_request(new_req):
                meta.add_request(token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids[0],
                                 block_size=self._block_size,
                                 is_store=True)

    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        num_computed_tokens = cached_reqs.num_computed_tokens[i]
        num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
        new_block_ids = cached_reqs.new_block_ids[i]
        resumed_from_preemption = cached_reqs.resumed_from_preemption[i]

        # NOTE(rob): here we rely on the resumed requests being
        # the first N requests in scheduled_cached_reqs.
        if not resumed_from_preemption:
            break
        if req_id in self._requests_need_load:
            # NOTE(rob): cached_req_data does not have the full
            # list of token ids (only new tokens). So we look it
            # up in the actual request object.
            request = self._requests_need_load[req_id]
            total_tokens = num_computed_tokens + num_new_tokens
            token_ids = request.all_token_ids[:total_tokens]

            # NOTE(rob): For resumed req, new_block_ids is all
            # of the block_ids for the request.
            block_ids = new_block_ids[0]

            meta.add_request(token_ids=token_ids,
                             block_ids=block_ids,
                             block_size=self._block_size,
                             is_store=False)
            total_need_load += 1

    assert total_need_load == len(self._requests_need_load)
    self._requests_need_load.clear()
    return meta

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int, bool]

Get the number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens.

Parameters:

    request (Request, required): the request object.
    num_computed_tokens (int, required): the number of locally computed tokens for this request.

Returns:

    tuple[int, bool]: the number of tokens that can be loaded from the external KV cache beyond what is already computed.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """
    Get the number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        the number of tokens that can be loaded from the 
        external KV cache beyond what is already computed.
    """
    # NOTE: in this debug implementation, we assume that the prompt is
    # cached_prompt + a newly generated single token. Therefore, we use
    # prompt_token_ids[:-1] to determine the folder name.

    # NOTE: in the current v1 scheduler, num_computed_tokens is aligned
    # to the block granularity, and it expects the returned blocks and
    # num_computed_tokens to also be block-aligned.
    if not self._found_match_for_request(request):
        return 0, False

    logger.info("External Cache Hit!")

    # Now that the first num_tokens_to_check tokens are hit, prepare the
    # metadata for the worker connector to correctly load the KV.
    num_tokens_to_check = align_to_block_size(
        len(request.prompt_token_ids) - 1, self._block_size)

    return num_tokens_to_check - num_computed_tokens, False
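
A worked example of the returned count, with hypothetical numbers:

from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
    align_to_block_size)

# Hypothetical request: 100 prompt tokens, block_size = 16, and no locally
# computed tokens.  The last prompt token is excluded, so only the first
# align_to_block_size(99, 16) = 96 tokens are checked against storage.
assert align_to_block_size(100 - 1, 16) == 96
# On a storage hit the method returns (96 - 0, False): 96 loadable tokens,
# with no asynchronous loading.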

save_kv_layer

save_kv_layer(
    layer_name: str,
    kv_layer: Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None

Start saving the KV cache of the layer from vLLM's paged buffer to the connector.

Parameters:

    layer_name (str, required): the name of the layer.
    kv_layer (Tensor, required): the paged KV buffer of the current layer in vLLM.
    attn_metadata (AttentionMetadata, required): the attention metadata.
    **kwargs: additional arguments for the save operation.
Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                  attn_metadata: "AttentionMetadata", **kwargs) -> None:
    """Start saving the KV cache of the layer from vLLM's paged buffer 
    to the connector.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current 
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation.
    """

    def extract_kv_from_layer(
        layer: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> torch.Tensor:
        """Extract the KV cache from the layer.

        Assume the shape of the layer is (2, num_pages, page_size, xxx)
        if MLA is not used, and (num_pages, page_size, xxx) otherwise.
        """
        if isinstance(attn_metadata, MLACommonMetadata):
            num_pages, page_size = layer.shape[0], layer.shape[1]
            return layer.reshape(num_pages * page_size, -1)[slot_mapping,
                                                            ...]
        num_pages, page_size = layer.shape[1], layer.shape[2]
        return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
                                                           ...]

    connector_metadata = self._get_connector_metadata()
    assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
    for request in connector_metadata.requests:
        if request.is_store:
            filename = self._generate_filename_debug(
                layer_name, request.token_ids)
            kv_cache = extract_kv_from_layer(kv_layer,
                                             request.slot_mapping)
            tensors = {"kv_cache": kv_cache.detach().cpu()}
            safetensors.torch.save_file(tensors, filename)
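
The extraction is the gather dual of the injection performed in start_load_kv. A toy sketch of the non-MLA path, with hypothetical sizes:

import torch

num_pages, page_size, head_dim = 6, 4, 8  # hypothetical sizes
layer = torch.randn(2, num_pages, page_size, head_dim)
slot_mapping = torch.tensor([12, 13, 14, 15])

# Gather the K and V rows for the mapped slots (non-MLA layout).
kv_cache = layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
assert kv_cache.shape == (2, 4, head_dim)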

start_load_kv

start_load_kv(
    forward_context: ForwardContext, **kwargs
) -> None

Start loading the KV cache from the connector buffer to vLLM's paged KV buffer.

Parameters:

    forward_context (ForwardContext, required): the forward context.
    **kwargs: additional arguments for the load operation.
Note

The number of elements in kv_caches and layer_names should be the same.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def start_load_kv(self, forward_context: "ForwardContext",
                  **kwargs) -> None:
    """Start loading the KV cache from the connector buffer to vLLM's 
    paged KV buffer.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation

    Note:
        The number of elements in kv_caches and layer_names should be 
        the same.
    """
    attn_metadata = forward_context.attn_metadata

    def inject_kv_into_layer(
        dst_kv_cache_layer: torch.Tensor,
        src_kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Inject the KV cache into the layer.

        Args:
            dst_kv_cache_layer (torch.Tensor): the destination KV cache 
                layer. In shape [2, num_pages, page_size, xxx] if not 
                using MLA, [num_pages, page_size, xxx] otherwise.
            src_kv_cache (torch.Tensor): the source KV cache. In shape
                [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] 
                otherwise.
            slot_mapping (torch.Tensor): the slot mapping. In shape 
                [num_tokens].
        """
        dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
        if isinstance(attn_metadata, MLACommonMetadata):
            num_pages = dst_kv_cache_layer_shape[0]
            page_size = dst_kv_cache_layer_shape[1]
            dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                num_pages * page_size, -1)
            dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
        else:
            num_pages = dst_kv_cache_layer_shape[1]
            page_size = dst_kv_cache_layer_shape[2]
            dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                2, num_pages * page_size, -1)
            dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
        # NOTE: reshape on the contiguous layer returns a view, so the
        # in-place writes above already updated the original buffer.

    # Get the metadata
    metadata: KVConnectorMetadata = self._get_connector_metadata()
    if metadata is None:
        logger.warning(
            "In connector.start_load_kv, but the connector metadata is None"
        )
        return
    assert isinstance(metadata, SharedStorageConnectorMetadata)

    if attn_metadata is None:
        logger.warning(
            "In connector.start_load_kv, but the attn_metadata is None")
        return

    # Load the KV cache for each request, layer by layer
    for request in metadata.requests:
        if request.is_store:
            continue
        logger.info("Inject KV cache of %d tokens to the paged memory",
                    len(request.slot_mapping))
        for layer_name in forward_context.no_compile_layers:
            attn_layer = forward_context.no_compile_layers[layer_name]
            kv_cache_layer = attn_layer.kv_cache[
                forward_context.virtual_engine]

            filename = self._generate_filename_debug(
                layer_name, request.token_ids)
            kv_cache = safetensors.torch.load_file(
                filename)["kv_cache"].cuda()
            inject_kv_into_layer(kv_cache_layer, kv_cache,
                                 request.slot_mapping)

update_state_after_alloc

update_state_after_alloc(
    request: Request,
    blocks: KVCacheBlocks,
    num_external_tokens: int,
)

Update KVConnector state after block allocation.

If blocks were allocated, add to _requests_need_load, such that we load the KVs in the next forward pass.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    """
    Update KVConnector state after block allocation.

    If blocks were allocated, add to _requests_need_load,
    such that we load the KVs in the next forward pass.
    """
    if num_external_tokens > 0:
        self._requests_need_load[request.request_id] = request

wait_for_layer_load

wait_for_layer_load(layer_name: str) -> None

Block until the KV for a specific layer is loaded into vLLM's paged buffer.

This interface will be useful for layer-by-layer pipelining.

Parameters:

    layer_name (str, required): the name of that layer.
Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def wait_for_layer_load(self, layer_name: str) -> None:
    """Blocking until the KV for a specific layer is loaded into vLLM's
    paged buffer. 

    This interface will be useful for layer-by-layer pipelining.

    Args:
        layer_name: the name of that layer
    """
    return

wait_for_save

wait_for_save()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def wait_for_save(self):
    return

SharedStorageConnectorMetadata dataclass

Bases: KVConnectorMetadata

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
    requests: list[ReqMeta]

    def __init__(self):
        self.requests = []

    def add_request(
        self,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        is_store: bool,
    ) -> None:
        self.requests.append(
            ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
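
The scheduler-side connector builds one of these objects per step. A minimal sketch of populating it by hand, with hypothetical token and block ids:

from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
    SharedStorageConnectorMetadata)

meta = SharedStorageConnectorMetadata()
meta.add_request(token_ids=[1, 2, 3, 4, 5], block_ids=[0],
                 block_size=4, is_store=True)
# ReqMeta.make_meta keeps the 4 block-aligned tokens and maps them to
# slots 0..3 of block 0.
assert meta.requests[0].slot_mapping.tolist() == [0, 1, 2, 3]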

requests instance-attribute

requests: list[ReqMeta] = []

__init__

__init__()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def __init__(self):
    self.requests = []

add_request

add_request(
    token_ids: list[int],
    block_ids: list[int],
    block_size: int,
    is_store: bool,
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def add_request(
    self,
    token_ids: list[int],
    block_ids: list[int],
    block_size: int,
    is_store: bool,
) -> None:
    self.requests.append(
        ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))

align_to_block_size

align_to_block_size(num_tokens: int, block_size) -> int

Round num_tokens down to the largest multiple of block_size that is strictly below num_tokens (an exact multiple drops one full block).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
def align_to_block_size(num_tokens: int, block_size) -> int:
    """Align the number of tokens to the block size.
    """
    return (num_tokens - 1) // block_size * block_size
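
Note the round-down behavior on exact multiples, which follows from the num_tokens - 1 term:

from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
    align_to_block_size)

assert align_to_block_size(33, 16) == 32  # partial last block dropped
assert align_to_block_size(32, 16) == 16  # exact multiple also rounds down
assert align_to_block_size(17, 16) == 16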