
vllm.v1.attention.backends.mamba_attn

Mamba2AttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mamba_attn.py
class Mamba2AttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
        return Mamba2AttentionMetadataBuilder

get_builder_cls staticmethod

get_builder_cls() -> type[Mamba2AttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/mamba_attn.py
@staticmethod
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
    return Mamba2AttentionMetadataBuilder
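As a rough usage sketch (not taken from the vLLM sources; the `runner`, `mamba_spec`, and `block_table` objects are assumed to be supplied by the GPU model runner), the backend only exposes the builder class, which the caller then instantiates itself:

# Illustrative only: resolve the builder class and construct it with the
# caller's own runner, KV-cache spec, and block table.
builder_cls = Mamba2AttentionBackend.get_builder_cls()
builder = builder_cls(runner, mamba_spec, block_table)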

Mamba2AttentionMetadata dataclass

Source code in vllm/v1/attention/backends/mamba_attn.py
@dataclass
class Mamba2AttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

    has_initial_states: torch.Tensor
    prep_initial_states: bool
    chunk_size: int
    seq_idx: torch.Tensor
    chunk_indices: torch.Tensor
    chunk_offsets: torch.Tensor

    state_indices_tensor: torch.Tensor  # shape: [batch,]

chunk_indices instance-attribute

chunk_indices: Tensor

chunk_offsets instance-attribute

chunk_offsets: Tensor

chunk_size instance-attribute

chunk_size: int

has_initial_states instance-attribute

has_initial_states: Tensor

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_prefill_tokens instance-attribute

num_prefill_tokens: int

num_prefills instance-attribute

num_prefills: int

prep_initial_states instance-attribute

prep_initial_states: bool

query_start_loc instance-attribute

query_start_loc: Tensor

seq_idx instance-attribute

seq_idx: Tensor

seq_lens instance-attribute

seq_lens: Tensor

state_indices_tensor instance-attribute

state_indices_tensor: Tensor

__init__

__init__(
    num_prefills: int,
    num_prefill_tokens: int,
    num_decodes: int,
    num_decode_tokens: int,
    query_start_loc: Tensor,
    seq_lens: Tensor,
    has_initial_states: Tensor,
    prep_initial_states: bool,
    chunk_size: int,
    seq_idx: Tensor,
    chunk_indices: Tensor,
    chunk_offsets: Tensor,
    state_indices_tensor: Tensor,
) -> None
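As an orientation aid, here is a minimal, hypothetical construction of the metadata for a decode-only batch of two requests. All tensor values are made up, and the prefill-only fields are passed as None, mirroring what the builder does when there are no prefills (the dataclass annotations are stricter than the values used at runtime):

import torch

from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata

# Hypothetical decode-only batch: two requests, one scheduled token each.
metadata = Mamba2AttentionMetadata(
    num_prefills=0,
    num_prefill_tokens=0,
    num_decodes=2,
    num_decode_tokens=2,
    query_start_loc=torch.tensor([0, 1, 2], dtype=torch.int32),
    seq_lens=torch.tensor([17, 5], dtype=torch.int32),
    has_initial_states=None,       # prefill-only, unused for pure decode
    prep_initial_states=False,
    chunk_size=256,                # assumed chunk size, model-dependent
    seq_idx=None,                  # prefill-only
    chunk_indices=None,            # prefill-only
    chunk_offsets=None,            # prefill-only
    state_indices_tensor=torch.tensor([3, 7], dtype=torch.int32),  # [batch,]
)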

Mamba2AttentionMetadataBuilder

Bases: AttentionMetadataBuilder[Mamba2AttentionMetadata]

Source code in vllm/v1/attention/backends/mamba_attn.py
class Mamba2AttentionMetadataBuilder(
        AttentionMetadataBuilder[Mamba2AttentionMetadata]):

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
                 block_table: BlockTable):
        self.runner = runner
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table
        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # NOTE (Chen): Copied from MLACommonMetadataBuilder and
        # FlashInferMetadataBuilder. Should be refactored later to avoid code
        # duplication of these 3 functions.
        # We want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back, using the
        # least number of swaps possible. (NOTE: for now we loosely use
        # "decode" to mean requests where attention is likely memory-bound
        # and "prefill" to mean requests where attention is likely
        # compute-bound; TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # For now, treat 1 scheduled token as "decode" even if it's not.
            # We should update this to something like < 8 in the future, but
            # currently the decode run only supports num_tokens = 1.
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that this reordering is fairly minimal since decodes
        # should be around for a number of iterations, so hopefully they are
        # relatively stationary (and new requests are generally appended to
        # the persistent batch, so they should already be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch,
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based
        # on the above loop.
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i.e. not already
            # within the first num_decodes slots, we can swap it with the
            # prefill closest to the front of the batch
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break

            input_batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        seq_idx = None
        chunk_indices, chunk_offsets = None, None
        # Need flags to indicate if there are initial states;
        # currently we really only support the FlashAttention backend
        has_initial_states = None
        prep_initial_states = False

        state_indices_tensor = self.block_table.block_table[:num_reqs, 0]

        # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
        if self._num_prefills > 0:
            # shape: [batch,]
            has_initial_states_cpu = (
                self.runner.input_batch.
                num_computed_tokens_cpu_tensor[num_reqs -
                                               self._num_prefills:num_reqs]
                > 0)
            prep_initial_states = torch.any(has_initial_states_cpu).item()
            has_initial_states = has_initial_states_cpu.to(
                query_start_loc.device)

            query_start_loc_p = common_attn_metadata.query_start_loc[
                -self._num_prefills - 1:] - self._num_decode_tokens

            seq_idx = torch.repeat_interleave(
                torch.arange(self._num_prefills,
                             dtype=torch.int32,
                             device=query_start_loc_p.device),
                query_start_loc_p.diff(),
                output_size=self._num_prefill_tokens)
            seq_idx.unsqueeze_(0)

            # We compute metadata for chunked prefill once at the top level
            # model forward and reuse them in mamba layers. If not needed,
            # they will be ignored inside mamba kernels.
            if prep_initial_states:
                chunk_indices, chunk_offsets = (
                    _query_start_loc_to_chunk_indices_offsets(
                        query_start_loc_p, self.chunk_size,
                        self._num_prefill_tokens))

        attn_metadata = Mamba2AttentionMetadata(
            num_prefills=self._num_prefills,
            num_prefill_tokens=self._num_prefill_tokens,
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            has_initial_states=has_initial_states,
            prep_initial_states=prep_initial_states,
            chunk_size=self.chunk_size,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            state_indices_tensor=state_indices_tensor,
        )
        return attn_metadata
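The builder is stateful across a scheduling step: reorder_batch stores the decode/prefill counts on self, and the subsequent build call reads them back. A hedged sketch of the expected call order (the runner, input_batch, scheduler_output, and common_attn_metadata objects are assumptions supplied by the model runner):

# Illustrative call order per scheduler step; surrounding objects are assumed.
builder = Mamba2AttentionMetadataBuilder(runner, mamba_spec, block_table)

# 1) Move decodes to the front of the batch and cache the counts
#    (_num_decodes, _num_prefills, ...) for the next build() call.
builder.reorder_batch(input_batch, scheduler_output)

# 2) Assemble the per-step metadata consumed by the Mamba2 kernels.
attn_metadata = builder.build(common_prefix_len=0,
                              common_attn_metadata=common_attn_metadata)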

block_table instance-attribute

block_table = block_table

chunk_size instance-attribute

chunk_size = get_mamba2_chunk_size(vllm_config)

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

runner instance-attribute

runner = runner

__init__

__init__(
    runner: GPUModelRunner,
    kv_cache_spec: MambaSpec,
    block_table: BlockTable,
)
Source code in vllm/v1/attention/backends/mamba_attn.py
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
             block_table: BlockTable):
    self.runner = runner
    self.kv_cache_spec = kv_cache_spec
    self.block_table = block_table
    self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
)
Source code in vllm/v1/attention/backends/mamba_attn.py
def build(self, common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata):
    num_reqs = common_attn_metadata.num_reqs
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens

    seq_idx = None
    chunk_indices, chunk_offsets = None, None
    # Need flags to indicate if there are initial states;
    # currently we really only support the FlashAttention backend
    has_initial_states = None
    prep_initial_states = False

    state_indices_tensor = self.block_table.block_table[:num_reqs, 0]

    # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
    if self._num_prefills > 0:
        # shape: [batch,]
        has_initial_states_cpu = (
            self.runner.input_batch.
            num_computed_tokens_cpu_tensor[num_reqs -
                                           self._num_prefills:num_reqs]
            > 0)
        prep_initial_states = torch.any(has_initial_states_cpu).item()
        has_initial_states = has_initial_states_cpu.to(
            query_start_loc.device)

        query_start_loc_p = common_attn_metadata.query_start_loc[
            -self._num_prefills - 1:] - self._num_decode_tokens

        seq_idx = torch.repeat_interleave(
            torch.arange(self._num_prefills,
                         dtype=torch.int32,
                         device=query_start_loc_p.device),
            query_start_loc_p.diff(),
            output_size=self._num_prefill_tokens)
        seq_idx.unsqueeze_(0)

        # We compute metadata for chunked prefill once at the top level
        # model forward and reuse them in mamba layers. If not needed,
        # they will be ignored inside mamba kernels.
        if prep_initial_states:
            chunk_indices, chunk_offsets = (
                _query_start_loc_to_chunk_indices_offsets(
                    query_start_loc_p, self.chunk_size,
                    self._num_prefill_tokens))

    attn_metadata = Mamba2AttentionMetadata(
        num_prefills=self._num_prefills,
        num_prefill_tokens=self._num_prefill_tokens,
        num_decodes=self._num_decodes,
        num_decode_tokens=self._num_decode_tokens,
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        has_initial_states=has_initial_states,
        prep_initial_states=prep_initial_states,
        chunk_size=self.chunk_size,
        seq_idx=seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        state_indices_tensor=state_indices_tensor,
    )
    return attn_metadata
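To make the prefill-side bookkeeping concrete, here is a small standalone sketch (made-up token counts, no vLLM imports) of how the query_start_loc slice and torch.repeat_interleave produce seq_idx, i.e. one label per prefill token naming the sequence it belongs to:

import torch

# Batch layout: 2 decode tokens followed by two prefills of 3 and 5 tokens.
query_start_loc = torch.tensor([0, 1, 2, 5, 10], dtype=torch.int32)
num_prefills = 2
num_decode_tokens = 2
num_prefill_tokens = 8

# Slice off the prefill portion and shift it to start at 0, as `build` does.
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decode_tokens
# -> tensor([0, 3, 8], dtype=torch.int32)

# Label every prefill token with the index of its sequence.
seq_idx = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)
# -> tensor([0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.int32)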

reorder_batch

reorder_batch(
    input_batch: InputBatch,
    scheduler_output: SchedulerOutput,
) -> bool
Source code in vllm/v1/attention/backends/mamba_attn.py
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    # NOTE (Chen): Copied from MLACommonMetadataBuilder and
    # FlashInferMetadataBuilder. Should be refactored later to avoid code
    # duplication of these 3 functions.
    # We want to reorder the batch so that the "decode" requests are at
    # the front and the "prefill" requests are at the back, using the
    # least number of swaps possible. (NOTE: for now we loosely use
    # "decode" to mean requests where attention is likely memory-bound
    # and "prefill" to mean requests where attention is likely
    # compute-bound; TODO(lucas): figure out a better naming here)
    decodes = []
    prefills = []
    num_decode_tokens = 0
    num_prefill_tokens = 0

    for i, req_id in enumerate(input_batch.req_ids):
        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
        # For now, treat 1 scheduled token as "decode" even if it's not.
        # We should update this to something like < 8 in the future, but
        # currently the decode run only supports num_tokens = 1.
        if num_tokens == 1:
            decodes.append(i)
            num_decode_tokens += num_tokens
        else:
            prefills.append(i)
            num_prefill_tokens += num_tokens

    # We hope that this reordering is fairly minimal since decodes
    # should be around for a number of iterations, so hopefully they are
    # relatively stationary (and new requests are generally appended to
    # the persistent batch, so they should already be at the back).
    # To achieve this we loop over the decodes in descending order and
    # the prefills in ascending order. We swap decodes from the "back",
    # i.e. past where the last decode should be in the reordered batch,
    # with prefills from the front of the batch.
    # `decodes` and `prefills` are already in ascending order just based
    # on the above loop.
    num_decodes = len(decodes)
    num_prefills = len(prefills)
    modified_batch = False

    for i in range(1, min(num_decodes, num_prefills) + 1):
        # If the decode is at the "back" of the batch, i.e. not already
        # within the first num_decodes slots, we can swap it with the
        # prefill closest to the front of the batch
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break

        input_batch.swap_states(prefills[i - 1], decode_idx)
        modified_batch = True

    # Save for next `build` call
    # TODO(lucas): this is a bit of a hack, we should probably have a
    # better way of doing this
    self._num_decodes = num_decodes
    self._num_prefills = num_prefills
    self._num_decode_tokens = num_decode_tokens
    self._num_prefill_tokens = num_prefill_tokens

    return modified_batch
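The swapping logic is easier to see on a toy batch. The sketch below (plain Python, no vLLM types, made-up token counts) mirrors the loop above: requests with a single scheduled token count as decodes, and any decode sitting past the first num_decodes slots is swapped with the earliest remaining prefill:

# One scheduled-token count per request, in batch order.
# 1 token => "decode", anything larger => "prefill".
num_scheduled = [1, 7, 1, 3, 1]

decodes = [i for i, n in enumerate(num_scheduled) if n == 1]   # [0, 2, 4]
prefills = [i for i, n in enumerate(num_scheduled) if n > 1]   # [1, 3]
num_decodes = len(decodes)

for i in range(1, min(num_decodes, len(prefills)) + 1):
    decode_idx = decodes[num_decodes - i]
    if decode_idx < num_decodes:
        # This decode already sits within the first num_decodes slots.
        break
    # Swap the rearmost out-of-place decode with the frontmost prefill.
    j = prefills[i - 1]
    num_scheduled[j], num_scheduled[decode_idx] = (
        num_scheduled[decode_idx], num_scheduled[j])

print(num_scheduled)  # [1, 1, 1, 3, 7]: decodes now sit at the front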

get_mamba2_chunk_size

get_mamba2_chunk_size(vllm_config: VllmConfig) -> int
Source code in vllm/v1/attention/backends/mamba_attn.py
def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
    chunk_sizes = set(layer.chunk_size for layer in layers.values())
    assert len(
        chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
    return chunk_sizes.pop()
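The helper is essentially a consistency check: collapse the per-layer chunk sizes into a set and require exactly one distinct value. The same pattern in isolation, with made-up numbers standing in for the MambaMixer2 layers:

# Toy stand-in for the per-layer chunk sizes gathered from the model config.
layer_chunk_sizes = [256, 256, 256]
chunk_sizes = set(layer_chunk_sizes)
assert len(chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
chunk_size = chunk_sizes.pop()  # -> 256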