Skip to content

vllm.model_executor.models.nemotron

Inference-only Nemotron model compatible with HuggingFace weights.

NemotronAttention

Bases: Module

Source code in vllm/model_executor/models/nemotron.py
class NemotronAttention(nn.Module):
    """Attention block: fused QKV projection, rotary embedding, attention
    and output projection.

    The constructor derives per-rank head counts from the tensor-parallel
    world size and builds the submodules; ``forward`` chains them together.
    """

    def __init__(
        self,
        config: NemotronConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention submodules.

        Args:
            config: Model config. ``partial_rotary_factor`` is read from it;
                ``head_dim`` is used when present and not ``None``.
            hidden_size: Width passed to the QKV and output projections.
            num_heads: Total number of query heads (before TP partitioning).
            num_kv_heads: Total number of key/value heads (before TP
                partitioning).
            rope_theta: Forwarded to ``get_rope`` as ``base``.
            rope_scaling: Forwarded to ``get_rope`` unchanged.
            max_position_embeddings: Forwarded to ``get_rope`` as
                ``max_position``.
            quant_config: Forwarded to both projections and to ``Attention``.
            bias: Forwarded as ``bias`` to both projections.
            cache_config: Forwarded to ``Attention``.
            prefix: Dotted name prefix used for the submodule prefixes.
        """
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly across the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: KV heads are replicated, so the
            # rank count must be a multiple of the KV head count.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: KV heads are partitioned,
            # so they must divide evenly across the ranks.
            assert self.total_num_kv_heads % world_size == 0
        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # An explicit ``head_dim`` on the config wins; otherwise it is
        # derived from the hidden size and the total query head count.
        resolved_head_dim = getattr(config, "head_dim", None)
        if resolved_head_dim is None:
            resolved_head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = resolved_head_dim

        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            partial_rotary_factor=self.partial_rotary_factor,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out.

        Args:
            positions: Passed to the rotary embedding together with Q and K.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The first element of the output projection's result.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The last dimension holds Q, then K, then V.
        slice_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(slice_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected

attn instance-attribute

attn = Attention(
    num_heads,
    head_dim,
    scaling,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

head_dim instance-attribute

head_dim = getattr(config, 'head_dim', None)

hidden_size instance-attribute

hidden_size = hidden_size

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    input_size=total_num_heads * head_dim,
    output_size=hidden_size,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

partial_rotary_factor instance-attribute

partial_rotary_factor = partial_rotary_factor

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size=hidden_size,
    head_size=head_dim,
    total_num_heads=total_num_heads,
    total_num_kv_heads=total_num_kv_heads,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

rope_theta instance-attribute

rope_theta = rope_theta

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    rope_scaling=rope_scaling,
    partial_rotary_factor=partial_rotary_factor,
)

scaling instance-attribute

scaling = head_dim ** -0.5

total_num_heads instance-attribute

total_num_heads = num_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_kv_heads

__init__

__init__(
    config: NemotronConfig,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    quant_config: Optional[QuantizationConfig] = None,
    bias: bool = False,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/nemotron.py
def __init__(
    self,
    config: NemotronConfig,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    quant_config: Optional[QuantizationConfig] = None,
    bias: bool = False,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "",
) -> None:
    """Derive per-rank head counts and build the attention submodules.

    Args:
        config: Model config. ``partial_rotary_factor`` is read from it;
            ``head_dim`` is used when present and not ``None``.
        hidden_size: Width passed to the QKV and output projections.
        num_heads: Total number of query heads (before TP partitioning).
        num_kv_heads: Total number of key/value heads (before TP
            partitioning).
        rope_theta: Forwarded to ``get_rope`` as ``base``.
        rope_scaling: Forwarded to ``get_rope`` unchanged.
        max_position_embeddings: Forwarded to ``get_rope`` as
            ``max_position``.
        quant_config: Forwarded to both projections and to ``Attention``.
        bias: Forwarded as ``bias`` to both projections.
        cache_config: Forwarded to ``Attention``.
        prefix: Dotted name prefix used for the submodule prefixes.
    """
    super().__init__()
    self.hidden_size = hidden_size
    world_size = get_tensor_model_parallel_world_size()

    # Query heads must divide evenly across the tensor-parallel ranks.
    self.total_num_heads = num_heads
    assert self.total_num_heads % world_size == 0
    self.num_heads = self.total_num_heads // world_size

    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads < world_size:
        # Fewer KV heads than ranks: KV heads are replicated, so the rank
        # count must be a multiple of the KV head count.
        assert world_size % self.total_num_kv_heads == 0
    else:
        # At least as many KV heads as ranks: KV heads are partitioned, so
        # they must divide evenly across the ranks.
        assert self.total_num_kv_heads % world_size == 0
    # Every rank keeps at least one KV head.
    self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

    # An explicit ``head_dim`` on the config wins; otherwise it is derived
    # from the hidden size and the total query head count.
    resolved_head_dim = getattr(config, "head_dim", None)
    if resolved_head_dim is None:
        resolved_head_dim = self.hidden_size // self.total_num_heads
    self.head_dim = resolved_head_dim

    # Per-rank widths of the Q and K/V slices of the fused projection.
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim
    self.scaling = self.head_dim**-0.5
    self.rope_theta = rope_theta
    self.partial_rotary_factor = config.partial_rotary_factor
    self.max_position_embeddings = max_position_embeddings

    self.qkv_proj = QKVParallelLinear(
        hidden_size=hidden_size,
        head_size=self.head_dim,
        total_num_heads=self.total_num_heads,
        total_num_kv_heads=self.total_num_kv_heads,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        input_size=self.total_num_heads * self.head_dim,
        output_size=hidden_size,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

    self.rotary_emb = get_rope(
        self.head_dim,
        rotary_dim=self.head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        rope_scaling=rope_scaling,
        partial_rotary_factor=self.partial_rotary_factor,
    )
    self.attn = Attention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
    )

forward

forward(positions: Tensor, hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/nemotron.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Project to Q/K/V, apply rotary embedding, attend, project out.

    Args:
        positions: Passed to the rotary embedding together with Q and K.
        hidden_states: Input to the fused QKV projection.

    Returns:
        The first element of the output projection's result.
    """
    fused_qkv, _ = self.qkv_proj(hidden_states)
    # The last dimension holds Q, then K, then V.
    slice_sizes = [self.q_size, self.kv_size, self.kv_size]
    query, key, value = fused_qkv.split(slice_sizes, dim=-1)
    query, key = self.rotary_emb(positions, query, key)
    context = self.attn(query, key, value)
    projected, _ = self.o_proj(context)
    return projected

NemotronDecoderLayer

Bases: Module

Source code in vllm/model_executor/models/nemotron.py
class NemotronDecoderLayer(nn.Module):
    """One decoder layer: layer norm, self-attention, layer norm, MLP.

    The residual stream is threaded through the two layer norms, which are
    called with ``(hidden_states, residual)`` and return the pair back.
    """

    def __init__(
        self,
        config: NemotronConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and layer-norm submodules from ``config``.

        Args:
            config: Model config supplying sizes, activation, norm epsilon
                and the optional rope/bias settings read below.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Dotted name prefix for the submodules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            # NOTE: this writes into the dict object obtained from the
            # config, not into a copy.
            original_max_pos = getattr(
                config, "original_max_position_embeddings", None)
            if original_max_pos:
                rope_scaling["original_max_position_embeddings"] = (
                    original_max_pos)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Either ``attention_bias`` or ``bias`` on the config turns on the
        # bias of the attention projections.
        attention_bias = getattr(config, "attention_bias", False)
        if not attention_bias:
            attention_bias = getattr(config, "bias", False)

        # Without ``num_key_value_heads`` the KV head count falls back to
        # the query head count.
        total_heads = config.num_attention_heads
        total_kv_heads = getattr(config, "num_key_value_heads", total_heads)

        self.self_attn = NemotronAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=total_heads,
            num_kv_heads=total_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = NemotronMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = NemotronLayerNorm1P(
            config.hidden_size,
            eps=config.norm_eps,
        )
        self.post_attention_layernorm = NemotronLayerNorm1P(
            config.hidden_size,
            eps=config.norm_eps,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to the attention block.
            hidden_states: Layer input.
            residual: Residual carried in from the caller, or ``None``; in
                the ``None`` case the incoming ``hidden_states`` becomes the
                residual.

        Returns:
            The MLP output and the residual returned by the
            post-attention layer norm.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions,
                                  hidden_states=attn_in)             if False else self.self_attn(positions=positions,
                                  hidden_states=hidden_states)

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual

hidden_size instance-attribute

hidden_size = hidden_size

input_layernorm instance-attribute

input_layernorm = NemotronLayerNorm1P(
    hidden_size, eps=norm_eps
)

mlp instance-attribute

mlp = NemotronMLP(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    hidden_act=hidden_act,
    quant_config=quant_config,
    bias=getattr(config, "mlp_bias", False),
    prefix=f"{prefix}.mlp",
)

post_attention_layernorm instance-attribute

post_attention_layernorm = NemotronLayerNorm1P(
    hidden_size, eps=norm_eps
)

self_attn instance-attribute

self_attn = NemotronAttention(
    config=config,
    hidden_size=hidden_size,
    num_heads=num_attention_heads,
    num_kv_heads=getattr(
        config, "num_key_value_heads", num_attention_heads
    ),
    rope_theta=rope_theta,
    rope_scaling=rope_scaling,
    max_position_embeddings=max_position_embeddings,
    quant_config=quant_config,
    bias=attention_bias,
    cache_config=cache_config,
    prefix=f"{prefix}.self_attn",
)

__init__

__init__(
    config: NemotronConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/nemotron.py
def __init__(
    self,
    config: NemotronConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Build the attention, MLP and layer-norm submodules from ``config``.

    Args:
        config: Model config supplying sizes, activation, norm epsilon and
            the optional rope/bias settings read below.
        cache_config: Forwarded to the attention block.
        quant_config: Forwarded to the attention block and the MLP.
        prefix: Dotted name prefix for the submodules.
    """
    super().__init__()
    self.hidden_size = config.hidden_size

    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is not None:
        # NOTE: this writes into the dict object obtained from the config,
        # not into a copy.
        original_max_pos = getattr(config,
                                   "original_max_position_embeddings", None)
        if original_max_pos:
            rope_scaling["original_max_position_embeddings"] = (
                original_max_pos)
    max_position_embeddings = getattr(config, "max_position_embeddings",
                                      8192)

    # Either ``attention_bias`` or ``bias`` on the config turns on the bias
    # of the attention projections.
    attention_bias = getattr(config, "attention_bias", False)
    if not attention_bias:
        attention_bias = getattr(config, "bias", False)

    # Without ``num_key_value_heads`` the KV head count falls back to the
    # query head count.
    total_heads = config.num_attention_heads
    total_kv_heads = getattr(config, "num_key_value_heads", total_heads)

    self.self_attn = NemotronAttention(
        config=config,
        hidden_size=self.hidden_size,
        num_heads=total_heads,
        num_kv_heads=total_kv_heads,
        rope_theta=rope_theta,
        rope_scaling=rope_scaling,
        max_position_embeddings=max_position_embeddings,
        quant_config=quant_config,
        bias=attention_bias,
        cache_config=cache_config,
        prefix=f"{prefix}.self_attn",
    )
    self.mlp = NemotronMLP(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        bias=getattr(config, "mlp_bias", False),
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = NemotronLayerNorm1P(
        config.hidden_size,
        eps=config.norm_eps,
    )
    self.post_attention_layernorm = NemotronLayerNorm1P(
        config.hidden_size,
        eps=config.norm_eps,
    )

forward

forward(
    positions: Tensor,
    hidden_states: Tensor,
    residual: Optional[Tensor],
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/nemotron.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the layer and return ``(hidden_states, residual)``.

    Args:
        positions: Forwarded to the attention block.
        hidden_states: Layer input.
        residual: Residual carried in from the caller, or ``None``; in the
            ``None`` case the incoming ``hidden_states`` becomes the
            residual.

    Returns:
        The MLP output and the residual returned by the post-attention
        layer norm.
    """
    # Self Attention
    if residual is not None:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)
    else:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    attn_out = self.self_attn(positions=positions,
                              hidden_states=hidden_states)

    # Fully Connected
    normed, residual = self.post_attention_layernorm(attn_out, residual)
    return self.mlp(normed), residual

NemotronForCausalLM

Bases: Module, SupportsLoRA, SupportsPP

Source code in vllm/model_executor/models/nemotron.py
class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Nemotron backbone plus, on the last pipeline rank, an LM head and a
    logits processor."""

    # Names of the separate projections packed into the fused ``qkv_proj``.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and, on the last pipeline rank, the LM head.

        Args:
            vllm_config: Source of the HF model config, the quantization
                config and the LoRA config.
            prefix: Name prefix combined with ``"model"`` for the backbone.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config
        lora = vllm_config.lora_config
        assert isinstance(hf_config, NemotronConfig)

        self.config = hf_config
        self.lora_config = lora
        self.quant_config = quantization

        self.model = NemotronModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        if not get_pp_group().is_last_rank:
            # Other pipeline ranks only hold a placeholder head.
            self.lm_head = PPMissingLayer()
        else:
            unpadded = hf_config.vocab_size
            if lora:
                unpadded += lora.lora_extra_vocab_size
            self.unpadded_vocab_size = unpadded

            # We need bigger padding if using lora for kernel
            # compatibility
            if lora:
                head_padding = lora.lora_vocab_padding_size
            else:
                head_padding = DEFAULT_VOCAB_PADDING_SIZE
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=head_padding,
                quant_config=quantization,
            )
            if hf_config.tie_word_embeddings:
                # Share the embedding weight with the LM head.
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(hf_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                hf_config.vocab_size,
                logit_scale,
            )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the backbone's ``get_input_embeddings``."""
        backbone = self.model
        return backbone.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Call the backbone with the given arguments and return its output
        unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits processor's result for ``hidden_states``.

        ``logits_processor`` is only created on the last pipeline rank.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs through ``AutoWeightsLoader`` and
        return what it returns."""
        return AutoWeightsLoader(self).load_weights(weights)

config instance-attribute

config = config

embedding_modules class-attribute instance-attribute

embedding_modules = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

embedding_padding_modules class-attribute instance-attribute

embedding_padding_modules = ['lm_head']

lm_head instance-attribute

lm_head = ParallelLMHead(
    unpadded_vocab_size,
    hidden_size,
    org_num_embeddings=vocab_size,
    padding_size=DEFAULT_VOCAB_PADDING_SIZE
    if not lora_config
    else lora_vocab_padding_size,
    quant_config=quant_config,
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    unpadded_vocab_size, vocab_size, logit_scale
)

lora_config instance-attribute

lora_config = lora_config

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = NemotronModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"]
}

quant_config instance-attribute

quant_config = quant_config

unpadded_vocab_size instance-attribute

unpadded_vocab_size = vocab_size

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/nemotron.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the backbone and, on the last pipeline rank, the LM head.

    Args:
        vllm_config: Source of the HF model config, the quantization config
            and the LoRA config.
        prefix: Name prefix combined with ``"model"`` for the backbone.
    """
    super().__init__()
    hf_config = vllm_config.model_config.hf_config
    quantization = vllm_config.quant_config
    lora = vllm_config.lora_config
    assert isinstance(hf_config, NemotronConfig)

    self.config = hf_config
    self.lora_config = lora
    self.quant_config = quantization

    self.model = NemotronModel(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "model"),
    )

    if not get_pp_group().is_last_rank:
        # Other pipeline ranks only hold a placeholder head.
        self.lm_head = PPMissingLayer()
    else:
        unpadded = hf_config.vocab_size
        if lora:
            unpadded += lora.lora_extra_vocab_size
        self.unpadded_vocab_size = unpadded

        # We need bigger padding if using lora for kernel
        # compatibility
        if lora:
            head_padding = lora.lora_vocab_padding_size
        else:
            head_padding = DEFAULT_VOCAB_PADDING_SIZE
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            padding_size=head_padding,
            quant_config=quantization,
        )
        if hf_config.tie_word_embeddings:
            # Share the embedding weight with the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(hf_config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            hf_config.vocab_size,
            logit_scale,
        )

    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/nemotron.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Return the logits processor's result for ``hidden_states``.

    The LM head, the hidden states and the sampling metadata are passed to
    ``self.logits_processor`` in that order.
    """
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/nemotron.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Call the backbone with the given arguments and return its output
    unchanged."""
    return self.model(input_ids, positions, intermediate_tensors,
                      inputs_embeds)

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/nemotron.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate to the backbone's ``get_input_embeddings``."""
    backbone = self.model
    return backbone.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/nemotron.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load ``(name, tensor)`` pairs through ``AutoWeightsLoader`` and
    return what it returns."""
    return AutoWeightsLoader(self).load_weights(weights)

NemotronLayerNorm1P

Bases: LayerNorm

Source code in vllm/model_executor/models/nemotron.py
class NemotronLayerNorm1P(nn.LayerNorm):
    """``nn.LayerNorm`` variant that adds 1 to the stored weight before
    normalizing, and optionally adds a residual to the input first."""

    def __init__(self,
                 normalized_shape: Union[int, list[int], torch.Size],
                 eps: float = 1e-5,
                 elementwise_affine: bool = True,
                 bias: bool = True,
                 device=None,
                 dtype=None):
        """Forward all arguments unchanged to ``nn.LayerNorm.__init__``."""
        super().__init__(normalized_shape, eps, elementwise_affine, bias,
                         device, dtype)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Apply layer norm with ``weight + 1`` as the scale.

        Args:
            x: Input tensor.
            residual: Optional tensor added to ``x`` before normalization.

        Returns:
            The normalized tensor when ``residual`` is ``None``; otherwise
            the pair ``(normalized, x + residual)``, so callers can carry
            the pre-norm sum forward as the next residual.
        """
        if residual is not None:
            x = x + residual
            residual = x
        # NOTE(review): ``self.weight + 1`` assumes the weight exists, i.e.
        # the module was built with ``elementwise_affine=True`` - confirm.
        # ``_cast_if_autocast_enabled`` is defined elsewhere in this file;
        # presumably it casts the arguments to the autocast dtype.
        args = _cast_if_autocast_enabled(x, self.normalized_shape,
                                         self.weight + 1, self.bias, self.eps)
        # The layer norm itself runs with CUDA autocast disabled.
        with torch.amp.autocast("cuda", enabled=False):
            x = torch.nn.functional.layer_norm(*args)
            return x if residual is None else (x, residual)

__init__

__init__(
    normalized_shape: Union[int, list[int], Size],
    eps: float = 1e-05,
    elementwise_affine: bool = True,
    bias: bool = True,
    device=None,
    dtype=None,
)
Source code in vllm/model_executor/models/nemotron.py
def __init__(self,
             normalized_shape: Union[int, list[int], torch.Size],
             eps: float = 1e-5,
             elementwise_affine: bool = True,
             bias: bool = True,
             device=None,
             dtype=None):
    """Forward all arguments unchanged to the parent constructor."""
    super().__init__(
        normalized_shape,
        eps,
        elementwise_affine,
        bias,
        device,
        dtype,
    )

forward

forward(
    x: Tensor, residual: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/models/nemotron.py
def forward(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Apply layer norm with ``weight + 1`` as the scale.

    Args:
        x: Input tensor.
        residual: Optional tensor added to ``x`` before normalization.

    Returns:
        The normalized tensor when ``residual`` is ``None``; otherwise the
        pair ``(normalized, x + residual)``, so callers can carry the
        pre-norm sum forward as the next residual.
    """
    if residual is not None:
        x = x + residual
        residual = x
    # NOTE(review): ``self.weight + 1`` assumes the weight exists, i.e. the
    # module was built with ``elementwise_affine=True`` - confirm.
    # ``_cast_if_autocast_enabled`` is defined elsewhere in this file;
    # presumably it casts the arguments to the autocast dtype.
    args = _cast_if_autocast_enabled(x, self.normalized_shape,
                                     self.weight + 1, self.bias, self.eps)
    # The layer norm itself runs with CUDA autocast disabled.
    with torch.amp.autocast("cuda", enabled=False):
        x = torch.nn.functional.layer_norm(*args)
        return x if residual is None else (x, residual)

NemotronMLP

Bases: Module

Source code in vllm/model_executor/models/nemotron.py
class NemotronMLP(nn.Module):
    """Feed-forward block: up projection, activation, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the two projections and look up the activation.

        Args:
            hidden_size: Input size of ``up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Output size of ``up_proj`` and input size of
                ``down_proj``.
            hidden_act: Activation name passed to ``get_act_fn``.
            quant_config: Forwarded to both projections.
            bias: Forwarded as ``bias`` to both projections.
            prefix: Dotted name prefix for the submodules.
        """
        super().__init__()
        self.up_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = get_act_fn(hidden_act)

    def forward(self, x):
        """Return ``down_proj(act_fn(up_proj(x)))``.

        Both projections return a pair; only the first element is used.
        """
        expanded, _ = self.up_proj(x)
        activated = self.act_fn(expanded)
        contracted, _ = self.down_proj(activated)
        return contracted

act_fn instance-attribute

act_fn = get_act_fn(hidden_act)

down_proj instance-attribute

down_proj = RowParallelLinear(
    input_size=intermediate_size,
    output_size=hidden_size,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

up_proj instance-attribute

up_proj = ColumnParallelLinear(
    input_size=hidden_size,
    output_size=intermediate_size,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.up_proj",
)

__init__

__init__(
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    bias: bool = False,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/nemotron.py
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    bias: bool = False,
    prefix: str = "",
) -> None:
    """Build the two projections and look up the activation.

    Args:
        hidden_size: Input size of ``up_proj`` and output size of
            ``down_proj``.
        intermediate_size: Output size of ``up_proj`` and input size of
            ``down_proj``.
        hidden_act: Activation name passed to ``get_act_fn``.
        quant_config: Forwarded to both projections.
        bias: Forwarded as ``bias`` to both projections.
        prefix: Dotted name prefix for the submodules.
    """
    super().__init__()
    self.up_proj = ColumnParallelLinear(
        input_size=hidden_size,
        output_size=intermediate_size,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.up_proj",
    )
    self.down_proj = RowParallelLinear(
        input_size=intermediate_size,
        output_size=hidden_size,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
    )
    self.act_fn = get_act_fn(hidden_act)

forward

forward(x)
Source code in vllm/model_executor/models/nemotron.py
def forward(self, x):
    """Return ``down_proj(act_fn(up_proj(x)))``.

    Both projections return a pair; only the first element is used.
    """
    expanded, _ = self.up_proj(x)
    activated = self.act_fn(expanded)
    contracted, _ = self.down_proj(activated)
    return contracted

NemotronModel

Bases: Module

Source code in vllm/model_executor/models/nemotron.py
@support_torch_compile
class NemotronModel(nn.Module):
    """Nemotron backbone: token embedding, decoder layers and a final norm.

    The embedding and the final norm are only instantiated on the
    pipeline-parallel ranks selected in ``__init__``; on the other ranks
    they are ``PPMissingLayer`` placeholders.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the submodules from ``vllm_config``.

        Args:
            vllm_config: Supplies the HF model config, the cache config, the
                quantization config and the (optional) LoRA config.
            prefix: Name prefix; decoder layers are created under
                ``f"{prefix}.layers"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra vocabulary entries for LoRA: lora_extra_vocab_size times
        # max_loras (1 when max_loras is falsy); 0 when there is no LoRA
        # config.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # The embedding exists on the first PP rank, and additionally on the
        # last PP rank when the config ties the word embeddings.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        # make_layers returns the [start_layer, end_layer) range used by
        # forward() together with the layer container itself.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: NemotronDecoderLayer(config=config,
                                                cache_config=cache_config,
                                                quant_config=quant_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers")
        # The final norm is only applied on the last PP rank.
        if get_pp_group().is_last_rank:
            self.norm = NemotronLayerNorm1P(config.hidden_size,
                                            eps=config.norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``self.embed_tokens`` applied to ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: Token ids; embedded on the first PP rank when
                ``inputs_embeds`` is ``None``.
            positions: Passed unchanged to every decoder layer.
            intermediate_tensors: Must not be ``None`` on non-first PP
                ranks; provides ``"hidden_states"`` and ``"residual"``.
            inputs_embeds: Optional pre-computed embeddings that take
                precedence over ``input_ids`` on the first PP rank.

        Returns:
            On the last PP rank, the first output of ``self.norm``.
            Otherwise an ``IntermediateTensors`` holding ``"hidden_states"``
            and ``"residual"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only the layers in [start_layer, end_layer) are executed here.
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names (after any renaming done here) that
            were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # A non-scalar scale tensor is reduced to its first element.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # From here on `name` refers to the fused qkv parameter.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: this `continue` (and the one below) only advances the
                # inner loop; without a `break` the for/else branch then runs
                # with the already renamed `name`.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Reached when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, hidden_size, org_num_embeddings=vocab_size
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size
    )
)

norm instance-attribute

norm = NemotronLayerNorm1P(hidden_size, eps=norm_eps)

org_vocab_size instance-attribute

org_vocab_size = vocab_size

quant_config instance-attribute

quant_config = quant_config

vocab_size instance-attribute

vocab_size = vocab_size + lora_vocab

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/nemotron.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the embedding, decoder layers and final norm for this rank.

    Args:
        vllm_config: Supplies the HF model config, the cache config, the
            quantization config and the (optional) LoRA config.
        prefix: Name prefix; decoder layers are created under
            ``f"{prefix}.layers"``.
    """
    super().__init__()

    hf_config = vllm_config.model_config.hf_config
    kv_cache_config = vllm_config.cache_config
    quantization = vllm_config.quant_config
    lora = vllm_config.lora_config

    self.config = hf_config
    self.quant_config = quantization

    # Extra vocabulary entries for LoRA; none without a LoRA config.
    if lora:
        extra_vocab = lora.lora_extra_vocab_size * (lora.max_loras or 1)
    else:
        extra_vocab = 0
    self.vocab_size = hf_config.vocab_size + extra_vocab
    self.org_vocab_size = hf_config.vocab_size

    pp_group = get_pp_group()

    # Embedding lives on the first PP rank, plus the last rank when the
    # config ties the word embeddings.
    owns_embedding = pp_group.is_first_rank or (
        hf_config.tie_word_embeddings and pp_group.is_last_rank)
    if owns_embedding:
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
        )
    else:
        self.embed_tokens = PPMissingLayer()

    def build_layer(prefix):
        # Parameter is named `prefix` to match the original factory.
        return NemotronDecoderLayer(
            config=hf_config,
            cache_config=kv_cache_config,
            quant_config=quantization,
            prefix=prefix,
        )

    self.start_layer, self.end_layer, self.layers = make_layers(
        hf_config.num_hidden_layers,
        build_layer,
        prefix=f"{prefix}.layers",
    )

    # Final norm is only materialized on the last PP rank.
    if pp_group.is_last_rank:
        self.norm = NemotronLayerNorm1P(hf_config.hidden_size,
                                        eps=hf_config.norm_eps)
    else:
        self.norm = PPMissingLayer()

    tensor_keys = ["hidden_states", "residual"]
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(tensor_keys,
                                                hf_config.hidden_size))

forward

forward(
    input_ids: Optional[Tensor],
    positions: Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/nemotron.py
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run this rank's slice of decoder layers.

    Args:
        input_ids: Token ids; embedded on the first PP rank when
            ``inputs_embeds`` is ``None``.
        positions: Passed unchanged to every decoder layer.
        intermediate_tensors: Must not be ``None`` on non-first PP ranks;
            provides ``"hidden_states"`` and ``"residual"``.
        inputs_embeds: Optional pre-computed embeddings that take
            precedence over ``input_ids`` on the first PP rank.

    Returns:
        On the last PP rank, the first output of ``self.norm``. Otherwise
        an ``IntermediateTensors`` holding ``"hidden_states"`` and
        ``"residual"``.
    """
    pp_group = get_pp_group()

    if not pp_group.is_first_rank:
        # Later ranks resume from what the previous rank produced.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    else:
        hidden_states = (inputs_embeds if inputs_embeds is not None else
                         self.get_input_embeddings(input_ids))
        residual = None

    local_layers = self.layers[self.start_layer:self.end_layer]
    for decoder_layer in local_layers:
        hidden_states, residual = decoder_layer(positions, hidden_states,
                                                residual)

    if pp_group.is_last_rank:
        normed, _ = self.norm(hidden_states, residual)
        return normed

    return IntermediateTensors({
        "hidden_states": hidden_states,
        "residual": residual
    })

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/nemotron.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of calling ``self.embed_tokens`` on ``input_ids``."""
    embedding_layer = self.embed_tokens
    return embedding_layer(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/nemotron.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this module's parameters.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        The set of parameter names (after any renaming done here) that
        were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if (self.quant_config is not None and
            (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # A non-scalar scale tensor is reduced to its first element.
            loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                             loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            # From here on `name` refers to the fused qkv parameter.
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            # NOTE: this `continue` (and the one below) only advances the
            # inner loop; without a `break` the for/else branch then runs
            # with the already renamed `name`.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)

            break
        else:
            # Reached when no stacked mapping loaded the tensor.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params

_cast_if_autocast_enabled

_cast_if_autocast_enabled(*args)
Source code in vllm/model_executor/models/nemotron.py
def _cast_if_autocast_enabled(*args):
    if not torch.is_autocast_enabled():
        return args
    else:
        return torch.amp.autocast_mode._cast(
            args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype())