Skip to content

vllm.model_executor.layers.mamba.mamba_mixer

MambaMixer

Bases: CustomOp

Compute ∆, A, B, C, and D, the state space parameters, and use them to compute the contextualized_states. A and D are input-independent (see Mamba paper [1], Section 3.5.2 "Interpretation of A", for why A isn't selective), whereas ∆, B, and C are input-dependent. This input dependence is a key difference between Mamba and the linear time-invariant S4, and is why Mamba is called a selective state space model.

[1] Gu & Dao, "Mamba: Linear-Time Sequence Modeling with Selective State Spaces", 2023.

Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
@CustomOp.register("mamba_mixer")
class MambaMixer(CustomOp):
    """Mamba (selective state space) mixer layer.

    Computes the state space parameters ∆, A, B, C and D and uses them to
    produce the `contextualized_states`. A and D are input-independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A is
    not selective), whereas ∆, B and C are input-dependent. That input
    dependence is the key difference between Mamba and the linear
    time-invariant S4, and is why Mamba is called a **selective** state
    space model.

    [1] Gu & Dao, "Mamba: Linear-Time Sequence Modeling with Selective
        State Spaces", 2023.
    """

    def __init__(self,
                 hidden_size: int,
                 ssm_state_size: int,
                 conv_kernel_size: int,
                 intermediate_size: int,
                 time_step_rank: int,
                 use_conv_bias: bool,
                 use_bias: bool,
                 use_rms_norm: bool,
                 rms_norm_has_weight: bool = True,
                 rms_norm_eps: float = 1e-5,
                 activation: str = "silu",
                 is_lora_enabled: bool = False):
        """Create the projections, causal convolution and SSM parameters.

        Args:
            hidden_size: Model hidden size; input of `in_proj` and output
                of `out_proj`.
            ssm_state_size: SSM state size; width of each of B and C.
            conv_kernel_size: Kernel width of the causal 1-D convolution.
            intermediate_size: Inner dimension of the mixer. A and D are
                allocated with `intermediate_size // tp_size` rows, i.e.
                sharded across tensor-parallel ranks.
            time_step_rank: Rank of the time-step projection: `x_proj`
                emits `time_step_rank` dt features that `dt_proj` maps
                back to `intermediate_size`.
            use_conv_bias: Whether the convolution has a bias.
            use_bias: Whether `in_proj` and `out_proj` have a bias.
            use_rms_norm: Apply RMSNorm to dt, B and C after `x_proj`.
            rms_norm_has_weight: Whether those RMSNorms have a weight.
            rms_norm_eps: Epsilon for those RMSNorms.
            activation: Activation name forwarded to the causal conv
                kernels.
            is_lora_enabled: If True, inputs to `x_proj` and `out_proj`
                are made contiguous (required by the LoRA kernels).
        """
        super().__init__()
        self.time_step_rank = time_step_rank
        self.ssm_state_size = ssm_state_size
        self.use_rms_norm = use_rms_norm
        self.activation = activation
        self.is_lora_enabled = is_lora_enabled

        self.conv1d = ColumnParallelLinear(
            input_size=conv_kernel_size,
            output_size=intermediate_size,
            bias=use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(hidden_size,
                                                  [intermediate_size] * 2,
                                                  bias=use_bias)

        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            intermediate_size,
            time_step_rank + ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(time_step_rank,
                                            intermediate_size,
                                            bias=True,
                                            skip_bias_add=True)

        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            # Split the full checkpoint tensor into tp_size equal chunks
            # along dim 0 and copy this rank's chunk into `param`.
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
            param.data.copy_(
                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
                                         dim=0)[tp_rank])

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            # The stored parameter is A = -exp(checkpoint value), computed
            # in float32 before sharding.
            weight_loader(param, -torch.exp(loaded_weight.float()))

        tp_size = get_tensor_model_parallel_world_size()
        self.A = nn.Parameter(
            torch.empty(
                intermediate_size // tp_size,
                ssm_state_size,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": weight_loader})
        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

        self.out_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=use_bias,
            input_is_parallel=True,
        )

        self.dt_layernorm = RMSNorm(
            time_step_rank,
            eps=rms_norm_eps,
            has_weight=rms_norm_has_weight,
        ) if use_rms_norm else None

        self.b_layernorm = RMSNorm(
            ssm_state_size,
            eps=rms_norm_eps,
            has_weight=rms_norm_has_weight,
        ) if use_rms_norm else None

        self.c_layernorm = RMSNorm(
            ssm_state_size,
            eps=rms_norm_eps,
            has_weight=rms_norm_has_weight,
        ) if use_rms_norm else None

    def forward_native(self, hidden_states: torch.Tensor,
                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
        """Native (non-CUDA) forward path — not implemented.

        Raises:
            NotImplementedError: always. Previously this method was a bare
                `pass` and silently returned None, which surfaced as a
                confusing failure far downstream.
        """
        raise NotImplementedError(
            "MambaMixer only implements forward_cuda; "
            "forward_native is not supported.")

    def forward_cuda(self, hidden_states: torch.Tensor,
                     mamba_cache_params: MambaCacheParams):
        """Run the mixer with the CUDA/Triton kernels.

        Args:
            hidden_states: Input activations fed to `in_proj`.
            mamba_cache_params: Provides `conv_state`, `ssm_state` and
                `state_indices_tensor`, which are handed to the conv and
                scan kernels (presumably updated in place by them —
                confirm against the kernel implementations).

        Returns:
            The output of `out_proj` (the contextualized states).
        """
        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

        # Batches that carry both `query_start_loc` and
        # `context_lens_tensor` go through the varlen (prefill-style)
        # kernels; otherwise the single-step update kernels are used.
        use_varlen_kernels = (
            attn_metadata.query_start_loc is not None
            and attn_metadata.context_lens_tensor is not None)
        # Sequences with prior context resume from the cached state.
        has_initial_state = (attn_metadata.context_lens_tensor > 0
                             if use_varlen_kernels else None)

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
        hidden_states, gate = projected_states.chunk(2, dim=-2)

        # 2. Convolution sequence transformation
        # Collapse the singleton dim added by `unsqueeze(1)` in __init__.
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        if use_varlen_kernels:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|
            hidden_states = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=has_initial_state,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            hidden_states = causal_conv1d_update(
                hidden_states.transpose(0, 1),
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)
            hidden_states = hidden_states.transpose(0, 1)

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        x_proj_input = hidden_states.transpose(-2, -1)
        if self.is_lora_enabled:
            # lora kernel requires contiguous tensor
            x_proj_input = x_proj_input.contiguous()
        ssm_parameters = self.x_proj(x_proj_input)[0]

        time_step, B, C = torch.split(
            ssm_parameters,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        if self.use_rms_norm:
            assert self.dt_layernorm is not None
            assert self.b_layernorm is not None
            assert self.c_layernorm is not None
            time_step = self.dt_layernorm(time_step.contiguous())
            B = self.b_layernorm(B.contiguous())
            C = self.c_layernorm(C.contiguous())

        # 3.b. discretization: dt_proj is applied without its bias
        # (skip_bias_add=True); the bias is passed to the scan kernel.
        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
        # A layer built with bias=False still registers `bias` as None, so
        # `hasattr` alone is not a sufficient guard before `.float()`.
        dt_bias = getattr(self.dt_proj, "bias", None)
        time_proj_bias = dt_bias.float() if dt_bias is not None else None

        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        if use_varlen_kernels:
            scan_outputs = selective_scan_fn(
                hidden_states,
                mamba_cache_params.ssm_state,
                discrete_time_step,
                self.A,
                B.transpose(-2, -1),
                C.transpose(-2, -1),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                cache_indices=mamba_cache_params.state_indices_tensor,
                has_initial_state=has_initial_state,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            scan_outputs = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states.transpose(0, 1),
                discrete_time_step.transpose(0, 1),
                self.A,
                B,
                C,
                self.D,
                gate.transpose(0, 1),
                time_proj_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor)
            scan_outputs = scan_outputs.transpose(0, 1)

        # 4. Final linear projection
        out_proj_input = scan_outputs.transpose(-2, -1)
        if self.is_lora_enabled:
            # lora kernel requires contiguous tensor
            out_proj_input = out_proj_input.contiguous()
        contextualized_states = self.out_proj(out_proj_input)[0]
        return contextualized_states

A instance-attribute

A = Parameter(
    empty(
        intermediate_size // tp_size,
        ssm_state_size,
        dtype=float32,
    )
)

D instance-attribute

D = Parameter(ones(intermediate_size // tp_size))

activation instance-attribute

activation = activation

b_layernorm instance-attribute

b_layernorm = (
    RMSNorm(
        ssm_state_size,
        eps=rms_norm_eps,
        has_weight=rms_norm_has_weight,
    )
    if use_rms_norm
    else None
)

c_layernorm instance-attribute

c_layernorm = (
    RMSNorm(
        ssm_state_size,
        eps=rms_norm_eps,
        has_weight=rms_norm_has_weight,
    )
    if use_rms_norm
    else None
)

conv1d instance-attribute

conv1d = ColumnParallelLinear(
    input_size=conv_kernel_size,
    output_size=intermediate_size,
    bias=use_conv_bias,
)

dt_layernorm instance-attribute

dt_layernorm = (
    RMSNorm(
        time_step_rank,
        eps=rms_norm_eps,
        has_weight=rms_norm_has_weight,
    )
    if use_rms_norm
    else None
)

dt_proj instance-attribute

dt_proj = ColumnParallelLinear(
    time_step_rank,
    intermediate_size,
    bias=True,
    skip_bias_add=True,
)

in_proj instance-attribute

in_proj = MergedColumnParallelLinear(
    hidden_size, [intermediate_size] * 2, bias=use_bias
)

is_lora_enabled instance-attribute

is_lora_enabled = is_lora_enabled

out_proj instance-attribute

out_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=use_bias,
    input_is_parallel=True,
)

ssm_state_size instance-attribute

ssm_state_size = ssm_state_size

time_step_rank instance-attribute

time_step_rank = time_step_rank

use_rms_norm instance-attribute

use_rms_norm = use_rms_norm

x_proj instance-attribute

x_proj = RowParallelLinear(
    intermediate_size,
    time_step_rank + ssm_state_size * 2,
    bias=False,
)

__init__

__init__(
    hidden_size: int,
    ssm_state_size: int,
    conv_kernel_size: int,
    intermediate_size: int,
    time_step_rank: int,
    use_conv_bias: bool,
    use_bias: bool,
    use_rms_norm: bool,
    rms_norm_has_weight: bool = True,
    rms_norm_eps: float = 1e-05,
    activation="silu",
    is_lora_enabled: bool = False,
)
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def __init__(self,
             hidden_size: int,
             ssm_state_size: int,
             conv_kernel_size: int,
             intermediate_size: int,
             time_step_rank: int,
             use_conv_bias: bool,
             use_bias: bool,
             use_rms_norm: bool,
             rms_norm_has_weight: bool = True,
             rms_norm_eps: float = 1e-5,
             activation="silu",
             is_lora_enabled: bool = False):
    """Register the mixer's sub-layers and SSM parameters.

    Sub-modules are created in a fixed order: conv1d, in_proj, x_proj,
    dt_proj, then A/D, out_proj, and the three optional RMSNorms (dt, B,
    C). A and D hold `intermediate_size // tp_size` rows and are filled
    by custom loaders that take this rank's slice along dim 0 of the
    checkpoint tensor. A's loader first maps the value through
    `-exp(.)` in float32.
    """
    super().__init__()
    # Hyper-parameters consulted during the forward pass.
    self.time_step_rank = time_step_rank
    self.ssm_state_size = ssm_state_size
    self.use_rms_norm = use_rms_norm
    self.activation = activation
    self.is_lora_enabled = is_lora_enabled

    conv = ColumnParallelLinear(
        input_size=conv_kernel_size,
        output_size=intermediate_size,
        bias=use_conv_bias,
    )
    self.conv1d = conv
    # The linear layer's weight gets an extra singleton dim at index 1 so
    # it matches the checkpoint's conv1d weight shape. This cannot live in
    # a custom weight loader, because ColumnParallelLinear already sets one
    # and set_weight_attrs refuses to override it.
    conv.weight.data = conv.weight.data.unsqueeze(1)

    self.in_proj = MergedColumnParallelLinear(
        hidden_size, [intermediate_size] * 2, bias=use_bias)

    # Produces the input-dependent dt, B and C in one packed output.
    self.x_proj = RowParallelLinear(
        intermediate_size,
        time_step_rank + 2 * ssm_state_size,
        bias=False,
    )
    # Discretization projection. Its bias is skipped here because the
    # selective scan kernel adds it.
    self.dt_proj = ColumnParallelLinear(
        time_step_rank,
        intermediate_size,
        bias=True,
        skip_bias_add=True,
    )

    def load_row_shard(param: Parameter, loaded_weight: torch.Tensor):
        # Copy this rank's equal-sized slice (along dim 0) into `param`.
        rank = get_tensor_model_parallel_rank()
        world = get_tensor_model_parallel_world_size()
        rows_per_rank = loaded_weight.shape[0] // world
        param.data.copy_(
            loaded_weight.data.split(rows_per_rank, dim=0)[rank])

    def load_A(param: Parameter, loaded_weight: torch.Tensor):
        load_row_shard(param, -torch.exp(loaded_weight.float()))

    local_rows = intermediate_size // get_tensor_model_parallel_world_size()
    self.A = nn.Parameter(
        torch.empty(local_rows, ssm_state_size, dtype=torch.float32))
    self.D = nn.Parameter(torch.ones(local_rows))

    set_weight_attrs(self.D, {"weight_loader": load_row_shard})
    set_weight_attrs(self.A, {"weight_loader": load_A})

    self.out_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=use_bias,
        input_is_parallel=True,
    )

    def make_norm(dim: int):
        # One RMSNorm per projected quantity, or None when disabled.
        if not use_rms_norm:
            return None
        return RMSNorm(dim, eps=rms_norm_eps, has_weight=rms_norm_has_weight)

    self.dt_layernorm = make_norm(time_step_rank)
    self.b_layernorm = make_norm(ssm_state_size)
    self.c_layernorm = make_norm(ssm_state_size)

forward_cuda

forward_cuda(
    hidden_states: Tensor,
    mamba_cache_params: MambaCacheParams,
)
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def forward_cuda(self, hidden_states: torch.Tensor,
                 mamba_cache_params: MambaCacheParams):
    """Run the mixer with the CUDA/Triton kernels.

    Args:
        hidden_states: Input activations fed to `in_proj`.
        mamba_cache_params: Provides `conv_state`, `ssm_state` and
            `state_indices_tensor`, which are handed to the conv and scan
            kernels (presumably updated in place by them — confirm against
            the kernel implementations).

    Returns:
        The output of `out_proj` (the contextualized states).
    """
    attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

    # Batches that carry both `query_start_loc` and `context_lens_tensor`
    # go through the varlen (prefill-style) kernels; otherwise the
    # single-step update kernels are used.
    use_varlen_kernels = (attn_metadata.query_start_loc is not None
                          and attn_metadata.context_lens_tensor is not None)
    # Sequences with prior context resume from the cached state.
    has_initial_state = (attn_metadata.context_lens_tensor > 0
                         if use_varlen_kernels else None)

    # 1. Gated MLP's linear projection
    projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
    hidden_states, gate = projected_states.chunk(2, dim=-2)

    # 2. Convolution sequence transformation
    # Collapse the singleton dim added by `unsqueeze(1)` in __init__.
    conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                           self.conv1d.weight.size(2))

    if use_varlen_kernels:
        # |---------- N-1 iteration --------|
        # |---------------- N iteration ---------------------|
        # |- tokenA -|......................|-- newTokens ---|
        # |---------- context_len ----------|
        # |-------------------- seq_len ---------------------|
        #                                   |-- query_len ---|
        hidden_states = causal_conv1d_fn(
            hidden_states,
            conv_weights,
            self.conv1d.bias,
            activation=self.activation,
            conv_states=mamba_cache_params.conv_state,
            has_initial_state=has_initial_state,
            cache_indices=mamba_cache_params.state_indices_tensor,
            query_start_loc=attn_metadata.query_start_loc)
    else:
        hidden_states = causal_conv1d_update(
            hidden_states.transpose(0, 1),
            mamba_cache_params.conv_state,
            conv_weights,
            self.conv1d.bias,
            self.activation,
            conv_state_indices=mamba_cache_params.state_indices_tensor)
        hidden_states = hidden_states.transpose(0, 1)

    # 3. State Space Model sequence transformation
    # 3.a. input varying initialization of time_step, B and C
    x_proj_input = hidden_states.transpose(-2, -1)
    if self.is_lora_enabled:
        # lora kernel requires contiguous tensor
        x_proj_input = x_proj_input.contiguous()
    ssm_parameters = self.x_proj(x_proj_input)[0]

    time_step, B, C = torch.split(
        ssm_parameters,
        [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
        dim=-1,
    )
    if self.use_rms_norm:
        assert self.dt_layernorm is not None
        assert self.b_layernorm is not None
        assert self.c_layernorm is not None
        time_step = self.dt_layernorm(time_step.contiguous())
        B = self.b_layernorm(B.contiguous())
        C = self.c_layernorm(C.contiguous())

    # 3.b. discretization: dt_proj is applied without its bias
    # (skip_bias_add=True); the bias is passed to the scan kernel.
    discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
    # A layer built with bias=False still registers `bias` as None, so
    # `hasattr` alone is not a sufficient guard before `.float()`.
    dt_bias = getattr(self.dt_proj, "bias", None)
    time_proj_bias = dt_bias.float() if dt_bias is not None else None

    # 3.c perform the recurrence y ← SSM(A, B, C)(x)
    if use_varlen_kernels:
        scan_outputs = selective_scan_fn(
            hidden_states,
            mamba_cache_params.ssm_state,
            discrete_time_step,
            self.A,
            B.transpose(-2, -1),
            C.transpose(-2, -1),
            self.D.float(),
            gate,
            time_proj_bias,
            delta_softplus=True,
            cache_indices=mamba_cache_params.state_indices_tensor,
            has_initial_state=has_initial_state,
            query_start_loc=attn_metadata.query_start_loc)
    else:
        scan_outputs = selective_state_update(
            mamba_cache_params.ssm_state,
            hidden_states.transpose(0, 1),
            discrete_time_step.transpose(0, 1),
            self.A,
            B,
            C,
            self.D,
            gate.transpose(0, 1),
            time_proj_bias,
            dt_softplus=True,
            state_batch_indices=mamba_cache_params.state_indices_tensor)
        scan_outputs = scan_outputs.transpose(0, 1)

    # 4. Final linear projection
    out_proj_input = scan_outputs.transpose(-2, -1)
    if self.is_lora_enabled:
        # lora kernel requires contiguous tensor
        out_proj_input = out_proj_input.contiguous()
    contextualized_states = self.out_proj(out_proj_input)[0]
    return contextualized_states

forward_native

forward_native(
    hidden_states: Tensor,
    conv_state: Tensor,
    ssm_state: Tensor,
)
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def forward_native(self, hidden_states: torch.Tensor,
                   conv_state: torch.Tensor, ssm_state: torch.Tensor):
    """Native (non-CUDA) forward path — not implemented.

    Raises:
        NotImplementedError: always. Previously this was a bare `pass`
            that silently returned None, which surfaced as a confusing
            failure far downstream.
    """
    raise NotImplementedError(
        "MambaMixer only implements forward_cuda; "
        "forward_native is not supported.")