Skip to content

vllm.model_executor.models.opt

Inference-only OPT model compatible with HuggingFace weights.

OPTAttention

Bases: Module

Source code in vllm/model_executor/models/opt.py
class OPTAttention(nn.Module):
    """Multi-head self-attention used inside each OPT decoder layer.

    Heads are split evenly across tensor-parallel (TP) ranks: every rank owns
    ``num_heads // tp_world_size`` heads of width ``embed_dim // num_heads``.
    A single fused ``QKVParallelLinear`` produces queries, keys and values,
    and a ``RowParallelLinear`` maps the attention output back to
    ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the fused QKV projection, attention op and out projection.

        Args:
            embed_dim: Hidden size of the layer (input and output width).
            num_heads: Total number of attention heads across all TP ranks.
            bias: Whether ``qkv_proj`` and ``out_proj`` have a bias term.
            cache_config: Passed through to ``Attention``.
            quant_config: Passed through to both projections and ``Attention``.
            prefix: Module path prefix used to name the sub-modules.

        Raises:
            AssertionError: If ``num_heads`` is not divisible by the TP world
                size, or ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0, (
            f"num_heads ({num_heads}) must be divisible by the tensor "
            f"parallel world size ({tensor_model_parallel_world_size})")
        # Without this check the floor division below silently truncates
        # head_dim, so num_heads * head_dim != embed_dim and the problem only
        # shows up later as an opaque shape error.
        assert embed_dim % total_num_heads == 0, (
            f"embed_dim ({embed_dim}) must be divisible by "
            f"num_heads ({total_num_heads})")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        # Scale handed to Attention: 1 / sqrt(head_dim).
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        # Attention only sees this rank's share of the heads.
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Layer input; presumably ``(num_tokens, embed_dim)``
                -- the shape is not checked here.

        Returns:
            The attention output after ``out_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split into three equal parts along the
        # last dimension: query, key, value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output

attn instance-attribute

attn = Attention(
    num_heads,
    head_dim,
    scale=scaling,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

embed_dim instance-attribute

embed_dim = embed_dim

head_dim instance-attribute

head_dim = embed_dim // total_num_heads

num_heads instance-attribute

num_heads = (
    total_num_heads // tensor_model_parallel_world_size
)

out_proj instance-attribute

out_proj = RowParallelLinear(
    embed_dim,
    embed_dim,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.out_proj",
)

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    embed_dim,
    head_dim,
    total_num_heads,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

scaling instance-attribute

scaling = head_dim ** -0.5

__init__

__init__(
    embed_dim: int,
    num_heads: int,
    bias: bool = True,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/opt.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    bias: bool = True,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Create the fused QKV projection, attention op and out projection.

    Args:
        embed_dim: Hidden size of the layer (input and output width).
        num_heads: Total number of attention heads across all TP ranks.
        bias: Whether ``qkv_proj`` and ``out_proj`` have a bias term.
        cache_config: Passed through to ``Attention``.
        quant_config: Passed through to both projections and ``Attention``.
        prefix: Module path prefix used to name the sub-modules.

    Raises:
        AssertionError: If ``num_heads`` is not divisible by the TP world
            size, or ``embed_dim`` is not divisible by ``num_heads``.
    """
    super().__init__()
    self.embed_dim = embed_dim
    tensor_model_parallel_world_size = (
        get_tensor_model_parallel_world_size())
    total_num_heads = num_heads
    assert num_heads % tensor_model_parallel_world_size == 0, (
        f"num_heads ({num_heads}) must be divisible by the tensor "
        f"parallel world size ({tensor_model_parallel_world_size})")
    # Without this check the floor division below silently truncates
    # head_dim, so num_heads * head_dim != embed_dim and the problem only
    # shows up later as an opaque shape error.
    assert embed_dim % total_num_heads == 0, (
        f"embed_dim ({embed_dim}) must be divisible by "
        f"num_heads ({total_num_heads})")
    self.num_heads = total_num_heads // tensor_model_parallel_world_size
    self.head_dim = embed_dim // total_num_heads
    # Scale handed to Attention: 1 / sqrt(head_dim).
    self.scaling = self.head_dim**-0.5

    self.qkv_proj = QKVParallelLinear(
        embed_dim,
        self.head_dim,
        total_num_heads,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.out_proj = RowParallelLinear(
        embed_dim,
        embed_dim,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )
    # Attention only sees this rank's share of the heads.
    self.attn = Attention(self.num_heads,
                          self.head_dim,
                          scale=self.scaling,
                          cache_config=cache_config,
                          quant_config=quant_config,
                          prefix=f"{prefix}.attn")

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/opt.py
def forward(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Self-attention over ``hidden_states`` followed by the out projection.

    The fused ``qkv_proj`` output is cut into three equal slices along the
    last axis (query, key, value) before being handed to ``attn``.
    """
    fused, _ = self.qkv_proj(hidden_states)
    query, key, value = fused.chunk(chunks=3, dim=-1)
    projected, _ = self.out_proj(self.attn(query, key, value))
    return projected

OPTDecoder

Bases: Module

Source code in vllm/model_executor/models/opt.py
class OPTDecoder(nn.Module):
    """OPT decoder stack: embeddings, decoder layers and output preparation.

    With pipeline parallelism (PP) the first rank turns token ids (or given
    embeddings) into hidden states, every rank runs its own slice of
    ``layers``, and only the last rank applies the final layer norm and the
    optional ``project_out``; other ranks hand their hidden states on as
    ``IntermediateTensors``.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build embeddings, optional width projections, final norm and layers.

        Args:
            config: HF ``OPTConfig`` describing the model.
            cache_config: Passed through to every ``OPTDecoderLayer``.
            quant_config: Passed through to the projections and layers.
            prefix: Module path prefix used to name sub-modules.
        """
        super().__init__()
        hidden_size = config.hidden_size
        proj_dim = config.word_embed_proj_dim

        self.config = config
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        # Token embeddings are produced at word_embed_proj_dim width.
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   proj_dim)
        # Positional embeddings are replicated (not sharded) and have the
        # decoder width, because they are added after project_in.
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, hidden_size)

        # When the embedding width differs from the decoder width, replicated
        # linears map between them on the way in and on the way out.
        if proj_dim != hidden_size:
            self.project_out = ReplicatedLinear(hidden_size,
                                                proj_dim,
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.project_out")
            self.project_in = ReplicatedLinear(proj_dim,
                                               hidden_size,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.project_in")
        else:
            self.project_out = None
            self.project_in = None

        # `config._remove_final_layer_norm` only keeps backward compatibility
        # with checkpoints fine-tuned before transformers v4.20.1, see
        # https://github.com/facebookresearch/metaseq/pull/164
        wants_final_norm = (config.do_layer_norm_before
                            and not config._remove_final_layer_norm)
        final_norm = None
        if wants_final_norm:
            final_norm = nn.LayerNorm(
                hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        self.final_layer_norm = final_norm

        def build_layer(prefix: str):
            # The parameter keeps the name ``prefix`` in case make_layers
            # passes it by keyword.
            return OPTDecoderLayer(config,
                                   cache_config,
                                   quant_config,
                                   prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` via ``embed_tokens``."""
        token_embedding = self.embed_tokens
        return token_embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder, or this pipeline stage's slice of it.

        Args:
            input_ids: Token ids; read only on the first PP rank and only
                when ``inputs_embeds`` is None.
            positions: Position indices for ``embed_positions`` (first rank).
            intermediate_tensors: Hidden states from the previous PP rank;
                required on every rank except the first.
            inputs_embeds: Optional precomputed token embeddings (at
                ``word_embed_proj_dim`` width) replacing the token lookup.

        Returns:
            On the last PP rank, the final hidden states (after the optional
            final layer norm and ``project_out``); otherwise an
            ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            token_embeds = (inputs_embeds if inputs_embeds is not None else
                            self.get_input_embeddings(input_ids))
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                token_embeds, _ = self.project_in(token_embeds)
            hidden_states = token_embeds + pos_embeds

        for decoder_layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = decoder_layer(hidden_states)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states

config instance-attribute

config = config

embed_positions instance-attribute

embed_positions = OPTLearnedPositionalEmbedding(
    max_position_embeddings, hidden_size
)

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, word_embed_proj_dim
)

final_layer_norm instance-attribute

final_layer_norm = LayerNorm(
    hidden_size,
    elementwise_affine=layer_norm_elementwise_affine,
)

max_target_positions instance-attribute

max_target_positions = max_position_embeddings

project_in instance-attribute

project_in = ReplicatedLinear(
    word_embed_proj_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.project_in",
)

project_out instance-attribute

project_out = ReplicatedLinear(
    hidden_size,
    word_embed_proj_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.project_out",
)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(
    config: OPTConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/opt.py
def __init__(
    self,
    config: OPTConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Build embeddings, optional width projections, final norm and layers.

    Args:
        config: HF ``OPTConfig`` describing the model.
        cache_config: Passed through to every ``OPTDecoderLayer``.
        quant_config: Passed through to the projections and layers.
        prefix: Module path prefix used to name sub-modules.
    """
    super().__init__()
    hidden_size = config.hidden_size
    proj_dim = config.word_embed_proj_dim

    self.config = config
    self.max_target_positions = config.max_position_embeddings
    self.vocab_size = config.vocab_size

    # Token embeddings are produced at word_embed_proj_dim width.
    self.embed_tokens = VocabParallelEmbedding(config.vocab_size, proj_dim)
    # Positional embeddings are replicated (not sharded) and have the
    # decoder width, because they are added after project_in.
    self.embed_positions = OPTLearnedPositionalEmbedding(
        config.max_position_embeddings, hidden_size)

    # When the embedding width differs from the decoder width, replicated
    # linears map between them on the way in and on the way out.
    if proj_dim != hidden_size:
        self.project_out = ReplicatedLinear(hidden_size,
                                            proj_dim,
                                            bias=False,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.project_out")
        self.project_in = ReplicatedLinear(proj_dim,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.project_in")
    else:
        self.project_out = None
        self.project_in = None

    # `config._remove_final_layer_norm` only keeps backward compatibility
    # with checkpoints fine-tuned before transformers v4.20.1, see
    # https://github.com/facebookresearch/metaseq/pull/164
    wants_final_norm = (config.do_layer_norm_before
                        and not config._remove_final_layer_norm)
    final_norm = None
    if wants_final_norm:
        final_norm = nn.LayerNorm(
            hidden_size,
            elementwise_affine=config.layer_norm_elementwise_affine)
    self.final_layer_norm = final_norm

    def build_layer(prefix: str):
        # The parameter keeps the name ``prefix`` in case make_layers
        # passes it by keyword.
        return OPTDecoderLayer(config,
                               cache_config,
                               quant_config,
                               prefix=prefix)

    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers")

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/opt.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the decoder, or this pipeline stage's slice of it.

    Args:
        input_ids: Token ids; read only on the first PP rank and only when
            ``inputs_embeds`` is None.
        positions: Position indices for ``embed_positions`` (first rank).
        intermediate_tensors: Hidden states from the previous PP rank;
            required on every rank except the first.
        inputs_embeds: Optional precomputed token embeddings (at
            ``word_embed_proj_dim`` width) replacing the token lookup.

    Returns:
        On the last PP rank, the final hidden states (after the optional
        final layer norm and ``project_out``); otherwise an
        ``IntermediateTensors`` holding ``"hidden_states"``.
    """
    pp_group = get_pp_group()
    if not pp_group.is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
    else:
        token_embeds = (inputs_embeds if inputs_embeds is not None else
                        self.get_input_embeddings(input_ids))
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            token_embeds, _ = self.project_in(token_embeds)
        hidden_states = token_embeds + pos_embeds

    for decoder_layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states = decoder_layer(hidden_states)

    if not pp_group.is_last_rank:
        return IntermediateTensors({"hidden_states": hidden_states})

    if self.final_layer_norm is not None:
        hidden_states = self.final_layer_norm(hidden_states)
    if self.project_out is not None:
        hidden_states, _ = self.project_out(hidden_states)
    return hidden_states

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/opt.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up token embeddings for ``input_ids`` via ``embed_tokens``."""
    token_embedding = self.embed_tokens
    return token_embedding(input_ids)

OPTDecoderLayer

Bases: Module

Source code in vllm/model_executor/models/opt.py
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention then a two-layer MLP.

    Each sub-block has a residual connection. ``config.do_layer_norm_before``
    selects where its layer norm sits: on the sub-block input (pre-norm) or
    on the residual sum (post-norm).
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build attention, the two layer norms and the fc1/fc2 MLP.

        Args:
            config: HF ``OPTConfig`` describing the layer sizes and options.
            cache_config: Passed through to ``OPTAttention``.
            quant_config: Passed through to attention and the linears.
            prefix: Module path prefix used to name sub-modules.
        """
        super().__init__()
        self.config = config
        self.embed_dim = width = config.hidden_size
        use_bias = config.enable_bias
        affine = config.layer_norm_elementwise_affine

        self.self_attn = OPTAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(width,
                                                 elementwise_affine=affine)
        # MLP: column-parallel up projection, activation, row-parallel down.
        self.fc1 = ColumnParallelLinear(
            width,
            config.ffn_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            width,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(width, elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with their residuals."""
        pre_norm = self.do_layer_norm_before

        # --- Self-attention sub-block ---
        # 125m, 1.7B, ..., 175B normalize the attention INPUT (pre-norm);
        # 350m normalizes AFTER the residual add (post-norm).
        attn_in = (self.self_attn_layer_norm(hidden_states)
                   if pre_norm else hidden_states)
        hidden_states = hidden_states + self.self_attn(hidden_states=attn_in)
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- Feed-forward sub-block ---
        # Same convention: pre-norm models normalize the MLP input, 350m
        # normalizes after the residual add.
        mlp_in = (self.final_layer_norm(hidden_states)
                  if pre_norm else hidden_states)
        mlp_out, _ = self.fc1(mlp_in)
        mlp_out = self.activation_fn(mlp_out)
        mlp_out, _ = self.fc2(mlp_out)
        hidden_states = hidden_states + mlp_out
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states

activation_fn instance-attribute

activation_fn = get_act_fn(activation_function)

config instance-attribute

config = config

do_layer_norm_before instance-attribute

do_layer_norm_before = do_layer_norm_before

embed_dim instance-attribute

embed_dim = hidden_size

fc1 instance-attribute

fc1 = ColumnParallelLinear(
    embed_dim,
    ffn_dim,
    bias=enable_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.fc1",
)

fc2 instance-attribute

fc2 = RowParallelLinear(
    ffn_dim,
    embed_dim,
    bias=enable_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.fc2",
)

final_layer_norm instance-attribute

final_layer_norm = LayerNorm(
    embed_dim,
    elementwise_affine=layer_norm_elementwise_affine,
)

self_attn instance-attribute

self_attn = OPTAttention(
    embed_dim=embed_dim,
    num_heads=num_attention_heads,
    bias=enable_bias,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

self_attn_layer_norm instance-attribute

self_attn_layer_norm = LayerNorm(
    embed_dim,
    elementwise_affine=layer_norm_elementwise_affine,
)

__init__

__init__(
    config: OPTConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/opt.py
def __init__(
    self,
    config: OPTConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Build attention, the two layer norms and the fc1/fc2 MLP.

    Args:
        config: HF ``OPTConfig`` describing the layer sizes and options.
        cache_config: Passed through to ``OPTAttention``.
        quant_config: Passed through to attention and the linears.
        prefix: Module path prefix used to name sub-modules.
    """
    super().__init__()
    self.config = config
    self.embed_dim = width = config.hidden_size
    use_bias = config.enable_bias
    affine = config.layer_norm_elementwise_affine

    self.self_attn = OPTAttention(
        embed_dim=width,
        num_heads=config.num_attention_heads,
        bias=use_bias,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.self_attn",
    )
    self.do_layer_norm_before = config.do_layer_norm_before

    self.self_attn_layer_norm = nn.LayerNorm(width, elementwise_affine=affine)
    # MLP: column-parallel up projection, activation, row-parallel down.
    self.fc1 = ColumnParallelLinear(
        width,
        config.ffn_dim,
        bias=use_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.fc1",
    )
    self.activation_fn = get_act_fn(config.activation_function)
    self.fc2 = RowParallelLinear(
        config.ffn_dim,
        width,
        bias=use_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.fc2",
    )
    self.final_layer_norm = nn.LayerNorm(width, elementwise_affine=affine)

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/opt.py
def forward(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Apply the attention and MLP sub-blocks with their residuals."""
    pre_norm = self.do_layer_norm_before

    # --- Self-attention sub-block ---
    # 125m, 1.7B, ..., 175B normalize the attention INPUT (pre-norm);
    # 350m normalizes AFTER the residual add (post-norm).
    attn_in = (self.self_attn_layer_norm(hidden_states)
               if pre_norm else hidden_states)
    hidden_states = hidden_states + self.self_attn(hidden_states=attn_in)
    if not pre_norm:
        hidden_states = self.self_attn_layer_norm(hidden_states)

    # --- Feed-forward sub-block ---
    # Same convention: pre-norm models normalize the MLP input, 350m
    # normalizes after the residual add.
    mlp_in = (self.final_layer_norm(hidden_states)
              if pre_norm else hidden_states)
    mlp_out, _ = self.fc1(mlp_in)
    mlp_out = self.activation_fn(mlp_out)
    mlp_out, _ = self.fc2(mlp_out)
    hidden_states = hidden_states + mlp_out
    if not pre_norm:
        hidden_states = self.final_layer_norm(hidden_states)
    return hidden_states

OPTForCausalLM

Bases: Module, SupportsPP

Source code in vllm/model_executor/models/opt.py
class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT causal language model: ``OPTModel`` backbone plus a vocab head.

    With ``tie_word_embeddings`` the head is the decoder's ``embed_tokens``
    module itself, and any ``lm_head.weight`` in the checkpoint is skipped.
    """

    # Fused vLLM module name -> checkpoint module names packed into it.
    # NOTE(review): no module in this file is named gate_up_proj; that entry
    # looks unused for OPT -- confirm before removing.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Checkpoint names starting with "decoder." are re-rooted under
    # "model.decoder." to match where the decoder lives in this module.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "decoder.": "model.decoder.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the (possibly tied) LM head and logits op."""
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        self.model = OPTModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))
        if hf_config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            head = self.model.decoder.embed_tokens
        else:
            head = ParallelLMHead(hf_config.vocab_size,
                                  hf_config.word_embed_proj_dim)
        self.lm_head = head
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the backbone."""
        backbone = self.model
        return backbone.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return whatever it produces.

        That is hidden states, or ``IntermediateTensors`` (per the return
        annotation) when this is not the final pipeline stage.
        """
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into vocabulary logits using ``lm_head``.

        Returns whatever ``logits_processor`` returns (annotated Optional).
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        Names are rewritten with ``hf_to_vllm_mapper``; with tied embeddings
        ``lm_head.weight`` is skipped because the head shares
        ``embed_tokens``.

        Returns:
            The set of loaded parameter names reported by the loader.
        """
        skip_prefixes = None
        if self.config.tie_word_embeddings:
            skip_prefixes = ["lm_head.weight"]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

config instance-attribute

config = config

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={"decoder.": "model.decoder."}
)

lm_head instance-attribute

lm_head = embed_tokens

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = OPTModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/opt.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the backbone, the (possibly tied) LM head and logits op."""
    super().__init__()
    hf_config = vllm_config.model_config.hf_config
    self.config = hf_config
    self.quant_config = vllm_config.quant_config
    self.model = OPTModel(vllm_config=vllm_config,
                          prefix=maybe_prefix(prefix, "model"))
    if hf_config.tie_word_embeddings:
        # Reuse the input embedding module as the output projection.
        head = self.model.decoder.embed_tokens
    else:
        head = ParallelLMHead(hf_config.vocab_size,
                              hf_config.word_embed_proj_dim)
    self.lm_head = head
    self.logits_processor = LogitsProcessor(hf_config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/opt.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Turn hidden states into vocabulary logits using ``lm_head``.

    Returns whatever ``logits_processor`` returns (annotated Optional).
    """
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/opt.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the backbone and return whatever it produces.

    That is hidden states, or ``IntermediateTensors`` (per the return
    annotation) when this is not the final pipeline stage.
    """
    return self.model(input_ids, positions, intermediate_tensors,
                      inputs_embeds)

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/opt.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate token-embedding lookup to the backbone."""
    backbone = self.model
    return backbone.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/opt.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors through ``AutoWeightsLoader``.

    Names are rewritten with ``hf_to_vllm_mapper``; with tied embeddings
    ``lm_head.weight`` is skipped because the head shares ``embed_tokens``.

    Returns:
        The set of loaded parameter names reported by the loader.
    """
    skip_prefixes = None
    if self.config.tie_word_embeddings:
        skip_prefixes = ["lm_head.weight"]
    loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

OPTLearnedPositionalEmbedding

Bases: Embedding

Source code in vllm/model_executor/models/opt.py
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position table with OPT's fixed index offset of 2.

    OPT is set up so that, when ``padding_idx`` is specified, embedding ids
    are shifted by 2 and the table grows by 2 rows accordingly (other models
    don't do this). Position ``p`` therefore reads row ``p + 2``, and the
    table holds ``num_embeddings + 2`` rows.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        """Allocate ``num_embeddings + 2`` rows of width ``embedding_dim``."""
        shift = 2
        # Set before nn.Embedding.__init__ so the attribute exists from the
        # start; it is a plain int, not a parameter or buffer.
        self.offset = shift
        super().__init__(num_embeddings + shift, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Look up embeddings for ``positions`` shifted by ``self.offset``."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)

offset instance-attribute

offset = 2

__init__

__init__(num_embeddings: int, embedding_dim: int)
Source code in vllm/model_executor/models/opt.py
def __init__(self, num_embeddings: int, embedding_dim: int):
    """Allocate ``num_embeddings + 2`` rows of width ``embedding_dim``.

    OPT is set up so that, when ``padding_idx`` is specified, embedding ids
    are shifted by 2 and the table grows by 2 rows accordingly (other models
    don't do this).
    """
    shift = 2
    # Set before nn.Embedding.__init__; it is a plain int attribute.
    self.offset = shift
    super().__init__(num_embeddings + shift, embedding_dim)

forward

forward(positions: Tensor)
Source code in vllm/model_executor/models/opt.py
def forward(self, positions: torch.Tensor):
    """Look up embeddings for ``positions`` shifted by ``self.offset``."""
    shifted_positions = positions + self.offset
    return super().forward(shifted_positions)

OPTModel

Bases: Module

Source code in vllm/model_executor/models/opt.py
@support_torch_compile
class OPTModel(nn.Module):
    """Backbone wrapper that owns the ``OPTDecoder`` under ``decoder``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder from ``vllm_config`` and the PP buffer factory."""
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.decoder = OPTDecoder(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.decoder")
        # Pipeline-parallel hand-off carries a single "hidden_states" tensor
        # of width hidden_size.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder."""
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder and return its output unchanged."""
        return self.decoder(input_ids,
                            positions,
                            intermediate_tensors,
                            inputs_embeds=inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written into
        the matching shard of the fused ``qkv_proj`` parameter; every other
        tensor is loaded by name.

        Args:
            weights: ``(name, tensor)`` pairs whose names (after q/k/v
                renaming) match ``self.named_parameters()``.

        Returns:
            The set of parameter names that received a tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Rename at most once. Scanning further after a match would
            # re-match the rewritten name ("qkv_proj" contains "v_proj") and
            # corrupt it into "qkqkv_proj".
            shard_id = None
            for param_name, weight_name, mapped_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_shard_id
                    break
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Presumably parameters held by another pipeline-parallel rank.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters must expose a shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

decoder instance-attribute

decoder = OPTDecoder(
    config,
    cache_config,
    quant_config,
    prefix=f"{prefix}.decoder",
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], hidden_size
    )
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/opt.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the decoder from ``vllm_config`` and the PP buffer factory."""
    super().__init__()
    hf_config = vllm_config.model_config.hf_config
    self.decoder = OPTDecoder(hf_config,
                              vllm_config.cache_config,
                              vllm_config.quant_config,
                              prefix=f"{prefix}.decoder")
    # Pipeline-parallel hand-off carries a single "hidden_states" tensor of
    # width hidden_size.
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(["hidden_states"],
                                                hf_config.hidden_size))

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/opt.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the decoder and return its output unchanged."""
    decoder = self.decoder
    return decoder(input_ids,
                   positions,
                   intermediate_tensors,
                   inputs_embeds=inputs_embeds)

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/opt.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate token-embedding lookup to the decoder."""
    decoder = self.decoder
    return decoder.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/opt.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this module's parameters.

    Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written into the
    matching shard of the fused ``qkv_proj`` parameter; every other tensor is
    loaded by name.

    Args:
        weights: ``(name, tensor)`` pairs whose names (after q/k/v renaming)
            match ``self.named_parameters()``.

    Returns:
        The set of parameter names that received a tensor.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # Rename at most once. Scanning further after a match would re-match
        # the rewritten name ("qkv_proj" contains "v_proj") and corrupt it
        # into "qkqkv_proj".
        shard_id = None
        for param_name, weight_name, mapped_shard_id in (
                stacked_params_mapping):
            if weight_name in name:
                name = name.replace(weight_name, param_name)
                shard_id = mapped_shard_id
                break
        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            continue
        # Presumably parameters held by another pipeline-parallel rank.
        if is_pp_missing_parameter(name, self):
            continue
        param = params_dict[name]
        if shard_id is None:
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        else:
            # Fused parameters must expose a shard-aware loader.
            param.weight_loader(param, loaded_weight, shard_id)
        loaded_params.add(name)
    return loaded_params