vllm.model_executor.models.minimax_text_01

Inference-only MiniMaxText01 model.

MiniMaxText01Attention

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        num_kv_heads: int,
        rotary_dim: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        sliding_window: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "mha",
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx

        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        return

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                **kwargs) -> torch.Tensor:
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = attn_metadata.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

attn instance-attribute

attn = Attention(
    num_heads,
    head_dim,
    scaling,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

head_dim instance-attribute

head_dim = head_dim

hidden_size instance-attribute

hidden_size = hidden_size

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

layer_idx instance-attribute

layer_idx = layer_idx

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

rope_theta instance-attribute

rope_theta = rope_theta

scaling instance-attribute

scaling = head_dim ** -0.5

sliding_window instance-attribute

sliding_window = sliding_window

total_num_heads instance-attribute

total_num_heads = num_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_kv_heads

__init__

__init__(
    hidden_size: int,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    rotary_dim: int,
    max_position: int = 4096 * 32,
    rope_theta: float = 10000,
    sliding_window: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = None,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "mha",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    rotary_dim: int,
    max_position: int = 4096 * 32,
    rope_theta: float = 10000,
    sliding_window: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = None,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "mha",
) -> None:
    super().__init__()
    self.layer_idx = layer_idx

    self.hidden_size = hidden_size
    tp_size = get_tensor_model_parallel_world_size()
    self.total_num_heads = num_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        assert self.total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    self.head_dim = head_dim

    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim
    self.scaling = self.head_dim**-0.5
    self.rope_theta = rope_theta
    self.sliding_window = sliding_window

    self.qkv_proj = QKVParallelLinear(
        hidden_size,
        self.head_dim,
        self.total_num_heads,
        self.total_num_kv_heads,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        self.total_num_heads * self.head_dim,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )
    self.attn = Attention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
    )
    return

forward

forward(
    hidden_states: Tensor, positions: Tensor, **kwargs
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
            **kwargs) -> torch.Tensor:
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = attn_metadata.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
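
MiniMaxText01Attention shards attention heads across tensor-parallel ranks and fuses the Q/K/V projections, so forward has to split the fused output back into per-rank q, k, v chunks of width q_size and kv_size. Below is a minimal sketch of that bookkeeping under assumed toy sizes, with an ordinary nn.Linear standing in for vLLM's QKVParallelLinear on a single rank:

import torch
import torch.nn as nn

# Toy sizes for illustration only; real values come from the HF config.
hidden_size, total_heads, total_kv_heads, head_dim, tp_size = 512, 8, 4, 64, 2

# Per-rank head counts, mirroring the asserts in __init__.
assert total_heads % tp_size == 0
num_heads = total_heads // tp_size
num_kv_heads = max(1, total_kv_heads // tp_size)

q_size = num_heads * head_dim       # per-rank query width
kv_size = num_kv_heads * head_dim   # per-rank key/value width

# Single-rank stand-in for QKVParallelLinear: one fused projection.
qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)

x = torch.randn(3, hidden_size)     # (num_tokens, hidden_size)
qkv = qkv_proj(x)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)    # (3, 256), (3, 128), (3, 128)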

MiniMaxText01DecoderLayer

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01DecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        expert_num: int = 1,
        layer_id: int = None,
        linear_layer_id: Optional[int] = None,
        prefix: str = "decoder",
    ) -> None:
        self._ilayer = layer_id
        self._irank = get_tensor_model_parallel_rank()
        super().__init__()

        self.hidden_size = config.hidden_size
        self.expert_num = expert_num

        rope_theta = getattr(config, "rope_theta", 10000)

        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        if hasattr(config, "max_model_len") and isinstance(
                config.max_model_len, int):
            max_position_embeddings = min(config.max_position_embeddings,
                                          config.max_model_len)
        if config.attention_type == 0:
            use_headxdim = True
            hidden_inner = (head_dim * config.num_attention_heads
                            if use_headxdim else config.hidden_size)
            self.self_attn = MiniMaxText01LinearAttention(
                hidden_size=self.hidden_size,
                hidden_inner_size=hidden_inner,
                num_heads=config.num_attention_heads,
                head_dim=head_dim,
                max_position=max_position_embeddings,
                block_size=config.block if hasattr(config, "block") else 256,
                num_hidden_layer=config.num_hidden_layers,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                linear_layer_idx=linear_layer_id,
                prefix=prefix)
        elif config.attention_type == 1:
            self.self_attn = MiniMaxText01Attention(
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                head_dim=head_dim,
                rotary_dim=config.rotary_dim
                if hasattr(config, "rotary_dim") else head_dim,
                num_kv_heads=config.num_key_value_heads,
                max_position=max_position_embeddings,
                rope_theta=rope_theta,
                sliding_window=config.sliding_window,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                cache_config=cache_config,
                prefix=prefix)
        else:
            raise ValueError(
                f"Unsupported attention type: {self.config.attention_type}")

        if expert_num == 1:
            self.mlp = MiniMaxText01MLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                prefix=prefix)
        else:
            self.block_sparse_moe = MiniMaxText01MoE(
                num_experts=expert_num,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                layer_idx=self._ilayer,
                quant_config=quant_config,
                prefix=prefix)

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        if config.attention_type == 0:
            self.layernorm_attention_alpha = getattr(
                config, 'layernorm_linear_attention_alpha', 1)
            self.layernorm_attention_beta = getattr(
                config, 'layernorm_linear_attention_beta', 1)
        else:
            self.layernorm_attention_alpha = getattr(
                config, 'layernorm_full_attention_alpha', 1)
            self.layernorm_attention_beta = getattr(
                config, 'layernorm_full_attention_beta', 1)
        self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
        self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
        self.postnorm = getattr(config, 'postnorm', False)
        self.shared_moe = False

        shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
        if isinstance(shared_intermediate, list):
            shared_intermediate = shared_intermediate[
                layer_id] if layer_id < len(shared_intermediate) else 0
        if shared_intermediate > 0:
            self.shared_moe = True
            self.shared_mlp = MiniMaxText01MLP(
                hidden_size=self.hidden_size,
                intermediate_size=shared_intermediate,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                prefix=prefix)
            self.coefficient = ReplicatedLinear(
                self.hidden_size,
                1,
                bias=False,
                quant_config=quant_config,
                params_dtype=torch.float32,
            )
            self.coefficient.weight.weight_loader = (
                self.shared_moe_coefficient_loader)
            self.shared_moe_mode = getattr(config, 'shared_moe_mode',
                                           'softmax')
        return

    def forward(self,
                hidden_states: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: Union[list[dict], Optional[torch.Tensor]],
                attn_metadata: AttentionMetadata,
                residual: Optional[torch.Tensor],
                is_warmup: bool = False,
                **kwargs) -> tuple[torch.Tensor, torch.Tensor]:

        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        layernorm_input = hidden_states
        layernorm_output = self.input_layernorm(layernorm_input)
        residual = layernorm_output if self.postnorm else layernorm_input
        self_attention_output = self.self_attn(
            hidden_states=layernorm_output,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        residual = residual * self.layernorm_attention_alpha
        self_attention_output = (self_attention_output *
                                 self.layernorm_attention_beta)

        layernorm_input = residual + self_attention_output
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        residual = layernorm_output if self.postnorm else layernorm_input

        if self.expert_num == 1:
            hidden_states = self.mlp(layernorm_output)
        else:
            moe_hidden_states = self.block_sparse_moe(
                copy.deepcopy(layernorm_output))
            if self.shared_moe:
                before_moe_dtype = layernorm_output.dtype
                moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
                output_mlp = self.shared_mlp(layernorm_output).to(
                    torch.float32)

                coef, _ = self.coefficient(layernorm_output.to(torch.float32))

                if self.shared_moe_mode == 'softmax':
                    coef = torch.nn.functional.softmax(coef, dim=-1)
                    hidden_states = moe_hidden_fp32 * (
                        1 - coef) + output_mlp * coef
                elif self.shared_moe_mode == 'sigmoid':
                    coef = torch.nn.functional.sigmoid(coef)
                    hidden_states = moe_hidden_fp32 * (
                        1 - coef) + output_mlp * coef

                hidden_states = hidden_states.to(before_moe_dtype)
            else:
                hidden_states = moe_hidden_states

        residual = residual * self.layernorm_mlp_alpha
        hidden_states = hidden_states * self.layernorm_mlp_beta

        hidden_states = residual + hidden_states

        return hidden_states, None

    @staticmethod
    def shared_moe_coefficient_loader(param: torch.Tensor,
                                      loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()

        param.data.copy_(loaded_weight.to(torch.float32))
        return

_ilayer instance-attribute

_ilayer = layer_id

_irank instance-attribute

_irank = get_tensor_model_parallel_rank()

block_sparse_moe instance-attribute

block_sparse_moe = MiniMaxText01MoE(
    num_experts=expert_num,
    top_k=num_experts_per_tok,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    layer_idx=_ilayer,
    quant_config=quant_config,
    prefix=prefix,
)

coefficient instance-attribute

coefficient = ReplicatedLinear(
    hidden_size,
    1,
    bias=False,
    quant_config=quant_config,
    params_dtype=float32,
)

expert_num instance-attribute

expert_num = expert_num

hidden_size instance-attribute

hidden_size = hidden_size

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

layernorm_attention_alpha instance-attribute

layernorm_attention_alpha = getattr(
    config, "layernorm_linear_attention_alpha", 1
)

layernorm_attention_beta instance-attribute

layernorm_attention_beta = getattr(
    config, "layernorm_linear_attention_beta", 1
)

layernorm_mlp_alpha instance-attribute

layernorm_mlp_alpha = getattr(
    config, "layernorm_mlp_alpha", 1
)

layernorm_mlp_beta instance-attribute

layernorm_mlp_beta = getattr(
    config, "layernorm_mlp_beta", 1
)

mlp instance-attribute

mlp = MiniMaxText01MLP(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    quant_config=quant_config,
    layer_idx=_ilayer,
    prefix=prefix,
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

postnorm instance-attribute

postnorm = getattr(config, "postnorm", False)

self_attn instance-attribute

self_attn = MiniMaxText01LinearAttention(
    hidden_size=hidden_size,
    hidden_inner_size=hidden_inner,
    num_heads=num_attention_heads,
    head_dim=head_dim,
    max_position=max_position_embeddings,
    block_size=block if hasattr(config, "block") else 256,
    num_hidden_layer=num_hidden_layers,
    quant_config=quant_config,
    layer_idx=_ilayer,
    linear_layer_idx=linear_layer_id,
    prefix=prefix,
)

shared_mlp instance-attribute

shared_mlp = MiniMaxText01MLP(
    hidden_size=hidden_size,
    intermediate_size=shared_intermediate,
    quant_config=quant_config,
    layer_idx=_ilayer,
    prefix=prefix,
)

shared_moe instance-attribute

shared_moe = False

shared_moe_mode instance-attribute

shared_moe_mode = getattr(
    config, "shared_moe_mode", "softmax"
)

__init__

__init__(
    config: PretrainedConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    expert_num: int = 1,
    layer_id: int = None,
    linear_layer_id: Optional[int] = None,
    prefix: str = "decoder",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    config: PretrainedConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    expert_num: int = 1,
    layer_id: int = None,
    linear_layer_id: Optional[int] = None,
    prefix: str = "decoder",
) -> None:
    self._ilayer = layer_id
    self._irank = get_tensor_model_parallel_rank()
    super().__init__()

    self.hidden_size = config.hidden_size
    self.expert_num = expert_num

    rope_theta = getattr(config, "rope_theta", 10000)

    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    if hasattr(config, "max_model_len") and isinstance(
            config.max_model_len, int):
        max_position_embeddings = min(config.max_position_embeddings,
                                      config.max_model_len)
    if config.attention_type == 0:
        use_headxdim = True
        hidden_inner = (head_dim * config.num_attention_heads
                        if use_headxdim else config.hidden_size)
        self.self_attn = MiniMaxText01LinearAttention(
            hidden_size=self.hidden_size,
            hidden_inner_size=hidden_inner,
            num_heads=config.num_attention_heads,
            head_dim=head_dim,
            max_position=max_position_embeddings,
            block_size=config.block if hasattr(config, "block") else 256,
            num_hidden_layer=config.num_hidden_layers,
            quant_config=quant_config,
            layer_idx=self._ilayer,
            linear_layer_idx=linear_layer_id,
            prefix=prefix)
    elif config.attention_type == 1:
        self.self_attn = MiniMaxText01Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            head_dim=head_dim,
            rotary_dim=config.rotary_dim
            if hasattr(config, "rotary_dim") else head_dim,
            num_kv_heads=config.num_key_value_heads,
            max_position=max_position_embeddings,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            quant_config=quant_config,
            layer_idx=self._ilayer,
            cache_config=cache_config,
            prefix=prefix)
    else:
        raise ValueError(
            f"Unsupported attention type: {self.config.attention_type}")

    if expert_num == 1:
        self.mlp = MiniMaxText01MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            layer_idx=self._ilayer,
            prefix=prefix)
    else:
        self.block_sparse_moe = MiniMaxText01MoE(
            num_experts=expert_num,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_idx=self._ilayer,
            quant_config=quant_config,
            prefix=prefix)

    self.input_layernorm = RMSNorm(config.hidden_size,
                                   eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
    if config.attention_type == 0:
        self.layernorm_attention_alpha = getattr(
            config, 'layernorm_linear_attention_alpha', 1)
        self.layernorm_attention_beta = getattr(
            config, 'layernorm_linear_attention_beta', 1)
    else:
        self.layernorm_attention_alpha = getattr(
            config, 'layernorm_full_attention_alpha', 1)
        self.layernorm_attention_beta = getattr(
            config, 'layernorm_full_attention_beta', 1)
    self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
    self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
    self.postnorm = getattr(config, 'postnorm', False)
    self.shared_moe = False

    shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
    if isinstance(shared_intermediate, list):
        shared_intermediate = shared_intermediate[
            layer_id] if layer_id < len(shared_intermediate) else 0
    if shared_intermediate > 0:
        self.shared_moe = True
        self.shared_mlp = MiniMaxText01MLP(
            hidden_size=self.hidden_size,
            intermediate_size=shared_intermediate,
            quant_config=quant_config,
            layer_idx=self._ilayer,
            prefix=prefix)
        self.coefficient = ReplicatedLinear(
            self.hidden_size,
            1,
            bias=False,
            quant_config=quant_config,
            params_dtype=torch.float32,
        )
        self.coefficient.weight.weight_loader = (
            self.shared_moe_coefficient_loader)
        self.shared_moe_mode = getattr(config, 'shared_moe_mode',
                                       'softmax')
    return

forward

forward(
    hidden_states: Tensor,
    positions: Tensor,
    kv_caches: Union[list[dict], Optional[Tensor]],
    attn_metadata: AttentionMetadata,
    residual: Optional[Tensor],
    is_warmup: bool = False,
    **kwargs,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self,
            hidden_states: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: Union[list[dict], Optional[torch.Tensor]],
            attn_metadata: AttentionMetadata,
            residual: Optional[torch.Tensor],
            is_warmup: bool = False,
            **kwargs) -> tuple[torch.Tensor, torch.Tensor]:

    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    layernorm_input = hidden_states
    layernorm_output = self.input_layernorm(layernorm_input)
    residual = layernorm_output if self.postnorm else layernorm_input
    self_attention_output = self.self_attn(
        hidden_states=layernorm_output,
        positions=positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
    )

    residual = residual * self.layernorm_attention_alpha
    self_attention_output = (self_attention_output *
                             self.layernorm_attention_beta)

    layernorm_input = residual + self_attention_output
    layernorm_output = self.post_attention_layernorm(layernorm_input)
    residual = layernorm_output if self.postnorm else layernorm_input

    if self.expert_num == 1:
        hidden_states = self.mlp(layernorm_output)
    else:
        moe_hidden_states = self.block_sparse_moe(
            copy.deepcopy(layernorm_output))
        if self.shared_moe:
            before_moe_dtype = layernorm_output.dtype
            moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
            output_mlp = self.shared_mlp(layernorm_output).to(
                torch.float32)

            coef, _ = self.coefficient(layernorm_output.to(torch.float32))

            if self.shared_moe_mode == 'softmax':
                coef = torch.nn.functional.softmax(coef, dim=-1)
                hidden_states = moe_hidden_fp32 * (
                    1 - coef) + output_mlp * coef
            elif self.shared_moe_mode == 'sigmoid':
                coef = torch.nn.functional.sigmoid(coef)
                hidden_states = moe_hidden_fp32 * (
                    1 - coef) + output_mlp * coef

            hidden_states = hidden_states.to(before_moe_dtype)
        else:
            hidden_states = moe_hidden_states

    residual = residual * self.layernorm_mlp_alpha
    hidden_states = hidden_states * self.layernorm_mlp_beta

    hidden_states = residual + hidden_states

    return hidden_states, None
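
For layers that combine routed experts with a shared MLP, forward blends the two branch outputs in float32 using a learned per-token coefficient and either softmax or sigmoid gating. The following is a minimal sketch of just that blending step, with placeholder tensors standing in for the routed-MoE and shared-MLP outputs and a plain nn.Linear standing in for the ReplicatedLinear coefficient head:

import torch
import torch.nn as nn

hidden_size, num_tokens = 512, 3
layernorm_output = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)

# Stand-ins for the routed-MoE output, the shared MLP output, and the
# ReplicatedLinear(hidden_size, 1) coefficient head.
moe_out = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)
shared_out = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)
coef_head = nn.Linear(hidden_size, 1, bias=False).float()

before_moe_dtype = layernorm_output.dtype
coef = coef_head(layernorm_output.float())           # (num_tokens, 1)

shared_moe_mode = "sigmoid"                          # or "softmax", per the config
if shared_moe_mode == "softmax":
    coef = torch.softmax(coef, dim=-1)
else:
    coef = torch.sigmoid(coef)

# Convex blend in float32, then cast back to the original dtype.
hidden = moe_out.float() * (1 - coef) + shared_out.float() * coef
hidden = hidden.to(before_moe_dtype)
print(hidden.shape, hidden.dtype)                    # torch.Size([3, 512]) torch.bfloat16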

shared_moe_coefficient_loader staticmethod

shared_moe_coefficient_loader(
    param: Tensor, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def shared_moe_coefficient_loader(param: torch.Tensor,
                                  loaded_weight: torch.Tensor) -> None:
    assert param.size() == loaded_weight.size()

    param.data.copy_(loaded_weight.to(torch.float32))
    return
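
The decoder layer also scales its residual connections: the residual is multiplied by a layernorm_*_alpha factor and the branch output by the matching layernorm_*_beta factor before they are added, and postnorm selects whether the residual is taken before or after the norm. Here is a condensed sketch of the attention branch of that pattern, with nn.LayerNorm and nn.Linear as stand-ins for the real RMSNorm and self_attn modules:

import torch
import torch.nn as nn

hidden_size, num_tokens = 512, 3
norm = nn.LayerNorm(hidden_size)                         # stand-in for vLLM's RMSNorm
attn = nn.Linear(hidden_size, hidden_size, bias=False)   # stand-in for self_attn

alpha, beta = 1.0, 1.0   # layernorm_full_attention_alpha / _beta from the config
postnorm = False         # config.postnorm

hidden_states = torch.randn(num_tokens, hidden_size)
layernorm_output = norm(hidden_states)
residual = layernorm_output if postnorm else hidden_states
attn_out = attn(layernorm_output)

# Scaled residual add, as in MiniMaxText01DecoderLayer.forward.
hidden_states = residual * alpha + attn_out * beta
print(hidden_states.shape)                               # torch.Size([3, 512])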

MiniMaxText01ForCausalLM

Bases: Module, HasInnerState, IsHybrid, SupportsV0Only

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                               SupportsV0Only):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        if not hasattr(config, "sliding_window"):
            config.sliding_window = None

        self.CONCAT_FFN = True

        self.unpadded_vocab_size = self.config.vocab_size
        if hasattr(vllm_config.model_config, "max_model_len"):
            self.config.max_model_len = vllm_config.model_config.max_model_len
        self.model = MiniMaxText01Model(
            self.config,
            quant_config,
            cache_config=vllm_config.cache_config,
            scheduler_config=vllm_config.scheduler_config,
            prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                self.config.hidden_size,
                org_num_embeddings=self.config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            )

            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    self.config.vocab_size)

        else:
            self.lm_head = PPMissingLayer()
        self.lm_head.float()
        flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
                                if attn_type == 1)
        self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
        return

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
            batch_size)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, **kwargs)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states.float(),
                                       sampling_metadata)

        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        def which_layer(name: str) -> int:
            if "layers" in name:
                after_layer = name.split("layers")[-1]
                return int(after_layer.split(".")[1])
            return None

        def is_linear_attn_layer(layer_idx: int) -> bool:
            if layer_idx is None or not hasattr(self.config, "attn_type_list"):
                return False
            return self.config.attn_type_list[layer_idx] == 0

        def is_moe_weight(name: str) -> bool:
            return "block_sparse_moe" in name and not name.endswith(".bias")

        def get_expert_id(param_name):
            pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.'
            match = re.search(pattern, param_name)
            if match:
                return match.group(1)
            return None

        def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
                                   self) -> None:
            if isinstance(self.config.num_local_experts, list):
                expert_params_mapping = [
                    ("w13_weight"
                     if weight_name in ["w1", "w3"] else "w2_weight",
                     f"experts.{expert_id}.{weight_name}.weight", expert_id)
                    for expert_id in range(max(self.config.num_local_experts))
                    for weight_name in ["w1", "w2", "w3"]
                ]
            else:
                expert_params_mapping = [
                    ("w13_scale" if weight_name in ["w1", "w3"] else
                     "w2_scale", f"{expert_id}.{weight_name}.weight_scale",
                     expert_id, weight_name)
                    for expert_id in range(self.config.num_local_experts)
                    for weight_name in ["w1", "w2", "w3"]
                ] + [("w13_weight" if weight_name in ["w1", "w3"] else
                      "w2_weight", f"{expert_id}.{weight_name}.weight",
                      expert_id, weight_name)
                     for expert_id in range(self.config.num_local_experts)
                     for weight_name in ["w1", "w2", "w3"]]
            for (param_name, weight_name, expert_id,
                 shard_id) in expert_params_mapping:
                name_expert_id = get_expert_id(name)
                if name_expert_id is not None and int(name_expert_id) != int(
                        expert_id):
                    continue
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param,
                              loaded_weight,
                              weight_name,
                              expert_id=expert_id,
                              shard_id=shard_id)
                loaded_params.add(name)
                break
            else:
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
            return

        def is_shared_mlp_weight(name: str) -> bool:
            return "shared_mlp" in name and not name.endswith(".bias")

        def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor,
                                   self) -> None:
            if not self.CONCAT_FFN:
                if "gate_proj" in name:
                    name = name.replace("gate_proj", "w1", 1)
                elif "up_proj" in name:
                    name = name.replace("up_proj", "w3", 1)
                elif "down_proj" in name:
                    name = name.replace("down_proj", "w2", 1)
            else:
                if "gate_proj" in name:
                    name = name.replace("gate_proj", "gate_up_proj", 1)
                    loaded_shard_id = 0
                elif "up_proj" in name:
                    name = name.replace("up_proj", "gate_up_proj", 1)
                    loaded_shard_id = 1
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            if not self.CONCAT_FFN:
                weight_loader(param, loaded_weight)
            else:
                if "gate_up_proj" in name:
                    weight_loader(param, loaded_weight, loaded_shard_id)
                elif "down_proj" in name:
                    weight_loader(param, loaded_weight)
                else:
                    raise AssertionError(
                        "MLP weight not in [gate_up_proj, down_proj]")
            loaded_params.add(name)
            return

        def is_mha_weight(name: str) -> bool:
            return "self_attn" in name and not name.endswith(".bias")

        def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor,
                                    self) -> None:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]

            weight_loader = getattr(
                param, "weight_loader",
                MiniMaxText01LinearAttention.weight_direct_load)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
            return

        def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
                                   self) -> None:

            flash_mha_params_mapping = [
                ("qkv_proj", "q_proj", "q"),
                ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v"),
                ("gate_up_proj", "gate_proj", 0),
                ("gate_up_proj", "up_proj", 1),
            ]
            for (param_name, weight_name,
                 shard_id) in flash_mha_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]

                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
            return

        def is_layer_norm_weight(name: str) -> bool:
            return "norm" in name and not name.endswith(
                ".bias") and name in params_dict

        def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor,
                                   self) -> None:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
            return

        def load_basic_weight(name: str, loaded_weight: torch.Tensor,
                              self) -> None:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
            return

        for name, loaded_weight in weights:
            weight_at_layer = which_layer(name)
            if weight_at_layer and weight_at_layer >= len(
                    self.config.attn_type_list):
                continue

            if is_layer_norm_weight(name):
                load_layer_norm_weight(name, loaded_weight, self)
                continue
            if is_mha_weight(name):
                if is_linear_attn_layer(weight_at_layer):
                    load_linear_attn_weight(name, loaded_weight, self)
                else:
                    load_flash_attn_weight(name, loaded_weight, self)
                continue
            if is_moe_weight(name):
                load_sparse_moe_weight(name, loaded_weight, self)
                continue
            if is_shared_mlp_weight(name):
                load_shared_mlp_weight(name, loaded_weight, self)
                continue

            if "rotary_emb.inv_freq" in name:
                continue

            load_basic_weight(name, loaded_weight, self)
        return loaded_params

CONCAT_FFN instance-attribute

CONCAT_FFN = True

config instance-attribute

config = config

kv_cache instance-attribute

kv_cache = [tensor([]) for _ in range(flash_layer_count)]

lm_head instance-attribute

lm_head = ParallelLMHead(
    unpadded_vocab_size,
    hidden_size,
    org_num_embeddings=vocab_size,
    padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    unpadded_vocab_size, vocab_size
)

lora_config instance-attribute

lora_config = lora_config

model instance-attribute

model = MiniMaxText01Model(
    config,
    quant_config,
    cache_config=cache_config,
    scheduler_config=scheduler_config,
    prefix=maybe_prefix(prefix, "model"),
)

unpadded_vocab_size instance-attribute

unpadded_vocab_size = vocab_size

__init__

__init__(
    *, vllm_config: VllmConfig, prefix: str = ""
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    lora_config = vllm_config.lora_config
    self.config = config
    self.lora_config = lora_config

    if not hasattr(config, "sliding_window"):
        config.sliding_window = None

    self.CONCAT_FFN = True

    self.unpadded_vocab_size = self.config.vocab_size
    if hasattr(vllm_config.model_config, "max_model_len"):
        self.config.max_model_len = vllm_config.model_config.max_model_len
    self.model = MiniMaxText01Model(
        self.config,
        quant_config,
        cache_config=vllm_config.cache_config,
        scheduler_config=vllm_config.scheduler_config,
        prefix=maybe_prefix(prefix, "model"))
    if get_pp_group().is_last_rank:
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.config.vocab_size)

    else:
        self.lm_head = PPMissingLayer()
    self.lm_head.float()
    flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
                            if attn_type == 1)
    self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
    return

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    logits = self.logits_processor(self.lm_head, hidden_states.float(),
                                   sampling_metadata)

    return logits
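
Because __init__ calls lm_head.float() and compute_logits casts the hidden states with hidden_states.float(), the final vocabulary projection always runs in float32 regardless of the model dtype. A tiny stand-alone illustration of that dtype handling, with an ordinary nn.Linear in place of ParallelLMHead and the LogitsProcessor omitted:

import torch
import torch.nn as nn

vocab_size, hidden_size = 1000, 512
lm_head = nn.Linear(hidden_size, vocab_size, bias=False).float()  # stand-in for ParallelLMHead

hidden_states = torch.randn(3, hidden_size, dtype=torch.bfloat16)
logits = lm_head(hidden_states.float())   # both operands in float32
print(logits.dtype)                       # torch.float32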

copy_inputs_before_cuda_graphs

copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
Source code in vllm/model_executor/models/minimax_text_01.py
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
    return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
        input_buffers, **kwargs)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            **kwargs) -> torch.Tensor:
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
                               inputs_embeds, **kwargs)

    return hidden_states

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
) -> torch.Tensor:
    return self.model.get_input_embeddings(input_ids)

get_seqlen_agnostic_capture_inputs

get_seqlen_agnostic_capture_inputs(batch_size: int)
Source code in vllm/model_executor/models/minimax_text_01.py
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
    return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
        batch_size)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/minimax_text_01.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    def which_layer(name: str) -> int:
        if "layers" in name:
            after_layer = name.split("layers")[-1]
            return int(after_layer.split(".")[1])
        return None

    def is_linear_attn_layer(layer_idx: int) -> bool:
        if layer_idx is None or not hasattr(self.config, "attn_type_list"):
            return False
        return self.config.attn_type_list[layer_idx] == 0

    def is_moe_weight(name: str) -> bool:
        return "block_sparse_moe" in name and not name.endswith(".bias")

    def get_expert_id(param_name):
        pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.'
        match = re.search(pattern, param_name)
        if match:
            return match.group(1)
        return None

    def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
                               self) -> None:
        if isinstance(self.config.num_local_experts, list):
            expert_params_mapping = [
                ("w13_weight"
                 if weight_name in ["w1", "w3"] else "w2_weight",
                 f"experts.{expert_id}.{weight_name}.weight", expert_id)
                for expert_id in range(max(self.config.num_local_experts))
                for weight_name in ["w1", "w2", "w3"]
            ]
        else:
            expert_params_mapping = [
                ("w13_scale" if weight_name in ["w1", "w3"] else
                 "w2_scale", f"{expert_id}.{weight_name}.weight_scale",
                 expert_id, weight_name)
                for expert_id in range(self.config.num_local_experts)
                for weight_name in ["w1", "w2", "w3"]
            ] + [("w13_weight" if weight_name in ["w1", "w3"] else
                  "w2_weight", f"{expert_id}.{weight_name}.weight",
                  expert_id, weight_name)
                 for expert_id in range(self.config.num_local_experts)
                 for weight_name in ["w1", "w2", "w3"]]
        for (param_name, weight_name, expert_id,
             shard_id) in expert_params_mapping:
            name_expert_id = get_expert_id(name)
            if name_expert_id is not None and int(name_expert_id) != int(
                    expert_id):
                continue
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param,
                          loaded_weight,
                          weight_name,
                          expert_id=expert_id,
                          shard_id=shard_id)
            loaded_params.add(name)
            break
        else:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return

    def is_shared_mlp_weight(name: str) -> bool:
        return "shared_mlp" in name and not name.endswith(".bias")

    def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor,
                               self) -> None:
        if not self.CONCAT_FFN:
            if "gate_proj" in name:
                name = name.replace("gate_proj", "w1", 1)
            elif "up_proj" in name:
                name = name.replace("up_proj", "w3", 1)
            elif "down_proj" in name:
                name = name.replace("down_proj", "w2", 1)
        else:
            if "gate_proj" in name:
                name = name.replace("gate_proj", "gate_up_proj", 1)
                loaded_shard_id = 0
            elif "up_proj" in name:
                name = name.replace("up_proj", "gate_up_proj", 1)
                loaded_shard_id = 1
        if is_pp_missing_parameter(name, self):
            return
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader = weight_loader_with_alias(name)(weight_loader)
        if not self.CONCAT_FFN:
            weight_loader(param, loaded_weight)
        else:
            if "gate_up_proj" in name:
                weight_loader(param, loaded_weight, loaded_shard_id)
            elif "down_proj" in name:
                weight_loader(param, loaded_weight)
            else:
                raise AssertionError(
                    "MLP weight not in [gate_up_proj, down_proj]")
        loaded_params.add(name)
        return

    def is_mha_weight(name: str) -> bool:
        return "self_attn" in name and not name.endswith(".bias")

    def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor,
                                self) -> None:
        if is_pp_missing_parameter(name, self):
            return
        param = params_dict[name]

        weight_loader = getattr(
            param, "weight_loader",
            MiniMaxText01LinearAttention.weight_direct_load)
        weight_loader = weight_loader_with_alias(name)(weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(name)
        return

    def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
                               self) -> None:

        flash_mha_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        for (param_name, weight_name,
             shard_id) in flash_mha_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
            break
        else:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return

    def is_layer_norm_weight(name: str) -> bool:
        return "norm" in name and not name.endswith(
            ".bias") and name in params_dict

    def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor,
                               self) -> None:
        if is_pp_missing_parameter(name, self):
            return
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader = weight_loader_with_alias(name)(weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(name)
        return

    def load_basic_weight(name: str, loaded_weight: torch.Tensor,
                          self) -> None:
        if is_pp_missing_parameter(name, self):
            return
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader = weight_loader_with_alias(name)(weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(name)
        return

    for name, loaded_weight in weights:
        weight_at_layer = which_layer(name)
        if weight_at_layer and weight_at_layer >= len(
                self.config.attn_type_list):
            continue

        if is_layer_norm_weight(name):
            load_layer_norm_weight(name, loaded_weight, self)
            continue
        if is_mha_weight(name):
            if is_linear_attn_layer(weight_at_layer):
                load_linear_attn_weight(name, loaded_weight, self)
            else:
                load_flash_attn_weight(name, loaded_weight, self)
            continue
        if is_moe_weight(name):
            load_sparse_moe_weight(name, loaded_weight, self)
            continue
        if is_shared_mlp_weight(name):
            load_shared_mlp_weight(name, loaded_weight, self)
            continue

        if "rotary_emb.inv_freq" in name:
            continue

        load_basic_weight(name, loaded_weight, self)
    return loaded_params
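
load_weights routes each checkpoint tensor by parsing the layer index out of its name and consulting config.attn_type_list (0 = linear attention, 1 = full attention) to pick a loader. Below is a simplified, regex-based sketch of that routing decision; the parameter names and the attn_type_list values are hypothetical:

import re
from typing import Optional

# Hypothetical hybrid layout: 0 = linear attention, 1 = full (flash) attention.
attn_type_list = [0, 0, 0, 1, 0, 0, 0, 1]

def which_layer(name: str) -> Optional[int]:
    # "model.layers.<idx>." -> <idx>; simplified regex version of the helper.
    match = re.search(r"layers\.(\d+)\.", name)
    return int(match.group(1)) if match else None

def is_linear_attn_layer(layer_idx: Optional[int]) -> bool:
    return layer_idx is not None and attn_type_list[layer_idx] == 0

for name in [
    "model.layers.2.self_attn.qkv_proj.weight",   # hypothetical parameter names
    "model.layers.3.self_attn.q_proj.weight",
    "model.embed_tokens.weight",
]:
    idx = which_layer(name)
    if idx is None:
        print(name, "-> basic loader")
    elif is_linear_attn_layer(idx):
        print(name, "-> linear-attention loader")
    else:
        print(name, "-> flash-attention loader (q/k/v fused into qkv_proj)")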

make_empty_intermediate_tensors

make_empty_intermediate_tensors(
    batch_size: int, dtype: dtype, device: device
) -> IntermediateTensors
Source code in vllm/model_executor/models/minimax_text_01.py
def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype,
        device: torch.device) -> IntermediateTensors:
    return IntermediateTensors({
        "hidden_states":
        torch.zeros((batch_size, self.config.hidden_size),
                    dtype=dtype,
                    device=device),
        "residual":
        torch.zeros((batch_size, self.config.hidden_size),
                    dtype=dtype,
                    device=device),
    })

MiniMaxText01LinearAttention

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01LinearAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        hidden_inner_size: int,
        num_heads: int,
        head_dim: int,
        max_position: int,
        block_size: int,
        num_hidden_layer: int,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = 0,
        linear_layer_idx: int = 0,
        prefix: str = "linear_attn",
    ) -> None:
        super().__init__()

        self.layer_idx = layer_idx
        self.BLOCK = block_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.total_num_heads = num_heads
        self.hidden_inner_size = hidden_inner_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        assert self.total_num_heads % self.tp_size == 0
        self.tp_heads = self.total_num_heads // self.tp_size
        self.qkv_size = self.num_heads * self.head_dim
        self.tp_hidden = self.head_dim * self.tp_heads

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size * 3,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.output_gate = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_gate",
        )
        self.out_proj = RowParallelLinear(
            self.hidden_inner_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.norm = MiniMaxText01RMSNormTP(
            self.hidden_inner_size,
            eps=1e-5,
        )

        slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
            self.num_heads)
        if num_hidden_layer <= 1:
            self.slope_rate = slope_rate * (1 + 1e-5)
        else:
            self.slope_rate = slope_rate * (1 - layer_idx /
                                            (num_hidden_layer - 1) + 1e-5)
        self.tp_slope = self.slope_rate[self.tp_rank *
                                        self.tp_heads:(self.tp_rank + 1) *
                                        self.tp_heads].contiguous()

    @staticmethod
    def weight_direct_load(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)
        return

    @staticmethod
    def _build_slope_tensor(n_attention_heads: int):

        def get_slopes(n):

            def get_slopes_power_of_2(n):
                start = 2**(-(2**-(math.log2(n) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(n)
            else:
                closest_power_of_2 = 2**math.floor(math.log2(n))
                return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                    2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

        slopes = torch.tensor(get_slopes(n_attention_heads),
                              dtype=torch.float32).reshape(
                                  n_attention_heads, 1, 1)
        return slopes

    def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                               attn_metadata):
        hidden = []
        for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
            if _prefill_idx >= len(attn_metadata.query_start_loc):
                break
            if _prefill_idx >= len(state_indices_tensor):
                break
            _start = attn_metadata.query_start_loc[_prefill_idx]
            _end = attn_metadata.query_start_loc[_prefill_idx + 1]
            slot_id = state_indices_tensor[_prefill_idx]
            qs = q[_start:_end].transpose(0, 1).contiguous()
            ks = k[_start:_end].transpose(0, 1).contiguous()
            vs = v[_start:_end].transpose(0, 1).contiguous()
            slot_id = state_indices_tensor[_prefill_idx]
            slice_layer_cache = kv_cache[slot_id, ...]

            out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
                qs,
                ks,
                vs,
                slice_layer_cache,
                self.tp_slope,
                self.BLOCK,
                layer_idx=self.layer_idx)
            hidden.append(out_slice.contiguous())
        if attn_metadata.num_decode_tokens > 0:
            hidden.append(
                self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
                                   attn_metadata))

        if not hidden:
            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

        hidden = torch.concat(hidden, dim=0).contiguous()
        return hidden

    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                      attn_metadata):
        q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
        k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
        v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
        slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
                                               ):]
        hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                              slot_id, 32)
        return hidden

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        qkv32 = qkv.to(torch.float32)
        qkvact = torch.nn.functional.silu(qkv32)
        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        kv_cache = kv_caches.minimax_cache
        state_indices_tensor = kv_caches.state_indices_tensor

        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
        if not decode_only:
            hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
                                                 state_indices_tensor,
                                                 attn_metadata)
        else:
            hidden = self._decode_infer(q, k, v, kv_cache,
                                        state_indices_tensor, attn_metadata)

        hidden = self.norm._forward(hidden)
        gate, _ = self.output_gate(hidden_states)
        hidden = F.sigmoid(gate) * hidden
        hidden = hidden.to(hidden_states.dtype)
        hidden, _ = self.out_proj(hidden)
        return hidden

BLOCK instance-attribute

BLOCK = block_size

head_dim instance-attribute

head_dim = head_dim

hidden_inner_size instance-attribute

hidden_inner_size = hidden_inner_size

hidden_size instance-attribute

hidden_size = hidden_size

layer_idx instance-attribute

layer_idx = layer_idx

norm instance-attribute

norm = MiniMaxText01RMSNormTP(hidden_inner_size, eps=1e-05)

num_heads instance-attribute

num_heads = num_heads

out_proj instance-attribute

out_proj = RowParallelLinear(
    hidden_inner_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.out_proj",
)

output_gate instance-attribute

output_gate = ColumnParallelLinear(
    hidden_size,
    hidden_inner_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.output_gate",
)

qkv_proj instance-attribute

qkv_proj = ColumnParallelLinear(
    hidden_size,
    hidden_inner_size * 3,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

qkv_size instance-attribute

qkv_size = num_heads * head_dim

slope_rate instance-attribute

slope_rate = slope_rate * (1 + 1e-05)

total_num_heads instance-attribute

total_num_heads = num_heads

tp_heads instance-attribute

tp_heads = total_num_heads // tp_size

tp_hidden instance-attribute

tp_hidden = head_dim * tp_heads

tp_rank instance-attribute

tp_rank = get_tensor_model_parallel_rank()

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

tp_slope instance-attribute

tp_slope = slope_rate[
    tp_rank * tp_heads : (tp_rank + 1) * tp_heads
].contiguous()

__init__

__init__(
    hidden_size: int,
    hidden_inner_size: int,
    num_heads: int,
    head_dim: int,
    max_position: int,
    block_size: int,
    num_hidden_layer: int,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = 0,
    linear_layer_idx: int = 0,
    prefix: str = "linear_attn",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    hidden_size: int,
    hidden_inner_size: int,
    num_heads: int,
    head_dim: int,
    max_position: int,
    block_size: int,
    num_hidden_layer: int,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = 0,
    linear_layer_idx: int = 0,
    prefix: str = "linear_attn",
) -> None:
    super().__init__()

    self.layer_idx = layer_idx
    self.BLOCK = block_size
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.total_num_heads = num_heads
    self.hidden_inner_size = hidden_inner_size
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()

    assert self.total_num_heads % self.tp_size == 0
    self.tp_heads = self.total_num_heads // self.tp_size
    self.qkv_size = self.num_heads * self.head_dim
    self.tp_hidden = self.head_dim * self.tp_heads

    self.qkv_proj = ColumnParallelLinear(
        hidden_size,
        self.hidden_inner_size * 3,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.output_gate = ColumnParallelLinear(
        hidden_size,
        self.hidden_inner_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.output_gate",
    )
    self.out_proj = RowParallelLinear(
        self.hidden_inner_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )
    self.norm = MiniMaxText01RMSNormTP(
        self.hidden_inner_size,
        eps=1e-5,
    )

    slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
        self.num_heads)
    if num_hidden_layer <= 1:
        self.slope_rate = slope_rate * (1 + 1e-5)
    else:
        self.slope_rate = slope_rate * (1 - layer_idx /
                                        (num_hidden_layer - 1) + 1e-5)
    self.tp_slope = self.slope_rate[self.tp_rank *
                                    self.tp_heads:(self.tp_rank + 1) *
                                    self.tp_heads].contiguous()

_build_slope_tensor staticmethod

_build_slope_tensor(n_attention_heads: int)
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def _build_slope_tensor(n_attention_heads: int):

    def get_slopes(n):

        def get_slopes_power_of_2(n):
            start = 2**(-(2**-(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2**math.floor(math.log2(n))
            return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

    slopes = torch.tensor(get_slopes(n_attention_heads),
                          dtype=torch.float32).reshape(
                              n_attention_heads, 1, 1)
    return slopes
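
The slopes follow the ALiBi-style geometric schedule. For a power-of-two head count the start value and ratio coincide; with 8 heads both are 0.5, giving slopes 0.5, 0.25, ..., 0.00390625 (2^-8). A standalone check of that case (pure Python plus torch, no vLLM imports):

import math
import torch

def get_slopes_power_of_2(n):
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * start ** i for i in range(n)]

slopes = torch.tensor(get_slopes_power_of_2(8)).reshape(8, 1, 1)
# tensor of shape (8, 1, 1) holding 0.5, 0.25, ..., 0.00390625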

_decode_infer

_decode_infer(
    q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
Source code in vllm/model_executor/models/minimax_text_01.py
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                  attn_metadata):
    q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
    k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
    v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
    slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
                                           ):]
    hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                          slot_id, 32)
    return hidden
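
During decode every running request contributes exactly one token, so the slices after num_prefill_tokens are unsqueezed to add a sequence dimension of length 1 before the Triton kernel runs (the trailing 32 is assumed here to be the kernel's block size). A shape walk-through with arbitrary sizes, not tied to any real config:

import torch

num_prefill_tokens, num_decode_tokens, tp_heads, head_dim = 5, 3, 4, 16
q = torch.randn(num_prefill_tokens + num_decode_tokens, tp_heads, head_dim)
q_decode = q[num_prefill_tokens:].unsqueeze(2)
print(q_decode.shape)  # torch.Size([3, 4, 1, 16]) -- one token per decoding request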

_prefill_and_mix_infer

_prefill_and_mix_infer(
    q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
Source code in vllm/model_executor/models/minimax_text_01.py
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                           attn_metadata):
    hidden = []
    for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
        if _prefill_idx >= len(attn_metadata.query_start_loc):
            break
        if _prefill_idx >= len(state_indices_tensor):
            break
        _start = attn_metadata.query_start_loc[_prefill_idx]
        _end = attn_metadata.query_start_loc[_prefill_idx + 1]
        slot_id = state_indices_tensor[_prefill_idx]
        qs = q[_start:_end].transpose(0, 1).contiguous()
        ks = k[_start:_end].transpose(0, 1).contiguous()
        vs = v[_start:_end].transpose(0, 1).contiguous()
        slot_id = state_indices_tensor[_prefill_idx]
        slice_layer_cache = kv_cache[slot_id, ...]

        out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
            qs,
            ks,
            vs,
            slice_layer_cache,
            self.tp_slope,
            self.BLOCK,
            layer_idx=self.layer_idx)
        hidden.append(out_slice.contiguous())
    if attn_metadata.num_decode_tokens > 0:
        hidden.append(
            self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
                               attn_metadata))

    if not hidden:
        return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

    hidden = torch.concat(hidden, dim=0).contiguous()
    return hidden
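
query_start_loc holds cumulative token offsets, so prefill request i owns the token span [query_start_loc[i], query_start_loc[i + 1]); each span is transposed to (heads, tokens, head_dim) before the prefix kernel runs against that request's cache slot. A small illustration of the slicing, with made-up sizes:

import torch

query_start_loc = torch.tensor([0, 4, 9])   # two prefill requests of 4 and 5 tokens
q = torch.randn(9, 2, 16)                   # (tokens, tp_heads, head_dim)
q0 = q[query_start_loc[0]:query_start_loc[1]].transpose(0, 1).contiguous()
print(q0.shape)  # torch.Size([2, 4, 16]) -- (tp_heads, tokens of request 0, head_dim)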

forward

forward(
    hidden_states: Tensor,
    positions: Tensor,
    kv_caches: MinimaxCacheParams,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
            kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    qkv32 = qkv.to(torch.float32)
    qkvact = torch.nn.functional.silu(qkv32)
    qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
    q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    kv_cache = kv_caches.minimax_cache
    state_indices_tensor = kv_caches.state_indices_tensor

    decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
    if not decode_only:
        hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
                                             state_indices_tensor,
                                             attn_metadata)
    else:
        hidden = self._decode_infer(q, k, v, kv_cache,
                                    state_indices_tensor, attn_metadata)

    hidden = self.norm._forward(hidden)
    gate, _ = self.output_gate(hidden_states)
    hidden = F.sigmoid(gate) * hidden
    hidden = hidden.to(hidden_states.dtype)
    hidden, _ = self.out_proj(hidden)
    return hidden
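
The fused qkv projection yields, per tensor-parallel rank, tp_heads * head_dim * 3 features per token; these are SiLU-activated in float32, reshaped to (tokens, tp_heads, 3 * head_dim), and split into q, k, v of head_dim each. A shape-only sketch with arbitrary sizes (the real projection is a ColumnParallelLinear):

import torch
import torch.nn.functional as F

tokens, tp_heads, head_dim = 6, 4, 16
qkv = torch.randn(tokens, tp_heads * head_dim * 3)
qkvact = F.silu(qkv.to(torch.float32)).view(tokens, tp_heads, -1)
q, k, v = torch.split(qkvact, [head_dim] * 3, dim=-1)
print(q.shape, k.shape, v.shape)  # each torch.Size([6, 4, 16])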

weight_direct_load staticmethod

weight_direct_load(
    param: Tensor, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def weight_direct_load(param: torch.Tensor,
                       loaded_weight: torch.Tensor) -> None:
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)
    return

MiniMaxText01LinearKernel

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01LinearKernel:

    @staticmethod
    def jit_linear_forward_prefix(q: torch.Tensor,
                                  k: torch.Tensor,
                                  v: torch.Tensor,
                                  kv_caches: torch.Tensor,
                                  slope_rate: torch.Tensor,
                                  block_size: int,
                                  layer_idx: int = None,
                                  **kwargs) -> torch.Tensor:

        slope_rate = slope_rate.to(torch.float32)
        should_pad_dim = q.dim() == 3
        if should_pad_dim:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
        b, h, n, d = q.shape
        e = d
        kv_history = kv_caches.reshape(1, h, d, e).contiguous()
        output, kv_history = lightning_attention(q,
                                                 k,
                                                 v,
                                                 slope_rate,
                                                 block_size=block_size,
                                                 kv_history=kv_history)
        kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
        assert output.shape[0] == 1, "batch size must be 1"
        return rearrange(output.squeeze(0), "h n d -> n (h d)")

jit_linear_forward_prefix staticmethod

jit_linear_forward_prefix(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kv_caches: Tensor,
    slope_rate: Tensor,
    block_size: int,
    layer_idx: int = None,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def jit_linear_forward_prefix(q: torch.Tensor,
                              k: torch.Tensor,
                              v: torch.Tensor,
                              kv_caches: torch.Tensor,
                              slope_rate: torch.Tensor,
                              block_size: int,
                              layer_idx: int = None,
                              **kwargs) -> torch.Tensor:

    slope_rate = slope_rate.to(torch.float32)
    should_pad_dim = q.dim() == 3
    if should_pad_dim:
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
    b, h, n, d = q.shape
    e = d
    kv_history = kv_caches.reshape(1, h, d, e).contiguous()
    output, kv_history = lightning_attention(q,
                                             k,
                                             v,
                                             slope_rate,
                                             block_size=block_size,
                                             kv_history=kv_history)
    kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
    assert output.shape[0] == 1, "batch size must be 1"
    return rearrange(output.squeeze(0), "h n d -> n (h d)")
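
The per-request cache slot of shape (heads, head_dim, head_dim) is viewed as a batch-1 kv_history for lightning_attention, and the kernel's (1, heads, tokens, head_dim) output is flattened back to (tokens, heads * head_dim). A layout-only illustration of that final rearrange, using einops on a dummy tensor:

import torch
from einops import rearrange

h, n, d = 4, 7, 16
output = torch.randn(1, h, n, d)
flat = rearrange(output.squeeze(0), "h n d -> n (h d)")
print(flat.shape)  # torch.Size([7, 64])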

MiniMaxText01MLP

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = None,
        prefix: str = "mlp",
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()
        return

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

act_fn instance-attribute

act_fn = SiluAndMul()

down_proj instance-attribute

down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    hidden_size,
    [intermediate_size] * 2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

layer_idx instance-attribute

layer_idx = layer_idx

__init__

__init__(
    hidden_size: int,
    intermediate_size: int,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = None,
    prefix: str = "mlp",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = None,
    prefix: str = "mlp",
) -> None:
    super().__init__()
    self.layer_idx = layer_idx

    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    self.down_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
    )
    self.act_fn = SiluAndMul()
    return

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self, x: torch.Tensor) -> torch.Tensor:

    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(x)
    return x
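
gate_up_proj packs the gate and up projections along the last dimension, and SiluAndMul applies SiLU to the first half and multiplies by the second. A plain-torch sketch of that semantics (not the fused vLLM op itself):

import torch
import torch.nn.functional as F

intermediate_size = 32
gate_up = torch.randn(2, 2 * intermediate_size)   # (tokens, 2 * intermediate_size)
gate, up = gate_up.chunk(2, dim=-1)
out = F.silu(gate) * up                           # (tokens, intermediate_size)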

MiniMaxText01MoE

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01MoE(nn.Module):

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        layer_idx: int = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "moe",
    ) -> None:
        super().__init__()

        self.layer_idx = layer_idx
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size
        self.quant_config = quant_config

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_total_experts,
            bias=False,
            params_dtype=torch.float32,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        self.gate.weight.weight_loader = MiniMaxText01MoE.gate_weight_loader

        self.experts = FusedMoE(
            num_experts=self.num_total_experts,
            top_k=self.top_k,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size * self.tp_size,
            params_dtype=self.params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=self.quant_config,
            tp_size=self.tp_size,
            prefix=f"{prefix}.experts",
        )
        return

    @staticmethod
    def gate_weight_loader(param: nn.Parameter,
                           loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight.to(torch.float32))
        return

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32))
        final_hidden_states = self.experts(
            hidden_states, router_logits_fp32.to(hidden_states.dtype))
        final_hidden = final_hidden_states.view(num_tokens, hidden_size)
        return final_hidden

experts instance-attribute

experts = FusedMoE(
    num_experts=num_total_experts,
    top_k=top_k,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size * tp_size,
    params_dtype=params_dtype,
    reduce_results=True,
    renormalize=True,
    quant_config=quant_config,
    tp_size=tp_size,
    prefix=f"{prefix}.experts",
)

gate instance-attribute

gate = ReplicatedLinear(
    hidden_size,
    num_total_experts,
    bias=False,
    params_dtype=float32,
    quant_config=None,
    prefix=f"{prefix}.gate",
)

hidden_size instance-attribute

hidden_size = hidden_size

intermediate_size instance-attribute

intermediate_size = intermediate_size // tp_size

layer_idx instance-attribute

layer_idx = layer_idx

num_total_experts instance-attribute

num_total_experts = num_experts

params_dtype instance-attribute

params_dtype = params_dtype

quant_config instance-attribute

quant_config = quant_config

top_k instance-attribute

top_k = top_k

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

__init__

__init__(
    num_experts: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[dtype] = None,
    layer_idx: int = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "moe",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    num_experts: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[torch.dtype] = None,
    layer_idx: int = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "moe",
) -> None:
    super().__init__()

    self.layer_idx = layer_idx
    self.tp_size = get_tensor_model_parallel_world_size()
    self.num_total_experts = num_experts
    self.top_k = top_k
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size // self.tp_size
    self.quant_config = quant_config

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    self.params_dtype = params_dtype

    self.gate = ReplicatedLinear(
        self.hidden_size,
        self.num_total_experts,
        bias=False,
        params_dtype=torch.float32,
        quant_config=None,
        prefix=f"{prefix}.gate",
    )
    self.gate.weight.weight_loader = MiniMaxText01MoE.gate_weight_loader

    self.experts = FusedMoE(
        num_experts=self.num_total_experts,
        top_k=self.top_k,
        hidden_size=self.hidden_size,
        intermediate_size=self.intermediate_size * self.tp_size,
        params_dtype=self.params_dtype,
        reduce_results=True,
        renormalize=True,
        quant_config=self.quant_config,
        tp_size=self.tp_size,
        prefix=f"{prefix}.experts",
    )
    return

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    num_tokens, hidden_size = hidden_states.shape
    hidden_states = hidden_states.view(-1, self.hidden_size)
    router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32))
    final_hidden_states = self.experts(
        hidden_states, router_logits_fp32.to(hidden_states.dtype))
    final_hidden = final_hidden_states.view(num_tokens, hidden_size)
    return final_hidden
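
The gate is kept in float32, presumably for numerically stable routing; FusedMoE then performs renormalized top-k routing over the experts. A plain-torch sketch of that routing step under the renormalize=True assumption (the actual computation happens inside the fused kernel):

import torch

num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts, dtype=torch.float32)
probs = torch.softmax(router_logits, dim=-1)
topk_w, topk_idx = probs.topk(top_k, dim=-1)
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)  # renormalize over the selected experts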

gate_weight_loader staticmethod

gate_weight_loader(
    param: Parameter, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def gate_weight_loader(param: nn.Parameter,
                       loaded_weight: torch.Tensor) -> None:
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight.to(torch.float32))
    return

MiniMaxText01Model

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        scheduler_config=None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.decoder_attention_types = getattr(
            config, "attn_type_list", False) or getattr(
                config, "decoder_attention_types", False)
        if not self.decoder_attention_types:
            self.decoder_attention_types = [1] * config.num_hidden_layers
        self.num_layers = config.num_hidden_layers

        self._layer_barrier = False
        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=self.vocab_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def layer_fn(prefix):
            layer_idx = int(prefix.split('.')[-1])
            layer_config = config
            layer_config.attention_type = self.decoder_attention_types[
                layer_idx]
            layer_config.layer_idx = layer_idx

            decoder_kwargs = {
                "quant_config": quant_config,
                "layer_id": layer_idx,
                "cache_config": cache_config
            }

            if layer_config.attention_type == 0:
                decoder_kwargs["linear_layer_id"] = sum(
                    1 for i in range(layer_idx)
                    if self.decoder_attention_types[i] == 0)
            else:
                decoder_kwargs["linear_layer_id"] = None

            if hasattr(config, "num_local_experts") and isinstance(
                    config.num_local_experts, list):
                decoder_kwargs["expert_num"] = config.num_local_experts[
                    layer_idx]
            elif hasattr(config, "num_local_experts") and isinstance(
                    config.num_local_experts, int):
                decoder_kwargs["expert_num"] = config.num_local_experts
            else:
                decoder_kwargs["expert_num"] = 1

            return MiniMaxText01DecoderLayer(layer_config,
                                             **decoder_kwargs,
                                             prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers")

        linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
                                if self.decoder_attention_types[i] == 0)
        max_slots_number = scheduler_config.max_num_seqs
        self.cache_shape = (linear_layer_nums, max_slots_number,
                            config.num_attention_heads //
                            get_tensor_model_parallel_world_size(),
                            config.head_dim, config.head_dim)
        _dummy = torch.zeros(1)
        self._dtype = _dummy.dtype
        del _dummy

        self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
                                                 cache_shape=self.cache_shape)

        rope_theta = getattr(config, "rope_theta", 10000)
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        if hasattr(config, "max_model_len") and isinstance(
                config.max_model_len, int):
            max_position_embeddings = min(config.max_position_embeddings,
                                          config.max_model_len)
        self.rotary_emb = MiniMaxText01RotaryEmbedding(
            head_dim,
            rotary_dim=config.rotary_dim
            if hasattr(config, "rotary_dim") else head_dim,
            max_position=max_position_embeddings,
            base=int(rope_theta),
            is_neox_style=True,
            cache_dtype=torch.float32,
        )

        norm_kwargs = {}
        if hasattr(config, "rms_norm_eps"):
            norm_kwargs["eps"] = config.rms_norm_eps
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, **norm_kwargs)
        else:
            self.norm = PPMissingLayer()
        self.embed_scale = 1.0
        return

    def _clear_prefill_cache(self, attn_metadata,
                             minimax_cache_tensors: torch.Tensor, **kwargs):
        seq_to_slot_maps = {}
        seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), [])
        for _, seq_to_slot_map in (
                self.minimax_cache.cache_indices_mapping.items()):
            seq_to_slot_maps.update(seq_to_slot_map)

        slots_to_clear = []
        for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
            if _prefill_id >= len(seq_id_map):
                break
            seq_id = seq_id_map[_prefill_id]
            if attn_metadata.context_lens_tensor[
                    _prefill_id] == 0 and seq_id in seq_to_slot_maps:
                slots_to_clear.append(seq_to_slot_maps[seq_id])

        if slots_to_clear:
            slots_tensor = torch.tensor(slots_to_clear,
                                        device=minimax_cache_tensors.device,
                                        dtype=torch.long)
            minimax_cache_tensors[:, slots_tensor, ...] = 0

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(self,
                input_ids: Optional[torch.Tensor],
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            return None
        if "request_ids_to_seq_ids" not in kwargs:
            kwargs["request_ids_to_seq_ids"] = {}
        if "finished_requests_ids" not in kwargs:
            kwargs["finished_requests_ids"] = []

        (
            minimax_cache_tensors,
            state_indices_tensor,
        ) = self.minimax_cache.current_run_tensors(**kwargs)
        if getattr(attn_metadata, "num_prefills", 0) > 0:
            self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
                                      **kwargs)

        minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
                                                  state_indices_tensor)
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                hidden_states = self.embed_scale * self.embed_tokens(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        minimax_cache_index = 0
        attn_metadata.rotary_emb = self.rotary_emb
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            _caches = None
            if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
                current_state_layer = minimax_cache_index
                _caches = minimax_cache_params.at_layer_idx(
                    current_state_layer)
                minimax_cache_index += 1
            hidden_states, residual = layer(
                hidden_states=hidden_states,
                positions=positions,
                kv_caches=_caches,
                attn_metadata=attn_metadata,
                residual=residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)

        return hidden_states

_dtype instance-attribute

_dtype = dtype

_layer_barrier instance-attribute

_layer_barrier = False

cache_shape instance-attribute

cache_shape = (
    linear_layer_nums,
    max_slots_number,
    num_attention_heads
    // get_tensor_model_parallel_world_size(),
    head_dim,
    head_dim,
)

decoder_attention_types instance-attribute

decoder_attention_types = getattr(
    config, "attn_type_list", False
) or getattr(config, "decoder_attention_types", False)

embed_scale instance-attribute

embed_scale = 1.0

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, hidden_size, org_num_embeddings=vocab_size
)

minimax_cache instance-attribute

minimax_cache = MinimaxCacheManager(
    dtype=float32, cache_shape=cache_shape
)

norm instance-attribute

norm = RMSNorm(hidden_size, **norm_kwargs)

num_layers instance-attribute

num_layers = num_hidden_layers

padding_idx instance-attribute

padding_idx = pad_token_id

rotary_emb instance-attribute

rotary_emb = MiniMaxText01RotaryEmbedding(
    head_dim,
    rotary_dim=rotary_dim
    if hasattr(config, "rotary_dim")
    else head_dim,
    max_position=max_position_embeddings,
    base=int(rope_theta),
    is_neox_style=True,
    cache_dtype=float32,
)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    scheduler_config=None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    scheduler_config=None,
    prefix: str = "",
) -> None:
    super().__init__()

    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    self.decoder_attention_types = getattr(
        config, "attn_type_list", False) or getattr(
            config, "decoder_attention_types", False)
    if not self.decoder_attention_types:
        self.decoder_attention_types = [1] * config.num_hidden_layers
    self.num_layers = config.num_hidden_layers

    self._layer_barrier = False
    if get_pp_group().is_first_rank:
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=self.vocab_size,
        )
    else:
        self.embed_tokens = PPMissingLayer()

    def layer_fn(prefix):
        layer_idx = int(prefix.split('.')[-1])
        layer_config = config
        layer_config.attention_type = self.decoder_attention_types[
            layer_idx]
        layer_config.layer_idx = layer_idx

        decoder_kwargs = {
            "quant_config": quant_config,
            "layer_id": layer_idx,
            "cache_config": cache_config
        }

        if layer_config.attention_type == 0:
            decoder_kwargs["linear_layer_id"] = sum(
                1 for i in range(layer_idx)
                if self.decoder_attention_types[i] == 0)
        else:
            decoder_kwargs["linear_layer_id"] = None

        if hasattr(config, "num_local_experts") and isinstance(
                config.num_local_experts, list):
            decoder_kwargs["expert_num"] = config.num_local_experts[
                layer_idx]
        elif hasattr(config, "num_local_experts") and isinstance(
                config.num_local_experts, int):
            decoder_kwargs["expert_num"] = config.num_local_experts
        else:
            decoder_kwargs["expert_num"] = 1

        return MiniMaxText01DecoderLayer(layer_config,
                                         **decoder_kwargs,
                                         prefix=prefix)

    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers")

    linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
                            if self.decoder_attention_types[i] == 0)
    max_slots_number = scheduler_config.max_num_seqs
    self.cache_shape = (linear_layer_nums, max_slots_number,
                        config.num_attention_heads //
                        get_tensor_model_parallel_world_size(),
                        config.head_dim, config.head_dim)
    _dummy = torch.zeros(1)
    self._dtype = _dummy.dtype
    del _dummy

    self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
                                             cache_shape=self.cache_shape)

    rope_theta = getattr(config, "rope_theta", 10000)
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    if hasattr(config, "max_model_len") and isinstance(
            config.max_model_len, int):
        max_position_embeddings = min(config.max_position_embeddings,
                                      config.max_model_len)
    self.rotary_emb = MiniMaxText01RotaryEmbedding(
        head_dim,
        rotary_dim=config.rotary_dim
        if hasattr(config, "rotary_dim") else head_dim,
        max_position=max_position_embeddings,
        base=int(rope_theta),
        is_neox_style=True,
        cache_dtype=torch.float32,
    )

    norm_kwargs = {}
    if hasattr(config, "rms_norm_eps"):
        norm_kwargs["eps"] = config.rms_norm_eps
    if get_pp_group().is_last_rank:
        self.norm = RMSNorm(config.hidden_size, **norm_kwargs)
    else:
        self.norm = PPMissingLayer()
    self.embed_scale = 1.0
    return

_clear_prefill_cache

_clear_prefill_cache(
    attn_metadata, minimax_cache_tensors: Tensor, **kwargs
)
Source code in vllm/model_executor/models/minimax_text_01.py
def _clear_prefill_cache(self, attn_metadata,
                         minimax_cache_tensors: torch.Tensor, **kwargs):
    seq_to_slot_maps = {}
    seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), [])
    for _, seq_to_slot_map in (
            self.minimax_cache.cache_indices_mapping.items()):
        seq_to_slot_maps.update(seq_to_slot_map)

    slots_to_clear = []
    for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
        if _prefill_id >= len(seq_id_map):
            break
        seq_id = seq_id_map[_prefill_id]
        if attn_metadata.context_lens_tensor[
                _prefill_id] == 0 and seq_id in seq_to_slot_maps:
            slots_to_clear.append(seq_to_slot_maps[seq_id])

    if slots_to_clear:
        slots_tensor = torch.tensor(slots_to_clear,
                                    device=minimax_cache_tensors.device,
                                    dtype=torch.long)
        minimax_cache_tensors[:, slots_tensor, ...] = 0

forward

forward(
    input_ids: Optional[Tensor],
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self,
            input_ids: Optional[torch.Tensor],
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return None
    if "request_ids_to_seq_ids" not in kwargs:
        kwargs["request_ids_to_seq_ids"] = {}
    if "finished_requests_ids" not in kwargs:
        kwargs["finished_requests_ids"] = []

    (
        minimax_cache_tensors,
        state_indices_tensor,
    ) = self.minimax_cache.current_run_tensors(**kwargs)
    if getattr(attn_metadata, "num_prefills", 0) > 0:
        self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
                                  **kwargs)

    minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
                                              state_indices_tensor)
    if get_pp_group().is_first_rank:
        if inputs_embeds is None:
            hidden_states = self.embed_scale * self.embed_tokens(input_ids)
        else:
            hidden_states = inputs_embeds
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    minimax_cache_index = 0
    attn_metadata.rotary_emb = self.rotary_emb
    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        _caches = None
        if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
            current_state_layer = minimax_cache_index
            _caches = minimax_cache_params.at_layer_idx(
                current_state_layer)
            minimax_cache_index += 1
        hidden_states, residual = layer(
            hidden_states=hidden_states,
            positions=positions,
            kv_caches=_caches,
            attn_metadata=attn_metadata,
            residual=residual,
        )
    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })
    if residual is not None:
        hidden_states, _ = self.norm(hidden_states, residual)
    else:
        hidden_states = self.norm(hidden_states)

    return hidden_states
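
Only linear-attention layers (attention type 0) consume a slot from the MiniMax cache; softmax-attention layers use the regular KV cache managed by vLLM, which is why minimax_cache_index is advanced only for MiniMaxText01LinearAttention layers. A small illustration of how the linear-layer count that sizes cache_shape is derived (the attention pattern below is illustrative only):

attn_type_list = [0, 0, 0, 1, 0, 0, 0, 1]   # 0 = linear attention, 1 = softmax attention
linear_layer_nums = sum(1 for t in attn_type_list if t == 0)
print(linear_layer_nums)  # 6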

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
) -> torch.Tensor:
    return self.embed_tokens(input_ids)

MiniMaxText01RMSNormTP

Bases: CustomOp

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01RMSNormTP(CustomOp):
    name = "MiniMaxText01RMSNormTP"

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.tp_world = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.weight = nn.Parameter(torch.ones(int(hidden_size /
                                                  self.tp_world)))

        self.weight.weight_loader = self.weight_loader
        self.variance_epsilon = eps
        return

    @staticmethod
    def weight_loader(
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
    ) -> None:
        tp_world = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        shard_size = loaded_weight.shape[0] // tp_world
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        param.data.copy_(loaded_weight[shard])
        return

    def _forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
        if self.tp_world > 1:
            variance = tensor_model_parallel_all_reduce(
                variance) / self.tp_world
        x = x * torch.rsqrt(variance + self.variance_epsilon)

        weight = self.weight
        if x.size(-1) != self.weight.size(0):
            if self.weight.size(0) < x.size(-1):
                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
                full_weight = self.weight.repeat(repeat_count)
                weight = full_weight[:x.size(-1)]
            else:
                weight = self.weight[:x.size(-1)]

        x = x.to(orig_dtype) * weight
        return x

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert residual is None, "RMSNorm does not support residual connection."
        return self._forward(x)

name class-attribute instance-attribute

name = 'MiniMaxText01RMSNormTP'

tp_rank instance-attribute

tp_rank = get_tensor_model_parallel_rank()

tp_world instance-attribute

tp_world = get_tensor_model_parallel_world_size()

variance_epsilon instance-attribute

variance_epsilon = eps

weight instance-attribute

weight = Parameter(ones(int(hidden_size / tp_world)))

__init__

__init__(hidden_size: int, eps: float = 1e-06) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
    super().__init__()
    self.tp_world = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()
    self.weight = nn.Parameter(torch.ones(int(hidden_size /
                                              self.tp_world)))

    self.weight.weight_loader = self.weight_loader
    self.variance_epsilon = eps
    return

_forward

_forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def _forward(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
    if self.tp_world > 1:
        variance = tensor_model_parallel_all_reduce(
            variance) / self.tp_world
    x = x * torch.rsqrt(variance + self.variance_epsilon)

    weight = self.weight
    if x.size(-1) != self.weight.size(0):
        if self.weight.size(0) < x.size(-1):
            repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
            full_weight = self.weight.repeat(repeat_count)
            weight = full_weight[:x.size(-1)]
        else:
            weight = self.weight[:x.size(-1)]

    x = x.to(orig_dtype) * weight
    return x
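
Each tensor-parallel rank holds hidden_size / tp_world features, so the per-rank mean of squares is all-reduced (summed) and divided by tp_world to recover the global mean before the rsqrt. A single-process equivalence check of that reduction, assuming equally sized shards:

import torch

x = torch.randn(3, 8)
shards = x.chunk(2, dim=-1)                                   # pretend tp_world = 2
per_shard = torch.stack([s.pow(2).mean(-1) for s in shards])  # (tp_world, tokens)
assert torch.allclose(per_shard.mean(0), x.pow(2).mean(-1))   # matches the full-vector mean of squares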

forward

forward(
    x: Tensor, residual: Optional[Tensor] = None
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    assert residual is None, "RMSNorm does not support residual connection."
    return self._forward(x)

weight_loader staticmethod

weight_loader(
    param: Parameter, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def weight_loader(
    param: nn.Parameter,
    loaded_weight: torch.Tensor,
) -> None:
    tp_world = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    shard_size = loaded_weight.shape[0] // tp_world
    shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    param.data.copy_(loaded_weight[shard])
    return

MiniMaxText01RotaryEmbedding

Bases: CustomOp

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01RotaryEmbedding(CustomOp):
    name = "MiniMaxText01RotaryEmbedding"

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position: int,
        base: float,
        is_neox_style: bool,
        cache_dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position
        self.base = base
        self.is_neox_style = is_neox_style
        self.cache_dtype = cache_dtype
        cache = self._compute_cos_sin_cache().to(cache_dtype)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        from vllm import _custom_ops as ops
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
        query_cast = query.to(self.cache_dtype)
        key_cast = key.to(self.cache_dtype)
        ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
                             self.cos_sin_cache, self.is_neox_style)
        query = query_cast.to(query.dtype)
        key = key_cast.to(key.dtype)
        return query, key

base instance-attribute

base = base

cache_dtype instance-attribute

cache_dtype = cache_dtype

head_size instance-attribute

head_size = head_size

is_neox_style instance-attribute

is_neox_style = is_neox_style

max_position_embeddings instance-attribute

max_position_embeddings = max_position

name class-attribute instance-attribute

name = 'MiniMaxText01RotaryEmbedding'

rotary_dim instance-attribute

rotary_dim = rotary_dim

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool,
    cache_dtype: dtype,
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool,
    cache_dtype: torch.dtype,
) -> None:
    super().__init__()
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position
    self.base = base
    self.is_neox_style = is_neox_style
    self.cache_dtype = cache_dtype
    cache = self._compute_cos_sin_cache().to(cache_dtype)
    self.register_buffer("cos_sin_cache", cache, persistent=False)

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor

Compute the cos and sin cache.

Source code in vllm/model_executor/models/minimax_text_01.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Compute the cos and sin cache."""
    inv_freq = self._compute_inv_freq(self.base)
    t = torch.arange(self.max_position_embeddings, dtype=torch.float)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos()
    sin = freqs.sin()
    cache = torch.cat((cos, sin), dim=-1)
    return cache
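
Since inv_freq has rotary_dim / 2 entries, the cos and sin halves concatenate into a cache of shape (max_position, rotary_dim). A standalone reproduction with small, arbitrary sizes:

import torch

rotary_dim, max_position, base = 8, 16, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
print(cache.shape)  # torch.Size([16, 8]) -- (max_position, rotary_dim)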

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor

Compute the inverse frequency.

Source code in vllm/model_executor/models/minimax_text_01.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    """Compute the inverse frequency."""
    inv_freq = 1.0 / (base**(torch.arange(
        0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
    return inv_freq

forward

forward(
    positions: Tensor, query: Tensor, key: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    from vllm import _custom_ops as ops
    self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
    query_cast = query.to(self.cache_dtype)
    key_cast = key.to(self.cache_dtype)
    ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
                         self.cos_sin_cache, self.is_neox_style)
    query = query_cast.to(query.dtype)
    key = key_cast.to(key.dtype)
    return query, key

replace_weight_name

replace_weight_name(
    name: str,
    key: str = None,
    to: str = None,
    count: int = None,
    prefix: str = None,
) -> str
Source code in vllm/model_executor/models/minimax_text_01.py
def replace_weight_name(name: str,
                        key: str = None,
                        to: str = None,
                        count: int = None,
                        prefix: str = None) -> str:
    name = name.replace(key, to) if count is None else \
        name.replace(key, to, count)
    return name

weight_loader_with_alias

weight_loader_with_alias(alias: str)
Source code in vllm/model_executor/models/minimax_text_01.py
def weight_loader_with_alias(alias: str):

    def wrapper(func: callable):

        def inner_func(param: torch.Tensor,
                       loaded_weight: torch.Tensor,
                       *args,
                       prefix: str = None,
                       **kwargs):
            value = func(param, loaded_weight, *args, **kwargs)
            return value

        return inner_func

    return wrapper
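
As used in load_weights above, the decorator binds the weight's name (alias) at wrap time and otherwise forwards to the underlying loader unchanged. A minimal usage sketch with a hand-rolled loader standing in for default_weight_loader:

import torch

def simple_loader(param, loaded_weight):
    param.data.copy_(loaded_weight)

loader = weight_loader_with_alias("model.layers.0.mlp.down_proj.weight")(simple_loader)
param = torch.nn.Parameter(torch.zeros(4))
loader(param, torch.ones(4))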