Skip to content

vllm.model_executor.models.jais

Inference-only Jais model compatible with HuggingFace weights.

JAISAttention

Bases: Module

Source code in vllm/model_executor/models/jais.py
class JAISAttention(nn.Module):
    """Jais self-attention with ALiBi slopes, sharded by tensor parallelism.

    A fused ``c_attn`` projection produces Q, K and V. ``Attention`` runs
    with per-head ALiBi slopes for the heads owned by this rank, and
    ``c_proj`` maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the projections and the attention kernel wrapper.

        Args:
            config: Jais config (hidden size, head count, muP QK flag).
            cache_config: KV-cache settings forwarded to ``Attention``.
            quant_config: Optional quantization config for the linear
                layers and the attention layer.
            prefix: Module name prefix; the kernel is named
                ``{prefix}.attn``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tp_size == 0
        self.num_heads = total_num_heads // tp_size
        self.head_dim = self.hidden_size // total_num_heads

        # Normalise the flag name: when the config exposes
        # ``scale_qk_dot_by_d`` it is copied onto ``mup_scale_qk_dot_by_d``
        # (note: this mutates ``config``).
        if hasattr(config, "scale_qk_dot_by_d"):
            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
        # With the flag set, scores are scaled by head_dim**-1 instead of the
        # conventional head_dim**-0.5.
        self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(self.hidden_size,
                                        self.head_dim,
                                        total_num_heads,
                                        bias=True,
                                        quant_config=quant_config)
        self.c_proj = RowParallelLinear(self.hidden_size,
                                        self.hidden_size,
                                        bias=True,
                                        quant_config=quant_config)

        # Each rank owns a contiguous run of ``num_heads`` heads, so it keeps
        # the matching slice of the global slope list.
        first_head = get_tensor_model_parallel_rank() * self.num_heads
        all_slopes = _get_alibi_slopes(total_num_heads)
        rank_slopes = all_slopes[first_head:first_head + self.num_heads]
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              alibi_slopes=rank_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, attend, and project back to ``hidden_size``."""
        fused_qkv, _ = self.c_attn(hidden_states)
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value)
        projected, _ = self.c_proj(context)
        return projected

attn instance-attribute

attn = Attention(
    num_heads,
    head_dim,
    scale=scale,
    alibi_slopes=alibi_slopes,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

attn_scale_power instance-attribute

attn_scale_power = 1.0 if mup_scale_qk_dot_by_d else 0.5

c_attn instance-attribute

c_attn = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    bias=True,
    quant_config=quant_config,
)

c_proj instance-attribute

c_proj = RowParallelLinear(
    hidden_size,
    hidden_size,
    bias=True,
    quant_config=quant_config,
)

head_dim instance-attribute

head_dim = hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

num_heads instance-attribute

num_heads = (
    total_num_heads // tensor_model_parallel_world_size
)

scale instance-attribute

scale = head_dim ** -attn_scale_power

__init__

__init__(
    config: JAISConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/jais.py
def __init__(
    self,
    config: JAISConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Create the QKV/output projections and the ALiBi attention kernel.

    Args:
        config: Jais config (hidden size, head count, muP QK flag).
        cache_config: KV-cache settings forwarded to ``Attention``.
        quant_config: Optional quantization config for the linear layers
            and the attention layer.
        prefix: Module name prefix; the kernel is named ``{prefix}.attn``.
    """
    super().__init__()
    self.hidden_size = config.hidden_size
    total_num_heads = config.num_attention_heads
    tp_size = get_tensor_model_parallel_world_size()
    assert total_num_heads % tp_size == 0
    self.num_heads = total_num_heads // tp_size
    self.head_dim = self.hidden_size // total_num_heads

    # Normalise the flag name: when the config exposes ``scale_qk_dot_by_d``
    # it is copied onto ``mup_scale_qk_dot_by_d`` (this mutates ``config``).
    if hasattr(config, "scale_qk_dot_by_d"):
        config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
    # With the flag set, scores are scaled by head_dim**-1 instead of the
    # conventional head_dim**-0.5.
    self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
    self.scale = self.head_dim**-self.attn_scale_power

    self.c_attn = QKVParallelLinear(self.hidden_size,
                                    self.head_dim,
                                    total_num_heads,
                                    bias=True,
                                    quant_config=quant_config)
    self.c_proj = RowParallelLinear(self.hidden_size,
                                    self.hidden_size,
                                    bias=True,
                                    quant_config=quant_config)

    # Each rank owns a contiguous run of ``num_heads`` heads, so it keeps the
    # matching slice of the global slope list.
    first_head = get_tensor_model_parallel_rank() * self.num_heads
    all_slopes = _get_alibi_slopes(total_num_heads)
    rank_slopes = all_slopes[first_head:first_head + self.num_heads]
    self.attn = Attention(self.num_heads,
                          self.head_dim,
                          scale=self.scale,
                          alibi_slopes=rank_slopes,
                          cache_config=cache_config,
                          quant_config=quant_config,
                          prefix=f"{prefix}.attn")

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/jais.py
def forward(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Project to Q/K/V, attend, and project back to ``hidden_size``."""
    fused_qkv, _ = self.c_attn(hidden_states)
    query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
    context = self.attn(query, key, value)
    projected, _ = self.c_proj(context)
    return projected

JAISBlock

Bases: Module

Source code in vllm/model_executor/models/jais.py
class JAISBlock(nn.Module):
    """A single pre-LayerNorm Jais transformer layer.

    ``x -> x + attn(ln_1(x))`` followed by ``x -> x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build both LayerNorms, the attention sublayer and the MLP.

        The MLP width is ``config.n_inner``, or ``4 * hidden_size`` when
        that is ``None``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        if config.n_inner is None:
            inner_dim = 4 * hidden_size
        else:
            inner_dim = config.n_inner
        eps = config.layer_norm_epsilon

        # Registration order (ln_1, attn, ln_2, mlp) is kept as-is.
        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = JAISAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = JAISMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP sublayers, each with a residual."""
        attn_output = self.attn(hidden_states=self.ln_1(hidden_states))
        hidden_states = attn_output + hidden_states
        mlp_output = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_output

attn instance-attribute

attn = JAISAttention(
    config,
    cache_config,
    quant_config,
    prefix=f"{prefix}.attn",
)

ln_1 instance-attribute

ln_1 = LayerNorm(hidden_size, eps=layer_norm_epsilon)

ln_2 instance-attribute

ln_2 = LayerNorm(hidden_size, eps=layer_norm_epsilon)

mlp instance-attribute

mlp = JAISMLP(inner_dim, config, quant_config)

__init__

__init__(
    config: JAISConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/jais.py
def __init__(
    self,
    config: JAISConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Build both LayerNorms, the attention sublayer and the MLP.

    The MLP width is ``config.n_inner``, or ``4 * hidden_size`` when that
    is ``None``.
    """
    super().__init__()
    hidden_size = config.hidden_size
    if config.n_inner is None:
        inner_dim = 4 * hidden_size
    else:
        inner_dim = config.n_inner
    eps = config.layer_norm_epsilon

    # Registration order (ln_1, attn, ln_2, mlp) is kept as-is.
    self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
    self.attn = JAISAttention(config,
                              cache_config,
                              quant_config,
                              prefix=f"{prefix}.attn")
    self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
    self.mlp = JAISMLP(inner_dim, config, quant_config)

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/jais.py
def forward(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Apply the attention and MLP sublayers, each with a residual add."""
    attn_output = self.attn(hidden_states=self.ln_1(hidden_states))
    hidden_states = attn_output + hidden_states
    mlp_output = self.mlp(self.ln_2(hidden_states))
    return hidden_states + mlp_output

JAISLMHeadModel

Bases: Module, SupportsPP

Source code in vllm/model_executor/models/jais.py
class JAISLMHeadModel(nn.Module, SupportsPP):
    """Jais causal LM: a ``JAISModel`` backbone plus an LM head.

    The LM head is the input embedding (``transformer.wte``) when
    ``config.tie_word_embeddings`` is set, otherwise a separate
    ``ParallelLMHead``. ``output_logits_scale`` is passed as the ``scale``
    of the ``LogitsProcessor``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = JAISModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)
        # Configs expose the logit multiplier either directly as
        # ``width_scale`` or as the two muP factors.
        if hasattr(config, "width_scale"):
            self.output_logits_scale = config.width_scale
        else:
            self.output_logits_scale = (config.mup_output_alpha *
                                        config.mup_width_scale)
        self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
                                                scale=self.output_logits_scale)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the backbone's embedding table."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[IntermediateTensors, torch.Tensor]:
        """Run the backbone; returns hidden states or PP intermediates."""
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn final hidden states into logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load HF Jais checkpoint tensors.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        tied = self.config.tie_word_embeddings
        for name, loaded_weight in weights:
            is_lm_head = "lm_head.weight" in name
            if is_lm_head and tied:
                # lm_head *is* transformer.wte here and is filled from the
                # embedding entry. An untied ParallelLMHead must be loaded,
                # so it is not skipped.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if "relative_pe" in name:
                continue
            # The untied LM head lives at the top level ("lm_head.weight"),
            # not under "transformer.".
            if not is_lm_head and not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights (exactly once,
            # even if a name matched several of the keys below).
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                    conv1d_name in name
                    for conv1d_name in ("c_attn", "c_proj", "c_fc")):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

lm_head instance-attribute

lm_head = wte

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    vocab_size=vocab_size, scale=output_logits_scale
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

output_logits_scale instance-attribute

output_logits_scale = width_scale

quant_config instance-attribute

quant_config = quant_config

transformer instance-attribute

transformer = JAISModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "transformer"),
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/jais.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the backbone, the (possibly tied) LM head and the logits
    processor from ``vllm_config``."""
    super().__init__()
    config = vllm_config.model_config.hf_config
    self.config = config
    self.quant_config = vllm_config.quant_config
    self.transformer = JAISModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "transformer"))

    # Tied embeddings: reuse the input embedding module as the LM head.
    if not config.tie_word_embeddings:
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
    else:
        self.lm_head = self.transformer.wte

    # The logit multiplier is either given directly or as two muP factors.
    if hasattr(config, "width_scale"):
        logits_scale = config.width_scale
    else:
        logits_scale = config.mup_output_alpha * config.mup_width_scale
    self.output_logits_scale = logits_scale

    self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
                                            scale=self.output_logits_scale)
    self.make_empty_intermediate_tensors = (
        self.transformer.make_empty_intermediate_tensors)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/jais.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Delegate to the logits processor with this model's LM head."""
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[IntermediateTensors, Tensor]
Source code in vllm/model_executor/models/jais.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]:
    """Run the backbone; its output (hidden states or PP intermediates) is
    returned unchanged."""
    return self.transformer(input_ids, positions, intermediate_tensors,
                            inputs_embeds)

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/jais.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed token ids using the wrapped transformer."""
    backbone = self.transformer
    return backbone.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/jais.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load HF Jais checkpoint tensors.

    Args:
        weights: ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The set of parameter names that were loaded.
    """
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    tied = self.config.tie_word_embeddings
    for name, loaded_weight in weights:
        is_lm_head = "lm_head.weight" in name
        if is_lm_head and tied:
            # lm_head *is* transformer.wte here and is filled from the
            # embedding entry. An untied ParallelLMHead must be loaded, so
            # it is not skipped.
            continue
        if ".attn.bias" in name or ".attn.masked_bias" in name:
            # Skip attention mask.
            # NOTE: "c_attn.bias" should not be skipped.
            continue
        if "relative_pe" in name:
            continue
        # The untied LM head lives at the top level ("lm_head.weight"), not
        # under "transformer.".
        if not is_lm_head and not name.startswith("transformer."):
            name = "transformer." + name

        if is_pp_missing_parameter(name, self):
            continue

        param = params_dict[name]
        # The HF's GPT-2 implementation uses Conv1D instead of Linear.
        # Because of this, we need to transpose the weights (exactly once,
        # even if a name matched several of the keys below).
        # Note(zhuohan): the logic below might break quantized models.
        if name.endswith(".weight") and any(
                conv1d_name in name
                for conv1d_name in ("c_attn", "c_proj", "c_fc")):
            loaded_weight = loaded_weight.t()
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params

JAISMLP

Bases: Module

Source code in vllm/model_executor/models/jais.py
class JAISMLP(nn.Module):
    """Feed-forward sublayer of a Jais block.

    With ``activation_function == "swiglu"``, two parallel up-projections
    are combined as ``c_fc(x) * silu(c_fc2(x))``. Otherwise a single
    up-projection is followed by an element-wise activation. Either way,
    ``c_proj`` maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections and the activation.

        Raises:
            ValueError: if ``config.activation_function`` is neither
                ``"swiglu"`` nor one of the supported single-input
                activations.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_fc2 = (ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        ) if self.swiglu else None)
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        if self.swiglu:
            self.act = SwiGLUActivation()
        else:
            # forward() calls self.act with ONE tensor on this path, so the
            # two-input SwiGLUActivation would raise TypeError there. Map the
            # HF activation name to a matching single-input module instead
            # (HF's gelu_new / gelu_pytorch_tanh are the tanh approximation).
            factories = {
                "gelu": lambda: nn.GELU(),
                "gelu_new": lambda: nn.GELU(approximate="tanh"),
                "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
                "relu": lambda: nn.ReLU(),
                "silu": lambda: nn.SiLU(),
                "swish": lambda: nn.SiLU(),
            }
            act_name = config.activation_function
            if act_name not in factories:
                raise ValueError("Unsupported activation_function for "
                                 f"JAISMLP: {act_name!r}")
            self.act = factories[act_name]()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection(s), activation and down-projection."""
        if self.swiglu:
            hidden_states2, _ = self.c_fc2(hidden_states)
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = (self.act(hidden_states, hidden_states2)
                         if self.swiglu else self.act(hidden_states))
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states

act instance-attribute

c_fc instance-attribute

c_fc = ColumnParallelLinear(
    hidden_size,
    intermediate_size,
    bias=True,
    quant_config=quant_config,
)

c_fc2 instance-attribute

c_fc2 = (
    ColumnParallelLinear(
        hidden_size,
        intermediate_size,
        bias=True,
        quant_config=quant_config,
    )
    if swiglu
    else None
)

c_proj instance-attribute

c_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=True,
    quant_config=quant_config,
)

swiglu instance-attribute

swiglu = activation_function == 'swiglu'

__init__

__init__(
    intermediate_size: int,
    config: JAISConfig,
    quant_config: Optional[QuantizationConfig] = None,
)
Source code in vllm/model_executor/models/jais.py
def __init__(
    self,
    intermediate_size: int,
    config: JAISConfig,
    quant_config: Optional[QuantizationConfig] = None,
):
    """Build the MLP projections and its activation.

    Raises:
        ValueError: if ``config.activation_function`` is neither ``"swiglu"``
            nor one of the supported single-input activations.
    """
    super().__init__()
    hidden_size = config.hidden_size
    self.swiglu = config.activation_function == "swiglu"
    self.c_fc = ColumnParallelLinear(
        hidden_size,
        intermediate_size,
        bias=True,
        quant_config=quant_config,
    )
    self.c_fc2 = (ColumnParallelLinear(
        hidden_size,
        intermediate_size,
        bias=True,
        quant_config=quant_config,
    ) if self.swiglu else None)
    self.c_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=True,
        quant_config=quant_config,
    )

    if self.swiglu:
        self.act = SwiGLUActivation()
    else:
        # forward() calls self.act with ONE tensor on this path, so the
        # two-input SwiGLUActivation would raise TypeError there. Map the HF
        # activation name to a matching single-input module instead (HF's
        # gelu_new / gelu_pytorch_tanh are the tanh approximation).
        factories = {
            "gelu": lambda: nn.GELU(),
            "gelu_new": lambda: nn.GELU(approximate="tanh"),
            "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
            "relu": lambda: nn.ReLU(),
            "silu": lambda: nn.SiLU(),
            "swish": lambda: nn.SiLU(),
        }
        act_name = config.activation_function
        if act_name not in factories:
            raise ValueError("Unsupported activation_function for "
                             f"JAISMLP: {act_name!r}")
        self.act = factories[act_name]()

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/jais.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Up-project, activate (gated when ``self.swiglu``), down-project."""
    if self.swiglu:
        gate, _ = self.c_fc2(hidden_states)
        up, _ = self.c_fc(hidden_states)
        activated = self.act(up, gate)
    else:
        up, _ = self.c_fc(hidden_states)
        activated = self.act(up)
    projected, _ = self.c_proj(activated)
    return projected

JAISModel

Bases: Module

Source code in vllm/model_executor/models/jais.py
@support_torch_compile
class JAISModel(nn.Module):
    """Jais decoder backbone: scaled embeddings, a stack of ``JAISBlock``
    layers (split across pipeline ranks), and a final LayerNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # Config options that this implementation does not handle.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # With ALiBi, positions enter through the attention slopes, so there
        # is no learned position table.
        self.wpe = (nn.Embedding(config.max_position_embeddings,
                                 self.embed_dim)
                    if config.position_embedding_type != "alibi" else None)
        if hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.embeddings_scale
        else:
            self.embeddings_scale = config.mup_embeddings_scale

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: JAISBlock(config=config,
                                     cache_config=cache_config,
                                     quant_config=quant_config,
                                     prefix=prefix),
            prefix=f"{prefix}.h",
        )

        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[IntermediateTensors, torch.Tensor]:
        """Run this pipeline rank's share of the backbone.

        The first rank builds scaled embeddings. Later ranks read
        ``intermediate_tensors["hidden_states"]``. Non-last ranks return
        ``IntermediateTensors``, and the last rank returns normalised
        hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            if self.wpe is not None:
                position_embeds = self.wpe(position_ids)
                hidden_states = inputs_embeds + position_embeds
            else:
                hidden_states = inputs_embeds
            # Out-of-place on purpose: without wpe, hidden_states IS
            # inputs_embeds (possibly caller-owned), and `*=` would rescale
            # the caller's tensor. The scale is still cast to the activation
            # dtype first, as before.
            scale = torch.tensor(float(self.embeddings_scale),
                                 dtype=hidden_states.dtype)
            hidden_states = hidden_states * scale
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in self.h[self.start_layer:self.end_layer]:
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states

config instance-attribute

config = config

embed_dim instance-attribute

embed_dim = hidden_size

embeddings_scale instance-attribute

embeddings_scale = embeddings_scale

ln_f instance-attribute

ln_f = LayerNorm(embed_dim, eps=layer_norm_epsilon)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], n_embd
    )
)

wpe instance-attribute

wpe = (
    Embedding(max_position_embeddings, embed_dim)
    if position_embedding_type != "alibi"
    else None
)

wte instance-attribute

wte = VocabParallelEmbedding(vocab_size, embed_dim)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/jais.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build embeddings, this rank's decoder layers and the final norm."""
    super().__init__()

    config = vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config

    self.config = config
    # Config options that this implementation does not handle.
    assert not config.add_cross_attention
    assert not config.scale_attn_by_inverse_layer_idx
    assert not config.reorder_and_upcast_attn

    self.embed_dim = config.hidden_size
    self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
    if config.position_embedding_type == "alibi":
        # Positions enter through the ALiBi attention slopes instead.
        self.wpe = None
    else:
        self.wpe = nn.Embedding(config.max_position_embeddings,
                                self.embed_dim)

    if hasattr(config, "embeddings_scale"):
        self.embeddings_scale = config.embeddings_scale
    else:
        self.embeddings_scale = config.mup_embeddings_scale

    # The factory parameter must be called ``prefix``.
    def build_block(prefix):
        return JAISBlock(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)

    self.start_layer, self.end_layer, self.h = make_layers(
        config.num_hidden_layers, build_block, prefix=f"{prefix}.h")

    self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(["hidden_states"],
                                                config.n_embd))

forward

forward(
    input_ids: Tensor,
    position_ids: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[IntermediateTensors, Tensor]
Source code in vllm/model_executor/models/jais.py
def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]:
    """Run this pipeline rank's share of the backbone.

    The first rank builds scaled embeddings. Later ranks read
    ``intermediate_tensors["hidden_states"]``. Non-last ranks return
    ``IntermediateTensors``, and the last rank returns normalised hidden
    states.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)
        if self.wpe is not None:
            position_embeds = self.wpe(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            hidden_states = inputs_embeds
        # Out-of-place on purpose: without wpe, hidden_states IS
        # inputs_embeds (possibly caller-owned), and `*=` would rescale the
        # caller's tensor. The scale is still cast to the activation dtype
        # first, as before.
        scale = torch.tensor(float(self.embeddings_scale),
                             dtype=hidden_states.dtype)
        hidden_states = hidden_states * scale
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    for layer in self.h[self.start_layer:self.end_layer]:
        hidden_states = layer(hidden_states)

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({"hidden_states": hidden_states})

    hidden_states = self.ln_f(hidden_states)
    return hidden_states

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/jais.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up token embeddings in the vocab-parallel table."""
    embedding_table = self.wte
    return embedding_table(input_ids)

SwiGLUActivation

Bases: Module

Source code in vllm/model_executor/models/jais.py
class SwiGLUActivation(nn.Module):
    """Two-input gated activation: ``x1 * silu(x2)``.

    The second argument is the gate; the first is passed through linearly.
    """

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        gate = nn.functional.silu(x2)
        return x1 * gate

forward

forward(x1: Tensor, x2: Tensor) -> Tensor
Source code in vllm/model_executor/models/jais.py
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Return ``x1 * silu(x2)``; ``x2`` acts as the gate."""
    gate = nn.functional.silu(x2)
    return x1 * gate

_get_alibi_slopes

_get_alibi_slopes(n)
Source code in vllm/model_executor/models/jais.py
def _get_alibi_slopes(n):

    def get_slopes_power_of_2(n):
        start = 2**(-(2**-(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    else:
        closest_power_of_2 = 2**math.floor(math.log2(n))
        return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
            2 * closest_power_of_2)[0::2][:n - closest_power_of_2])