vllm.model_executor.models.minicpm3

Inference-only MiniCPM3 model compatible with HuggingFace weights.
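
For orientation, the sketch below shows one way to run this model through vLLM's offline LLM API. It is a minimal, hedged example: the HuggingFace repo id "openbmb/MiniCPM3-4B" and the trust_remote_code=True flag are illustrative assumptions, not something this module mandates.

# Minimal usage sketch; the model id and trust_remote_code flag are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="openbmb/MiniCPM3-4B", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Explain multi-head latent attention briefly."],
                       sampling_params)
print(outputs[0].outputs[0].text)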

MiniCPM3Attention

Bases: Module

Source code in vllm/model_executor/models/minicpm3.py
class MiniCPM3Attention(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads

        tp_size = get_tensor_model_parallel_world_size()
        assert self.num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                         self.q_lora_rank,
                                         bias=False,
                                         quant_config=quant_config)
        self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
        self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                             self.num_heads * self.qk_head_dim,
                                             bias=False,
                                             quant_config=quant_config)

        self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
                                                   self.kv_lora_rank +
                                                   self.qk_rope_head_dim,
                                                   bias=False,
                                                   quant_config=quant_config)
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config)
        # O projection.
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)

        self.rotary_emb = get_rope(
            self.qk_rope_head_dim,
            rotary_dim=self.qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_local_heads,
                              self.qk_head_dim,
                              self.scaling,
                              num_kv_heads=self.num_local_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        q, _ = self.q_a_proj(hidden_states)
        q = self.q_a_layernorm(q)
        q, _ = self.q_b_proj(q)
        q = q.view(-1, self.num_local_heads, self.qk_head_dim)
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                          dim=-1)
        latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states)
        kv_a, _ = latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv, _ = self.kv_b_proj(kv_a)
        kv = kv.view(-1, self.num_local_heads,
                     self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k_pe = latent_cache[:, :, self.kv_lora_rank:]

        q_pe, k_pe = self.rotary_emb(
            positions,
            q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim),
            k_pe.reshape(-1, self.qk_rope_head_dim))
        q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim)
        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

        q[..., self.qk_nope_head_dim:] = q_pe

        k = torch.empty_like(q)

        k[..., :self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim:] = k_pe

        q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
        k = k.view(-1, self.num_local_heads * self.qk_head_dim)
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim],
            value=0).view(-1, self.num_local_heads * self.qk_head_dim)

        attn_output = self.attn(q, k, v)
        attn_output = attn_output.view(
            -1, self.num_local_heads,
            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                -1, self.num_local_heads * self.v_head_dim)

        output, _ = self.o_proj(attn_output)
        return output

attn instance-attribute

attn = Attention(
    num_local_heads,
    qk_head_dim,
    scaling,
    num_kv_heads=num_local_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

hidden_size instance-attribute

hidden_size = hidden_size

kv_a_layernorm instance-attribute

kv_a_layernorm = RMSNorm(kv_lora_rank, eps=rms_norm_eps)

kv_a_proj_with_mqa instance-attribute

kv_a_proj_with_mqa = ReplicatedLinear(
    hidden_size,
    kv_lora_rank + qk_rope_head_dim,
    bias=False,
    quant_config=quant_config,
)

kv_b_proj instance-attribute

kv_b_proj = ColumnParallelLinear(
    kv_lora_rank,
    num_heads * (qk_nope_head_dim + v_head_dim),
    bias=False,
    quant_config=quant_config,
)

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

num_heads instance-attribute

num_heads = num_heads

num_local_heads instance-attribute

num_local_heads = num_heads // tp_size

o_proj instance-attribute

o_proj = RowParallelLinear(
    num_heads * v_head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
)

q_a_layernorm instance-attribute

q_a_layernorm = RMSNorm(q_lora_rank, eps=rms_norm_eps)

q_a_proj instance-attribute

q_a_proj = ReplicatedLinear(
    hidden_size,
    q_lora_rank,
    bias=False,
    quant_config=quant_config,
)

q_b_proj instance-attribute

q_b_proj = ColumnParallelLinear(
    q_lora_rank,
    num_heads * qk_head_dim,
    bias=False,
    quant_config=quant_config,
)

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

qk_head_dim instance-attribute

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

rope_theta instance-attribute

rope_theta = rope_theta

rotary_emb instance-attribute

rotary_emb = get_rope(
    qk_rope_head_dim,
    rotary_dim=qk_rope_head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    rope_scaling=rope_scaling,
)

scaling instance-attribute

scaling = qk_head_dim ** -0.5

v_head_dim instance-attribute

v_head_dim = v_head_dim

__init__

__init__(
    config: PretrainedConfig,
    hidden_size: int,
    num_heads: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: int,
    kv_lora_rank: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/minicpm3.py
def __init__(
    self,
    config: PretrainedConfig,
    hidden_size: int,
    num_heads: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: int,
    kv_lora_rank: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    self.hidden_size = hidden_size
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    self.v_head_dim = v_head_dim
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.num_heads = num_heads

    tp_size = get_tensor_model_parallel_world_size()
    assert self.num_heads % tp_size == 0
    self.num_local_heads = num_heads // tp_size

    self.scaling = self.qk_head_dim**-0.5
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings

    self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                     self.q_lora_rank,
                                     bias=False,
                                     quant_config=quant_config)
    self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
    self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                         self.num_heads * self.qk_head_dim,
                                         bias=False,
                                         quant_config=quant_config)

    self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
                                               self.kv_lora_rank +
                                               self.qk_rope_head_dim,
                                               bias=False,
                                               quant_config=quant_config)
    self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                  eps=config.rms_norm_eps)
    self.kv_b_proj = ColumnParallelLinear(
        self.kv_lora_rank,
        self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        bias=False,
        quant_config=quant_config)
    # O projection.
    self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                    self.hidden_size,
                                    bias=False,
                                    quant_config=quant_config)

    self.rotary_emb = get_rope(
        self.qk_rope_head_dim,
        rotary_dim=self.qk_rope_head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        rope_scaling=rope_scaling,
    )
    self.attn = Attention(self.num_local_heads,
                          self.qk_head_dim,
                          self.scaling,
                          num_kv_heads=self.num_local_heads,
                          cache_config=cache_config,
                          quant_config=quant_config,
                          prefix=f"{prefix}.attn")
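
The q and kv projections above use a low-rank factorization: instead of one dense hidden_size -> num_heads * qk_head_dim matrix, the query path goes through q_a_proj (hidden_size -> q_lora_rank) and q_b_proj (q_lora_rank -> num_heads * qk_head_dim), with an RMSNorm in between; the key/value path is analogous via kv_lora_rank. A quick parameter count with illustrative dimensions (not necessarily MiniCPM3's real config values) shows the saving for the query path:

# Illustrative parameter count for the query path; all dimensions here are
# hypothetical examples, not taken from the MiniCPM3 config.
hidden_size, num_heads = 2560, 40
qk_head_dim, q_lora_rank = 96, 768

dense_q = hidden_size * num_heads * qk_head_dim         # single dense q_proj
low_rank_q = (hidden_size * q_lora_rank                 # q_a_proj
              + q_lora_rank * num_heads * qk_head_dim)  # q_b_proj
print(dense_q, low_rank_q)  # 9830400 4915200 -> roughly half with these numbers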

forward

forward(positions: Tensor, hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/minicpm3.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    q, _ = self.q_a_proj(hidden_states)
    q = self.q_a_layernorm(q)
    q, _ = self.q_b_proj(q)
    q = q.view(-1, self.num_local_heads, self.qk_head_dim)
    _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                      dim=-1)
    latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states)
    kv_a, _ = latent_cache.split(
        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    latent_cache = latent_cache.unsqueeze(1)
    kv_a = self.kv_a_layernorm(kv_a.contiguous())
    kv, _ = self.kv_b_proj(kv_a)
    kv = kv.view(-1, self.num_local_heads,
                 self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    k_pe = latent_cache[:, :, self.kv_lora_rank:]

    q_pe, k_pe = self.rotary_emb(
        positions,
        q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim),
        k_pe.reshape(-1, self.qk_rope_head_dim))
    q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim)
    k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

    q[..., self.qk_nope_head_dim:] = q_pe

    k = torch.empty_like(q)

    k[..., :self.qk_nope_head_dim] = k_nope
    k[..., self.qk_nope_head_dim:] = k_pe

    q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
    k = k.view(-1, self.num_local_heads * self.qk_head_dim)
    v = torch.nn.functional.pad(
        v, [0, self.qk_head_dim - self.v_head_dim],
        value=0).view(-1, self.num_local_heads * self.qk_head_dim)

    attn_output = self.attn(q, k, v)
    attn_output = attn_output.view(
        -1, self.num_local_heads,
        self.qk_head_dim)[..., :self.v_head_dim].reshape(
            -1, self.num_local_heads * self.v_head_dim)

    output, _ = self.o_proj(attn_output)
    return output
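
Because the shared Attention layer is constructed with head_size = qk_head_dim and num_kv_heads = num_local_heads, the value tensor (whose true head size is v_head_dim) is zero-padded up to qk_head_dim before the kernel call, and the padding columns are sliced off the attention output afterwards. Below is a toy sketch of that pad-then-slice round trip, using illustrative head sizes rather than MiniCPM3's actual ones.

# Toy sketch of the pad-then-slice trick in forward(); head sizes are
# illustrative assumptions, not the real MiniCPM3 configuration.
import torch

num_tokens, num_local_heads = 4, 2
qk_head_dim, v_head_dim = 96, 64

v = torch.randn(num_tokens, num_local_heads, v_head_dim)
v_padded = torch.nn.functional.pad(v, [0, qk_head_dim - v_head_dim], value=0)
assert v_padded.shape == (num_tokens, num_local_heads, qk_head_dim)

# Only the first v_head_dim columns of each head carry real values,
# so slicing them back recovers the original tensor exactly.
assert torch.equal(v_padded[..., :v_head_dim], v)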

MiniCPM3DecoderLayer

Bases: MiniCPMDecoderLayer

Source code in vllm/model_executor/models/minicpm3.py
class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):

    def _init_attn_block(self):
        self.input_layernorm = RMSNorm(self.config.hidden_size,
                                       eps=self.config.rms_norm_eps)
        self.self_attn = MiniCPM3Attention(
            config=self.config,
            hidden_size=self.hidden_size,
            num_heads=self.config.num_attention_heads,
            qk_nope_head_dim=self.config.qk_nope_head_dim,
            qk_rope_head_dim=self.config.qk_rope_head_dim,
            v_head_dim=self.config.v_head_dim,
            q_lora_rank=self.config.q_lora_rank,
            kv_lora_rank=self.config.kv_lora_rank,
            rope_theta=self.rope_theta,
            rope_scaling=self.rope_scaling,
            max_position_embeddings=self.max_position_embeddings,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
            prefix=f"{self.prefix}.self_attn",
        )

_init_attn_block

_init_attn_block()
Source code in vllm/model_executor/models/minicpm3.py
def _init_attn_block(self):
    self.input_layernorm = RMSNorm(self.config.hidden_size,
                                   eps=self.config.rms_norm_eps)
    self.self_attn = MiniCPM3Attention(
        config=self.config,
        hidden_size=self.hidden_size,
        num_heads=self.config.num_attention_heads,
        qk_nope_head_dim=self.config.qk_nope_head_dim,
        qk_rope_head_dim=self.config.qk_rope_head_dim,
        v_head_dim=self.config.v_head_dim,
        q_lora_rank=self.config.q_lora_rank,
        kv_lora_rank=self.config.kv_lora_rank,
        rope_theta=self.rope_theta,
        rope_scaling=self.rope_scaling,
        max_position_embeddings=self.max_position_embeddings,
        cache_config=self.cache_config,
        quant_config=self.quant_config,
        prefix=f"{self.prefix}.self_attn",
    )

MiniCPM3ForCausalLM

Bases: MiniCPMForCausalLM

Source code in vllm/model_executor/models/minicpm3.py
class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
    packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
        return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "gate_up_proj": ["gate_proj", "up_proj"]
}
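
This mapping tells vLLM's weight loader that the separate gate_proj and up_proj tensors in the checkpoint feed a single fused gate_up_proj module; MiniCPM3 does not list a fused qkv_proj here because its attention builds queries and keys/values from the *_a_proj / *_b_proj pairs shown above. The sketch below is an illustrative rendering of how such a mapping is typically consulted, not vLLM's actual loader code.

# Illustrative-only sketch of applying a packed-module mapping during weight
# loading; remap() is a hypothetical helper, not part of vLLM.
packed_modules_mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}

def remap(checkpoint_name: str):
    """Map a checkpoint weight name to (fused module name, shard index)."""
    for fused, shards in packed_modules_mapping.items():
        for shard_id, shard in enumerate(shards):
            if shard in checkpoint_name:
                return checkpoint_name.replace(shard, fused), shard_id
    return checkpoint_name, None

print(remap("model.layers.0.mlp.up_proj.weight"))
# -> ('model.layers.0.mlp.gate_up_proj.weight', 1)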

_init_model

_init_model(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/minicpm3.py
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
    return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)

MiniCPM3Model

Bases: MiniCPMModel

Source code in vllm/model_executor/models/minicpm3.py
class MiniCPM3Model(MiniCPMModel):

    def _init_layers(
        self,
        prefix: str,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig],
        quant_config: Optional[QuantizationConfig],
    ):
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MiniCPM3DecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

_init_layers

_init_layers(
    prefix: str,
    config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
)
Source code in vllm/model_executor/models/minicpm3.py
def _init_layers(
    self,
    prefix: str,
    config: PretrainedConfig,
    cache_config: Optional[CacheConfig],
    quant_config: Optional[QuantizationConfig],
):
    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        lambda prefix: MiniCPM3DecoderLayer(
            config, cache_config, quant_config, prefix=prefix),
        prefix=f"{prefix}.layers")
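
The lambda passed to make_layers receives a per-layer prefix (for example "model.layers.7"), so every decoder layer and its .self_attn submodule get unique names for weight loading and KV-cache registration. Roughly, and ignoring the pipeline-parallel partitioning that produces start_layer and end_layer, make_layers behaves like this simplified, assumption-level sketch rather than the real implementation:

# Simplified, assumption-level sketch of the make_layers pattern; the real
# helper also handles pipeline-parallel rank partitioning.
import torch.nn as nn

def make_layers_sketch(num_layers, layer_fn, prefix):
    layers = nn.ModuleList(
        [layer_fn(prefix=f"{prefix}.{i}") for i in range(num_layers)])
    return 0, num_layers, layers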