vllm.model_executor.models.deepseek_v2

Inference-only DeepseekV2/DeepseekV3 model.

DeepseekV2Attention

Bases: Module

Source code in vllm/model_executor/models/deepseek_v2.py
class DeepseekV2Attention(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        # O projection.
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        if rope_scaling:
            rope_scaling["rope_type"] = 'deepseek_yarn'

        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(self.num_local_heads,
                              self.qk_head_dim,
                              self.scaling,
                              num_kv_heads=self.num_local_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
                                         self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
                                                   self.qk_head_dim)
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads,
                     self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank:]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        q[..., self.qk_nope_head_dim:] = q_pe
        k = torch.empty_like(q)
        k[..., :self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim:] = k_pe
        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim],
            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        attn_output = attn_output.view(
            -1, self.num_local_heads,
            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                -1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

attn instance-attribute

attn = Attention(
    num_local_heads,
    qk_head_dim,
    scaling,
    num_kv_heads=num_local_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

hidden_size instance-attribute

hidden_size = hidden_size

kv_a_layernorm instance-attribute

kv_a_layernorm = RMSNorm(kv_lora_rank, eps=rms_norm_eps)

kv_a_proj_with_mqa instance-attribute

kv_a_proj_with_mqa = ReplicatedLinear(
    hidden_size,
    kv_lora_rank + qk_rope_head_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.kv_a_proj_with_mqa",
)

kv_b_proj instance-attribute

kv_b_proj = ColumnParallelLinear(
    kv_lora_rank,
    num_heads * (qk_nope_head_dim + v_head_dim),
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.kv_b_proj",
)

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

num_heads instance-attribute

num_heads = num_heads

num_local_heads instance-attribute

num_local_heads = num_heads // tp_size

o_proj instance-attribute

o_proj = RowParallelLinear(
    num_heads * v_head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

q_a_layernorm instance-attribute

q_a_layernorm = RMSNorm(q_lora_rank, eps=rms_norm_eps)

q_a_proj instance-attribute

q_a_proj = ReplicatedLinear(
    hidden_size,
    q_lora_rank,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.q_a_proj",
)

q_b_proj instance-attribute

q_b_proj = ColumnParallelLinear(
    q_lora_rank,
    num_heads * qk_head_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.q_b_proj",
)

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

q_proj instance-attribute

q_proj = ColumnParallelLinear(
    hidden_size,
    num_heads * qk_head_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.q_proj",
)

qk_head_dim instance-attribute

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

rope_theta instance-attribute

rope_theta = rope_theta

rotary_emb instance-attribute

rotary_emb = get_rope(
    qk_rope_head_dim,
    rotary_dim=qk_rope_head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    rope_scaling=rope_scaling,
    is_neox_style=False,
)

scaling instance-attribute

scaling = qk_head_dim ** -0.5

v_head_dim instance-attribute

v_head_dim = v_head_dim

__init__

__init__(
    config: PretrainedConfig,
    hidden_size: int,
    num_heads: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: int,
    kv_lora_rank: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/deepseek_v2.py
def __init__(
    self,
    config: PretrainedConfig,
    hidden_size: int,
    num_heads: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: int,
    kv_lora_rank: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    self.hidden_size = hidden_size
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    self.v_head_dim = v_head_dim
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.num_heads = num_heads
    tp_size = get_tensor_model_parallel_world_size()
    assert num_heads % tp_size == 0
    self.num_local_heads = num_heads // tp_size
    self.scaling = self.qk_head_dim**-0.5
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings

    if self.q_lora_rank is not None:
        self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                         self.q_lora_rank,
                                         bias=False,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.q_a_proj")
        self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                     eps=config.rms_norm_eps)
        self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                             self.num_heads *
                                             self.qk_head_dim,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_b_proj")
    else:
        self.q_proj = ColumnParallelLinear(self.hidden_size,
                                           self.num_heads *
                                           self.qk_head_dim,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.q_proj")

    self.kv_a_proj_with_mqa = ReplicatedLinear(
        self.hidden_size,
        self.kv_lora_rank + self.qk_rope_head_dim,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.kv_a_proj_with_mqa")
    self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                  eps=config.rms_norm_eps)
    self.kv_b_proj = ColumnParallelLinear(
        self.kv_lora_rank,
        self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.kv_b_proj")
    # O projection.
    self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                    self.hidden_size,
                                    bias=False,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.o_proj")
    if rope_scaling:
        rope_scaling["rope_type"] = 'deepseek_yarn'

    self.rotary_emb = get_rope(qk_rope_head_dim,
                               rotary_dim=qk_rope_head_dim,
                               max_position=max_position_embeddings,
                               base=rope_theta,
                               rope_scaling=rope_scaling,
                               is_neox_style=False)

    if rope_scaling:
        mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
        scaling_factor = rope_scaling["factor"]
        mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
        self.scaling = self.scaling * mscale * mscale

    self.attn = Attention(self.num_local_heads,
                          self.qk_head_dim,
                          self.scaling,
                          num_kv_heads=self.num_local_heads,
                          cache_config=cache_config,
                          quant_config=quant_config,
                          prefix=f"{prefix}.attn")

forward

forward(positions: Tensor, hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/deepseek_v2.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    if self.q_lora_rank is not None:
        q = self.q_a_proj(hidden_states)[0]
        q = self.q_a_layernorm(q)
        q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
                                     self.qk_head_dim)
    else:
        q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
                                               self.qk_head_dim)
    q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                           dim=-1)
    latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
    kv_a, _ = latent_cache.split(
        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    latent_cache = latent_cache.unsqueeze(1)
    kv_a = self.kv_a_layernorm(kv_a.contiguous())
    kv = self.kv_b_proj(kv_a)[0]
    kv = kv.view(-1, self.num_local_heads,
                 self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    k_pe = latent_cache[:, :, self.kv_lora_rank:]

    q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

    q[..., self.qk_nope_head_dim:] = q_pe
    k = torch.empty_like(q)
    k[..., :self.qk_nope_head_dim] = k_nope
    k[..., self.qk_nope_head_dim:] = k_pe
    # padding value to qk_head_dim for alignment
    v = torch.nn.functional.pad(
        v, [0, self.qk_head_dim - self.v_head_dim],
        value=0).view(-1, self.num_local_heads * self.qk_head_dim)
    attn_output = self.attn(q, k, v)
    attn_output = attn_output.view(
        -1, self.num_local_heads,
        self.qk_head_dim)[..., :self.v_head_dim].reshape(
            -1, self.num_local_heads * self.v_head_dim)
    output, _ = self.o_proj(attn_output)
    return output
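
The forward pass is mostly shape bookkeeping: the query is split into a non-rotary ("nope") part and a rotary ("rope") part, the latent KV projection is expanded per head, and V is zero-padded up to qk_head_dim so q, k and v share a single head size for the attention kernel. A minimal sketch of that arithmetic, using illustrative DeepSeek-V2-style head dimensions (the concrete numbers here are assumptions for the example, not read from any config):

import torch

# Illustrative dimensions (assumed); not taken from a real checkpoint config.
num_local_heads = 4
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192
num_tokens = 3

# Query as produced by q_proj / q_b_proj, reshaped per head, then split.
q = torch.randn(num_tokens, num_local_heads, qk_head_dim)
q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
assert q_nope.shape == (num_tokens, num_local_heads, qk_nope_head_dim)
assert q_pe.shape == (num_tokens, num_local_heads, qk_rope_head_dim)

# Value vectors are padded from v_head_dim up to qk_head_dim so that
# q, k and v all use one head size inside the attention kernel.
v = torch.randn(num_tokens, num_local_heads, v_head_dim)
v_padded = torch.nn.functional.pad(v, [0, qk_head_dim - v_head_dim], value=0)
assert v_padded.shape == (num_tokens, num_local_heads, qk_head_dim)

# After attention, only the first v_head_dim channels of each head are kept.
attn_out = v_padded  # stand-in for the attention output
attn_out = attn_out[..., :v_head_dim].reshape(num_tokens,
                                               num_local_heads * v_head_dim)
assert attn_out.shape == (num_tokens, num_local_heads * v_head_dim)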

DeepseekV2DecoderLayer

Bases: Module

Source code in vllm/model_executor/models/deepseek_v2.py
class DeepseekV2DecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.layer_idx = layer_idx
        if model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=config.q_lora_rank
            if hasattr(config, "q_lora_rank") else None,
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.routed_scaling_factor = config.routed_scaling_factor

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        if hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # We scale both hidden_states and residual before
            # rmsnorm, and rmsnorm result would not affect by scale.
            hidden_states *= 1. / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, we only scale it on
                # first layer.
                residual *= 1. / self.routed_scaling_factor

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp,
                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # Scaling the DeepseekV2MLP output, it is the input of
            # input_layernorm of next decoder layer.
            # The scaling of DeepseekV2MOE output would be done in the forward
            # of DeepseekV2MOE
            hidden_states *= 1. / self.routed_scaling_factor

        return hidden_states, residual

hidden_size instance-attribute

hidden_size = hidden_size

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

layer_idx instance-attribute

layer_idx = layer_idx

mlp instance-attribute

mlp = DeepseekV2MoE(
    config=config,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
    enable_eplb=enable_eplb,
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

routed_scaling_factor instance-attribute

routed_scaling_factor = routed_scaling_factor

self_attn instance-attribute

self_attn = attn_cls(
    config=config,
    hidden_size=hidden_size,
    num_heads=num_attention_heads,
    qk_nope_head_dim=qk_nope_head_dim,
    qk_rope_head_dim=qk_rope_head_dim,
    v_head_dim=v_head_dim,
    q_lora_rank=q_lora_rank
    if hasattr(config, "q_lora_rank")
    else None,
    kv_lora_rank=kv_lora_rank,
    rope_theta=rope_theta,
    rope_scaling=rope_scaling,
    max_position_embeddings=max_position_embeddings,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

__init__

__init__(
    config: PretrainedConfig,
    prefix: str,
    model_config: ModelConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    enable_eplb: bool = False,
) -> None
Source code in vllm/model_executor/models/deepseek_v2.py
def __init__(
    self,
    config: PretrainedConfig,
    prefix: str,
    model_config: ModelConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    enable_eplb: bool = False,
) -> None:
    super().__init__()
    self.hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    max_position_embeddings = getattr(config, "max_position_embeddings",
                                      8192)
    # DecoderLayers are created with `make_layers` which passes the prefix
    # with the layer's index.
    layer_idx = int(prefix.split(sep='.')[-1])
    self.layer_idx = layer_idx
    if model_config.use_mla:
        attn_cls = DeepseekV2MLAAttention
    else:
        attn_cls = DeepseekV2Attention
    self.self_attn = attn_cls(
        config=config,
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        qk_nope_head_dim=config.qk_nope_head_dim,
        qk_rope_head_dim=config.qk_rope_head_dim,
        v_head_dim=config.v_head_dim,
        q_lora_rank=config.q_lora_rank
        if hasattr(config, "q_lora_rank") else None,
        kv_lora_rank=config.kv_lora_rank,
        rope_theta=rope_theta,
        rope_scaling=rope_scaling,
        max_position_embeddings=max_position_embeddings,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.self_attn",
    )

    if (config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0):
        self.mlp = DeepseekV2MoE(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            enable_eplb=enable_eplb,
        )
    else:
        self.mlp = DeepseekV2MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
    self.input_layernorm = RMSNorm(config.hidden_size,
                                   eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
    self.routed_scaling_factor = config.routed_scaling_factor

forward

forward(
    positions: Tensor,
    hidden_states: Tensor,
    residual: Optional[Tensor],
) -> Tensor
Source code in vllm/model_executor/models/deepseek_v2.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
) -> torch.Tensor:
    # Self Attention
    if residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)
    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )

    if hidden_states.dtype == torch.float16:
        # Fix FP16 overflow
        # We scale both hidden_states and residual before
        # rmsnorm, and rmsnorm result would not affect by scale.
        hidden_states *= 1. / self.routed_scaling_factor
        if self.layer_idx == 0:
            # The residual is shared by all layers, we only scale it on
            # first layer.
            residual *= 1. / self.routed_scaling_factor

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(
        hidden_states, residual)
    hidden_states = self.mlp(hidden_states)

    if isinstance(self.mlp,
                  DeepseekV2MLP) and hidden_states.dtype == torch.float16:
        # Fix FP16 overflow
        # Scaling the DeepseekV2MLP output, it is the input of
        # input_layernorm of next decoder layer.
        # The scaling of DeepseekV2MOE output would be done in the forward
        # of DeepseekV2MOE
        hidden_states *= 1. / self.routed_scaling_factor

    return hidden_states, residual
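
The FP16 branch above relies on RMS normalization being invariant to a uniform positive rescaling of its input: dividing both hidden_states and the shared residual by routed_scaling_factor avoids overflow in the additions while leaving the post-norm activations essentially unchanged. A minimal numeric check of that property, using a reference RMSNorm rather than vLLM's fused add-and-norm kernel:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm: x / rms(x) * weight, no mean subtraction.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

torch.manual_seed(0)
hidden_size = 16
weight = torch.ones(hidden_size)
hidden = torch.randn(2, hidden_size)
residual = torch.randn(2, hidden_size)
scale = 1.0 / 2.5  # stand-in for 1 / routed_scaling_factor

# Normalizing the scaled sum gives the same result up to the eps term.
out_unscaled = rms_norm(hidden + residual, weight)
out_scaled = rms_norm(hidden * scale + residual * scale, weight)
assert torch.allclose(out_unscaled, out_scaled, atol=1e-4)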

DeepseekV2ForCausalLM

Bases: Module, SupportsPP, MixtureOfExperts

Source code in vllm/model_executor/models/deepseek_v2.py
class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config)
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        self.expert_weights = []

        # Set MoE hyperparameters
        self.num_moe_layers = (config.num_hidden_layers -
                               config.first_k_dense_replace)
        self.num_expert_groups = config.n_group

        self.moe_layers: list[FusedMoE] = []
        for layer in self.model.layers:
            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
                self.moe_layers.append(layer.mlp.experts)

        # Pick last one layer since the first ones may be dense layers.
        example_moe = typing.cast(
            DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp)
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        for layer_idx, layer in enumerate(self.moe_layers):
            # Register the expert weights.
            self.expert_weights.append(layer.get_expert_weights())
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=self.num_redundant_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(Callable[..., bool],
                                                param.weight_loader)
                    success = weight_loader(param,
                                            loaded_weight,
                                            name_mapped,
                                            shard_id=shard_id,
                                            expert_id=expert_id,
                                            return_success=True)
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params

config instance-attribute

config = config

expert_weights instance-attribute

expert_weights = []

lm_head instance-attribute

lm_head = ParallelLMHead(
    vocab_size, hidden_size, quant_config=quant_config
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

model instance-attribute

model = DeepseekV2Model(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

moe_layers instance-attribute

moe_layers: list[FusedMoE] = []

num_expert_groups instance-attribute

num_expert_groups = n_group

num_local_physical_experts instance-attribute

num_local_physical_experts = n_local_physical_experts

num_logical_experts instance-attribute

num_logical_experts = n_logical_experts

num_moe_layers instance-attribute

num_moe_layers = num_hidden_layers - first_k_dense_replace

num_physical_experts instance-attribute

num_physical_experts = n_physical_experts

num_redundant_experts instance-attribute

num_redundant_experts = n_redundant_experts

num_routed_experts instance-attribute

num_routed_experts = n_routed_experts

num_shared_experts instance-attribute

num_shared_experts = n_shared_experts

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/deepseek_v2.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config
    self.model = DeepseekV2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
    if get_pp_group().is_last_rank:
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
    else:
        self.lm_head = PPMissingLayer()
    self.logits_processor = LogitsProcessor(config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)
    self.expert_weights = []

    # Set MoE hyperparameters
    self.num_moe_layers = (config.num_hidden_layers -
                           config.first_k_dense_replace)
    self.num_expert_groups = config.n_group

    self.moe_layers: list[FusedMoE] = []
    for layer in self.model.layers:
        assert isinstance(layer, DeepseekV2DecoderLayer)
        if isinstance(layer.mlp, DeepseekV2MoE):
            self.moe_layers.append(layer.mlp.experts)

    # Pick last one layer since the first ones may be dense layers.
    example_moe = typing.cast(
        DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp)
    self.num_logical_experts = example_moe.n_logical_experts
    self.num_physical_experts = example_moe.n_physical_experts
    self.num_local_physical_experts = example_moe.n_local_physical_experts
    self.num_routed_experts = example_moe.n_routed_experts
    self.num_shared_experts = example_moe.n_shared_experts
    self.num_redundant_experts = example_moe.n_redundant_experts

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/deepseek_v2.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    logits = self.logits_processor(self.lm_head, hidden_states,
                                   sampling_metadata)
    return logits

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/deepseek_v2.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
                               inputs_embeds)
    return hidden_states
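
This class is constructed by the vLLM engine rather than instantiated by hand; a minimal sketch of exercising it end to end through vLLM's public API (the checkpoint name is illustrative, any DeepseekV2ForCausalLM-compatible model works):

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)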

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/deepseek_v2.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.model.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/deepseek_v2.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts,
        num_redundant_experts=self.num_redundant_experts)

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue  # skip spec decode layers for main model

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if (("mlp.experts." in name) and name not in params_dict):
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue

                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later
                is_expert_weight = True

                # Do not modify `name` since the loop may continue here
                # Instead, create a new variable
                name_mapped = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name_mapped, self):
                    continue

                param = params_dict[name_mapped]
                # We should ask the weight loader to return success or not
                # here since otherwise we may skip experts with other
                # available replicas.
                weight_loader = typing.cast(Callable[..., bool],
                                            param.weight_loader)
                success = weight_loader(param,
                                        loaded_weight,
                                        name_mapped,
                                        shard_id=shard_id,
                                        expert_id=expert_id,
                                        return_success=True)
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
        loaded_params.add(name)

    return loaded_params
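
The stacked_params_mapping at the top of load_weights exists because the checkpoint stores gate_proj and up_proj separately while vLLM fuses them into a single gate_up_proj parameter; the shard id tells the fused parameter's weight loader which half to fill. A minimal sketch of the name rewrite (not vLLM internals, just the mapping shown above applied to example names):

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(ckpt_name: str):
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in ckpt_name:
            return ckpt_name.replace(weight_name, param_name), shard_id
    return ckpt_name, None

print(remap("model.layers.3.mlp.gate_proj.weight"))
# -> ('model.layers.3.mlp.gate_up_proj.weight', 0)
print(remap("model.layers.3.mlp.up_proj.weight"))
# -> ('model.layers.3.mlp.gate_up_proj.weight', 1)
print(remap("model.layers.0.self_attn.o_proj.weight"))
# -> ('model.layers.0.self_attn.o_proj.weight', None)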

make_empty_intermediate_tensors

make_empty_intermediate_tensors(
    batch_size: int, dtype: dtype, device: device
) -> IntermediateTensors
Source code in vllm/model_executor/models/deepseek_v2.py
def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype,
        device: torch.device) -> IntermediateTensors:
    return IntermediateTensors({
        "hidden_states":
        torch.zeros((batch_size, self.config.hidden_size),
                    dtype=dtype,
                    device=device),
        "residual":
        torch.zeros((batch_size, self.config.hidden_size),
                    dtype=dtype,
                    device=device),
    })

set_eplb_state

set_eplb_state(
    expert_load_view: Tensor,
    logical_to_physical_map: Tensor,
    logical_replica_count: Tensor,
) -> None
Source code in vllm/model_executor/models/deepseek_v2.py
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    for layer_idx, layer in enumerate(self.moe_layers):
        # Register the expert weights.
        self.expert_weights.append(layer.get_expert_weights())
        layer.set_eplb_state(
            moe_layer_idx=layer_idx,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

DeepseekV2MLAAttention

Bases: Module

Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py

Source code in vllm/model_executor/models/deepseek_v2.py
class DeepseekV2MLAAttention(nn.Module):
    """
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #     kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=self.kv_b_proj,
        )

        self.prefix = prefix
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            q_c = self.q_a_proj(hidden_states)[0]
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            q = self.q_proj(hidden_states)[0]
        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())

        q = q.view(-1, self.num_local_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
            positions, q[..., self.qk_nope_head_dim:], k_pe)

        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(hidden_states.shape[0],
                          self.num_local_heads * self.v_head_dim))
        return self.o_proj(attn_out)[0]
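
As the comment before the Attention construction notes, the MLA backend caches a single latent per token: head_size is kv_lora_rank + qk_rope_head_dim with num_kv_heads=1, rather than full per-head K and V. A back-of-the-envelope comparison against the non-MLA DeepseekV2Attention cache, using illustrative DeepSeek-V2-style dimensions (the numbers are assumptions for the example, not read from a config):

num_heads = 128
kv_lora_rank = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192

# MLA path: one "KV head" of size kv_lora_rank + qk_rope_head_dim per token.
mla_cache_per_token = 1 * (kv_lora_rank + qk_rope_head_dim)   # 576 elements

# Non-MLA path (DeepseekV2Attention): per-head K and V, with V padded to
# qk_head_dim, so K and V each occupy num_heads * qk_head_dim per token.
mha_cache_per_token = 2 * num_heads * qk_head_dim             # 49152 elements

print(mla_cache_per_token, mha_cache_per_token,
      mha_cache_per_token / mla_cache_per_token)              # 576 49152 85.33...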

debug_layer_idx instance-attribute

debug_layer_idx = int(prefix.split('.')[-2])

hidden_size instance-attribute

hidden_size = hidden_size

kv_a_layernorm instance-attribute

kv_a_layernorm = RMSNorm(kv_lora_rank, eps=rms_norm_eps)

kv_a_proj_with_mqa instance-attribute

kv_a_proj_with_mqa = ReplicatedLinear(
    hidden_size,
    kv_lora_rank + qk_rope_head_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.kv_a_proj_with_mqa",
)

kv_b_proj instance-attribute

kv_b_proj = ColumnParallelLinear(
    kv_lora_rank,
    num_heads * (qk_nope_head_dim + v_head_dim),
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.kv_b_proj",
)

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

mla_attn instance-attribute

mla_attn = Attention(
    num_heads=num_local_heads,
    head_size=kv_lora_rank + qk_rope_head_dim,
    scale=scaling,
    num_kv_heads=1,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
    use_mla=True,
    q_lora_rank=q_lora_rank,
    kv_lora_rank=kv_lora_rank,
    qk_nope_head_dim=qk_nope_head_dim,
    qk_rope_head_dim=qk_rope_head_dim,
    qk_head_dim=qk_head_dim,
    v_head_dim=v_head_dim,
    kv_b_proj=kv_b_proj,
)

num_heads instance-attribute

num_heads = num_heads

num_local_heads instance-attribute

num_local_heads = num_heads // tp_size

o_proj instance-attribute

o_proj = RowParallelLinear(
    num_heads * v_head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

prefix instance-attribute

prefix = prefix

q_a_layernorm instance-attribute

q_a_layernorm = RMSNorm(q_lora_rank, eps=rms_norm_eps)

q_a_proj instance-attribute

q_a_proj = ReplicatedLinear(
    hidden_size,
    q_lora_rank,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.q_a_proj",
)

q_b_proj instance-attribute

q_b_proj = ColumnParallelLinear(
    q_lora_rank,
    num_heads * qk_head_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.q_b_proj",
)

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

q_proj instance-attribute

q_proj = ColumnParallelLinear(
    hidden_size,
    num_heads * qk_head_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.q_proj",
)

qk_head_dim instance-attribute

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

rope_theta instance-attribute

rope_theta = rope_theta

rotary_emb instance-attribute

rotary_emb = get_rope(
    qk_rope_head_dim,
    rotary_dim=qk_rope_head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    rope_scaling=rope_scaling,
    is_neox_style=False,
)

scaling instance-attribute

scaling = qk_head_dim ** -0.5

v_head_dim instance-attribute

v_head_dim = v_head_dim

__init__

__init__(
    config: PretrainedConfig,
    hidden_size: int,
    num_heads: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/deepseek_v2.py
def __init__(
    self,
    config: PretrainedConfig,
    hidden_size: int,
    num_heads: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    self.hidden_size = hidden_size
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    self.v_head_dim = v_head_dim

    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank

    self.num_heads = num_heads
    tp_size = get_tensor_model_parallel_world_size()
    assert num_heads % tp_size == 0
    self.num_local_heads = num_heads // tp_size

    self.scaling = self.qk_head_dim**-0.5
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings

    if self.q_lora_rank is not None:
        self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                         self.q_lora_rank,
                                         bias=False,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.q_a_proj")
        self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                     eps=config.rms_norm_eps)
        self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                             self.num_heads *
                                             self.qk_head_dim,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_b_proj")
    else:
        self.q_proj = ColumnParallelLinear(self.hidden_size,
                                           self.num_heads *
                                           self.qk_head_dim,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.q_proj")

    self.kv_a_proj_with_mqa = ReplicatedLinear(
        self.hidden_size,
        self.kv_lora_rank + self.qk_rope_head_dim,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.kv_a_proj_with_mqa")
    self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                  eps=config.rms_norm_eps)
    self.kv_b_proj = ColumnParallelLinear(
        self.kv_lora_rank,
        self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.kv_b_proj")
    self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                    self.hidden_size,
                                    bias=False,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.o_proj")

    if rope_scaling:
        rope_scaling["rope_type"] = 'deepseek_yarn'
    self.rotary_emb = get_rope(qk_rope_head_dim,
                               rotary_dim=qk_rope_head_dim,
                               max_position=max_position_embeddings,
                               base=rope_theta,
                               rope_scaling=rope_scaling,
                               is_neox_style=False)
    if rope_scaling:
        mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
        scaling_factor = rope_scaling["factor"]
        mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
        self.scaling = self.scaling * mscale * mscale

    # In the MLA backend, kv_cache includes both k_c and
    # pe (i.e. decoupled position embeddings). In particular,
    # the concat_and_cache_mla op requires
    #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
    # i.e.
    #     kv_lora_rank + qk_rope_head_dim == head_size
    self.mla_attn = Attention(
        num_heads=self.num_local_heads,
        head_size=self.kv_lora_rank + self.qk_rope_head_dim,
        scale=self.scaling,
        num_kv_heads=1,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
        use_mla=True,
        # MLA Args
        q_lora_rank=self.q_lora_rank,
        kv_lora_rank=self.kv_lora_rank,
        qk_nope_head_dim=self.qk_nope_head_dim,
        qk_rope_head_dim=self.qk_rope_head_dim,
        qk_head_dim=self.qk_head_dim,
        v_head_dim=self.v_head_dim,
        kv_b_proj=self.kv_b_proj,
    )

    self.prefix = prefix
    self.debug_layer_idx = int(self.prefix.split(".")[-2])

forward

forward(positions: Tensor, hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/deepseek_v2.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    if self.q_lora_rank is not None:
        q_c = self.q_a_proj(hidden_states)[0]
        q_c = self.q_a_layernorm(q_c)
        q = self.q_b_proj(q_c)[0]
    else:
        q = self.q_proj(hidden_states)[0]
    kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())

    q = q.view(-1, self.num_local_heads, self.qk_head_dim)
    # Add head dim of 1 to k_pe
    k_pe = k_pe.unsqueeze(1)

    q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
        positions, q[..., self.qk_nope_head_dim:], k_pe)

    attn_out = self.mla_attn(
        q,
        kv_c_normed,
        k_pe,
        output_shape=(hidden_states.shape[0],
                      self.num_local_heads * self.v_head_dim))
    return self.o_proj(attn_out)[0]
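
A minimal shape sketch may make the caching constraint above concrete: the compressed latent kv_c (width kv_lora_rank) and the decoupled rotary part k_pe (width qk_rope_head_dim) are stored together, so the MLA backend's head_size must equal their sum. The dimensions below are illustrative stand-ins, not values read from a real checkpoint config.

import torch

# Hypothetical per-rank dimensions, for illustration only.
num_local_heads = 16
kv_lora_rank = 512
qk_nope_head_dim = 128
qk_rope_head_dim = 64
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
num_tokens = 4

# q is the output of q_b_proj (or q_proj), reshaped to one row per head.
q = torch.randn(num_tokens, num_local_heads, qk_head_dim)

# kv_a_proj_with_mqa emits the latent and the rotary part in one tensor,
# which forward() splits along the last dimension.
latent = torch.randn(num_tokens, kv_lora_rank + qk_rope_head_dim)
kv_c, k_pe = latent.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
k_pe = k_pe.unsqueeze(1)  # add a head dim of 1, as in forward()

# What gets cached per token is kv_c plus k_pe, hence the backend's
# head_size of kv_lora_rank + qk_rope_head_dim.
assert kv_c.shape[-1] + k_pe.shape[-1] == kv_lora_rank + qk_rope_head_dim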

DeepseekV2MLP

Bases: Module

Source code in vllm/model_executor/models/deepseek_v2.py
class DeepseekV2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

act_fn instance-attribute

act_fn = SiluAndMul()

down_proj instance-attribute

down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    reduce_results=reduce_results,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    hidden_size,
    [intermediate_size] * 2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

__init__

__init__(
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    reduce_results: bool = True,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/deepseek_v2.py
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    reduce_results: bool = True,
    prefix: str = "",
) -> None:
    super().__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size, [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj")
    self.down_proj = RowParallelLinear(intermediate_size,
                                       hidden_size,
                                       bias=False,
                                       quant_config=quant_config,
                                       reduce_results=reduce_results,
                                       prefix=f"{prefix}.down_proj")
    if hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")
    self.act_fn = SiluAndMul()

forward

forward(x)
Source code in vllm/model_executor/models/deepseek_v2.py
def forward(self, x):
    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(x)
    return x
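
For readers unfamiliar with SiluAndMul, a rough single-device equivalent of this MLP, ignoring tensor parallelism and quantization, can be sketched in plain PyTorch: gate_up_proj produces the gate and up projections concatenated, SiluAndMul applies SiLU to the first half and multiplies elementwise by the second, and down_proj maps back to the hidden size. The layer sizes below are made up for illustration.

import torch
import torch.nn.functional as F

# Rough single-device sketch of the gated MLP; sizes are illustrative.
hidden_size, intermediate_size = 8, 16
x = torch.randn(2, hidden_size)
gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

# SiluAndMul: SiLU on the gate half, elementwise product with the up half.
gate, up = gate_up_proj(x).chunk(2, dim=-1)
y = down_proj(F.silu(gate) * up)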

DeepseekV2MoE

Bases: Module

Source code in vllm/model_executor/models/deepseek_v2.py
class DeepseekV2MoE(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor

        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")
        if config.topk_method == "noaux_tc":
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts))
        else:
            self.gate.e_score_correction_bias = None

        # Load balancing settings.
        vllm_config = get_current_vllm_config()
        parallel_config = vllm_config.parallel_config
        self.enable_eplb = enable_eplb

        self.n_redundant_experts = parallel_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = (self.n_logical_experts +
                                   self.n_redundant_experts)
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = (self.ep_rank *
                                      self.n_local_physical_experts)
        self.physical_expert_end = (self.physical_expert_start +
                                    self.n_local_physical_experts)

        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=self.experts.must_reduce_shared_expert_outputs(
                ),
                prefix=f"{prefix}.shared_experts",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        if hidden_states.dtype != torch.float16:
            final_hidden_states = self.experts(
                hidden_states=hidden_states,
                router_logits=router_logits) * self.routed_scaling_factor
        else:
            # Fix FP16 overflow
            # See DeepseekV2DecoderLayer for more details.
            final_hidden_states = self.experts(hidden_states=hidden_states,
                                               router_logits=router_logits)
        if shared_output is not None:
            if hidden_states.dtype != torch.float16:
                final_hidden_states = final_hidden_states + shared_output
            else:
                # Fix FP16 overflow
                # See DeepseekV2DecoderLayer for more details.
                final_hidden_states = final_hidden_states + shared_output \
                    * (1. / self.routed_scaling_factor)

        if self.tp_size > 1:
            final_hidden_states = (
                self.experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states))

        return final_hidden_states.view(num_tokens, hidden_dim)

enable_eplb instance-attribute

enable_eplb = enable_eplb

ep_group instance-attribute

ep_group = get_ep_group().device_group

ep_rank instance-attribute

ep_rank = ep_group.rank()

ep_size instance-attribute

ep_size = ep_group.size()

experts instance-attribute

experts = FusedMoE(
    num_experts=n_routed_experts,
    top_k=num_experts_per_tok,
    hidden_size=hidden_size,
    intermediate_size=moe_intermediate_size,
    reduce_results=False,
    renormalize=norm_topk_prob,
    quant_config=quant_config,
    use_grouped_topk=True,
    num_expert_group=n_group,
    topk_group=topk_group,
    prefix=f"{prefix}.experts",
    scoring_func=scoring_func,
    e_score_correction_bias=e_score_correction_bias,
    enable_eplb=enable_eplb,
    num_redundant_experts=n_redundant_experts,
)

gate instance-attribute

gate = ReplicatedLinear(
    hidden_size,
    n_routed_experts,
    bias=False,
    quant_config=None,
    prefix=f"{prefix}.gate",
)

n_local_physical_experts instance-attribute

n_local_physical_experts = n_physical_experts // ep_size

n_logical_experts instance-attribute

n_logical_experts = n_routed_experts

n_physical_experts instance-attribute

n_physical_experts = n_logical_experts + n_redundant_experts

n_redundant_experts instance-attribute

n_redundant_experts = num_redundant_experts

n_routed_experts instance-attribute

n_routed_experts: int = n_routed_experts

n_shared_experts instance-attribute

n_shared_experts: int = n_shared_experts

physical_expert_end instance-attribute

physical_expert_end = (
    physical_expert_start + n_local_physical_experts
)

physical_expert_start instance-attribute

physical_expert_start = ep_rank * n_local_physical_experts

routed_scaling_factor instance-attribute

routed_scaling_factor = routed_scaling_factor

shared_experts instance-attribute

shared_experts = DeepseekV2MLP(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    hidden_act=hidden_act,
    quant_config=quant_config,
    reduce_results=experts.must_reduce_shared_expert_outputs(),
    prefix=f"{prefix}.shared_experts",
)

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

__init__

__init__(
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    enable_eplb: bool = False,
)
Source code in vllm/model_executor/models/deepseek_v2.py
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    enable_eplb: bool = False,
):
    super().__init__()
    self.tp_size = get_tensor_model_parallel_world_size()
    self.routed_scaling_factor = config.routed_scaling_factor

    self.ep_group = get_ep_group().device_group
    self.ep_rank = self.ep_group.rank()
    self.ep_size = self.ep_group.size()
    self.n_routed_experts: int = config.n_routed_experts
    self.n_shared_experts: int = config.n_shared_experts

    if config.hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                         "Only silu is supported for now.")

    self.gate = ReplicatedLinear(config.hidden_size,
                                 config.n_routed_experts,
                                 bias=False,
                                 quant_config=None,
                                 prefix=f"{prefix}.gate")
    if config.topk_method == "noaux_tc":
        self.gate.e_score_correction_bias = nn.Parameter(
            torch.empty(config.n_routed_experts))
    else:
        self.gate.e_score_correction_bias = None

    # Load balancing settings.
    vllm_config = get_current_vllm_config()
    parallel_config = vllm_config.parallel_config
    self.enable_eplb = enable_eplb

    self.n_redundant_experts = parallel_config.num_redundant_experts
    self.n_logical_experts = self.n_routed_experts
    self.n_physical_experts = (self.n_logical_experts +
                               self.n_redundant_experts)
    self.n_local_physical_experts = self.n_physical_experts // self.ep_size

    self.physical_expert_start = (self.ep_rank *
                                  self.n_local_physical_experts)
    self.physical_expert_end = (self.physical_expert_start +
                                self.n_local_physical_experts)

    self.experts = FusedMoE(
        num_experts=config.n_routed_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.moe_intermediate_size,
        reduce_results=False,
        renormalize=config.norm_topk_prob,
        quant_config=quant_config,
        use_grouped_topk=True,
        num_expert_group=config.n_group,
        topk_group=config.topk_group,
        prefix=f"{prefix}.experts",
        scoring_func=config.scoring_func,
        e_score_correction_bias=self.gate.e_score_correction_bias,
        enable_eplb=self.enable_eplb,
        num_redundant_experts=self.n_redundant_experts)

    if config.n_shared_experts is not None:
        intermediate_size = (config.moe_intermediate_size *
                             config.n_shared_experts)
        self.shared_experts = DeepseekV2MLP(
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            reduce_results=self.experts.must_reduce_shared_expert_outputs(
            ),
            prefix=f"{prefix}.shared_experts",
        )

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/deepseek_v2.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    if self.n_shared_experts is not None:
        shared_output = self.shared_experts(hidden_states)
    # router_logits: (num_tokens, n_experts)
    router_logits, _ = self.gate(hidden_states)

    if hidden_states.dtype != torch.float16:
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits) * self.routed_scaling_factor
    else:
        # Fix FP16 overflow
        # See DeepseekV2DecoderLayer for more details.
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)
    if shared_output is not None:
        if hidden_states.dtype != torch.float16:
            final_hidden_states = final_hidden_states + shared_output
        else:
            # Fix FP16 overflow
            # See DeepseekV2DecoderLayer for more details.
            final_hidden_states = final_hidden_states + shared_output \
                * (1. / self.routed_scaling_factor)

    if self.tp_size > 1:
        final_hidden_states = (
            self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states))

    return final_hidden_states.view(num_tokens, hidden_dim)
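
The two branches above combine the routed and shared outputs in algebraically equivalent ways: in the common case the routed output is multiplied by routed_scaling_factor immediately, while in FP16 that multiplication is deferred (to avoid overflow, per the comments referencing DeepseekV2DecoderLayer) and the shared output is pre-divided by the factor instead. A small numeric sketch of the equivalence, with made-up values:

import torch

routed_scaling_factor = 2.5          # illustrative value
routed = torch.tensor([1.0, 2.0])    # routed-expert output
shared = torch.tensor([0.5, 0.5])    # shared-expert output

# Non-FP16 path: scale the routed output up front, then add the shared output.
out_now = routed * routed_scaling_factor + shared

# FP16 path: return routed + shared / factor; the deferred multiplication by
# routed_scaling_factor (assumed to happen downstream) yields the same result.
out_deferred = (routed + shared * (1.0 / routed_scaling_factor)) * routed_scaling_factor

assert torch.allclose(out_now, out_deferred)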

DeepseekV2Model

Bases: Module

Source code in vllm/model_executor/models/deepseek_v2.py
@support_torch_compile
class DeepseekV2Model(nn.Module):

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        enable_eplb = vllm_config.parallel_config.enable_eplb
        self.config = config

        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens")
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers")

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

config instance-attribute

config = config

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size,
    hidden_size,
    quant_config=quant_config,
    prefix=f"{prefix}.embed_tokens",
)

fall_back_to_pt_during_load class-attribute instance-attribute

fall_back_to_pt_during_load = False

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size
    )
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=rms_norm_eps)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/deepseek_v2.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()

    config = vllm_config.model_config.hf_config
    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    enable_eplb = vllm_config.parallel_config.enable_eplb
    self.config = config

    self.vocab_size = config.vocab_size

    if get_pp_group().is_first_rank:
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens")
    else:
        self.embed_tokens = PPMissingLayer()

    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        lambda prefix: DeepseekV2DecoderLayer(
            config,
            prefix,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            enable_eplb=enable_eplb,
        ),
        prefix=f"{prefix}.layers")

    if get_pp_group().is_last_rank:
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    else:
        self.norm = PPMissingLayer()
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size))

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/deepseek_v2.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(positions, hidden_states, residual)

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/deepseek_v2.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.embed_tokens(input_ids)

DeepseekV3ForCausalLM

Bases: DeepseekV2ForCausalLM

Source code in vllm/model_executor/models/deepseek_v2.py
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    pass

get_spec_layer_idx_from_weight_name

get_spec_layer_idx_from_weight_name(
    config: PretrainedConfig, weight_name: str
) -> Optional[int]
Source code in vllm/model_executor/models/deepseek_v2.py
def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
                                        weight_name: str) -> Optional[int]:
    if hasattr(config,
               "num_nextn_predict_layers") and (config.num_nextn_predict_layers
                                                > 0):
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
            if weight_name.startswith(f"model.layers.{layer_idx+i}."):
                return layer_idx + i
    return None
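
As a usage sketch: weights belonging to the extra next-token-prediction (MTP) layers sit past the main decoder stack, so their names resolve to indices at or beyond num_hidden_layers, while ordinary layer weights return None. The config values below are hypothetical stand-ins for a real PretrainedConfig.

from types import SimpleNamespace

from vllm.model_executor.models.deepseek_v2 import (
    get_spec_layer_idx_from_weight_name)

# Hypothetical stand-in for a PretrainedConfig; real values come from the
# checkpoint's HuggingFace config.
config = SimpleNamespace(num_hidden_layers=61, num_nextn_predict_layers=1)

get_spec_layer_idx_from_weight_name(
    config, "model.layers.61.embed_tokens.weight")   # -> 61
get_spec_layer_idx_from_weight_name(
    config, "model.layers.3.mlp.gate.weight")        # -> None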

yarn_get_mscale

yarn_get_mscale(
    scale: float = 1, mscale: float = 1
) -> float
Source code in vllm/model_executor/models/deepseek_v2.py
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    import math
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
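
A worked example of how this feeds the attention scale in the MLA __init__ above (scaling = scaling * mscale * mscale); the YaRN numbers here are illustrative rather than read from a particular checkpoint:

from vllm.model_executor.models.deepseek_v2 import yarn_get_mscale

# Illustrative YaRN rope_scaling values.
scaling_factor = 40.0
mscale_all_dim = 1.0

mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
# 0.1 * 1.0 * log(40.0) + 1.0 ≈ 1.369

qk_head_dim = 128 + 64                 # qk_nope_head_dim + qk_rope_head_dim
scaling = qk_head_dim ** -0.5          # base softmax scale ≈ 0.0722
scaling = scaling * mscale * mscale    # ≈ 0.135 after the YaRN correction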