vllm.attention.backends.utils

Attention backend utils

PAD_SLOT_ID module-attribute

PAD_SLOT_ID = -1

STR_NOT_IMPL_ENC_DEC_ROCM_HIP module-attribute

STR_NOT_IMPL_ENC_DEC_ROCM_HIP = "ROCm/HIP is not currently supported with encoder/decoder models."

TAttentionMetadata module-attribute

TAttentionMetadata = TypeVar(
    "TAttentionMetadata", bound="AttentionMetadata"
)

_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL module-attribute

_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256

logger module-attribute

logger = init_logger(__name__)

CommonAttentionState

Bases: AttentionState

Source code in vllm/attention/backends/utils.py
class CommonAttentionState(AttentionState):

    def __init__(self, runner: "ModelRunnerBase"):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):

        self._is_graph_capturing = True

        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)

        yield

        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables

    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                   ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
                f"Expected attn_backend name to be either 'XFORMERS'," \
                f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
                f"got '{self.runner.attn_backend.get_name()}'"
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                   ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
                f"Expected attn_backend name to be either 'XFORMERS'," \
                f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
                f"got '{self.runner.attn_backend.get_name()}'"
            self._add_additional_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> None:
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or "\
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    def begin_forward(self, model_input) -> None:
        return

    def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
                                                    attn_metadata):
        """
        Updates the attention metadata parameters for CUDA graph capture in an
        encoder-decoder model.

        This method modifies attention-related tensors and metadata required
        for CUDA graph capture in encoder-decoder models. Specifically, it
        updates the cross-attention and encoder sequence tensors in the 
        AttentionMetadata object.
        """
        # During decode phase the cross_slot_mapping will be empty. Hence set
        # an empty tensor for CUDA Graph capture.
        attn_metadata.cross_slot_mapping = torch.tensor(
            [], dtype=torch.int).cuda()
        attn_metadata.cross_block_tables = torch.full(
            (batch_size, self.runner.get_max_block_per_batch()),
            1,
            dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
                                                    1,
                                                    dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens_tensor = torch.full(
            (batch_size, ), 1, dtype=torch.int).cuda()
        attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
        attn_metadata.num_encoder_tokens = 0

    def _add_additional_input_buffers_for_enc_dec_model(
            self, attn_metadata, input_buffers: Dict[str, Any]):
        """
        Saves additional input buffers specific to the encoder-decoder model
        from the attention metadata.

        This method extracts and stores encoder-decoder related input buffers
        from the `attn_metadata` into the `input_buffers` dictionary. The
        buffers include encoder sequence lengths, cross-slot mappings, and
        cross-block tables, which are essential for the encoder-decoder model
        during CUDA graph replay.
        """
        input_buffers["encoder_seq_lens_tensor"] = (
            attn_metadata.decode_metadata.encoder_seq_lens_tensor)
        input_buffers["cross_slot_mapping"] = (
            attn_metadata.decode_metadata.cross_slot_mapping)
        input_buffers["cross_block_tables"] = (
            attn_metadata.decode_metadata.cross_block_tables)

    def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
                                                 input_buffers: Dict[str,
                                                                     Any]):
        """
        Populates input buffers with data from the encoder-decoder model's
        attention metadata.

        This method fills the input buffers with encoder-decoder specific
        tensors, copying the encoder sequence lengths, cross-slot mapping,
        and cross-block tables from the `attn_metadata` into the
        corresponding buffers in the `input_buffers` dictionary.
        """
        input_buffers["encoder_seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.encoder_seq_lens_tensor,
            non_blocking=True)
        input_buffers["cross_slot_mapping"].copy_(
            attn_metadata.decode_metadata.cross_slot_mapping,
            non_blocking=True)
        input_buffers["cross_block_tables"].copy_(
            attn_metadata.decode_metadata.cross_block_tables,
            non_blocking=True)
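
The following is a minimal usage sketch, not taken from vLLM itself: the DummyBackend class and the SimpleNamespace runner are hypothetical stand-ins for a real attention backend and ModelRunnerBase, shown only to illustrate how graph_capture and graph_capture_get_metadata_for_batch fit together during CUDA graph capture.

from types import SimpleNamespace

import numpy as np

from vllm.attention.backends.utils import CommonAttentionState


class DummyBackend:
    """Hypothetical backend; a real one returns its AttentionMetadata class."""

    def get_name(self) -> str:
        return "FLASH_ATTN"

    def make_metadata(self, **kwargs):
        return SimpleNamespace(**kwargs)


runner = SimpleNamespace(
    device="cpu",
    max_seq_len_to_capture=2048,
    # [max batch size, max blocks per sequence]
    graph_block_tables=np.zeros((8, 16), dtype=np.int32),
    attn_backend=DummyBackend(),
)

state = CommonAttentionState(runner)
with state.graph_capture(max_batch_size=8):
    attn_metadata = state.graph_capture_get_metadata_for_batch(batch_size=4)
    print(attn_metadata.slot_mapping.shape)  # torch.Size([4]), filled with PAD_SLOT_ID
    print(attn_metadata.num_decode_tokens)   # 4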

_is_graph_capturing instance-attribute

_is_graph_capturing = False

runner instance-attribute

runner = runner

__init__

__init__(runner: ModelRunnerBase)
Source code in vllm/attention/backends/utils.py
def __init__(self, runner: "ModelRunnerBase"):
    self.runner = runner
    self._is_graph_capturing = False

_add_additional_input_buffers_for_enc_dec_model

_add_additional_input_buffers_for_enc_dec_model(
    attn_metadata, input_buffers: Dict[str, Any]
)

Saves additional input buffers specific to the encoder-decoder model from the attention metadata.

This method extracts and stores encoder-decoder related input buffers from the attn_metadata into the input_buffers dictionary. The buffers include encoder sequence lengths, cross-slot mappings, and cross-block tables, which are essential for the encoder-decoder model during CUDA graph replay.

Source code in vllm/attention/backends/utils.py
def _add_additional_input_buffers_for_enc_dec_model(
        self, attn_metadata, input_buffers: Dict[str, Any]):
    """
    Saves additional input buffers specific to the encoder-decoder model
    from the attention metadata.

    This method extracts and stores encoder-decoder related input buffers
    from the `attn_metadata` into the `input_buffers` dictionary. The
    buffers include encoder sequence lengths, cross-slot mappings, and
    cross-block tables, which are essential for the encoder-decoder model
    during CUDA graph replay.
    """
    input_buffers["encoder_seq_lens_tensor"] = (
        attn_metadata.decode_metadata.encoder_seq_lens_tensor)
    input_buffers["cross_slot_mapping"] = (
        attn_metadata.decode_metadata.cross_slot_mapping)
    input_buffers["cross_block_tables"] = (
        attn_metadata.decode_metadata.cross_block_tables)

_prepare_input_buffers_for_enc_dec_model

_prepare_input_buffers_for_enc_dec_model(
    attn_metadata, input_buffers: Dict[str, Any]
)

Populates input buffers with data from the encoder-decoder model's attention metadata.

This method fills the input buffers with encoder-decoder specific tensors, copying the encoder sequence lengths, cross-slot mapping, and cross-block tables from the attn_metadata into the corresponding buffers in the input_buffers dictionary.

Source code in vllm/attention/backends/utils.py
def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
                                             input_buffers: Dict[str,
                                                                 Any]):
    """
    Populates input buffers with data from the encoder-decoder model's
    attention metadata.

    This method fills the input buffers with encoder-decoder specific
    tensors, copying the encoder sequence lengths, cross-slot mapping,
    and cross-block tables from the `attn_metadata` into the
    corresponding buffers in the `input_buffers` dictionary.
    """
    input_buffers["encoder_seq_lens_tensor"].copy_(
        attn_metadata.decode_metadata.encoder_seq_lens_tensor,
        non_blocking=True)
    input_buffers["cross_slot_mapping"].copy_(
        attn_metadata.decode_metadata.cross_slot_mapping,
        non_blocking=True)
    input_buffers["cross_block_tables"].copy_(
        attn_metadata.decode_metadata.cross_block_tables,
        non_blocking=True)

_update_captured_metadata_for_enc_dec_model

_update_captured_metadata_for_enc_dec_model(
    batch_size: int, attn_metadata
)

Updates the attention metadata parameters for CUDA graph capture in an encoder-decoder model.

This method modifies attention-related tensors and metadata required for CUDA graph capture in encoder-decoder models. Specifically, it updates the cross-attention and encoder sequence tensors in the AttentionMetadata object.

Source code in vllm/attention/backends/utils.py
def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
                                                attn_metadata):
    """
    Updates the attention metadata parameters for CUDA graph capture in an
    encoder-decoder model.

    This method modifies attention-related tensors and metadata required
    for CUDA graph capture in encoder-decoder models. Specifically, it
    updates the cross-attention and encoder sequence tensors in the 
    AttentionMetadata object.
    """
    # During decode phase the cross_slot_mapping will be empty. Hence set
    # an empty tensor for CUDA Graph capture.
    attn_metadata.cross_slot_mapping = torch.tensor(
        [], dtype=torch.int).cuda()
    attn_metadata.cross_block_tables = torch.full(
        (batch_size, self.runner.get_max_block_per_batch()),
        1,
        dtype=torch.int).cuda()
    attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
                                                1,
                                                dtype=torch.int).cuda()
    attn_metadata.encoder_seq_lens_tensor = torch.full(
        (batch_size, ), 1, dtype=torch.int).cuda()
    attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
    attn_metadata.num_encoder_tokens = 0

begin_forward

begin_forward(model_input) -> None
Source code in vllm/attention/backends/utils.py
def begin_forward(self, model_input) -> None:
    return

get_graph_input_buffers

get_graph_input_buffers(
    attn_metadata, is_encoder_decoder_model: bool = False
) -> Dict[str, Any]
Source code in vllm/attention/backends/utils.py
def get_graph_input_buffers(
        self,
        attn_metadata,
        is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
    input_buffers = {
        "slot_mapping": attn_metadata.slot_mapping,
        "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
        "block_tables": attn_metadata.decode_metadata.block_tables,
    }
    if is_encoder_decoder_model:
        # The encoder decoder model works only with XFormers and
        # Flash Attention backend. Assert the same.
        assert self.runner.attn_backend.get_name() in \
               ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
            f"Expected attn_backend name to be either 'XFORMERS'," \
            f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
            f"got '{self.runner.attn_backend.get_name()}'"
        self._add_additional_input_buffers_for_enc_dec_model(
            attn_metadata=attn_metadata, input_buffers=input_buffers)
    return input_buffers

graph_capture

graph_capture(max_batch_size: int)
Source code in vllm/attention/backends/utils.py
@contextmanager
def graph_capture(self, max_batch_size: int):

    self._is_graph_capturing = True

    self._graph_slot_mapping = torch.full((max_batch_size, ),
                                          PAD_SLOT_ID,
                                          dtype=torch.long,
                                          device=self.runner.device)
    self._graph_seq_lens = torch.ones(max_batch_size,
                                      dtype=torch.int32,
                                      device=self.runner.device)
    self._graph_block_tables = torch.from_numpy(
        self.runner.graph_block_tables).to(device=self.runner.device)

    yield

    self._is_graph_capturing = False
    del self._graph_slot_mapping
    del self._graph_seq_lens
    del self._graph_block_tables

graph_capture_get_metadata_for_batch

graph_capture_get_metadata_for_batch(
    batch_size: int, is_encoder_decoder_model: bool = False
)
Source code in vllm/attention/backends/utils.py
def graph_capture_get_metadata_for_batch(
        self, batch_size: int, is_encoder_decoder_model: bool = False):
    assert self._is_graph_capturing
    attn_metadata = self.runner.attn_backend.make_metadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=batch_size,
        slot_mapping=self._graph_slot_mapping[:batch_size],
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
        seq_lens=None,
        seq_lens_tensor=self._graph_seq_lens[:batch_size],
        max_query_len=1,
        max_decode_query_len=1,
        max_prefill_seq_len=0,
        max_decode_seq_len=self.runner.max_seq_len_to_capture,
        query_start_loc=None,
        seq_start_loc=None,
        context_lens_tensor=None,
        block_tables=self._graph_block_tables[:batch_size],
        use_cuda_graph=True,
    )
    if is_encoder_decoder_model:
        # The encoder decoder model works only with XFormers and
        # Flash Attention backend. Assert the same.
        assert self.runner.attn_backend.get_name() in \
               ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
            f"Expected attn_backend name to be either 'XFORMERS'," \
            f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
            f"got '{self.runner.attn_backend.get_name()}'"
        self._update_captured_metadata_for_enc_dec_model(
            batch_size=batch_size, attn_metadata=attn_metadata)

    return attn_metadata

graph_clone

graph_clone(batch_size: int) -> CommonAttentionState
Source code in vllm/attention/backends/utils.py
def graph_clone(self, batch_size: int) -> "CommonAttentionState":
    assert self._is_graph_capturing
    return self.__class__(self.runner)

prepare_graph_input_buffers

prepare_graph_input_buffers(
    input_buffers,
    attn_metadata,
    is_encoder_decoder_model: bool = False,
) -> None
Source code in vllm/attention/backends/utils.py
def prepare_graph_input_buffers(
        self,
        input_buffers,
        attn_metadata,
        is_encoder_decoder_model: bool = False) -> None:
    input_buffers["seq_lens_tensor"].copy_(
        attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
    input_buffers["block_tables"].copy_(
        attn_metadata.decode_metadata.block_tables, non_blocking=True)
    if is_encoder_decoder_model:
        # The encoder decoder model works only with XFormers and
        # Flash Attention backend. Assert the same.
        assert self.runner.attn_backend.get_name() in\
            ["XFORMERS", "FLASH_ATTN"], \
            f"Expected attn_backend name to be either 'XFORMERS' or "\
            f"'FLASH_ATTN', but "\
            f"got '{self.runner.attn_backend.get_name()}'"
        self._prepare_input_buffers_for_enc_dec_model(
            attn_metadata, input_buffers)

CommonMetadataBuilder

Bases: AttentionMetadataBuilder[TAttentionMetadata]

Source code in vllm/attention/backends/utils.py
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):

    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )

_metadata_cls instance-attribute

_metadata_cls: Type[TAttentionMetadata]

block_size instance-attribute

block_size = block_size

input_builder instance-attribute

input_builder = input_builder

runner instance-attribute

runner = runner

sliding_window instance-attribute

sliding_window = sliding_window

__init__

__init__(input_builder: ModelInputForGPUBuilder)
Source code in vllm/attention/backends/utils.py
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
    self.input_builder = input_builder
    self.runner = input_builder.runner

    self.sliding_window = input_builder.sliding_window
    self.block_size = input_builder.block_size

_add_seq_group

_add_seq_group(
    inter_data: InterDataForSeqGroup,
    chunked_prefill_enabled: bool,
)
Source code in vllm/attention/backends/utils.py
def _add_seq_group(
        self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
        chunked_prefill_enabled: bool):
    is_prompt = inter_data.is_prompt
    block_tables = inter_data.block_tables

    for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
         curr_sliding_window_block) in zip(
             inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
             inter_data.orig_seq_lens, inter_data.seq_lens,
             inter_data.query_lens, inter_data.context_lens,
             inter_data.curr_sliding_window_blocks):
        self.context_lens.append(context_len)
        if is_prompt:
            mm_maps = inter_data.multi_modal_placeholder_maps
            if mm_maps:
                for modality, placeholders in mm_maps.items():
                    self.multimodal_placeholder_maps[modality].extend(
                        placeholders)

            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)
        else:
            assert query_len == 1, (
                "seq_len: {}, context_len: {}, query_len: {}".format(
                    seq_len, context_len, query_len))
            self.num_decode_tokens += query_len
            self.curr_seq_lens.append(curr_seq_len)

        # Compute block table.
        # TODO(sang): Combine chunked prefill and prefix caching by
        # only allowing multiple of block_size chunk size.
        # NOTE: This only works for oooooooxxx style attention.
        block_table = []
        if inter_data.prefix_cache_hit:
            block_table = block_tables[seq_id]
        elif ((chunked_prefill_enabled or not is_prompt)
              and block_tables is not None):
            if curr_sliding_window_block == 0:
                block_table = block_tables[seq_id]
            else:
                block_table = block_tables[seq_id][
                    -curr_sliding_window_block:]
        self.block_tables.append(block_table)

        # Compute slot mapping.
        is_profile_run = is_block_tables_empty(block_tables)
        start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                   context_len,
                                                   self.sliding_window)
        compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                             seq_len, context_len, start_idx,
                             self.block_size, inter_data.block_tables)

build

build(
    seq_lens: List[int],
    query_lens: List[int],
    cuda_graph_pad_size: int,
    batch_size: int,
)

Build attention metadata with on-device tensors.

Parameters:

  • seq_lens (List[int]): The maybe padded sequence lengths of the input sequences. (required)
  • query_lens (List[int]): The query lengths of the input sequences. (required)
  • cuda_graph_pad_size (int): The padding size for cuda graph. -1 if cuda graph is not used. (required)
  • batch_size (int): The maybe padded batch size. (required)
Source code in vllm/attention/backends/utils.py
def build(self, seq_lens: List[int], query_lens: List[int],
          cuda_graph_pad_size: int, batch_size: int):
    """Build attention metadata with on-device tensors.

    Args:
        seq_lens: The maybe padded sequence lengths of the input sequences.
        query_lens: The query lengths of the input sequences.
        cuda_graph_pad_size: The padding size for cuda graph.
                             -1 if cuda graph is not used.
        batch_size: The maybe padded batch size.
    """
    for inter_data in self.input_builder.inter_data_list:
        self._add_seq_group(inter_data,
                            self.input_builder.chunked_prefill_enabled)

    device = self.runner.device
    use_captured_graph = cuda_graph_pad_size != -1

    max_query_len = max(query_lens)
    max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
    max_decode_seq_len = max(self.curr_seq_lens, default=0)
    num_decode_tokens = self.num_decode_tokens
    query_start_loc = list(accumulate(query_lens, initial=0))
    seq_start_loc = list(accumulate(seq_lens, initial=0))

    if use_captured_graph:
        self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
        self.block_tables.extend([] * cuda_graph_pad_size)
        num_decode_tokens = batch_size

        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        input_block_tables = self.runner.graph_block_tables[:batch_size]
        for i, block_table in enumerate(self.block_tables):
            if block_table:
                input_block_tables[i, :len(block_table)] = block_table
        block_tables = torch.from_numpy(input_block_tables).to(
            device, non_blocking=True)
    else:
        block_tables = make_tensor_with_pad(
            self.block_tables,
            pad=0,
            dtype=torch.int,
            device=device,
        )
    assert max_query_len > 0, "query_lens: {}".format(query_lens)

    assert device is not None
    context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                           device, self.runner.pin_memory)
    seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                       self.runner.pin_memory)
    slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                           device, self.runner.pin_memory)
    query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                              device,
                                              self.runner.pin_memory)
    seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                            device, self.runner.pin_memory)
    placeholder_index_maps = {
        modality: placeholder_map.index_map()
        for modality, placeholder_map in
        self.multimodal_placeholder_maps.items()
    }

    return self._metadata_cls(  # type: ignore
        num_prefills=self.num_prefills,
        slot_mapping=slot_mapping_tensor,
        multi_modal_placeholder_index_maps=placeholder_index_maps,
        enable_kv_scales_calculation=True,
        num_prefill_tokens=self.num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        seq_lens=seq_lens,
        seq_lens_tensor=seq_lens_tensor,
        max_query_len=max_query_len,
        max_prefill_seq_len=max_prefill_seq_len,
        max_decode_seq_len=max_decode_seq_len,
        query_start_loc=query_start_loc_tensor,
        seq_start_loc=seq_start_loc_tensor,
        context_lens_tensor=context_lens_tensor,
        block_tables=block_tables,
        use_cuda_graph=use_captured_graph,
    )
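
As a small, standalone illustration (plain Python, not vLLM code) of how build derives query_start_loc and seq_start_loc above: both are cumulative sums of the per-sequence lengths computed with itertools.accumulate, and are later moved to the device via async_tensor_h2d. The concrete lengths below are made up.

from itertools import accumulate

query_lens = [5, 1, 1]   # one prefill of 5 tokens plus two single-token decodes
seq_lens = [5, 12, 9]    # total lengths (context + query) of the same sequences

query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))

print(query_start_loc)   # [0, 5, 6, 7]
print(seq_start_loc)     # [0, 5, 17, 26]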

prepare

prepare()
Source code in vllm/attention/backends/utils.py
def prepare(self):
    self.slot_mapping: List[int] = []
    self.prefill_seq_lens: List[int] = []
    self.context_lens: List[int] = []
    self.block_tables: List[List[int]] = []
    self.curr_seq_lens: List[int] = []
    self.multimodal_placeholder_maps: Dict[
        str,
        MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
    self.num_prefills = 0
    self.num_prefill_tokens = 0
    self.num_decode_tokens = 0

MLADims dataclass

Source code in vllm/attention/backends/utils.py
@dataclass
class MLADims:
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int

kv_lora_rank instance-attribute

kv_lora_rank: int

q_lora_rank instance-attribute

q_lora_rank: Optional[int]

qk_nope_head_dim instance-attribute

qk_nope_head_dim: int

qk_rope_head_dim instance-attribute

qk_rope_head_dim: int

v_head_dim instance-attribute

v_head_dim: int

__init__

__init__(
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
) -> None

_compute_slot_mapping_numpy

_compute_slot_mapping_numpy(
    slot_mapping: List[int],
    block_table: List[int],
    range_start: int,
    range_end: int,
    block_size: int,
)
Source code in vllm/attention/backends/utils.py
def _compute_slot_mapping_numpy(slot_mapping: List[int],
                                block_table: List[int], range_start: int,
                                range_end: int, block_size: int):
    block_table_array = np.array(block_table)
    idx = np.arange(range_start, range_end)
    block_offset = idx % block_size
    idx //= block_size
    seq_slot_mapping_array = block_table_array[idx]
    seq_slot_mapping_array *= block_size
    seq_slot_mapping_array += block_offset
    slot_mapping.extend(seq_slot_mapping_array)

_compute_slot_mapping_python

_compute_slot_mapping_python(
    slot_mapping: List[int],
    block_table: List[int],
    range_start: int,
    range_end: int,
    block_size: int,
)
Source code in vllm/attention/backends/utils.py
def _compute_slot_mapping_python(slot_mapping: List[int],
                                 block_table: List[int], range_start: int,
                                 range_end: int, block_size: int):
    for i in range(range_start, range_end):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot = block_number * block_size + block_offset
        slot_mapping.append(slot)

compute_slot_mapping

compute_slot_mapping(
    is_profile_run: bool,
    slot_mapping: List[int],
    seq_id: int,
    seq_len: int,
    context_len: int,
    start_idx: int,
    block_size: int,
    block_tables: Dict[int, List[int]],
)

Compute slot mapping.

Source code in vllm/attention/backends/utils.py
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """
    Compute slot mapping.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not
        # initialized yet. In this case, we just use a dummy
        # slot mapping.
        # In embeddings, the block tables are {seq_id: None}.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with
    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
    # sliding_window). For example, if the prompt len is 10,
    # sliding window is 8, and block size is 4, the first two
    # tokens are masked and the slot mapping will be
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    padding_mask_len = max(0, start_idx - context_len)
    slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)

    range_start = max(start_idx, context_len)
    range_end = seq_len
    numel = range_end - range_start
    block_table = block_tables[seq_id]

    # numpy implementation will be faster than python if we have
    # many elements, otherwise it will be slower.
    if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
        _compute_slot_mapping_python(slot_mapping, block_table, range_start,
                                     range_end, block_size)
    else:
        _compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
                                    range_end, block_size)
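
The per-token arithmetic both helpers implement is slot = block_table[i // block_size] * block_size + i % block_size for every token index i in [range_start, range_end). A standalone sketch with made-up block numbers:

block_size = 4
block_table = [7, 3]          # hypothetical physical block numbers for one sequence
context_len, seq_len = 2, 6   # tokens 2..5 are the new tokens this step

slot_mapping = []
for i in range(context_len, seq_len):
    block_number = block_table[i // block_size]
    block_offset = i % block_size
    slot_mapping.append(block_number * block_size + block_offset)

print(slot_mapping)           # [30, 31, 12, 13]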

compute_slot_mapping_start_idx

compute_slot_mapping_start_idx(
    is_prompt: bool,
    query_len: int,
    context_len: int,
    sliding_window: int,
)

Compute the start index of slot mapping.

Source code in vllm/attention/backends/utils.py
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int):
    """
    Compute the start index of slot mapping.
    """
    start_idx = 0
    if is_prompt and sliding_window is not None:
        start_idx = max(0, query_len - sliding_window)
    return start_idx

get_mla_dims

get_mla_dims(model_config: ModelConfig) -> MLADims
Source code in vllm/attention/backends/utils.py
def get_mla_dims(model_config: ModelConfig) -> MLADims:
    hf_text_config = model_config.hf_text_config

    return MLADims(
        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
        kv_lora_rank=hf_text_config.kv_lora_rank,
        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
        v_head_dim=hf_text_config.v_head_dim,
    )

get_num_prefill_decode_query_kv_tokens

get_num_prefill_decode_query_kv_tokens(
    attn_metadata, attn_type: str
) -> Tuple[int, int, int]

Calculate the number of prefill and decode tokens for query, key/value based on the attention metadata and the specified attention type.

Parameters:

  • attn_metadata (AttentionMetadata): Attention Metadata object. (required)
  • attn_type (AttentionType): The type of attention being used. (required)

Returns:

  • Tuple[int, int, int]: A tuple containing three integers: the number of prefill query tokens, the number of prefill key/value tokens, and the number of decode query tokens.

Raises:

  • AssertionError: If the number of encoder tokens in attn_metadata is None when required for the calculations.

Source code in vllm/attention/backends/utils.py
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: str,
) -> Tuple[int, int, int]:
    """
    Calculate the number of prefill and decode tokens for query, key/value
    based on the attention metadata and the specified attention type.

    Args:
        attn_metadata (AttentionMetadata): Attention Metadata object.
        attn_type (AttentionType): The type of attention being used.
    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If the number of encoder tokens in `attn_metadata` 
        is `None` when required for the calculations.
    """
    num_prefill_query_tokens = 0
    num_decode_query_tokens = 0
    num_prefill_kv_tokens = 0
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during prefill phase.
        # The same input serves as both query and key.
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_encoder_tokens
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = 0
    elif attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        # The key is the encoder/cross-attention.
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens
    else:  # attn_type == AttentionType.DECODER or
        # attn_type == AttentionType.ENCODER_ONLY
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens

    return (num_prefill_query_tokens, num_prefill_kv_tokens,
            num_decode_query_tokens)
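
A minimal sketch of the decoder branch, using a SimpleNamespace in place of a real AttentionMetadata object; the AttentionType import location is assumed here, and the token counts are made up.

from types import SimpleNamespace

from vllm.attention import AttentionType  # assumed import path
from vllm.attention.backends.utils import (
    get_num_prefill_decode_query_kv_tokens)

meta = SimpleNamespace(num_prefill_tokens=13,
                       num_decode_tokens=4,
                       num_encoder_tokens=None)

# Decoder self-attention: queries and keys/values come from the same tokens.
print(get_num_prefill_decode_query_kv_tokens(meta, AttentionType.DECODER))
# (13, 13, 4)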

get_seq_len_block_table_args

get_seq_len_block_table_args(
    attn_metadata, is_prompt: bool, attn_type: str
) -> tuple

The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent on the type of attention operation.

  • Decoder attn -> select entirely decoder self-attention-related fields
  • Encoder/decoder cross-attn -> select encoder sequence lengths & cross-attn block-tables fields
  • Encoder attn -> select encoder sequence lengths fields & no block tables

Arguments:

  • attn_metadata: Attention metadata structure associated with attention op
  • is_prompt: True if prefill, False otherwise
  • attn_type: encoder attention, decoder self-attention, encoder/decoder cross-attention

Returns:

  • Appropriate sequence-lengths tensor
  • Appropriate max sequence-length scalar
  • Appropriate block tables (or None)
Source code in vllm/attention/backends/utils.py
def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: str,
) -> tuple:
    '''
    The particular choice of sequence-length- and block-table-related
    attributes which should be extracted from attn_metadata is dependent
    on the type of attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths & 
                                  cross-attn block-tables fields
    Encoder attn -> select encoder sequence lengths fields & no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        if is_prompt:
            max_seq_len = attn_metadata.max_prefill_seq_len
        else:
            max_seq_len = attn_metadata.max_decode_seq_len
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)
    elif attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
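
A quick sketch of the decoder dispatch path, again with a SimpleNamespace standing in for the metadata; the placeholder strings mark which attributes get selected, and the AttentionType import path is assumed.

from types import SimpleNamespace

from vllm.attention import AttentionType  # assumed import path
from vllm.attention.backends.utils import get_seq_len_block_table_args

meta = SimpleNamespace(seq_lens_tensor="<seq_lens_tensor>",
                       max_prefill_seq_len=16,
                       max_decode_seq_len=128,
                       block_tables="<block_tables>")

# Decode phase (is_prompt=False) picks the decode max sequence length.
print(get_seq_len_block_table_args(meta, is_prompt=False,
                                   attn_type=AttentionType.DECODER))
# ('<seq_lens_tensor>', 128, '<block_tables>')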

is_all_cross_attn_metadata_set

is_all_cross_attn_metadata_set(attn_metadata)

All attention metadata required for enc/dec cross-attention is set.

Superset of encoder attention required metadata.

Source code in vllm/attention/backends/utils.py
def is_all_cross_attn_metadata_set(attn_metadata):
    '''
    All attention metadata required for enc/dec cross-attention is set.

    Superset of encoder attention required metadata.
    '''
    return (attn_metadata.is_all_encoder_attn_metadata_set
            and (attn_metadata.cross_slot_mapping is not None)
            and (attn_metadata.cross_block_tables is not None))

is_all_encoder_attn_metadata_set

is_all_encoder_attn_metadata_set(attn_metadata)

All attention metadata required for encoder attention is set.

Source code in vllm/attention/backends/utils.py
def is_all_encoder_attn_metadata_set(attn_metadata):
    '''
    All attention metadata required for encoder attention is set.
    '''
    return ((attn_metadata.encoder_seq_lens is not None)
            and (attn_metadata.encoder_seq_lens_tensor is not None)
            and (attn_metadata.max_encoder_seq_len is not None))

is_block_tables_empty

is_block_tables_empty(block_tables: Union[None, Dict])

Check if block_tables is None or a dictionary with all None values.

Source code in vllm/attention/backends/utils.py
def is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    return (isinstance(block_tables, dict)
            and all(value is None for value in block_tables.values()))
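
A quick illustration of the three cases this helper distinguishes, importing from the module documented here:

from vllm.attention.backends.utils import is_block_tables_empty

print(is_block_tables_empty(None))             # True: no block tables at all
print(is_block_tables_empty({0: None}))        # True: profile run / embedding case
print(is_block_tables_empty({0: [3, 4, 5]}))   # False: a real block table is present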