Skip to content

vllm.model_executor.models.chatglm

Inference-only ChatGLM model compatible with THUDM weights.

ChatGLMBaseModel

Bases: Module

Source code in vllm/model_executor/models/chatglm.py
class ChatGLMBaseModel(nn.Module):
    """Wrap a ChatGLM transformer with logits computation and weight loading.

    The wrapped transformer's ``output_layer`` is exposed as ``lm_head``.
    """

    # Passed as ``mapper`` to the weight loader in ``load_weights``; maps the
    # substring ".word_embeddings" to "".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".word_embeddings": ""})

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[ChatGLMModel] = ChatGLMModel,
    ) -> None:
        """Build the transformer and the pieces needed to produce logits.

        Args:
            vllm_config: Source of the HF config, quantization config, LoRA
                config and multimodal config stored on the instance.
            prefix: Combined with ``"transformer"`` through ``maybe_prefix``
                to name the transformer.
            transformer_type: Class instantiated as ``self.transformer``.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        self.lora_config = vllm_config.lora_config
        self.multimodal_config = model_config.multimodal_config
        self.max_position_embeddings = getattr(
            hf_config, "max_sequence_length", 8192)

        transformer = transformer_type(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "transformer"),
        )
        self.transformer = transformer
        if hf_config.tie_word_embeddings:
            # Tied embeddings: the output layer reuses the embedding weight.
            transformer.output_layer.weight = transformer.embedding.weight
        self.lm_head = transformer.output_layer
        self.logits_processor = LogitsProcessor(hf_config.padded_vocab_size)
        self.make_empty_intermediate_tensors = (
            transformer.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the result of ``self.logits_processor`` applied to
        ``hidden_states`` with ``self.lm_head`` and ``sampling_metadata``.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load ``weights`` through an ``AutoWeightsLoader`` bound to this
        module, using ``hf_to_vllm_mapper`` as the mapper, and return the
        loader's result.
        """
        auto_loader = AutoWeightsLoader(self)
        mapper = self.hf_to_vllm_mapper
        return auto_loader.load_weights(weights, mapper=mapper)

config instance-attribute

config = config

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_substr={".word_embeddings": ""}
)

lm_head instance-attribute

lm_head = output_layer

logits_processor instance-attribute

logits_processor = LogitsProcessor(padded_vocab_size)

lora_config instance-attribute

lora_config = lora_config

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

max_position_embeddings instance-attribute

max_position_embeddings = getattr(
    config, "max_sequence_length", 8192
)

multimodal_config instance-attribute

multimodal_config = multimodal_config

quant_config instance-attribute

quant_config = quant_config

transformer instance-attribute

transformer = transformer_type(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "transformer"),
)

__init__

__init__(
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    transformer_type: type[ChatGLMModel] = ChatGLMModel,
) -> None
Source code in vllm/model_executor/models/chatglm.py
def __init__(
    self,
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    transformer_type: type[ChatGLMModel] = ChatGLMModel,
) -> None:
    """Build the transformer and the pieces needed to produce logits.

    Args:
        vllm_config: Source of the HF config, quantization config, LoRA
            config and multimodal config stored on the instance.
        prefix: Combined with ``"transformer"`` through ``maybe_prefix``
            to name the transformer.
        transformer_type: Class instantiated as ``self.transformer``.
    """
    super().__init__()
    model_config = vllm_config.model_config
    hf_config = model_config.hf_config

    self.config = hf_config
    self.quant_config = vllm_config.quant_config
    self.lora_config = vllm_config.lora_config
    self.multimodal_config = model_config.multimodal_config
    self.max_position_embeddings = getattr(
        hf_config, "max_sequence_length", 8192)

    transformer = transformer_type(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "transformer"),
    )
    self.transformer = transformer
    if hf_config.tie_word_embeddings:
        # Tied embeddings: the output layer reuses the embedding weight.
        transformer.output_layer.weight = transformer.embedding.weight
    self.lm_head = transformer.output_layer
    self.logits_processor = LogitsProcessor(hf_config.padded_vocab_size)
    self.make_empty_intermediate_tensors = (
        transformer.make_empty_intermediate_tensors)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/chatglm.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Return the result of ``self.logits_processor`` applied to
    ``hidden_states`` with ``self.lm_head`` and ``sampling_metadata``.
    """
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)

load_weights

load_weights(weights: Iterable[tuple[str, Tensor]])
Source code in vllm/model_executor/models/chatglm.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load ``weights`` through an ``AutoWeightsLoader`` bound to this
    module, using ``hf_to_vllm_mapper`` as the mapper, and return the
    loader's result.
    """
    auto_loader = AutoWeightsLoader(self)
    mapper = self.hf_to_vllm_mapper
    return auto_loader.load_weights(weights, mapper=mapper)

ChatGLMForCausalLM

Bases: ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsQuant

Source code in vllm/model_executor/models/chatglm.py
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                         SupportsQuant):
    """Text-only ChatGLM model.

    Construction fails with ``RuntimeError`` when the HF config has a
    ``vision_config`` attribute.
    """

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Reject vision-capable configs, then defer to the base class.

        Raises:
            RuntimeError: If the HF config has a ``vision_config`` attribute.
                The message suggests the ``--hf-overrides`` value to use.
        """
        hf_config = vllm_config.model_config.hf_config
        if hasattr(hf_config, "vision_config"):
            overrides_json = json.dumps(
                {"architectures": ["GLM4VForCausalLM"]})
            message = (
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{overrides_json}'`")
            raise RuntimeError(message)

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward all four arguments positionally to ``self.transformer``
        and return what it returns.
        """
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "query_key_value": ["query_key_value"],
    "dense_h_to_4h": ["dense_h_to_4h"],
}

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/chatglm.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Reject vision-capable configs, then defer to the base class.

    Raises:
        RuntimeError: If the HF config has a ``vision_config`` attribute.
            The message suggests the ``--hf-overrides`` value to use.
    """
    hf_config = vllm_config.model_config.hf_config
    if hasattr(hf_config, "vision_config"):
        overrides_json = json.dumps({"architectures": ["GLM4VForCausalLM"]})
        message = (
            "The configuration of this model indicates that it supports "
            "vision inputs, but you instantiated the text-only version "
            "of this model. Please use the vision model by setting "
            f"`--hf-overrides '{overrides_json}'`")
        raise RuntimeError(message)

    super().__init__(vllm_config=vllm_config, prefix=prefix)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/chatglm.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Forward all four arguments positionally to ``self.transformer``
    and return what it returns.
    """
    return self.transformer(input_ids, positions, intermediate_tensors,
                            inputs_embeds)

ChatGLMModel

Bases: Module, SupportsQuant

Source code in vllm/model_executor/models/chatglm.py
@support_torch_compile
class ChatGLMModel(nn.Module, SupportsQuant):
    """ChatGLM backbone: input embedding, ``GLMTransformer`` encoder and
    output layer.

    Also implements ``load_weights``, which folds the checkpoint's
    ``linear_proj.gate_proj`` / ``linear_proj.dense_h_to_4h`` tensors into
    the merged ``linear_proj.merged_proj`` parameter.
    """
    # Merged parameter name -> the checkpoint names stacked into it
    # (same pairs as ``stacked_params_mapping`` in ``load_weights``).
    packed_modules_mapping = {
        "linear_proj.merged_proj":
        ["linear_proj.gate_proj", "linear_proj.dense_h_to_4h"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the embedding, encoder and output layer.

        Args:
            vllm_config: Supplies the HF config, cache config and
                quantization config.
            prefix: Name prefix; sub-modules are named ``{prefix}.embedding``,
                ``{prefix}.encoder`` and ``{prefix}.output_layer``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.embedding")

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.encoder")

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.output_layer")

        # Re-export the encoder's factory under the same name on this module.
        self.make_empty_intermediate_tensors = (
            self.encoder.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``self.embedding(input_ids)``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Pick the initial hidden states and run them through the encoder.

        When ``get_pp_group().is_first_rank`` is true, the hidden states are
        ``inputs_embeds`` if given, otherwise the embedding of ``input_ids``.
        Otherwise they are read from
        ``intermediate_tensors["hidden_states"]``, which must not be ``None``.
        ``kwargs`` is accepted but not used.

        Returns:
            The encoder output for the chosen hidden states and ``positions``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=hidden_states,
            position_ids=positions,
        )

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names recorded as loaded. For stacked weights this is the
            rewritten name (``linear_proj.merged_proj``), not the checkpoint
            name.

        Raises:
            KeyError: If a name that is not skipped has no matching entry in
                ``named_parameters()``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
            ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # NOTE: ``name`` stays rewritten even if one of the
                # ``continue`` statements below skips this mapping.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # NOTE(review): presumably true for parameters held by another
                # pipeline-parallel rank; confirm against the helper.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                # Shard loaded; ``break`` bypasses the ``else`` branch below.
                break
            else:
                # Reached whenever the inner loop ends without ``break``:
                # no stacked mapping matched, or a matching one was skipped.
                # Each ``continue`` here advances the OUTER loop, so the name
                # is not added to ``loaded_params``.
                if "rotary_pos_emb.inv_freq" in name:
                    continue
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                # Parameters without their own loader use the default one.
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

embedding instance-attribute

embedding = VocabParallelEmbedding(
    padded_vocab_size,
    hidden_size,
    quant_config=quant_config,
    prefix=f"{prefix}.embedding",
)

encoder instance-attribute

encoder = GLMTransformer(
    config,
    cache_config,
    quant_config,
    prefix=f"{prefix}.encoder",
)

kv_channels instance-attribute

kv_channels = kv_channels

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

multi_query_group_num instance-attribute

multi_query_group_num = multi_query_group_num

num_layers instance-attribute

num_layers = num_layers

output_layer instance-attribute

output_layer = ParallelLMHead(
    padded_vocab_size,
    hidden_size,
    quant_config=quant_config,
    prefix=f"{prefix}.output_layer",
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "linear_proj.merged_proj": [
        "linear_proj.gate_proj",
        "linear_proj.dense_h_to_4h",
    ]
}

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/chatglm.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Create the embedding, encoder and output layer.

    Args:
        vllm_config: Supplies the HF config, cache config and quantization
            config.
        prefix: Name prefix; sub-modules are named ``{prefix}.embedding``,
            ``{prefix}.encoder`` and ``{prefix}.output_layer``.
    """
    super().__init__()

    hf_config = vllm_config.model_config.hf_config
    cache_cfg = vllm_config.cache_config
    quant_cfg = vllm_config.quant_config
    self.config = hf_config

    self.embedding = VocabParallelEmbedding(
        hf_config.padded_vocab_size,
        hf_config.hidden_size,
        quant_config=quant_cfg,
        prefix=f"{prefix}.embedding",
    )

    self.num_layers = hf_config.num_layers
    self.multi_query_group_num = hf_config.multi_query_group_num
    self.kv_channels = hf_config.kv_channels

    encoder = GLMTransformer(
        hf_config,
        cache_cfg,
        quant_cfg,
        prefix=f"{prefix}.encoder",
    )
    self.encoder = encoder

    self.output_layer = ParallelLMHead(
        hf_config.padded_vocab_size,
        hf_config.hidden_size,
        quant_config=quant_cfg,
        prefix=f"{prefix}.output_layer",
    )

    # Re-export the encoder's factory under the same name on this module.
    self.make_empty_intermediate_tensors = (
        encoder.make_empty_intermediate_tensors)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs: object,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/chatglm.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pick the initial hidden states and run them through the encoder.

    When ``get_pp_group().is_first_rank`` is true, the hidden states are
    ``inputs_embeds`` if given, otherwise the embedding of ``input_ids``.
    Otherwise they are read from ``intermediate_tensors["hidden_states"]``,
    which must not be ``None``. ``kwargs`` is accepted but not used.

    Returns:
        The encoder output for the chosen hidden states and ``positions``.
    """
    if not get_pp_group().is_first_rank:
        # Later pipeline ranks continue from the tensors handed to them.
        assert intermediate_tensors is not None
        encoder_input = intermediate_tensors["hidden_states"]
    elif inputs_embeds is None:
        encoder_input = self.get_input_embeddings(input_ids)
    else:
        encoder_input = inputs_embeds

    return self.encoder(
        hidden_states=encoder_input,
        position_ids=positions,
    )

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/chatglm.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of calling ``self.embedding`` on ``input_ids``."""
    embed = self.embedding
    return embed(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/chatglm.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Copy checkpoint tensors into this module's parameters.

    Args:
        weights: ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The names recorded as loaded. For stacked weights this is the
        rewritten name (``linear_proj.merged_proj``), not the checkpoint
        name.

    Raises:
        KeyError: If a name that is not skipped has no matching entry in
            ``named_parameters()``.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
        ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            # NOTE: ``name`` stays rewritten even if one of the ``continue``
            # statements below skips this mapping.
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # NOTE(review): presumably true for parameters held by another
            # pipeline-parallel rank; confirm against the helper.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # Shard loaded; ``break`` bypasses the ``else`` branch below.
            break
        else:
            # Reached whenever the inner loop ends without ``break``: no
            # stacked mapping matched, or a matching one was skipped.
            # Each ``continue`` here advances the OUTER loop, so the name is
            # not added to ``loaded_params``.
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            # Parameters without their own loader use the default one.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params

GLMAttention

Bases: Module

Source code in vllm/model_executor/models/chatglm.py
class GLMAttention(nn.Module):
    """Attention block: fused QKV projection, rotary embedding, attention
    and an output projection (``dense``).
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Derive per-rank head counts and build the sub-layers.

        Args:
            config: Model config providing the head counts, hidden size,
                bias flags and rotary-embedding settings read below.
            cache_config: Passed through to ``Attention``.
            quant_config: Passed to the linear layers and to ``Attention``.
            prefix: Name prefix for the sub-layers.
        """
        super().__init__()
        self.hidden_size = hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = total_heads = config.num_attention_heads
        assert total_heads % tp_size == 0
        self.num_heads = total_heads // tp_size

        self.multi_query_attention = config.multi_query_attention
        if config.multi_query_attention:
            total_kv_heads = config.multi_query_group_num
        else:
            total_kv_heads = config.num_attention_heads
        self.total_num_kv_heads = total_kv_heads

        if total_kv_heads < tp_size:
            # Fewer KV heads than tensor-parallel ranks: the KV heads are
            # replicated, so the rank count must be a multiple of them.
            assert tp_size % total_kv_heads == 0
        else:
            # At least as many KV heads as ranks: they are partitioned, so
            # they must divide evenly across the ranks.
            assert total_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, total_kv_heads // tp_size)

        self.head_dim = head_dim = config.hidden_size // total_heads
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            hidden_size,
            head_dim,
            total_heads,
            total_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            total_heads * head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
        # which is equivalent to is_neox_style=True
        neox_style = not config.original_rope
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=neox_style,
        )
        self.attn = Attention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply the rotary embedding to Q and K, attend,
        and return the first output of the ``dense`` projection.
        """
        fused_qkv, _ = self.query_key_value(hidden_states)
        # The last dimension is laid out as [q | k | v].
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        attended = self.attn(query, key, value)
        projected, _ = self.dense(attended)
        return projected

attn instance-attribute

attn = Attention(
    num_heads,
    head_dim,
    scaling,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

dense instance-attribute

dense = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=add_bias_linear,
    quant_config=quant_config,
    prefix=f"{prefix}.dense",
)

head_dim instance-attribute

head_dim = hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

multi_query_attention instance-attribute

multi_query_attention = multi_query_attention

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

q_size instance-attribute

q_size = num_heads * head_dim

query_key_value instance-attribute

query_key_value = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=add_bias_linear or add_qkv_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.query_key_value",
)

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim // 2,
    max_position=max_positions,
    base=10000 * rope_ratio,
    is_neox_style=is_neox_style,
)

scaling instance-attribute

scaling = head_dim ** -0.5

total_num_heads instance-attribute

total_num_heads = num_attention_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = (
    multi_query_group_num
    if multi_query_attention
    else num_attention_heads
)

__init__

__init__(
    config: ChatGLMConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/chatglm.py
def __init__(
    self,
    config: ChatGLMConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Derive per-rank head counts and build the sub-layers.

    Args:
        config: Model config providing the head counts, hidden size, bias
            flags and rotary-embedding settings read below.
        cache_config: Passed through to ``Attention``.
        quant_config: Passed to the linear layers and to ``Attention``.
        prefix: Name prefix for the sub-layers.
    """
    super().__init__()
    self.hidden_size = hidden_size = config.hidden_size
    tp_size = get_tensor_model_parallel_world_size()

    self.total_num_heads = total_heads = config.num_attention_heads
    assert total_heads % tp_size == 0
    self.num_heads = total_heads // tp_size

    self.multi_query_attention = config.multi_query_attention
    if config.multi_query_attention:
        total_kv_heads = config.multi_query_group_num
    else:
        total_kv_heads = config.num_attention_heads
    self.total_num_kv_heads = total_kv_heads

    if total_kv_heads < tp_size:
        # Fewer KV heads than tensor-parallel ranks: the KV heads are
        # replicated, so the rank count must be a multiple of them.
        assert tp_size % total_kv_heads == 0
    else:
        # At least as many KV heads as ranks: they are partitioned, so
        # they must divide evenly across the ranks.
        assert total_kv_heads % tp_size == 0
    self.num_kv_heads = max(1, total_kv_heads // tp_size)

    self.head_dim = head_dim = config.hidden_size // total_heads
    self.q_size = self.num_heads * head_dim
    self.kv_size = self.num_kv_heads * head_dim
    self.scaling = head_dim**-0.5

    self.query_key_value = QKVParallelLinear(
        hidden_size,
        head_dim,
        total_heads,
        total_kv_heads,
        bias=config.add_bias_linear or config.add_qkv_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.query_key_value",
    )
    self.dense = RowParallelLinear(
        total_heads * head_dim,
        config.hidden_size,
        bias=config.add_bias_linear,
        quant_config=quant_config,
        prefix=f"{prefix}.dense",
    )

    # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
    rope_ratio = getattr(config, "rope_ratio", 1.0)
    max_positions = getattr(config, "seq_length", 8192)
    # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
    # which is equivalent to is_neox_style=True
    neox_style = not config.original_rope
    self.rotary_emb = get_rope(
        head_dim,
        rotary_dim=head_dim // 2,
        max_position=max_positions,
        base=10000 * rope_ratio,
        is_neox_style=neox_style,
    )
    self.attn = Attention(
        self.num_heads,
        head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
    )

forward

forward(
    hidden_states: Tensor, position_ids: Tensor
) -> Tensor
Source code in vllm/model_executor/models/chatglm.py
def forward(
    self,
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
) -> torch.Tensor:
    """Project to Q/K/V, apply the rotary embedding to Q and K, attend,
    and return the first output of the ``dense`` projection.
    """
    fused_qkv, _ = self.query_key_value(hidden_states)
    # The last dimension is laid out as [q | k | v].
    split_sizes = [self.q_size, self.kv_size, self.kv_size]
    query, key, value = fused_qkv.split(split_sizes, dim=-1)
    query, key = self.rotary_emb(position_ids, query, key)
    attended = self.attn(query, key, value)
    projected, _ = self.dense(attended)
    return projected

GLMBlock

Bases: Module

A single transformer layer.

The transformer layer takes hidden states of size [num_tokens, h] (as noted in the forward pass) and returns an output of the same size.

Source code in vllm/model_executor/models/chatglm.py
class GLMBlock(nn.Module):
    """A single transformer layer.

    Applies layer norm, self attention and an MLP, each followed by a
    residual addition. The forward pass treats ``hidden_states`` as
    [num_tokens, h].
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the two norms, the attention block and the MLP.

        Args:
            config: Model config; selects ``RMSNorm`` vs ``LayerNorm`` via
                ``config.rmsnorm`` and supplies sizes and flags.
            cache_config: Passed through to ``GLMAttention``.
            quant_config: Passed to ``GLMAttention`` and ``GLMMLP``.
            prefix: Name prefix for the sub-modules.
        """
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)
        self.fp32_residual_connection = config.fp32_residual_connection

        norm_cls = RMSNorm if config.rmsnorm else LayerNorm

        # Norm applied to the layer input.
        self.input_layernorm = norm_cls(
            config.hidden_size,
            eps=config.layernorm_epsilon,
        )

        self.self_attention = GLMAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attention",
        )
        self.hidden_dropout = config.hidden_dropout

        # Norm applied after the attention residual addition.
        self.post_attention_layernorm = norm_cls(
            config.hidden_size,
            eps=config.layernorm_epsilon,
        )

        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP, each with a residual addition.

        When ``apply_residual_connection_post_layernorm`` is set, each
        residual is the layer-norm output; otherwise it is the value that
        was fed into that layer norm.
        """
        # hidden_states: [num_tokens, h]
        residual_after_norm = self.apply_residual_connection_post_layernorm

        normed_input = self.input_layernorm(hidden_states)
        attn_out = self.self_attention(
            hidden_states=normed_input,
            position_ids=position_ids,
        )

        # First residual connection.
        attn_residual = normed_input if residual_after_norm else hidden_states
        post_attn = attn_residual + attn_out

        normed_post_attn = self.post_attention_layernorm(post_attn)

        # Second residual connection.
        mlp_residual = normed_post_attn if residual_after_norm else post_attn
        return self.mlp(normed_post_attn) + mlp_residual

apply_residual_connection_post_layernorm instance-attribute

apply_residual_connection_post_layernorm = (
    apply_residual_connection_post_layernorm
)

fp32_residual_connection instance-attribute

fp32_residual_connection = fp32_residual_connection

hidden_dropout instance-attribute

hidden_dropout = hidden_dropout

input_layernorm instance-attribute

input_layernorm = layer_norm_func(
    hidden_size, eps=layernorm_epsilon
)

mlp instance-attribute

mlp = GLMMLP(config, quant_config, prefix=f'{prefix}.mlp')

post_attention_layernorm instance-attribute

post_attention_layernorm = layer_norm_func(
    hidden_size, eps=layernorm_epsilon
)

self_attention instance-attribute

self_attention = GLMAttention(
    config,
    cache_config,
    quant_config,
    prefix=f"{prefix}.self_attention",
)

__init__

__init__(
    config: ChatGLMConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/chatglm.py
def __init__(
    self,
    config: ChatGLMConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Build the two norms, the attention block and the MLP.

    Args:
        config: Model config; selects ``RMSNorm`` vs ``LayerNorm`` via
            ``config.rmsnorm`` and supplies sizes and flags.
        cache_config: Passed through to ``GLMAttention``.
        quant_config: Passed to ``GLMAttention`` and ``GLMMLP``.
        prefix: Name prefix for the sub-modules.
    """
    super().__init__()
    self.apply_residual_connection_post_layernorm = (
        config.apply_residual_connection_post_layernorm)
    self.fp32_residual_connection = config.fp32_residual_connection

    norm_cls = RMSNorm if config.rmsnorm else LayerNorm

    # Norm applied to the layer input.
    self.input_layernorm = norm_cls(
        config.hidden_size,
        eps=config.layernorm_epsilon,
    )

    self.self_attention = GLMAttention(
        config,
        cache_config,
        quant_config,
        prefix=f"{prefix}.self_attention",
    )
    self.hidden_dropout = config.hidden_dropout

    # Norm applied after the attention residual addition.
    self.post_attention_layernorm = norm_cls(
        config.hidden_size,
        eps=config.layernorm_epsilon,
    )

    self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

forward

forward(
    hidden_states: Tensor, position_ids: Tensor
) -> Tensor
Source code in vllm/model_executor/models/chatglm.py
def forward(
    self,
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
) -> torch.Tensor:
    """Run one layer: norm, self-attention, norm, MLP, with two residuals.

    When ``self.apply_residual_connection_post_layernorm`` is truthy each
    residual branch starts from the normalised tensor; otherwise it starts
    from the tensor that was fed into the corresponding norm.

    Args:
        hidden_states: Layer input, ``[num_tokens, h]``.
        position_ids: Forwarded unchanged to ``self.self_attention``.

    Returns:
        The MLP output plus the second residual.
    """
    residual_after_norm = self.apply_residual_connection_post_layernorm

    # Attention sub-layer on the normalised input.
    normed_input = self.input_layernorm(hidden_states)
    attn_out = self.self_attention(hidden_states=normed_input,
                                   position_ids=position_ids)

    # First residual connection.
    attn_skip = normed_input if residual_after_norm else hidden_states
    attn_sum = attn_skip + attn_out

    # MLP sub-layer on the normalised attention result.
    normed_attn = self.post_attention_layernorm(attn_sum)

    # Second residual connection.
    mlp_skip = normed_attn if residual_after_norm else attn_sum
    return self.mlp(normed_attn) + mlp_skip

GLMMLP

Bases: Module

MLP.

The MLP takes an input with hidden size h, projects it to a 4*h hidden dimension, applies a nonlinear transformation, and projects the result back to hidden size h.

Source code in vllm/model_executor/models/chatglm.py
class GLMMLP(nn.Module):
    """Feed-forward sub-layer: up-projection, activation, down-projection.

    The input of width ``config.hidden_size`` goes through a merged
    column-parallel projection with two outputs of width
    ``config.ffn_hidden_size``, then through ``SiluAndMul``, and is finally
    projected back to ``config.hidden_size`` by a row-parallel linear layer.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the two projections and the activation.

        Args:
            config: Model configuration; ``add_bias_linear``,
                ``hidden_size`` and ``ffn_hidden_size`` are read.
            quant_config: Passed through unchanged to both linear layers.
            prefix: Name of this module; the projections are created with
                the prefixes ``"<prefix>.dense_h_to_4h"`` and
                ``"<prefix>.dense_4h_to_h"``.
        """
        super().__init__()

        use_bias = config.add_bias_linear
        model_width = config.hidden_size
        ffn_width = config.ffn_hidden_size

        self.add_bias = use_bias

        # Up-projection: two outputs of width ffn_width in one merged layer.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            model_width,
            [ffn_width, ffn_width],
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )

        self.activation_func = SiluAndMul()

        # Down-projection back to the model width.
        self.dense_4h_to_h = RowParallelLinear(
            ffn_width,
            model_width,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, hidden_states):
        """Apply up-projection, activation and down-projection in sequence.

        Each projection call returns a pair; only its first element is
        used here.
        """
        up_projected, _ = self.dense_h_to_4h(hidden_states)
        activated = self.activation_func(up_projected)
        down_projected, _ = self.dense_4h_to_h(activated)
        return down_projected

activation_func instance-attribute

activation_func = SiluAndMul()

add_bias instance-attribute

add_bias = add_bias_linear

dense_4h_to_h instance-attribute

dense_4h_to_h = RowParallelLinear(
    ffn_hidden_size,
    hidden_size,
    bias=add_bias_linear,
    quant_config=quant_config,
    prefix=f"{prefix}.dense_4h_to_h",
)

dense_h_to_4h instance-attribute

dense_h_to_4h = MergedColumnParallelLinear(
    hidden_size,
    [ffn_hidden_size] * 2,
    bias=add_bias_linear,
    quant_config=quant_config,
    prefix=f"{prefix}.dense_h_to_4h",
)

__init__

__init__(
    config: ChatGLMConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/chatglm.py
def __init__(
    self,
    config: ChatGLMConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Create the two projections and the activation.

    Args:
        config: Model configuration; ``add_bias_linear``, ``hidden_size``
            and ``ffn_hidden_size`` are read.
        quant_config: Passed through unchanged to both linear layers.
        prefix: Name of this module; the projections are created with the
            prefixes ``"<prefix>.dense_h_to_4h"`` and
            ``"<prefix>.dense_4h_to_h"``.
    """
    super().__init__()

    use_bias = config.add_bias_linear
    model_width = config.hidden_size
    ffn_width = config.ffn_hidden_size

    self.add_bias = use_bias

    # Up-projection: two outputs of width ffn_width in one merged layer.
    self.dense_h_to_4h = MergedColumnParallelLinear(
        model_width,
        [ffn_width, ffn_width],
        bias=use_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.dense_h_to_4h",
    )

    self.activation_func = SiluAndMul()

    # Down-projection back to the model width.
    self.dense_4h_to_h = RowParallelLinear(
        ffn_width,
        model_width,
        bias=use_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.dense_4h_to_h",
    )

forward

forward(hidden_states)
Source code in vllm/model_executor/models/chatglm.py
def forward(self, hidden_states):
    """Apply up-projection, activation and down-projection in sequence.

    Each projection call returns a pair; only its first element is used
    here.
    """
    up_projected, _ = self.dense_h_to_4h(hidden_states)
    activated = self.activation_func(up_projected)
    down_projected, _ = self.dense_4h_to_h(activated)
    return down_projected

GLMTransformer

Bases: Module

Transformer class.

Source code in vllm/model_executor/models/chatglm.py
class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers with an optional final layer norm."""

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer stack and, if configured, the final norm.

        Args:
            config: Model configuration; ``post_layer_norm``,
                ``num_layers``, ``rmsnorm``, ``hidden_size`` and
                ``layernorm_epsilon`` are read, and the object is handed
                to every ``GLMBlock``.
            cache_config: Passed through unchanged to every ``GLMBlock``.
            quant_config: Passed through unchanged to every ``GLMBlock``.
            prefix: Name of this module; the layers are created under
                ``"<prefix>.layers"``.
        """
        super().__init__()
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers

        def build_block(prefix):
            # Factory handed to make_layers, which supplies the per-layer
            # prefix.
            return GLMBlock(config,
                            cache_config,
                            quant_config,
                            prefix=prefix)

        # make_layers yields the layer container together with the
        # [start_layer, end_layer) range that forward() iterates over.
        layer_info = make_layers(
            self.num_layers,
            build_block,
            prefix=f"{prefix}.layers",
        )
        self.start_layer, self.end_layer, self.layers = layer_info

        # The final norm exists only when the config asks for it.
        if self.post_layer_norm:
            norm_cls = RMSNorm if config.rmsnorm else LayerNorm
            self.final_layernorm = norm_cls(config.hidden_size,
                                            eps=config.layernorm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers in ``[start_layer, end_layer)`` over the input.

        Returns:
            On the last pipeline rank, the hidden states (after
            ``final_layernorm`` when ``post_layer_norm`` is set). On any
            other rank, an ``IntermediateTensors`` holding them under the
            key ``"hidden_states"``.
        """
        active_layers = self.layers[self.start_layer:self.end_layer]
        for block in active_layers:
            hidden_states = block(hidden_states=hidden_states,
                                  position_ids=position_ids)

        if get_pp_group().is_last_rank:
            if self.post_layer_norm:
                return self.final_layernorm(hidden_states)
            return hidden_states

        return IntermediateTensors({"hidden_states": hidden_states})

final_layernorm instance-attribute

final_layernorm = layer_norm_func(
    hidden_size, eps=layernorm_epsilon
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], hidden_size
    )
)

num_layers instance-attribute

num_layers = num_layers

post_layer_norm instance-attribute

post_layer_norm = post_layer_norm

__init__

__init__(
    config: ChatGLMConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/chatglm.py
def __init__(
    self,
    config: ChatGLMConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Build the layer stack and, if configured, the final norm.

    Args:
        config: Model configuration; ``post_layer_norm``, ``num_layers``,
            ``rmsnorm``, ``hidden_size`` and ``layernorm_epsilon`` are
            read, and the object is handed to every ``GLMBlock``.
        cache_config: Passed through unchanged to every ``GLMBlock``.
        quant_config: Passed through unchanged to every ``GLMBlock``.
        prefix: Name of this module; the layers are created under
            ``"<prefix>.layers"``.
    """
    super().__init__()
    self.post_layer_norm = config.post_layer_norm
    self.num_layers = config.num_layers

    def build_block(prefix):
        # Factory handed to make_layers, which supplies the per-layer
        # prefix.
        return GLMBlock(config, cache_config, quant_config, prefix=prefix)

    # make_layers yields the layer container together with the
    # start_layer / end_layer bounds stored alongside it.
    layer_info = make_layers(
        self.num_layers,
        build_block,
        prefix=f"{prefix}.layers",
    )
    self.start_layer, self.end_layer, self.layers = layer_info

    # The final norm exists only when the config asks for it.
    if self.post_layer_norm:
        norm_cls = RMSNorm if config.rmsnorm else LayerNorm
        self.final_layernorm = norm_cls(config.hidden_size,
                                        eps=config.layernorm_epsilon)

    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(["hidden_states"],
                                                config.hidden_size))

forward

forward(
    hidden_states: Tensor, position_ids: Tensor
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/chatglm.py
def forward(
    self,
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the layers in ``[start_layer, end_layer)`` over the input.

    Returns:
        On the last pipeline rank, the hidden states (after
        ``final_layernorm`` when ``post_layer_norm`` is set). On any other
        rank, an ``IntermediateTensors`` holding them under the key
        ``"hidden_states"``.
    """
    active_layers = self.layers[self.start_layer:self.end_layer]
    for block in active_layers:
        hidden_states = block(hidden_states=hidden_states,
                              position_ids=position_ids)

    if get_pp_group().is_last_rank:
        if self.post_layer_norm:
            return self.final_layernorm(hidden_states)
        return hidden_states

    return IntermediateTensors({"hidden_states": hidden_states})