vllm.model_executor.models.gemma3n

logger module-attribute

logger = init_logger(__name__)

Gemma3nAltUp

Bases: Module

Alternating updates (AltUp). The AltUp module wraps transformer layers: the predict step modifies the input to the transformer layer, and the correct step propagates the output of the transformer layer to the sparsely updated dimensions. See more in the research paper: https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
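
A minimal standalone sketch of this control flow (not part of vLLM; `altup` and `layer_fn` are illustrative placeholders with the interface described above):

import torch

# Hypothetical sketch: how predict/correct bracket one transformer layer.
# `altup` is assumed to expose predict()/correct() with the shapes used in
# this module; `layer_fn` stands in for the wrapped transformer layer.
def altup_wrapped_layer(altup, layer_fn, hidden_states: torch.Tensor,
                        active_idx: int = 0) -> torch.Tensor:
    # hidden_states: [altup_num_inputs, num_tokens, hidden_size]
    predictions = altup.predict(hidden_states)
    # Only the active slice goes through the transformer layer.
    activated = layer_fn(predictions[active_idx])  # [num_tokens, hidden_size]
    # correct() propagates the layer output to the other altup inputs.
    return altup.correct(predictions, activated)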

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nAltUp(nn.Module):
    """Alternating updates (Altup)
    The AltUp module wraps transformer layers. The `predict` step modifies the
    input to the transformer layer, and the `correct` step propagates the output
    of the transformer layer to the sparsely updated dimensions.
    See more in the research paper:
    https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
    """

    def __init__(
        self,
        hidden_size: int,
        rms_norm_eps: float,
        altup_num_inputs: int,
        altup_coef_clip: float,
        altup_active_idx: int,
        prefix: str,
    ):
        super().__init__()

        self.altup_num_inputs = altup_num_inputs
        self.altup_active_idx = altup_active_idx
        self.altup_coef_clip = altup_coef_clip

        self.correction_coefs = ReplicatedLinear(
            altup_num_inputs,
            altup_num_inputs,
            bias=False,
            prefix=f"{prefix}.correction_coefs",
            return_bias=False,
        )
        self.prediction_coefs = ReplicatedLinear(
            altup_num_inputs,
            altup_num_inputs**2,
            bias=False,
            prefix=f"{prefix}.prediction_coefs",
            return_bias=False,
        )
        self.modality_router = ReplicatedLinear(
            hidden_size,
            altup_num_inputs,
            bias=False,
            prefix=f"{prefix}.modality_router",
            return_bias=False,
        )
        self.router_norm = RMSNorm(
            hidden_size=hidden_size,
            eps=rms_norm_eps,
        )
        self.router_input_scale = torch.tensor(
            hidden_size**-1.0, dtype=self.modality_router.weight.dtype)
        self.correct_output_scale = nn.Parameter(
            torch.zeros(hidden_size, dtype=torch.float32))

    def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
        router_inputs = self.router_norm(x) * self.router_input_scale
        routed = self.modality_router(router_inputs)
        return torch.tanh(routed.float()).type_as(x)

    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
        return (corrected.type_as(self.correct_output_scale) *
                self.correct_output_scale).type_as(corrected)

    def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden:       [altup_num_inputs, num_tokens, hidden_size]
        # modalities:   [num_tokens, num_altup_inputs]
        # all_coefs:    [num_tokens, num_altup_inputs ** 2]
        modalities = self._compute_router_modalities(
            hidden_states[self.altup_active_idx])
        all_coefs = self.prediction_coefs(modalities)

        # Reshape and transpose the 2D matrix for the matmul.
        # all_coefs_T:  [num_tokens, num_altup_inputs, num_altup_inputs]
        all_coefs_T = all_coefs.reshape(
            -1,
            self.altup_num_inputs,
            self.altup_num_inputs,
        ).permute(0, 2, 1)

        # hidden_states to [num_tokens, hidden_size, altup_num_inputs]
        predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T)
        # [altup_num_inputs, num_tokens, hidden_size]
        predictions = predictions.permute(2, 0, 1)
        predictions += hidden_states
        return predictions.contiguous()

    def correct(self, predictions: torch.Tensor,
                activated: torch.Tensor) -> torch.Tensor:
        # predictions:  [altup_num_inputs, num_tokens, hidden_size]
        # activated:    [num_tokens, hidden_size]
        # modalities:   [num_tokens, altup_num_inputs]
        modalities = self._compute_router_modalities(activated)
        # innovation:   [num_tokens, hidden_size]
        innovation = activated - predictions[self.altup_active_idx]
        # innovation:   [altup_num_inputs, num_tokens, hidden_size]
        innovation = innovation.repeat(self.altup_num_inputs, 1, 1)

        # Transpose to [altup_num_inputs, num_tokens] so each coefficient
        # is a scalar applied to one altup input, then add a trailing
        # dim of size 1 for broadcastability over hidden_size.
        # all_coefs:    [num_tokens, altup_num_inputs]
        all_coefs = self.correction_coefs(modalities) + 1.0
        # all_coefs:    [altup_num_inputs, num_tokens, 1]
        all_coefs = all_coefs.T.unsqueeze(-1)

        # Elementwise (broadcast over hidden_size).
        corrected = torch.mul(innovation, all_coefs)
        corrected += predictions

        return corrected.contiguous()

altup_active_idx instance-attribute

altup_active_idx = altup_active_idx

altup_coef_clip instance-attribute

altup_coef_clip = altup_coef_clip

altup_num_inputs instance-attribute

altup_num_inputs = altup_num_inputs

correct_output_scale instance-attribute

correct_output_scale = Parameter(
    zeros(hidden_size, dtype=float32)
)

correction_coefs instance-attribute

correction_coefs = ReplicatedLinear(
    altup_num_inputs,
    altup_num_inputs,
    bias=False,
    prefix=f"{prefix}.correction_coefs",
    return_bias=False,
)

modality_router instance-attribute

modality_router = ReplicatedLinear(
    hidden_size,
    altup_num_inputs,
    bias=False,
    prefix=f"{prefix}.modality_router",
    return_bias=False,
)

prediction_coefs instance-attribute

prediction_coefs = ReplicatedLinear(
    altup_num_inputs,
    altup_num_inputs**2,
    bias=False,
    prefix=f"{prefix}.prediction_coefs",
    return_bias=False,
)

router_input_scale instance-attribute

router_input_scale = tensor(hidden_size**-1.0, dtype=dtype)

router_norm instance-attribute

router_norm = RMSNorm(
    hidden_size=hidden_size, eps=rms_norm_eps
)

__init__

__init__(
    hidden_size: int,
    rms_norm_eps: float,
    altup_num_inputs: int,
    altup_coef_clip: float,
    altup_active_idx: int,
    prefix: str,
)
Source code in vllm/model_executor/models/gemma3n.py
def __init__(
    self,
    hidden_size: int,
    rms_norm_eps: float,
    altup_num_inputs: int,
    altup_coef_clip: float,
    altup_active_idx: int,
    prefix: str,
):
    super().__init__()

    self.altup_num_inputs = altup_num_inputs
    self.altup_active_idx = altup_active_idx
    self.altup_coef_clip = altup_coef_clip

    self.correction_coefs = ReplicatedLinear(
        altup_num_inputs,
        altup_num_inputs,
        bias=False,
        prefix=f"{prefix}.correction_coefs",
        return_bias=False,
    )
    self.prediction_coefs = ReplicatedLinear(
        altup_num_inputs,
        altup_num_inputs**2,
        bias=False,
        prefix=f"{prefix}.prediction_coefs",
        return_bias=False,
    )
    self.modality_router = ReplicatedLinear(
        hidden_size,
        altup_num_inputs,
        bias=False,
        prefix=f"{prefix}.modality_router",
        return_bias=False,
    )
    self.router_norm = RMSNorm(
        hidden_size=hidden_size,
        eps=rms_norm_eps,
    )
    self.router_input_scale = torch.tensor(
        hidden_size**-1.0, dtype=self.modality_router.weight.dtype)
    self.correct_output_scale = nn.Parameter(
        torch.zeros(hidden_size, dtype=torch.float32))

_compute_router_modalities

_compute_router_modalities(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
    router_inputs = self.router_norm(x) * self.router_input_scale
    routed = self.modality_router(router_inputs)
    return torch.tanh(routed.float()).type_as(x)

correct

correct(predictions: Tensor, activated: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def correct(self, predictions: torch.Tensor,
            activated: torch.Tensor) -> torch.Tensor:
    # predictions:  [altup_num_inputs, num_tokens, hidden_size]
    # activated:    [num_tokens, hidden_size]
    # modalities:   [num_tokens, altup_num_inputs]
    modalities = self._compute_router_modalities(activated)
    # innovation:   [num_tokens, hidden_size]
    innovation = activated - predictions[self.altup_active_idx]
    # innovation:   [altup_num_inputs, num_tokens, hidden_size]
    innovation = innovation.repeat(self.altup_num_inputs, 1, 1)

    # Transpose to [altup_num_inputs, num_tokens] so each coefficient
    # is a scalar applied to one altup input, then add a trailing
    # dim of size 1 for broadcastability over hidden_size.
    # all_coefs:    [num_tokens, altup_num_inputs]
    all_coefs = self.correction_coefs(modalities) + 1.0
    # all_coefs:    [altup_num_inputs, num_tokens, 1]
    all_coefs = all_coefs.T.unsqueeze(-1)

    # Elementwise (broadcast over hidden_size).
    corrected = torch.mul(innovation, all_coefs)
    corrected += predictions

    return corrected.contiguous()

predict

predict(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden:       [altup_num_inputs, num_tokens, hidden_size]
    # modalities:   [num_tokens, num_altup_inputs]
    # all_coefs:    [num_tokens, num_altup_inputs ** 2]
    modalities = self._compute_router_modalities(
        hidden_states[self.altup_active_idx])
    all_coefs = self.prediction_coefs(modalities)

    # Reshape and transpose the 2D matrix for the matmul.
    # all_coefs_T:  [num_tokens, num_altup_inputs, num_altup_inputs]
    all_coefs_T = all_coefs.reshape(
        -1,
        self.altup_num_inputs,
        self.altup_num_inputs,
    ).permute(0, 2, 1)

    # hidden_states to [num_tokens, hidden_size, altup_num_inputs]
    predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T)
    # [altup_num_inputs, num_tokens, hidden_size]
    predictions = predictions.permute(2, 0, 1)
    predictions += hidden_states
    return predictions.contiguous()
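
The reshape/permute bookkeeping above can be checked in isolation; this is a hypothetical shape check with random data (the sizes are illustrative, not model parameters):

import torch

# Hypothetical shape check for the predict() matmul, using random data.
n_inputs, num_tokens, hidden_size = 4, 3, 8
hidden = torch.randn(n_inputs, num_tokens, hidden_size)
coefs = torch.randn(num_tokens, n_inputs * n_inputs)

# [num_tokens, n_inputs, n_inputs], with the last two dims transposed.
coefs_t = coefs.reshape(-1, n_inputs, n_inputs).permute(0, 2, 1)
# Batched matmul over tokens: [num_tokens, hidden_size, n_inputs].
pred = torch.matmul(hidden.permute(1, 2, 0), coefs_t)
# Back to [n_inputs, num_tokens, hidden_size], plus the residual.
pred = pred.permute(2, 0, 1) + hidden
assert pred.shape == (n_inputs, num_tokens, hidden_size)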

scale_corrected_output

scale_corrected_output(corrected: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
    return (corrected.type_as(self.correct_output_scale) *
            self.correct_output_scale).type_as(corrected)

Gemma3nAttention

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nAttention(nn.Module):

    def __init__(self,
                 config: Gemma3nTextConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.q_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=config.rms_norm_eps)
        self.v_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=config.rms_norm_eps,
                              has_weight=False)

        layer_idx = extract_layer_index(prefix)
        if config.layer_types[layer_idx] == "sliding_attention":
            self.sliding_window = config.sliding_window
            rope_theta = config.rope_local_base_freq
            rope_scaling = {"rope_type": "default"}
        else:
            self.sliding_window = None
            rope_theta = config.rope_theta
            rope_scaling = config.rope_scaling

        first_kv_shared_layer_idx = (config.num_hidden_layers -
                                     config.num_kv_shared_layers)
        self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx

        if self.is_kv_shared:
            # Last full attention layer is 1 before sharing
            # Last sliding attention layer is 2 before sharing
            offset = 2 if self.sliding_window is not None else 1
            kv_shared_layer_index = first_kv_shared_layer_idx - offset
            kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn"  # noqa: E501
        else:
            kv_sharing_target_layer_name = None

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=True,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=1.0,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=self.sliding_window,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = self.q_norm(q)
        q = q.flatten(-2, -1)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        k = self.k_norm(k)
        k = k.flatten(-2, -1)
        v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
        v = self.v_norm(v)
        v = v.flatten(-2, -1)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)

        output, _ = self.o_proj(attn_output)
        return output

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    scale=1.0,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    per_layer_sliding_window=sliding_window,
    kv_sharing_target_layer_name=kv_sharing_target_layer_name,
    prefix=f"{prefix}.attn",
)

config instance-attribute

config = config

head_dim instance-attribute

head_dim = head_dim

hidden_size instance-attribute

hidden_size = hidden_size

is_kv_shared instance-attribute

is_kv_shared = layer_idx >= first_kv_shared_layer_idx

k_norm instance-attribute

k_norm = RMSNorm(hidden_size=head_dim, eps=rms_norm_eps)

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=attention_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

q_norm instance-attribute

q_norm = RMSNorm(hidden_size=head_dim, eps=rms_norm_eps)

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=attention_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    is_neox_style=True,
    rope_scaling=rope_scaling,
)

sliding_window instance-attribute

sliding_window = sliding_window

total_num_heads instance-attribute

total_num_heads = num_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_kv_heads

v_norm instance-attribute

v_norm = RMSNorm(
    hidden_size=head_dim, eps=rms_norm_eps, has_weight=False
)

__init__

__init__(
    config: Gemma3nTextConfig,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_position_embeddings: int,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/gemma3n.py
def __init__(self,
             config: Gemma3nTextConfig,
             hidden_size: int,
             num_heads: int,
             num_kv_heads: int,
             head_dim: int,
             max_position_embeddings: int,
             cache_config: Optional[CacheConfig] = None,
             quant_config: Optional[QuantizationConfig] = None,
             prefix: str = "") -> None:
    super().__init__()
    self.config = config
    self.hidden_size = hidden_size
    tp_size = get_tensor_model_parallel_world_size()
    self.total_num_heads = num_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        # Number of KV heads is greater than TP size, so we partition
        # the KV heads across multiple tensor parallel GPUs.
        assert self.total_num_kv_heads % tp_size == 0
    else:
        # Number of KV heads is less than TP size, so we replicate
        # the KV heads across multiple tensor parallel GPUs.
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    self.head_dim = head_dim
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim

    self.qkv_proj = QKVParallelLinear(
        hidden_size,
        self.head_dim,
        self.total_num_heads,
        self.total_num_kv_heads,
        bias=config.attention_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        self.total_num_heads * self.head_dim,
        hidden_size,
        bias=config.attention_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )
    self.q_norm = RMSNorm(hidden_size=self.head_dim,
                          eps=config.rms_norm_eps)
    self.k_norm = RMSNorm(hidden_size=self.head_dim,
                          eps=config.rms_norm_eps)
    self.v_norm = RMSNorm(hidden_size=self.head_dim,
                          eps=config.rms_norm_eps,
                          has_weight=False)

    layer_idx = extract_layer_index(prefix)
    if config.layer_types[layer_idx] == "sliding_attention":
        self.sliding_window = config.sliding_window
        rope_theta = config.rope_local_base_freq
        rope_scaling = {"rope_type": "default"}
    else:
        self.sliding_window = None
        rope_theta = config.rope_theta
        rope_scaling = config.rope_scaling

    first_kv_shared_layer_idx = (config.num_hidden_layers -
                                 config.num_kv_shared_layers)
    self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx

    if self.is_kv_shared:
        # Last full attention layer is 1 before sharing
        # Last sliding attention layer is 2 before sharing
        offset = 2 if self.sliding_window is not None else 1
        kv_shared_layer_index = first_kv_shared_layer_idx - offset
        kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn"  # noqa: E501
    else:
        kv_sharing_target_layer_name = None

    self.rotary_emb = get_rope(
        self.head_dim,
        rotary_dim=self.head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        is_neox_style=True,
        rope_scaling=rope_scaling,
    )

    self.attn = Attention(
        num_heads=self.num_heads,
        head_size=self.head_dim,
        scale=1.0,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        per_layer_sliding_window=self.sliding_window,
        kv_sharing_target_layer_name=kv_sharing_target_layer_name,
        prefix=f"{prefix}.attn")

forward

forward(
    positions: Tensor, hidden_states: Tensor, **kwargs
) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

    q = q.unflatten(-1, (self.num_heads, self.head_dim))
    q = self.q_norm(q)
    q = q.flatten(-2, -1)
    k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
    k = self.k_norm(k)
    k = k.flatten(-2, -1)
    v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
    v = self.v_norm(v)
    v = v.flatten(-2, -1)

    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)

    output, _ = self.o_proj(attn_output)
    return output

Gemma3nDecoderLayer

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nDecoderLayer(nn.Module):

    def __init__(
        self,
        config: Gemma3nTextConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.altup_active_idx = config.altup_active_idx
        assert config.altup_correct_scale

        self.altup = Gemma3nAltUp(
            hidden_size=config.hidden_size,
            rms_norm_eps=config.rms_norm_eps,
            altup_num_inputs=config.altup_num_inputs,
            altup_coef_clip=config.altup_coef_clip,
            altup_active_idx=config.altup_active_idx,
            prefix=f"{prefix}.altup",
        )
        self.self_attn = Gemma3nAttention(
            config=config,
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = Gemma3nMLP(
            hidden_size=config.hidden_size,
            # NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501
            intermediate_size=config.intermediate_size[extract_layer_index(
                prefix)],
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            activation_sparsity=config.activation_sparsity_pattern[
                extract_layer_index(prefix)],
            prefix=f"{prefix}.mlp",
        )
        self.laurel = Gemma3nLaurelBlock(
            hidden_size=config.hidden_size,
            laurel_rank=config.laurel_rank,
            rms_norm_eps=config.rms_norm_eps,
            prefix=f"{prefix}.laurel",
        )

        # NOTE(rob): should be ColumnParallelLinear and RowParallelLinear
        # But, we need to add per_layer_input_gate(x) to per_layer_input.
        # per_layer_input cannot be sharded, so we replicate for now.
        self.per_layer_input_gate = ReplicatedLinear(
            config.hidden_size,
            config.hidden_size_per_layer_input,
            bias=False,
            prefix=f"{prefix}.per_layer_input_gate",
            return_bias=False,
        )
        self.per_layer_projection = ReplicatedLinear(
            config.hidden_size_per_layer_input,
            config.hidden_size,
            bias=False,
            prefix=f"{prefix}.per_layer_projection",
            return_bias=False,
        )

        # LayerNorms.
        self.input_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.pre_feedforward_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.post_feedforward_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.post_per_layer_input_norm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )

        self.act_fn = _ACTIVATION_REGISTRY[config.hidden_activation]

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_input: torch.Tensor,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:

        # AltUp (predict).
        predictions = self.altup.predict(hidden_states)
        active_prediction = predictions[self.altup_active_idx]
        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(active_prediction_normed)

        # Attention.
        attn = self.self_attn(
            positions=positions,
            hidden_states=active_prediction_normed,
            **kwargs,
        )
        attn = self.post_attention_layernorm(attn)
        attn_gated = attn + active_prediction
        attn_laurel = (attn_gated + laurel_output) / torch.sqrt(
            torch.tensor(2.0))

        # MLP.
        attn_norm = self.pre_feedforward_layernorm(attn_laurel)
        attn_ffw = self.mlp(attn_norm)
        attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm

        # AltUp (correct).
        corrected_predictions = self.altup.correct(predictions,
                                                   attn_ffw_laurel_gated)
        first_prediction = corrected_predictions[self.altup_active_idx]
        first_prediction = self.altup.scale_corrected_output(first_prediction)

        # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
        first_prediction = self.per_layer_input_gate(first_prediction)
        first_prediction = self.act_fn(first_prediction)
        first_prediction = torch.mul(first_prediction, per_layer_input)

        # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
        first_prediction = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        corrected_predictions[1:] += first_prediction

        return corrected_predictions

act_fn instance-attribute

act_fn = _ACTIVATION_REGISTRY[hidden_activation]

altup instance-attribute

altup = Gemma3nAltUp(
    hidden_size=hidden_size,
    rms_norm_eps=rms_norm_eps,
    altup_num_inputs=altup_num_inputs,
    altup_coef_clip=altup_coef_clip,
    altup_active_idx=altup_active_idx,
    prefix=f"{prefix}.altup",
)

altup_active_idx instance-attribute

altup_active_idx = altup_active_idx

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

laurel instance-attribute

laurel = Gemma3nLaurelBlock(
    hidden_size=hidden_size,
    laurel_rank=laurel_rank,
    rms_norm_eps=rms_norm_eps,
    prefix=f"{prefix}.laurel",
)

mlp instance-attribute

mlp = Gemma3nMLP(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size[
        extract_layer_index(prefix)
    ],
    hidden_activation=hidden_activation,
    quant_config=quant_config,
    activation_sparsity=activation_sparsity_pattern[
        extract_layer_index(prefix)
    ],
    prefix=f"{prefix}.mlp",
)

per_layer_input_gate instance-attribute

per_layer_input_gate = ReplicatedLinear(
    hidden_size,
    hidden_size_per_layer_input,
    bias=False,
    prefix=f"{prefix}.per_layer_input_gate",
    return_bias=False,
)

per_layer_projection instance-attribute

per_layer_projection = ReplicatedLinear(
    hidden_size_per_layer_input,
    hidden_size,
    bias=False,
    prefix=f"{prefix}.per_layer_projection",
    return_bias=False,
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

post_feedforward_layernorm instance-attribute

post_feedforward_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

post_per_layer_input_norm instance-attribute

post_per_layer_input_norm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

pre_feedforward_layernorm instance-attribute

pre_feedforward_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

self_attn instance-attribute

self_attn = Gemma3nAttention(
    config=config,
    hidden_size=hidden_size,
    num_heads=num_attention_heads,
    num_kv_heads=num_key_value_heads,
    head_dim=head_dim,
    max_position_embeddings=max_position_embeddings,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

__init__

__init__(
    config: Gemma3nTextConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/gemma3n.py
def __init__(
    self,
    config: Gemma3nTextConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    self.altup_active_idx = config.altup_active_idx
    assert config.altup_correct_scale

    self.altup = Gemma3nAltUp(
        hidden_size=config.hidden_size,
        rms_norm_eps=config.rms_norm_eps,
        altup_num_inputs=config.altup_num_inputs,
        altup_coef_clip=config.altup_coef_clip,
        altup_active_idx=config.altup_active_idx,
        prefix=f"{prefix}.altup",
    )
    self.self_attn = Gemma3nAttention(
        config=config,
        hidden_size=config.hidden_size,
        num_heads=config.num_attention_heads,
        num_kv_heads=config.num_key_value_heads,
        head_dim=config.head_dim,
        max_position_embeddings=config.max_position_embeddings,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.self_attn",
    )
    self.mlp = Gemma3nMLP(
        hidden_size=config.hidden_size,
        # NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501
        intermediate_size=config.intermediate_size[extract_layer_index(
            prefix)],
        hidden_activation=config.hidden_activation,
        quant_config=quant_config,
        activation_sparsity=config.activation_sparsity_pattern[
            extract_layer_index(prefix)],
        prefix=f"{prefix}.mlp",
    )
    self.laurel = Gemma3nLaurelBlock(
        hidden_size=config.hidden_size,
        laurel_rank=config.laurel_rank,
        rms_norm_eps=config.rms_norm_eps,
        prefix=f"{prefix}.laurel",
    )

    # NOTE(rob): should be ColumnParallelLinear and RowParallelLinear
    # But, we need to add per_layer_input_gate(x) to per_layer_input.
    # per_layer_input cannot be sharded, so we replicate for now.
    self.per_layer_input_gate = ReplicatedLinear(
        config.hidden_size,
        config.hidden_size_per_layer_input,
        bias=False,
        prefix=f"{prefix}.per_layer_input_gate",
        return_bias=False,
    )
    self.per_layer_projection = ReplicatedLinear(
        config.hidden_size_per_layer_input,
        config.hidden_size,
        bias=False,
        prefix=f"{prefix}.per_layer_projection",
        return_bias=False,
    )

    # LayerNorms.
    self.input_layernorm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )
    self.post_attention_layernorm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )
    self.pre_feedforward_layernorm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )
    self.post_feedforward_layernorm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )
    self.post_per_layer_input_norm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )

    self.act_fn = _ACTIVATION_REGISTRY[config.hidden_activation]

forward

forward(
    positions: Tensor,
    hidden_states: Tensor,
    per_layer_input: Tensor,
    **kwargs,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    per_layer_input: torch.Tensor,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:

    # AltUp (predict).
    predictions = self.altup.predict(hidden_states)
    active_prediction = predictions[self.altup_active_idx]
    active_prediction_normed = self.input_layernorm(active_prediction)
    laurel_output = self.laurel(active_prediction_normed)

    # Attention.
    attn = self.self_attn(
        positions=positions,
        hidden_states=active_prediction_normed,
        **kwargs,
    )
    attn = self.post_attention_layernorm(attn)
    attn_gated = attn + active_prediction
    attn_laurel = (attn_gated + laurel_output) / torch.sqrt(
        torch.tensor(2.0))

    # MLP.
    attn_norm = self.pre_feedforward_layernorm(attn_laurel)
    attn_ffw = self.mlp(attn_norm)
    attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
    attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm

    # AltUp (correct).
    corrected_predictions = self.altup.correct(predictions,
                                               attn_ffw_laurel_gated)
    first_prediction = corrected_predictions[self.altup_active_idx]
    first_prediction = self.altup.scale_corrected_output(first_prediction)

    # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
    first_prediction = self.per_layer_input_gate(first_prediction)
    first_prediction = self.act_fn(first_prediction)
    first_prediction = torch.mul(first_prediction, per_layer_input)

    # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
    first_prediction = self.per_layer_projection(first_prediction)
    first_prediction = self.post_per_layer_input_norm(first_prediction)
    corrected_predictions[1:] += first_prediction

    return corrected_predictions

Gemma3nForConditionalGeneration

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nForConditionalGeneration(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.model = Gemma3nModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
        self.logits_processor = LogitsProcessor(
            config.text_config.vocab_size,
            soft_cap=config.text_config.final_logit_softcapping)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.language_model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, **kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: Optional[SamplingMetadata],
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.model.language_model.embed_tokens,
                                       hidden_states, sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self,
                                   skip_substrs=([
                                       "embed_audio.", "embed_vision.",
                                       "audio_tower.", "vision_tower."
                                   ]))
        return loader.load_weights(weights)

config instance-attribute

config = config

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    vocab_size, soft_cap=final_logit_softcapping
)

model instance-attribute

model = Gemma3nModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/gemma3n.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    config = vllm_config.model_config.hf_config
    lora_config = vllm_config.lora_config
    del lora_config  # Unused.
    super().__init__()
    self.config = config
    self.model = Gemma3nModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))
    self.logits_processor = LogitsProcessor(
        config.text_config.vocab_size,
        soft_cap=config.text_config.final_logit_softcapping)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: Optional[SamplingMetadata],
) -> Optional[Tensor]
Source code in vllm/model_executor/models/gemma3n.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: Optional[SamplingMetadata],
) -> Optional[torch.Tensor]:
    logits = self.logits_processor(self.model.language_model.embed_tokens,
                                   hidden_states, sampling_metadata)
    return logits

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
                               inputs_embeds, **kwargs)
    return hidden_states

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.model.language_model.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/gemma3n.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    loader = AutoWeightsLoader(self,
                               skip_substrs=([
                                   "embed_audio.", "embed_vision.",
                                   "audio_tower.", "vision_tower."
                               ]))
    return loader.load_weights(weights)

Gemma3nLaurelBlock

Bases: Module

Learned Augmented Residual Layer

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nLaurelBlock(nn.Module):
    """Learned Augmented Residual Layer"""

    def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float,
                 prefix: str):
        super().__init__()

        self.linear_left = ColumnParallelLinear(
            hidden_size,
            laurel_rank,
            bias=False,
            prefix=f"{prefix}.linear_left",
            return_bias=False,
        )
        self.linear_right = RowParallelLinear(laurel_rank,
                                              hidden_size,
                                              bias=False,
                                              prefix=f"{prefix}.linear_right",
                                              return_bias=False)
        self.post_laurel_norm = RMSNorm(
            hidden_size=hidden_size,
            eps=rms_norm_eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        laurel_x = self.linear_left(x)
        laurel_x = self.linear_right(laurel_x)
        normed_laurel_x = self.post_laurel_norm(laurel_x)
        return x + normed_laurel_x

linear_left instance-attribute

linear_left = ColumnParallelLinear(
    hidden_size,
    laurel_rank,
    bias=False,
    prefix=f"{prefix}.linear_left",
    return_bias=False,
)

linear_right instance-attribute

linear_right = RowParallelLinear(
    laurel_rank,
    hidden_size,
    bias=False,
    prefix=f"{prefix}.linear_right",
    return_bias=False,
)

post_laurel_norm instance-attribute

post_laurel_norm = RMSNorm(
    hidden_size=hidden_size, eps=rms_norm_eps
)

__init__

__init__(
    hidden_size: int,
    laurel_rank: int,
    rms_norm_eps: float,
    prefix: str,
)
Source code in vllm/model_executor/models/gemma3n.py
def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float,
             prefix: str):
    super().__init__()

    self.linear_left = ColumnParallelLinear(
        hidden_size,
        laurel_rank,
        bias=False,
        prefix=f"{prefix}.linear_left",
        return_bias=False,
    )
    self.linear_right = RowParallelLinear(laurel_rank,
                                          hidden_size,
                                          bias=False,
                                          prefix=f"{prefix}.linear_right",
                                          return_bias=False)
    self.post_laurel_norm = RMSNorm(
        hidden_size=hidden_size,
        eps=rms_norm_eps,
    )

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    laurel_x = self.linear_left(x)
    laurel_x = self.linear_right(laurel_x)
    normed_laurel_x = self.post_laurel_norm(laurel_x)
    return x + normed_laurel_x

Gemma3nMLP

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        activation_sparsity: float = 0.0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`.")

        self.act_fn = GeluAndMulSparse(
            activation_sparsity=activation_sparsity,
            approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul(
                approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

act_fn instance-attribute

act_fn = (
    GeluAndMulSparse(
        activation_sparsity=activation_sparsity,
        approximate="tanh",
    )
    if activation_sparsity > 0.0
    else GeluAndMul(approximate="tanh")
)

down_proj instance-attribute

down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    hidden_size,
    [intermediate_size] * 2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

__init__

__init__(
    hidden_size: int,
    intermediate_size: int,
    hidden_activation: str,
    activation_sparsity: float = 0.0,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/gemma3n.py
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_activation: str,
    activation_sparsity: float = 0.0,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    self.down_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
    )
    if hidden_activation != "gelu_pytorch_tanh":
        raise ValueError(
            "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
            "function. Please set `hidden_act` and `hidden_activation` to "
            "`gelu_pytorch_tanh`.")

    self.act_fn = GeluAndMulSparse(
        activation_sparsity=activation_sparsity,
        approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul(
            approximate="tanh")

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(x)
    return x

Gemma3nModel

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nModel(nn.Module):

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.language_model = Gemma3nTextModel(vllm_config=vllm_config,
                                               prefix=maybe_prefix(
                                                   prefix, "language_model"))

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.language_model(input_ids=input_ids,
                                   positions=positions,
                                   inputs_embeds=inputs_embeds,
                                   **kwargs)

language_model instance-attribute

language_model = Gemma3nTextModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "language_model"),
)

__init__

__init__(vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/gemma3n.py
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    self.language_model = Gemma3nTextModel(vllm_config=vllm_config,
                                           prefix=maybe_prefix(
                                               prefix, "language_model"))

forward

forward(
    input_ids: Optional[Tensor],
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    return self.language_model(input_ids=input_ids,
                               positions=positions,
                               inputs_embeds=inputs_embeds,
                               **kwargs)

Gemma3nTextModel

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
@support_torch_compile
class Gemma3nTextModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config.text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=f"{prefix}.embed_tokens",
        )
        self.embed_scale = torch.tensor(
            config.hidden_size**0.5,
            dtype=self.embed_tokens.weight.dtype,
        )
        self.embed_tokens_per_layer = VocabParallelEmbedding(
            config.vocab_size_per_layer_input,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            prefix=f"{prefix}.per_layer_embed_tokens",
        )
        self.embed_scale_per_layer = torch.tensor(
            config.hidden_size_per_layer_input**0.5,
            dtype=self.embed_tokens.weight.dtype,
        )
        self.per_layer_model_projection = ColumnParallelLinear(
            config.hidden_size,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            bias=False,
            gather_output=True,
            return_bias=False,
            prefix=f"{prefix}.per_layer_model_projection",
        )
        self.per_layer_projection_norm = RMSNorm(
            hidden_size=config.hidden_size_per_layer_input,
            eps=config.rms_norm_eps,
        )
        self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to(
            self.embed_tokens.weight.dtype)
        self.per_layer_projection_scale = torch.tensor(
            config.hidden_size**0.5,
            dtype=self.embed_tokens.weight.dtype,
        )
        self.altup_projections = nn.ModuleList([
            ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                gather_output=True,
                return_bias=False,
                prefix=f"{prefix}.{idx-1}.altup_projections",
            ) for idx in range(1, self.config.altup_num_inputs)
        ])
        self.altup_unembed_projections = nn.ModuleList([
            ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                gather_output=True,
                return_bias=False,
                prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
            ) for idx in range(1, self.config.altup_num_inputs)
        ])

        # Transformer blocks.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Gemma3nDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.eps = torch.tensor(torch.finfo().min)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids) * self.embed_scale

    def get_per_layer_input_embeddings(
            self, input_ids: torch.Tensor) -> torch.Tensor:
        # Deal with the fact that vocab_size_per_layer_input < vocab_size
        # which causes us to have some out of vocab tokens by setting
        # those token ids to 0. This matches the HF implementation.
        per_layer_inputs_mask = torch.logical_and(
            input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
        per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
                                              torch.zeros_like(input_ids))
        return self.embed_tokens_per_layer(
            per_layer_inputs_tokens) * self.embed_scale_per_layer

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if inputs_embeds is not None:
            hidden_states_0 = inputs_embeds
        else:
            hidden_states_0 = self.get_input_embeddings(input_ids)

        # Per layer inputs.
        if input_ids is None:
            raise ValueError("Passing None for input ids is not supported.")
        per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
        per_layer_inputs = per_layer_inputs.reshape(
            -1, self.config.num_hidden_layers,
            self.config.hidden_size_per_layer_input)
        per_layer_projection = self.per_layer_model_projection(hidden_states_0)
        per_layer_projection = per_layer_projection.reshape(
            *hidden_states_0.shape[:-1],
            self.config.num_hidden_layers,
            self.config.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(
            per_layer_projection)
        per_layer_inputs = per_layer_projection + per_layer_inputs
        per_layer_inputs *= self.per_layer_input_scale

        # Altup embed.
        hidden_states = [hidden_states_0] * self.config.altup_num_inputs
        target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
                                      keepdim=True)**0.5
        for i in range(1, self.config.altup_num_inputs):
            hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
            new_magnitude = torch.mean(hidden_states[i]**2,
                                       dim=-1,
                                       keepdim=True)**0.5
            hidden_states[i] *= target_magnitude / torch.maximum(
                new_magnitude, self.eps)
        hidden_states = torch.stack(hidden_states, dim=0)

        # Transformer blocks.
        for layer_idx, layer in enumerate(self.layers):
            # [altup_num_inputs, num_tokens, hidden_size]
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,
                per_layer_input=per_layer_inputs[:, layer_idx, :],
                **kwargs,
            )

        # Altup unembed.
        target_magnitude = torch.mean(hidden_states[0]**2,
                                      dim=-1,
                                      keepdim=True)**0.5
        for i in range(1, self.config.altup_num_inputs):
            hidden_states[i] = self.altup_unembed_projections[i - 1](
                hidden_states[i])
            new_magnitude = torch.mean(hidden_states[i]**2,
                                       dim=-1,
                                       keepdim=True)**0.5
            hidden_states[i] *= target_magnitude / torch.maximum(
                new_magnitude, self.eps)
        # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
        hidden_states = torch.mean(hidden_states, dim=0)

        return self.norm(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # Avoid spurious match with ".up_proj".
                if "altup_projections" in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params
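
A quick way to see the shapes in the per-layer input path of forward above: the per-layer embeddings and the projection of the main embedding are reshaped to the same [num_tokens, num_hidden_layers, hidden_size_per_layer_input] layout, summed, and scaled by 1/sqrt(2). A minimal sketch with made-up sizes (not the real Gemma3n config values); the RMSNorm applied to the projection in the real forward pass is omitted for brevity:

import torch

# Illustrative sizes only.
num_tokens, num_layers, hidden_per_layer = 4, 3, 8

# Stand-ins for get_per_layer_input_embeddings(input_ids) and
# per_layer_model_projection(hidden_states_0).
per_layer_inputs = torch.randn(num_tokens, num_layers * hidden_per_layer)
per_layer_projection = torch.randn(num_tokens, num_layers * hidden_per_layer)

per_layer_inputs = per_layer_inputs.reshape(-1, num_layers, hidden_per_layer)
per_layer_projection = per_layer_projection.reshape(-1, num_layers, hidden_per_layer)

# Sum the two sources and rescale by 1/sqrt(2) to keep the variance comparable.
per_layer_inputs = (per_layer_projection + per_layer_inputs) * torch.rsqrt(torch.tensor(2.0))

# Decoder layer i consumes per_layer_inputs[:, i, :].
print(per_layer_inputs[:, 0, :].shape)  # torch.Size([4, 8])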

altup_projections instance-attribute

altup_projections = ModuleList(
    [
        ColumnParallelLinear(
            hidden_size,
            hidden_size,
            bias=False,
            gather_output=True,
            return_bias=False,
            prefix=f"{prefix}.{idx - 1}.altup_projections",
        )
        for idx in range(1, altup_num_inputs)
    ]
)

altup_unembed_projections instance-attribute

altup_unembed_projections = ModuleList(
    [
        ColumnParallelLinear(
            hidden_size,
            hidden_size,
            bias=False,
            gather_output=True,
            return_bias=False,
            prefix=f"{prefix}.{idx - 1}.altup_unembed_projections",
        )
        for idx in range(1, altup_num_inputs)
    ]
)

config instance-attribute

config = config

embed_scale instance-attribute

embed_scale = tensor(hidden_size ** 0.5, dtype=dtype)

embed_scale_per_layer instance-attribute

embed_scale_per_layer = tensor(
    hidden_size_per_layer_input**0.5, dtype=dtype
)

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, hidden_size, prefix=f"{prefix}.embed_tokens"
)

embed_tokens_per_layer instance-attribute

embed_tokens_per_layer = VocabParallelEmbedding(
    vocab_size_per_layer_input,
    num_hidden_layers * hidden_size_per_layer_input,
    prefix=f"{prefix}.per_layer_embed_tokens",
)

eps instance-attribute

eps = torch.tensor(torch.finfo().min)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=rms_norm_eps)

per_layer_input_scale instance-attribute

per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to(dtype)

per_layer_model_projection instance-attribute

per_layer_model_projection = ColumnParallelLinear(
    hidden_size,
    num_hidden_layers * hidden_size_per_layer_input,
    bias=False,
    gather_output=True,
    return_bias=False,
    prefix=f"{prefix}.per_layer_model_projection",
)

per_layer_projection_norm instance-attribute

per_layer_projection_norm = RMSNorm(
    hidden_size=hidden_size_per_layer_input,
    eps=rms_norm_eps,
)

per_layer_projection_scale instance-attribute

per_layer_projection_scale = tensor(
    hidden_size**0.5, dtype=dtype
)

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/gemma3n.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    config = vllm_config.model_config.hf_config.text_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config

    self.embed_tokens = VocabParallelEmbedding(
        config.vocab_size,
        config.hidden_size,
        prefix=f"{prefix}.embed_tokens",
    )
    self.embed_scale = torch.tensor(
        config.hidden_size**0.5,
        dtype=self.embed_tokens.weight.dtype,
    )
    self.embed_tokens_per_layer = VocabParallelEmbedding(
        config.vocab_size_per_layer_input,
        config.num_hidden_layers * config.hidden_size_per_layer_input,
        prefix=f"{prefix}.per_layer_embed_tokens",
    )
    self.embed_scale_per_layer = torch.tensor(
        config.hidden_size_per_layer_input**0.5,
        dtype=self.embed_tokens.weight.dtype,
    )
    self.per_layer_model_projection = ColumnParallelLinear(
        config.hidden_size,
        config.num_hidden_layers * config.hidden_size_per_layer_input,
        bias=False,
        gather_output=True,
        return_bias=False,
        prefix=f"{prefix}.per_layer_model_projection",
    )
    self.per_layer_projection_norm = RMSNorm(
        hidden_size=config.hidden_size_per_layer_input,
        eps=config.rms_norm_eps,
    )
    self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to(
        self.embed_tokens.weight.dtype)
    self.per_layer_projection_scale = torch.tensor(
        config.hidden_size**0.5,
        dtype=self.embed_tokens.weight.dtype,
    )
    self.altup_projections = nn.ModuleList([
        ColumnParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            gather_output=True,
            return_bias=False,
            prefix=f"{prefix}.{idx-1}.altup_projections",
        ) for idx in range(1, self.config.altup_num_inputs)
    ])
    self.altup_unembed_projections = nn.ModuleList([
        ColumnParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            gather_output=True,
            return_bias=False,
            prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
        ) for idx in range(1, self.config.altup_num_inputs)
    ])

    # Transformer blocks.
    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        lambda prefix: Gemma3nDecoderLayer(
            config, cache_config, quant_config, prefix=prefix),
        prefix=f"{prefix}.layers")
    self.norm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )
    self.eps = torch.tensor(torch.finfo().min)
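
For reference, the scalar scales computed above are simple functions of the config. With hypothetical values hidden_size=2048 and hidden_size_per_layer_input=256 (illustrative, not necessarily the shipped config):

import math

hidden_size = 2048                  # illustrative
hidden_size_per_layer_input = 256   # illustrative

embed_scale = hidden_size ** 0.5                             # ~45.25
embed_scale_per_layer = hidden_size_per_layer_input ** 0.5   # 16.0
per_layer_input_scale = 1 / math.sqrt(2.0)                   # ~0.7071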

forward

forward(
    input_ids: Optional[Tensor],
    positions: Tensor,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
    if inputs_embeds is not None:
        hidden_states_0 = inputs_embeds
    else:
        hidden_states_0 = self.get_input_embeddings(input_ids)

    # Per layer inputs.
    if input_ids is None:
        raise ValueError("Passing None for input ids is not supported.")
    per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
    per_layer_inputs = per_layer_inputs.reshape(
        -1, self.config.num_hidden_layers,
        self.config.hidden_size_per_layer_input)
    per_layer_projection = self.per_layer_model_projection(hidden_states_0)
    per_layer_projection = per_layer_projection.reshape(
        *hidden_states_0.shape[:-1],
        self.config.num_hidden_layers,
        self.config.hidden_size_per_layer_input,
    )
    per_layer_projection = self.per_layer_projection_norm(
        per_layer_projection)
    per_layer_inputs = per_layer_projection + per_layer_inputs
    per_layer_inputs *= self.per_layer_input_scale

    # Altup embed.
    hidden_states = [hidden_states_0] * self.config.altup_num_inputs
    target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
                                  keepdim=True)**0.5
    for i in range(1, self.config.altup_num_inputs):
        hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
        new_magnitude = torch.mean(hidden_states[i]**2,
                                   dim=-1,
                                   keepdim=True)**0.5
        hidden_states[i] *= target_magnitude / torch.maximum(
            new_magnitude, self.eps)
    hidden_states = torch.stack(hidden_states, dim=0)

    # Transformer blocks.
    for layer_idx, layer in enumerate(self.layers):
        # [altup_num_inputs, num_tokens, hidden_size]
        hidden_states = layer(
            positions=positions,
            hidden_states=hidden_states,
            per_layer_input=per_layer_inputs[:, layer_idx, :],
            **kwargs,
        )

    # Altup unembed.
    target_magnitude = torch.mean(hidden_states[0]**2,
                                  dim=-1,
                                  keepdim=True)**0.5
    for i in range(1, self.config.altup_num_inputs):
        hidden_states[i] = self.altup_unembed_projections[i - 1](
            hidden_states[i])
        new_magnitude = torch.mean(hidden_states[i]**2,
                                   dim=-1,
                                   keepdim=True)**0.5
        hidden_states[i] *= target_magnitude / torch.maximum(
            new_magnitude, self.eps)
    # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
    hidden_states = torch.mean(hidden_states, dim=0)

    return self.norm(hidden_states)
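
Both the "Altup embed" and "Altup unembed" loops above rescale each projected stream so its per-token RMS magnitude matches the reference stream before and after the transformer blocks. A standalone sketch of that rescaling, with a random linear layer standing in for an altup_projections entry and illustrative sizes; the real code guards the division with torch.maximum and self.eps, while a small positive floor is used here instead:

import torch

num_tokens, hidden_size = 4, 16                               # illustrative sizes
base = torch.randn(num_tokens, hidden_size)                   # plays the role of hidden_states_0
proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)  # stand-in projection

target_magnitude = base.pow(2).mean(dim=-1, keepdim=True).sqrt()

stream = proj(base)
new_magnitude = stream.pow(2).mean(dim=-1, keepdim=True).sqrt()

# Rescale the projected stream so its RMS matches the base stream token-by-token.
stream = stream * target_magnitude / torch.clamp(new_magnitude, min=1e-6)

print(stream.pow(2).mean(dim=-1).sqrt())  # ~ equal to the line below
print(base.pow(2).mean(dim=-1).sqrt())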

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.embed_tokens(input_ids) * self.embed_scale

get_per_layer_input_embeddings

get_per_layer_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def get_per_layer_input_embeddings(
        self, input_ids: torch.Tensor) -> torch.Tensor:
    # Because vocab_size_per_layer_input < vocab_size, some token ids fall
    # outside the per-layer vocabulary; map those ids to 0. This matches
    # the HF implementation.
    per_layer_inputs_mask = torch.logical_and(
        input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
    per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
                                          torch.zeros_like(input_ids))
    return self.embed_tokens_per_layer(
        per_layer_inputs_tokens) * self.embed_scale_per_layer
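
Only ids below vocab_size_per_layer_input have a row in the per-layer table; everything else is folded onto id 0 before the lookup. A small illustration with a made-up per-layer vocab size of 6:

import torch

vocab_size_per_layer_input = 6          # illustrative value
input_ids = torch.tensor([1, 5, 6, 9])  # 6 and 9 fall outside the per-layer vocab

mask = (input_ids >= 0) & (input_ids < vocab_size_per_layer_input)
per_layer_tokens = torch.where(mask, input_ids, torch.zeros_like(input_ids))
print(per_layer_tokens)  # tensor([1, 5, 0, 0])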

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/gemma3n.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if (self.quant_config is not None and
            (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache scales for compressed-tensors quantization
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, shard_name, shard_id) in stacked_params_mapping:
            if shard_name not in name:
                continue
            # Avoid spurious match with ".up_proj".
            if "altup_projections" in name:
                continue
            name = name.replace(shard_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)

    return loaded_params
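
The stacked_params_mapping above folds the per-shard checkpoint weights (q_proj/k_proj/v_proj, gate_proj/up_proj) into the fused vLLM parameters (qkv_proj, gate_up_proj), passing a shard id to the fused parameter's weight_loader. A rough sketch of just the renaming step; the checkpoint names below are hypothetical:

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(name: str):
    """Return (fused_name, shard_id), or (name, None) if no fused target applies."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        # Mirrors the guard against ".up_proj" spuriously matching
        # inside "altup_projections".
        if shard_name in name and "altup_projections" not in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None

print(remap("model.layers.0.self_attn.q_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap("model.layers.0.mlp.up_proj.weight"))
# -> ('model.layers.0.mlp.gate_up_proj.weight', 1)
print(remap("model.altup_projections.0.weight"))
# -> ('model.altup_projections.0.weight', None)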