Skip to content

vllm.model_executor.models.internlm2

InternLM2Attention

Bases: Module

Source code in vllm/model_executor/models/internlm2.py
class InternLM2Attention(nn.Module):
    """Tensor-parallel self-attention for InternLM2.

    Q, K and V come from one fused ``wqkv`` projection whose output is
    rearranged into separate tensors by ``split_qkv``; rotary embeddings are
    applied to q/k before ``Attention``, followed by the ``wo`` projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Compute per-rank head counts and build the attention sub-layers.

        Args:
            hidden_size: Model hidden size; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total query heads across all tensor-parallel ranks;
                must be divisible by the TP world size.
            num_kv_heads: Total key/value heads across all TP ranks; must be
                a multiple of the TP world size, or divide it.
            rope_theta: Rotary base, passed to ``get_rope`` as ``base``.
            rope_scaling: Optional rotary scaling config for ``get_rope``.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to ``wqkv``, ``wo`` and ``Attention``.
            prefix: Name prefix for the ``wqkv``, ``wo`` and ``attn``
                sub-modules.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % self.tp_size == 0
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= self.tp_size:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert self.tp_size % self.total_num_kv_heads == 0
        # KV heads held by this rank; never below 1 (replicated case).
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Query heads per KV head in the fused wqkv layout (see split_qkv).
        self.key_value_groups = int(self.num_heads / self.num_kv_heads)
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wqkv",
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def split_qkv(self, qkv: torch.Tensor):
        """Split this rank's fused ``wqkv`` output into ``(q, k, v)``.

        The ``view`` below treats the (gathered) fused output as
        ``total_num_kv_heads`` groups, each holding ``key_value_groups`` query
        heads followed by one key head and one value head.

        Args:
            qkv: Output of ``self.wqkv``; dim 0 is used as the token count.

        Returns:
            ``(q, k, v)`` whose last dims are ``q_size``, ``kv_size`` and
            ``kv_size`` for this rank.
        """
        seq_len = qkv.shape[0]
        if self.tp_size > 1:
            # Gather every rank's [q | k | v] slice, then regroup as
            # [all q | all k | all v] — presumably restoring the unsharded
            # fused ordering; verify against QKVParallelLinear's loader.
            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
            qkv = tensor_model_parallel_all_gather(qkv)
            qkv = torch.split(qkv, qkv_map, dim=-1)
            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
            qkv = torch.cat(qkv, dim=-1)

        # NOTE(review): the gathered width is tp_size * (q_size + 2 * kv_size)
        # (it must equal sum(qkv_map) for torch.split to succeed). That only
        # equals total_num_kv_heads * (key_value_groups + 2) * head_dim when
        # tp_size <= total_num_kv_heads. In the replicated-KV case
        # (tp_size > total_num_kv_heads) this view appears unable to succeed
        # — confirm whether that configuration is supported.
        qkv = qkv.view(seq_len, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
        # Flatten heads back out: full (all-rank) widths at this point.
        q = q.reshape(seq_len, self.q_size * self.tp_size)
        k = k.reshape(seq_len, self.kv_size * self.tp_size)
        v = v.reshape(seq_len, self.kv_size * self.tp_size)

        if self.tp_size > 1:
            # Keep only this rank's contiguous slice of each tensor.
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding and attention, project out.

        Args:
            positions: Positions passed to ``rotary_emb``.
            hidden_states: Input to the fused ``wqkv`` projection.

        Returns:
            Output of the ``wo`` projection.
        """
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.wo(attn_output)
        return output

attn instance-attribute

attn = Attention(
    num_heads,
    head_dim,
    scaling,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

head_dim instance-attribute

head_dim = hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

key_value_groups instance-attribute

key_value_groups = int(num_heads / num_kv_heads)

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

q_size instance-attribute

q_size = num_heads * head_dim

rope_theta instance-attribute

rope_theta = rope_theta

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    rope_scaling=rope_scaling,
)

scaling instance-attribute

scaling = head_dim ** -0.5

total_num_heads instance-attribute

total_num_heads = num_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_kv_heads

tp_rank instance-attribute

tp_size instance-attribute

wo instance-attribute

wo = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.wo",
)

wqkv instance-attribute

wqkv = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.wqkv",
)

__init__

__init__(
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/internlm2.py
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[dict[str, Any]] = None,
    max_position_embeddings: int = 8192,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Set up per-rank head bookkeeping, projections, RoPE and attention.

    Args:
        hidden_size: Model hidden size; ``head_dim`` is
            ``hidden_size // num_heads``.
        num_heads: Total query heads; must be divisible by the TP size.
        num_kv_heads: Total KV heads; must be a multiple of the TP size, or
            divide it (in which case each rank keeps one KV head).
        rope_theta: Rotary base for ``get_rope``.
        rope_scaling: Optional rotary scaling config for ``get_rope``.
        max_position_embeddings: ``max_position`` for ``get_rope``.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the linear layers and ``Attention``.
        prefix: Name prefix for the ``wqkv``, ``wo`` and ``attn`` modules.
    """
    super().__init__()
    world = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()

    self.hidden_size = hidden_size
    self.tp_size = world
    self.tp_rank = rank

    # Query heads are always split evenly across TP ranks.
    self.total_num_heads = num_heads
    assert num_heads % world == 0
    self.num_heads = num_heads // world

    # KV heads: replicated when there are fewer than TP ranks, otherwise
    # partitioned evenly.
    self.total_num_kv_heads = num_kv_heads
    if num_kv_heads < world:
        assert world % num_kv_heads == 0
    else:
        assert num_kv_heads % world == 0
    self.num_kv_heads = max(1, num_kv_heads // world)

    head_dim = hidden_size // num_heads
    self.head_dim = head_dim
    self.q_size = self.num_heads * head_dim
    self.kv_size = self.num_kv_heads * head_dim
    self.key_value_groups = int(self.num_heads / self.num_kv_heads)
    self.scaling = head_dim**-0.5
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings

    self.wqkv = QKVParallelLinear(
        hidden_size,
        head_dim,
        num_heads,
        num_kv_heads,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.wqkv",
    )
    self.wo = RowParallelLinear(
        num_heads * head_dim,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.wo",
    )
    self.rotary_emb = get_rope(
        head_dim,
        rotary_dim=head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        rope_scaling=rope_scaling,
    )
    self.attn = Attention(
        self.num_heads,
        head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
    )

forward

forward(positions: Tensor, hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/internlm2.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Fused QKV projection -> rotary embedding -> attention -> ``wo``.

    Args:
        positions: Positions passed to ``rotary_emb``.
        hidden_states: Input to the fused ``wqkv`` projection.

    Returns:
        Output of ``self.wo`` (its second return value is discarded).
    """
    fused, _bias = self.wqkv(hidden_states)
    query, key, value = self.split_qkv(fused)
    query, key = self.rotary_emb(positions, query, key)
    projected, _bias = self.wo(self.attn(query, key, value))
    return projected

split_qkv

split_qkv(qkv: Tensor)
Source code in vllm/model_executor/models/internlm2.py
def split_qkv(self, qkv: torch.Tensor):
    """Split this rank's fused ``wqkv`` output into ``(q, k, v)``.

    The ``view`` below treats the (gathered) fused output as
    ``total_num_kv_heads`` groups, each holding ``key_value_groups`` query
    heads followed by one key head and one value head.

    Args:
        qkv: Output of ``self.wqkv``; dim 0 is used as the token count.

    Returns:
        ``(q, k, v)`` whose last dims are ``q_size``, ``kv_size`` and
        ``kv_size`` for this rank.
    """
    seq_len = qkv.shape[0]
    if self.tp_size > 1:
        # Gather every rank's [q | k | v] slice, then regroup as
        # [all q | all k | all v] — presumably restoring the unsharded
        # fused ordering; verify against QKVParallelLinear's loader.
        qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
        qkv = tensor_model_parallel_all_gather(qkv)
        qkv = torch.split(qkv, qkv_map, dim=-1)
        qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
        qkv = torch.cat(qkv, dim=-1)

    # NOTE(review): the gathered width is tp_size * (q_size + 2 * kv_size)
    # (it must equal sum(qkv_map) for torch.split to succeed). That only
    # equals total_num_kv_heads * (key_value_groups + 2) * head_dim when
    # tp_size <= total_num_kv_heads. In the replicated-KV case
    # (tp_size > total_num_kv_heads) this view appears unable to succeed
    # — confirm whether that configuration is supported.
    qkv = qkv.view(seq_len, self.total_num_kv_heads,
                   self.key_value_groups + 2, self.head_dim)
    q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
    # Flatten heads back out: full (all-rank) widths at this point.
    q = q.reshape(seq_len, self.q_size * self.tp_size)
    k = k.reshape(seq_len, self.kv_size * self.tp_size)
    v = v.reshape(seq_len, self.kv_size * self.tp_size)

    if self.tp_size > 1:
        # Keep only this rank's contiguous slice of each tensor.
        splitter = partial(split_tensor_along_last_dim,
                           num_partitions=self.tp_size)
        q = splitter(q)[self.tp_rank]
        k = splitter(k)[self.tp_rank]
        v = splitter(v)[self.tp_rank]
    return q, k, v

InternLM2ForCausalLM

Bases: Module, SupportsPP, SupportsLoRA

Source code in vllm/model_executor/models/internlm2.py
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """InternLM2 decoder backbone topped with a vocabulary LM head (``output``).
    """
    # Fused module name -> checkpoint module names packed into it; matches
    # the w1/w3 -> gate_up_proj remapping performed in ``load_weights``.
    packed_modules_mapping = {
        "wqkv": ["wqkv"],
        "gate_up_proj": ["w1", "w3"],
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 model_type: type[InternLM2Model] = InternLM2Model):
        """Build the backbone, LM head and logits processor.

        Args:
            vllm_config: Global vLLM config (HF, quant and LoRA configs).
            prefix: Module name prefix for ``model`` and ``output``.
            model_type: Backbone class, constructed as
                ``model_type(vllm_config=..., prefix=...)``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.lora_config = lora_config

        self.model = model_type(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config,
                                     prefix=maybe_prefix(prefix, "output"))
        if self.config.tie_word_embeddings:
            # Share one weight between the input embedding and the LM head.
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and return its output unchanged.

        With the default ``InternLM2Model`` backbone this is the final hidden
        states on the last pipeline rank and ``IntermediateTensors`` on
        earlier ranks; logits are produced separately by ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply the LM head through ``logits_processor``."""
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Names containing ``w1``/``w3`` are rewritten to ``gate_up_proj`` and
        loaded as shard 0/1 of that merged parameter; everything else is
        loaded by name with the parameter's ``weight_loader`` (or
        ``default_weight_loader``).

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names (after any remapping) of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Skip any checkpoint entry for rotary inverse frequencies.
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: the `continue`s below only advance this inner loop,
                # with `name` already rewritten; if nothing `break`s, the
                # `else` clause re-applies the same skip checks and skips it.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked weight (or skipped above): load by name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

lora_config instance-attribute

lora_config = lora_config

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = model_type(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

output instance-attribute

output = ParallelLMHead(
    vocab_size,
    hidden_size,
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "output"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "wqkv": ["wqkv"],
    "gate_up_proj": ["w1", "w3"],
}

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    model_type: type[InternLM2Model] = InternLM2Model,
)
Source code in vllm/model_executor/models/internlm2.py
def __init__(self,
             *,
             vllm_config: VllmConfig,
             prefix: str = "",
             model_type: type[InternLM2Model] = InternLM2Model):
    """Create the backbone, the vocabulary LM head and the logits processor.

    Args:
        vllm_config: Global vLLM config providing the HF, quant and LoRA
            configs.
        prefix: Name prefix for the ``model`` and ``output`` sub-modules.
        model_type: Backbone class, built as
            ``model_type(vllm_config=..., prefix=...)``.
    """
    super().__init__()
    hf_cfg = vllm_config.model_config.hf_config
    quant_cfg = vllm_config.quant_config
    lora_cfg = vllm_config.lora_config

    self.config = hf_cfg
    self.quant_config = quant_cfg
    self.lora_config = lora_cfg

    self.model = model_type(vllm_config=vllm_config,
                            prefix=maybe_prefix(prefix, "model"))
    self.output = ParallelLMHead(
        hf_cfg.vocab_size,
        hf_cfg.hidden_size,
        quant_config=quant_cfg,
        prefix=maybe_prefix(prefix, "output"),
    )
    if hf_cfg.tie_word_embeddings:
        # Input embedding and LM head share a single weight tensor.
        self.output.weight = self.model.tok_embeddings.weight
    self.logits_processor = LogitsProcessor(hf_cfg.vocab_size)
    # Re-export the backbone's factory for pipeline-parallel buffers.
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/internlm2.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Turn final hidden states into logits using the ``output`` LM head.

    Delegates entirely to ``self.logits_processor``.
    """
    return self.logits_processor(self.output,
                                 hidden_states,
                                 sampling_metadata)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/models/internlm2.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Delegate to the backbone and hand back whatever it returns.

    Logits are not computed here; see ``compute_logits``.
    """
    backbone = self.model
    return backbone(input_ids, positions, intermediate_tensors,
                    inputs_embeds)

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/internlm2.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up embeddings for ``input_ids`` through the backbone."""
    backbone = self.model
    return backbone.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/internlm2.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this model's parameters.

    Names containing ``w1``/``w3`` are rewritten to ``gate_up_proj`` and
    loaded as shard 0/1 of that merged parameter; everything else is loaded
    by name with the parameter's ``weight_loader`` (or
    ``default_weight_loader``).

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        Names (after any remapping) of the parameters that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "w1", 0),
        ("gate_up_proj", "w3", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # Skip any checkpoint entry for rotary inverse frequencies.
        if "rotary_emb.inv_freq" in name:
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # NOTE: the `continue`s below only advance this inner loop, with
            # `name` already rewritten; if nothing `break`s, the `else`
            # clause re-applies the same skip checks and skips the weight.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Not a stacked weight (or skipped above): load by name.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params

InternLM2ForRewardModel

Bases: InternLM2ForCausalLM

Source code in vllm/model_executor/models/internlm2.py
class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """InternLM2 reward model: the causal-LM backbone plus a scalar value head.

    The parent's LM head (``output``) and ``logits_processor`` are removed
    and replaced by ``v_head``, a ``RowParallelLinear`` from ``hidden_size``
    to a single output feature. Pooling is done by a ``Pooler`` built with
    ``PoolingType.ALL`` as its default pooling type.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: type[InternLM2Model] = InternLM2Model,
    ):
        """Build the backbone, swap the LM head for ``v_head``, set up pooling.

        Args:
            vllm_config: Global vLLM config; its ``model_config`` supplies the
                HF config and the pooler config.
            prefix: Module name prefix (``v_head`` is named under it).
            model_type: Backbone class forwarded to the parent constructor.
        """
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         model_type=model_type)

        # A reward model emits scores, not vocabulary logits: drop the
        # parent's LM head and logits processor.
        for attr in ("output", "logits_processor"):
            delattr(self, attr)

        config = vllm_config.model_config.hf_config
        self.v_head = RowParallelLinear(
            config.hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            prefix=maybe_prefix(prefix, "v_head"),
        )

        pooler_config = vllm_config.model_config.pooler_config
        # Defaults: PoolingType.ALL, no normalization, no softmax;
        # pooler_config presumably overrides these when set.
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and score its hidden states with ``v_head``.

        Returns:
            ``v_head`` output on the last pipeline-parallel rank. On earlier
            ranks the backbone returns ``IntermediateTensors`` for the next
            stage; those are passed through unchanged, since ``v_head`` can
            only be applied to real hidden-state tensors.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        if isinstance(hidden_states, IntermediateTensors):
            # Not the last pipeline stage: forward activations, don't score.
            return hidden_states
        logits, _ = self.v_head(hidden_states)
        return logits

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool ``hidden_states`` with the pooler built in ``__init__``."""
        return self._pooler(hidden_states, pooling_metadata)

_pooler instance-attribute

_pooler = from_config_with_defaults(
    pooler_config,
    pooling_type=ALL,
    normalize=False,
    softmax=False,
)

v_head instance-attribute

v_head = RowParallelLinear(
    hidden_size,
    1,
    bias=False,
    input_is_parallel=False,
    prefix=maybe_prefix(prefix, "v_head"),
)

__init__

__init__(
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    model_type: type[InternLM2Model] = InternLM2Model,
)
Source code in vllm/model_executor/models/internlm2.py
def __init__(
    self,
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    model_type: type[InternLM2Model] = InternLM2Model,
):
    """Replace the causal-LM head with a 1-output value head and a pooler.

    Args:
        vllm_config: Global vLLM config (HF config and pooler config).
        prefix: Module name prefix; ``v_head`` is named under it.
        model_type: Backbone class passed to the parent constructor.
    """
    super().__init__(vllm_config=vllm_config,
                     prefix=prefix,
                     model_type=model_type)

    # The parent builds an LM head and logits processor; scoring needs
    # neither, so remove them.
    delattr(self, "output")
    delattr(self, "logits_processor")

    model_cfg = vllm_config.model_config
    self.v_head = RowParallelLinear(
        model_cfg.hf_config.hidden_size,
        1,
        bias=False,
        input_is_parallel=False,
        prefix=maybe_prefix(prefix, "v_head"),
    )
    self._pooler = Pooler.from_config_with_defaults(
        model_cfg.pooler_config,
        pooling_type=PoolingType.ALL,
        normalize=False,
        softmax=False,
    )

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/internlm2.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the backbone and score its hidden states with ``v_head``.

    Returns:
        ``v_head`` output on the last pipeline-parallel rank. On earlier
        ranks the backbone returns ``IntermediateTensors`` for the next
        stage; those are passed through unchanged, since ``v_head`` can only
        be applied to real hidden-state tensors.
    """
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
                               inputs_embeds)
    if isinstance(hidden_states, IntermediateTensors):
        # Not the last pipeline stage: forward activations, don't score.
        return hidden_states
    logits, _ = self.v_head(hidden_states)
    return logits

pooler

pooler(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> Optional[PoolerOutput]
Source code in vllm/model_executor/models/internlm2.py
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Apply the model's configured ``_pooler`` to ``hidden_states``."""
    pool_fn = self._pooler
    return pool_fn(hidden_states, pooling_metadata)

InternLM2MLP

Bases: Module

Source code in vllm/model_executor/models/internlm2.py
class InternLM2MLP(nn.Module):
    """InternLM2 feed-forward block.

    A merged gate/up column-parallel projection, a ``SiluAndMul``
    activation, then the row-parallel ``w2`` down projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input/output width.
            intermediate_size: Width of each of the two merged projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded to both linear layers.
            prefix: Name prefix for ``gate_up_proj`` and ``w2``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"`` (raised after the
                linear layers have been constructed).
        """
        super().__init__()
        # Two equally sized output shards; checkpoint w1/w3 are loaded into
        # them by the model's weight loader.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Return ``w2(act_fn(gate_up_proj(x)))``, discarding the biases."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.w2(activated)
        return out

act_fn instance-attribute

act_fn = SiluAndMul()

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    hidden_size,
    [intermediate_size] * 2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

w2 instance-attribute

w2 = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.w2",
)

__init__

__init__(
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/internlm2.py
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Build the merged gate/up projection, the ``w2`` projection and SiLU.

    Args:
        hidden_size: Input/output width.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Forwarded to both linear layers.
        prefix: Name prefix for ``gate_up_proj`` and ``w2``.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"`` (raised after the
            linear layers have been constructed).
    """
    super().__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size, intermediate_size],
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    self.w2 = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.w2",
    )
    if hidden_act == "silu":
        self.act_fn = SiluAndMul()
    else:
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")

forward

forward(x)
Source code in vllm/model_executor/models/internlm2.py
def forward(self, x):
    """Compute ``w2(act_fn(gate_up_proj(x)))``, dropping each layer's bias."""
    fused, _ = self.gate_up_proj(x)
    activated = self.act_fn(fused)
    out, _ = self.w2(activated)
    return out

InternLM2Model

Bases: Module

Source code in vllm/model_executor/models/internlm2.py
@support_torch_compile
class InternLM2Model(nn.Module):
    """InternLM2 decoder stack: token embedding, decoder layers, final RMSNorm.

    Under pipeline parallelism ``forward`` runs only
    ``layers[start_layer:end_layer]`` and exchanges ``hidden_states`` /
    ``residual`` with neighbouring stages via ``IntermediateTensors``.
    """
    # NOTE(review): this class is wrapped by ``support_torch_compile``, which
    # presumably inspects/compiles ``forward``; keep its signature and
    # annotations stable when editing.

    def __init__(
            self,
            *,
            vllm_config: VllmConfig,
            prefix: str = "",
            layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer):
        """Build the embedding table, decoder layers and final norm.

        Args:
            vllm_config: Source of the HF, cache and quant configs.
            prefix: Module name prefix; layers are created under
                ``{prefix}.layers``.
            layer_type: Decoder-layer class, constructed as
                ``layer_type(config, cache_config, quant_config, prefix=...)``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # make_layers also returns the [start_layer, end_layer) range that
        # forward iterates (presumably this pipeline rank's share).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with ``tok_embeddings``."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's decoder layers.

        Args:
            input_ids: Token ids, embedded on the first PP rank when
                ``inputs_embeds`` is not supplied.
            positions: Positions passed to every decoder layer.
            intermediate_tensors: Previous stage's activations; required on
                every rank except the first.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``
            on non-last ranks; the final-norm hidden states on the last rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        # The final norm also folds in the residual; its second output is
        # unused.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

config instance-attribute

config = config

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size
    )
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=rms_norm_eps)

tok_embeddings instance-attribute

tok_embeddings = VocabParallelEmbedding(
    vocab_size, hidden_size
)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    layer_type: type[
        InternLMDecoderLayer
    ] = InternLMDecoderLayer,
)
Source code in vllm/model_executor/models/internlm2.py
def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer):
    """Create the token embedding, decoder layers and the final RMSNorm.

    Args:
        vllm_config: Provides the HF, cache and quantization configs.
        prefix: Module name prefix; layers live under ``{prefix}.layers``.
        layer_type: Decoder-layer class, built as
            ``layer_type(config, cache_config, quant_config, prefix=...)``.
    """
    super().__init__()

    hf_cfg = vllm_config.model_config.hf_config
    cache_cfg = vllm_config.cache_config
    quant_cfg = vllm_config.quant_config

    def build_layer(prefix):
        # Parameter kept named ``prefix`` in case make_layers passes it by
        # keyword — verify before renaming.
        return layer_type(hf_cfg, cache_cfg, quant_cfg, prefix=prefix)

    self.config = hf_cfg
    self.vocab_size = hf_cfg.vocab_size
    self.tok_embeddings = VocabParallelEmbedding(
        hf_cfg.vocab_size,
        hf_cfg.hidden_size,
    )
    self.start_layer, self.end_layer, self.layers = make_layers(
        hf_cfg.num_hidden_layers,
        build_layer,
        prefix=f"{prefix}.layers",
    )
    self.norm = RMSNorm(hf_cfg.hidden_size, eps=hf_cfg.rms_norm_eps)
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_cfg.hidden_size))

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/internlm2.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run this pipeline stage's slice of decoder layers.

    Args:
        input_ids: Token ids, embedded on the first PP rank when
            ``inputs_embeds`` is not supplied.
        positions: Positions passed to every decoder layer.
        intermediate_tensors: Previous stage's activations; required on every
            rank except the first.
        inputs_embeds: Optional precomputed embeddings (first rank only).

    Returns:
        ``IntermediateTensors`` with ``hidden_states`` and ``residual`` on
        non-last ranks; the final-norm hidden states on the last rank.
    """
    if not get_pp_group().is_first_rank:
        # Later stages resume from the previous stage's activations.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    else:
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds
        residual = None

    for decoder_layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = decoder_layer(positions, hidden_states,
                                                residual)

    if get_pp_group().is_last_rank:
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
    return IntermediateTensors({
        "hidden_states": hidden_states,
        "residual": residual
    })

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/internlm2.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Map token ids to vectors with the ``tok_embeddings`` table."""
    embed = self.tok_embeddings
    return embed(input_ids)

InternLMDecoderLayer

Bases: Module

Source code in vllm/model_executor/models/internlm2.py
class InternLMDecoderLayer(nn.Module):
    """One InternLM2 transformer block.

    Applies ``attention_norm`` then ``attention``, followed by ``ffn_norm``
    then ``feed_forward``. A separate residual tensor is threaded through
    both halves and returned to the caller alongside the block output.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the attention and MLP sub-modules and their norms.

        Args:
            config: Model config. ``rope_theta``, ``rope_scaling`` and
                ``max_position_embeddings`` are optional attributes that
                default to ``10000``, ``None`` and ``8192`` respectively.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to both the attention and MLP modules.
            prefix: Name prefix; the sub-modules receive
                ``{prefix}.attention`` and ``{prefix}.feed_forward``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        # Rotary settings are optional on the config, so fall back to defaults.
        theta = getattr(config, "rope_theta", 10000)
        scaling = getattr(config, "rope_scaling", None)
        max_positions = getattr(config, "max_position_embeddings", 8192)

        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=theta,
            rope_scaling=scaling,
            max_position_embeddings=max_positions,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )

        # Pre-attention and pre-MLP norms share the same width and epsilon.
        norm_eps = config.rms_norm_eps
        self.attention_norm = RMSNorm(self.hidden_size, eps=norm_eps)
        self.ffn_norm = RMSNorm(self.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP on ``hidden_states``.

        Args:
            positions: Position indices forwarded to the attention module.
            hidden_states: Input activations for this block.
            residual: Residual carried from the previous block, or ``None``
                if there is none yet.

        Returns:
            ``(mlp_output, residual)``. The residual is the value returned by
            the two-argument ``ffn_norm`` call.
        """
        if residual is not None:
            # Two-argument norm call yields (normalized, updated residual);
            # presumably a fused add + RMSNorm -- see RMSNorm to confirm.
            normed, residual = self.attention_norm(hidden_states, residual)
        else:
            # Nothing accumulated yet, so the raw input becomes the residual.
            residual = hidden_states
            normed = self.attention_norm(hidden_states)

        attn_out = self.attention(positions=positions, hidden_states=normed)

        mlp_in, residual = self.ffn_norm(attn_out, residual)
        return self.feed_forward(mlp_in), residual

attention instance-attribute

attention = InternLM2Attention(
    hidden_size=hidden_size,
    num_heads=num_attention_heads,
    num_kv_heads=num_key_value_heads,
    rope_theta=rope_theta,
    rope_scaling=rope_scaling,
    max_position_embeddings=max_position_embeddings,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attention",
)

attention_norm instance-attribute

attention_norm = RMSNorm(hidden_size, eps=rms_norm_eps)

feed_forward instance-attribute

feed_forward = InternLM2MLP(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    hidden_act=hidden_act,
    quant_config=quant_config,
    prefix=f"{prefix}.feed_forward",
)

ffn_norm instance-attribute

ffn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)

hidden_size instance-attribute

hidden_size = hidden_size

__init__

__init__(
    config: PretrainedConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/internlm2.py
def __init__(
    self,
    config: PretrainedConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Create the attention and MLP sub-modules and their norms.

    Args:
        config: Model config. ``rope_theta``, ``rope_scaling`` and
            ``max_position_embeddings`` are optional attributes that default
            to ``10000``, ``None`` and ``8192`` respectively.
        cache_config: Forwarded to the attention module.
        quant_config: Forwarded to both the attention and MLP modules.
        prefix: Name prefix; the sub-modules receive
            ``{prefix}.attention`` and ``{prefix}.feed_forward``.
    """
    super().__init__()
    self.hidden_size = config.hidden_size

    # Rotary settings are optional on the config, so fall back to defaults.
    theta = getattr(config, "rope_theta", 10000)
    scaling = getattr(config, "rope_scaling", None)
    max_positions = getattr(config, "max_position_embeddings", 8192)

    self.attention = InternLM2Attention(
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        num_kv_heads=config.num_key_value_heads,
        rope_theta=theta,
        rope_scaling=scaling,
        max_position_embeddings=max_positions,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attention",
    )
    self.feed_forward = InternLM2MLP(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        prefix=f"{prefix}.feed_forward",
    )

    # Pre-attention and pre-MLP norms share the same width and epsilon.
    norm_eps = config.rms_norm_eps
    self.attention_norm = RMSNorm(self.hidden_size, eps=norm_eps)
    self.ffn_norm = RMSNorm(self.hidden_size, eps=norm_eps)

forward

forward(
    positions: Tensor,
    hidden_states: Tensor,
    residual: Optional[Tensor],
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/internlm2.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run attention and MLP on ``hidden_states``.

    Args:
        positions: Position indices forwarded to the attention module.
        hidden_states: Input activations for this block.
        residual: Residual carried from the previous block, or ``None`` if
            there is none yet.

    Returns:
        ``(mlp_output, residual)``. The residual is the value returned by the
        two-argument ``ffn_norm`` call.
    """
    if residual is not None:
        # Two-argument norm call yields (normalized, updated residual);
        # presumably a fused add + RMSNorm -- see RMSNorm to confirm.
        normed, residual = self.attention_norm(hidden_states, residual)
    else:
        # Nothing accumulated yet, so the raw input becomes the residual.
        residual = hidden_states
        normed = self.attention_norm(hidden_states)

    attn_out = self.attention(positions=positions, hidden_states=normed)

    mlp_in, residual = self.ffn_norm(attn_out, residual)
    return self.feed_forward(mlp_in), residual