vllm.model_executor.models.pixtral

PATCH_MERGE module-attribute

PATCH_MERGE = 'patch_merge'

USE_XFORMERS_OPS module-attribute

USE_XFORMERS_OPS = True

Attention

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class Attention(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        batch, patches, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)

args instance-attribute

args = args

head_dim instance-attribute

head_dim = hidden_size // num_attention_heads

n_heads instance-attribute

n_heads = num_attention_heads

wk instance-attribute

wk = Linear(hidden_size, hidden_size, bias=False)

wo instance-attribute

wo = Linear(hidden_size, hidden_size, bias=False)

wq instance-attribute

wq = Linear(hidden_size, hidden_size, bias=False)

wv instance-attribute

wv = Linear(hidden_size, hidden_size, bias=False)

__init__

__init__(args: VisionEncoderArgs)
Source code in vllm/model_executor/models/pixtral.py
def __init__(self, args: VisionEncoderArgs):
    super().__init__()
    self.args = args
    assert not args.hidden_size % args.num_attention_heads
    self.n_heads = args.num_attention_heads
    self.head_dim = args.hidden_size // args.num_attention_heads

    self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
    self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
    self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
    self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

forward

forward(
    x: Tensor, mask: Tensor, freqs_cis: Tensor
) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    x: torch.Tensor,
    mask: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    batch, patches, _ = x.shape

    q, k, v = self.wq(x), self.wk(x), self.wv(x)
    q = q.reshape(batch, patches, self.n_heads, self.head_dim)
    k = k.reshape(batch, patches, self.n_heads, self.head_dim)
    v = v.reshape(batch, patches, self.n_heads, self.head_dim)

    q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
    out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
    out = out.reshape(batch, patches, self.n_heads * self.head_dim)
    return self.wo(out)
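
The forward pass above splits the hidden states into per-head chunks, applies rotary embeddings, and runs memory-efficient attention. A minimal sketch of that shape flow, using torch's scaled_dot_product_attention in place of xops.memory_efficient_attention and omitting the rotary step (sizes are hypothetical):

import torch
import torch.nn.functional as F

batch, patches, hidden_size, n_heads = 1, 64, 1024, 16
head_dim = hidden_size // n_heads

x = torch.randn(batch, patches, hidden_size)
# In the real module q, k, v come from separate wq/wk/wv projections; the
# identical tensors here are only for shape illustration.
q = k = v = x.reshape(batch, patches, n_heads, head_dim)

# xformers consumes (batch, seq, heads, head_dim); torch SDPA wants
# (batch, heads, seq, head_dim), hence the transposes.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
out = out.reshape(batch, patches, n_heads * head_dim)  # (1, 64, 1024)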

FeedForward

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class FeedForward(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        self.w1 = nn.Linear(args.hidden_size,
                            args.intermediate_size,
                            bias=False)
        self.w2 = nn.Linear(args.intermediate_size,
                            args.hidden_size,
                            bias=False)
        self.w3 = nn.Linear(args.hidden_size,
                            args.intermediate_size,
                            bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

w1 instance-attribute

w1 = Linear(hidden_size, intermediate_size, bias=False)

w2 instance-attribute

w2 = Linear(intermediate_size, hidden_size, bias=False)

w3 instance-attribute

w3 = Linear(hidden_size, intermediate_size, bias=False)

__init__

__init__(args: VisionEncoderArgs)
Source code in vllm/model_executor/models/pixtral.py
def __init__(self, args: VisionEncoderArgs):
    super().__init__()
    assert args.intermediate_size is not None
    self.w1 = nn.Linear(args.hidden_size,
                        args.intermediate_size,
                        bias=False)
    self.w2 = nn.Linear(args.intermediate_size,
                        args.hidden_size,
                        bias=False)
    self.w3 = nn.Linear(args.hidden_size,
                        args.intermediate_size,
                        bias=False)

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.w2(F.silu(self.w1(x)) * self.w3(x))
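
FeedForward is the SwiGLU pattern w2(silu(w1(x)) * w3(x)). A standalone sketch with hypothetical sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 1024, 4096
w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
w3 = nn.Linear(hidden_size, intermediate_size, bias=False)

x = torch.randn(2, 16, hidden_size)
y = w2(F.silu(w1(x)) * w3(x))  # output keeps the input shape (2, 16, hidden_size)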

PatchMerger

Bases: Module

Learned merging of spatial_merge_size ** 2 patches

Source code in vllm/model_executor/models/pixtral.py
class PatchMerger(nn.Module):
    """
    Learned merging of spatial_merge_size ** 2 patches
    """

    def __init__(
        self,
        vision_encoder_dim: int,
        spatial_merge_size: int,
        use_mlp_bias: bool = False,
    ) -> None:
        super().__init__()

        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)

        self.spatial_merge_size = spatial_merge_size
        self.mlp_input_dim = mlp_input_dim

        self.merging_layer = nn.Linear(
            mlp_input_dim,
            vision_encoder_dim,
            bias=use_mlp_bias,
        )

    def forward(self, x: torch.Tensor,
                image_sizes: list[tuple[int, int]]) -> torch.Tensor:
        # image_sizes specified in tokens
        assert sum([h * w for h, w in image_sizes]) == len(x)

        # x is (N, vision_encoder_dim)
        x = self.permute(x, image_sizes)

        # x is (N / spatial_merge_size ** 2,
        #       vision_encoder_dim * spatial_merge_size ** 2)
        x = self.merging_layer(x)

        # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
        return x

    def permute(
        self,
        x: torch.Tensor,
        image_sizes: list[tuple[int, int]],
    ) -> torch.Tensor:
        """
        Args:
            x: (N, D) where N is flattened and concatenated patch tokens
                for all images
            image_sizes: list of tuple of (height, width) in tokens for
                each image
        Returns:
            image_features: reorders patch tokens so each grid of
                (spatial_merge_size, spatial_merge_size) is contiguous.
                now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
        """

        sub_grids = get_sub_grids(
            x=x,
            image_sizes=image_sizes,
            spatial_merge_size=self.spatial_merge_size
        )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
        permuted_tensor: list[torch.Tensor] = []
        for grid in sub_grids:
            n_patches = grid.shape[-1]
            permuted_tensor.append(grid.view(-1, n_patches).t(
            ))  # n_patches x d * sub_grid_size * sub_grid_size
        return torch.cat(
            permuted_tensor, dim=0
        )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)

merging_layer instance-attribute

merging_layer = Linear(
    mlp_input_dim, vision_encoder_dim, bias=use_mlp_bias
)

mlp_input_dim instance-attribute

mlp_input_dim = mlp_input_dim

spatial_merge_size instance-attribute

spatial_merge_size = spatial_merge_size

__init__

__init__(
    vision_encoder_dim: int,
    spatial_merge_size: int,
    use_mlp_bias: bool = False,
) -> None
Source code in vllm/model_executor/models/pixtral.py
def __init__(
    self,
    vision_encoder_dim: int,
    spatial_merge_size: int,
    use_mlp_bias: bool = False,
) -> None:
    super().__init__()

    mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)

    self.spatial_merge_size = spatial_merge_size
    self.mlp_input_dim = mlp_input_dim

    self.merging_layer = nn.Linear(
        mlp_input_dim,
        vision_encoder_dim,
        bias=use_mlp_bias,
    )

forward

forward(
    x: Tensor, image_sizes: list[tuple[int, int]]
) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def forward(self, x: torch.Tensor,
            image_sizes: list[tuple[int, int]]) -> torch.Tensor:
    # image_sizes specified in tokens
    assert sum([h * w for h, w in image_sizes]) == len(x)

    # x is (N, vision_encoder_dim)
    x = self.permute(x, image_sizes)

    # x is (N / spatial_merge_size ** 2,
    #       vision_encoder_dim * spatial_merge_size ** 2)
    x = self.merging_layer(x)

    # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
    return x

permute

permute(
    x: Tensor, image_sizes: list[tuple[int, int]]
) -> Tensor

Parameters:

    x (Tensor): (N, D) where N is the flattened and concatenated patch
        tokens for all images. Required.
    image_sizes (list[tuple[int, int]]): list of (height, width) tuples, in
        tokens, for each image. Required.

Returns:

    image_features: patch tokens reordered so that each
        (spatial_merge_size, spatial_merge_size) grid is contiguous; shape
        (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2).

Source code in vllm/model_executor/models/pixtral.py
def permute(
    self,
    x: torch.Tensor,
    image_sizes: list[tuple[int, int]],
) -> torch.Tensor:
    """
    Args:
        x: (N, D) where N is flattened and concatenated patch tokens
            for all images
        image_sizes: list of tuple of (height, width) in tokens for
            each image
    Returns:
        image_features: reorders patch tokens so each grid of
            (spatial_merge_size, spatial_merge_size) is contiguous.
            now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
    """

    sub_grids = get_sub_grids(
        x=x,
        image_sizes=image_sizes,
        spatial_merge_size=self.spatial_merge_size
    )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
    permuted_tensor: list[torch.Tensor] = []
    for grid in sub_grids:
        n_patches = grid.shape[-1]
        permuted_tensor.append(grid.view(-1, n_patches).t(
        ))  # n_patches x d * sub_grid_size * sub_grid_size
    return torch.cat(
        permuted_tensor, dim=0
    )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
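
The shape arithmetic implied by permute and the merging layer, worked through with hypothetical sizes (one image of 4 x 6 patch tokens, spatial_merge_size = 2):

D = 1024                       # vision_encoder_dim
h, w, merge = 4, 6, 2          # patch-token grid and spatial_merge_size
N = h * w                      # 24 flattened patch tokens; permute input is (N, D)
merged_tokens = N // merge**2  # 6 rows after grouping each 2x2 grid
merged_dim = D * merge**2      # 4096 features per row fed into merging_layer
# merging_layer then maps (6, 4096) back to (6, D)
print(merged_tokens, merged_dim)  # 6 4096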

PixtralDummyInputsBuilder

Bases: BaseDummyInputsBuilder[PixtralProcessingInfo]

Source code in vllm/model_executor/models/pixtral.py
class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        tokenizer = self.info.get_tokenizer()

        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
        dummy_images = dummy_mm_data.get("image", [])
        tokenization_kwargs = {"truncation": False}

        request = ChatCompletionRequest(messages=[
            UserMessage(content=[
                TextChunk(text=dummy_text),
                *(ImageChunk(image=image) for image in dummy_images),
            ]),
        ])
        res = tokenizer.mistral.encode_chat_completion(request)
        dummy_tokens = res.tokens

        return ProcessorInputs(prompt=dummy_tokens,
                               mm_data=dummy_mm_data,
                               tokenization_kwargs=tokenization_kwargs)

get_dummy_mm_data

get_dummy_mm_data(
    seq_len: int, mm_counts: Mapping[str, int]
) -> MultiModalDataDict
Source code in vllm/model_executor/models/pixtral.py
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
    num_images = mm_counts.get("image", 0)

    target_width, target_height = \
        self.info.get_image_size_with_most_features()

    return {
        "image":
        self._get_dummy_images(width=target_width,
                               height=target_height,
                               num_images=num_images)
    }

get_dummy_processor_inputs

get_dummy_processor_inputs(
    seq_len: int, mm_counts: Mapping[str, int]
) -> ProcessorInputs
Source code in vllm/model_executor/models/pixtral.py
def get_dummy_processor_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> ProcessorInputs:
    tokenizer = self.info.get_tokenizer()

    dummy_text = self.get_dummy_text(mm_counts)
    dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
    dummy_images = dummy_mm_data.get("image", [])
    tokenization_kwargs = {"truncation": False}

    request = ChatCompletionRequest(messages=[
        UserMessage(content=[
            TextChunk(text=dummy_text),
            *(ImageChunk(image=image) for image in dummy_images),
        ]),
    ])
    res = tokenizer.mistral.encode_chat_completion(request)
    dummy_tokens = res.tokens

    return ProcessorInputs(prompt=dummy_tokens,
                           mm_data=dummy_mm_data,
                           tokenization_kwargs=tokenization_kwargs)

get_dummy_text

get_dummy_text(mm_counts: Mapping[str, int]) -> str
Source code in vllm/model_executor/models/pixtral.py
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    return ""

PixtralForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP

Source code in vllm/model_executor/models/pixtral.py
@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
                                        info=PixtralProcessingInfo,
                                        dummy_inputs=PixtralDummyInputsBuilder)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.vision_encoder = VisionTransformer(self.vision_args)

        if self.vision_args.add_pre_mm_projector_layer_norm:
            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
                                                 eps=1e-5)

        if self.vision_args.mm_projector_id == PATCH_MERGE:
            self.patch_merger = PatchMerger(
                vision_encoder_dim=self.vision_args.hidden_size,
                spatial_merge_size=self.vision_args.spatial_merge_size,
                use_mlp_bias=False,
            )

        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PixtralImagePixelInputs]:
        images = kwargs.pop("images", None)
        if images is None:
            return None

        if not isinstance(images, (torch.Tensor, list)):
            raise ValueError("Incorrect type of images. "
                             f"Got type: {type(images)}")

        return PixtralImagePixelInputs(
            type="pixel_values",
            images=flatten_bn(images),
        )

    def _process_image_input(
        self,
        image_input: PixtralImagePixelInputs,
    ) -> tuple[torch.Tensor, ...]:
        images = image_input["images"]
        image_features = self.vision_encoder(images)
        feature_sizes = [
            image_feature.shape[0] for image_feature in image_features
        ]
        image_features = torch.cat(image_features)
        if self.vision_args.add_pre_mm_projector_layer_norm:
            image_features = self.pre_mm_projector_norm(image_features)
        if self.vision_args.mm_projector_id == PATCH_MERGE:
            patch_size = self.vision_args.patch_size
            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
            img_patch_dims = [(img.shape[1] // patch_size,
                               img.shape[2] // patch_size) for img in images]
            feature_sizes = [
                feature_size // spatial_merge_size_square
                for feature_size in feature_sizes
            ]
            image_features = self.patch_merger(image_features,
                                               image_sizes=img_patch_dims)
        image_embeds = self.vision_language_adapter(image_features)
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.vision_args.image_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for pixtral."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

        def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        def is_patch_merger(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("patch_merger")

        def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("pre_mm_projector_norm")

        # Get references to parameters for direct loading
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        patch_merger_dict = dict(self.patch_merger.named_parameters(
        )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
        pre_mm_projector_norm_dict = dict(
            self.pre_mm_projector_norm.named_parameters(
            )) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())

        def llm_weights_generator():
            # Single pass over weights
            for name, w in weights:
                if is_vision_encoder_weights((name, w)):
                    # Load vision encoder weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = vision_encoder_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_patch_merger((name, w)):
                    # Load vision patch merger weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = patch_merger_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_pre_mm_projector_norm((name, w)):
                    # Load vision pre_mm_projector_norm weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = pre_mm_projector_norm_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_vision_lang_adapter_weights((name, w)):
                    # Load vision-language adapter weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = vision_lang_adapter_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                else:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    yield (name, w)

        # Now we call the language model load with the generator
        self.language_model.load_weights(llm_weights_generator())

config instance-attribute

config = config

language_model instance-attribute

language_model = init_vllm_registered_model(
    vllm_config=vllm_config,
    hf_config=text_config,
    prefix=maybe_prefix(prefix, "language_model"),
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

multimodal_config instance-attribute

multimodal_config = multimodal_config

patch_merger instance-attribute

patch_merger = PatchMerger(
    vision_encoder_dim=hidden_size,
    spatial_merge_size=spatial_merge_size,
    use_mlp_bias=False,
)

pre_mm_projector_norm instance-attribute

pre_mm_projector_norm = RMSNorm(hidden_size, eps=1e-05)

vision_args instance-attribute

vision_args = VisionEncoderArgs(**vision_args)

vision_encoder instance-attribute

vision_encoder = VisionTransformer(vision_args)

vision_language_adapter instance-attribute

vision_language_adapter = VisionLanguageAdapter(
    vision_args, dim=hidden_size
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/pixtral.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    config = vllm_config.model_config.hf_config
    multimodal_config = vllm_config.model_config.multimodal_config
    self.config = config
    self.multimodal_config = multimodal_config

    dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
    vision_args = {
        key: value
        for key, value in self.config.vision_config.to_dict().items()
        if key in dataclass_fields
    }

    self.vision_args = VisionEncoderArgs(**vision_args)

    # init MistralForCausalLM
    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=config.text_config,
        prefix=maybe_prefix(prefix, "language_model"),
    )

    self.vision_encoder = VisionTransformer(self.vision_args)

    if self.vision_args.add_pre_mm_projector_layer_norm:
        self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
                                             eps=1e-5)

    if self.vision_args.mm_projector_id == PATCH_MERGE:
        self.patch_merger = PatchMerger(
            vision_encoder_dim=self.vision_args.hidden_size,
            spatial_merge_size=self.vision_args.spatial_merge_size,
            use_mlp_bias=False,
        )

    self.vision_language_adapter = VisionLanguageAdapter(
        self.vision_args, dim=config.text_config.hidden_size)

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)

_parse_and_validate_image_input

_parse_and_validate_image_input(
    **kwargs: object,
) -> Optional[PixtralImagePixelInputs]
Source code in vllm/model_executor/models/pixtral.py
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[PixtralImagePixelInputs]:
    images = kwargs.pop("images", None)
    if images is None:
        return None

    if not isinstance(images, (torch.Tensor, list)):
        raise ValueError("Incorrect type of images. "
                         f"Got type: {type(images)}")

    return PixtralImagePixelInputs(
        type="pixel_values",
        images=flatten_bn(images),
    )

_process_image_input

_process_image_input(
    image_input: PixtralImagePixelInputs,
) -> tuple[Tensor, ...]
Source code in vllm/model_executor/models/pixtral.py
def _process_image_input(
    self,
    image_input: PixtralImagePixelInputs,
) -> tuple[torch.Tensor, ...]:
    images = image_input["images"]
    image_features = self.vision_encoder(images)
    feature_sizes = [
        image_feature.shape[0] for image_feature in image_features
    ]
    image_features = torch.cat(image_features)
    if self.vision_args.add_pre_mm_projector_layer_norm:
        image_features = self.pre_mm_projector_norm(image_features)
    if self.vision_args.mm_projector_id == PATCH_MERGE:
        patch_size = self.vision_args.patch_size
        spatial_merge_size_square = self.vision_args.spatial_merge_size**2
        img_patch_dims = [(img.shape[1] // patch_size,
                           img.shape[2] // patch_size) for img in images]
        feature_sizes = [
            feature_size // spatial_merge_size_square
            for feature_size in feature_sizes
        ]
        image_features = self.patch_merger(image_features,
                                           image_sizes=img_patch_dims)
    image_embeds = self.vision_language_adapter(image_features)
    image_embeds = torch.split(image_embeds, feature_sizes)
    return image_embeds
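
In the PATCH_MERGE branch, each image's feature count shrinks by spatial_merge_size ** 2; worked through with hypothetical numbers (patch_size = 16, spatial_merge_size = 2, one 64 x 96 pixel image):

patch_size, spatial_merge_size = 16, 2
img_h, img_w = 64, 96

tokens_h, tokens_w = img_h // patch_size, img_w // patch_size  # 4 x 6 patch grid
feature_size = tokens_h * tokens_w                             # 24 encoder tokens
merged = feature_size // spatial_merge_size**2                 # 6 tokens after merging
print((tokens_h, tokens_w), feature_size, merged)              # (4, 6) 24 6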

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/pixtral.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    return self.language_model.compute_logits(hidden_states,
                                              sampling_metadata)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs: object,
) -> Union[Tensor, IntermediateTensors]

Run forward pass for pixtral.

Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run forward pass for pixtral."""
    if intermediate_tensors is not None:
        inputs_embeds = None

    # NOTE: In v1, inputs_embeds is always generated at model runner, this
    # condition is for v0 compatibility.
    elif inputs_embeds is None:
        vision_embeddings = self.get_multimodal_embeddings(**kwargs)
        inputs_embeds = self.get_input_embeddings(input_ids,
                                                  vision_embeddings)
        input_ids = None

    hidden_states = self.language_model.model(input_ids,
                                              positions,
                                              intermediate_tensors,
                                              inputs_embeds=inputs_embeds)

    return hidden_states

get_input_embeddings

get_input_embeddings(
    input_ids: Tensor,
    multimodal_embeddings: Optional[
        MultiModalEmbeddings
    ] = None,
) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    inputs_embeds = self.language_model.get_input_embeddings(input_ids)
    if multimodal_embeddings is not None \
        and len(multimodal_embeddings) != 0:
        inputs_embeds = merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            multimodal_embeddings,
            self.vision_args.image_token_id,
        )
    return inputs_embeds
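
Conceptually, merge_multimodal_embeddings overwrites the rows of inputs_embeds whose token id equals image_token_id with the flattened image embeddings. A hedged standalone sketch of that behavior (not the vLLM implementation):

import torch

image_token_id = 10
input_ids = torch.tensor([1, 10, 10, 2])   # two image placeholder tokens
inputs_embeds = torch.zeros(4, 8)          # text embeddings (zeros for clarity)
image_embeds = torch.ones(2, 8)            # one row per placeholder token

mask = input_ids == image_token_id
inputs_embeds[mask] = image_embeds
print(inputs_embeds[:, 0])  # tensor([0., 1., 1., 0.])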

get_language_model

get_language_model() -> Module
Source code in vllm/model_executor/models/pixtral.py
def get_language_model(self) -> torch.nn.Module:
    return self.language_model

get_multimodal_embeddings

get_multimodal_embeddings(
    **kwargs: object,
) -> MultiModalEmbeddings
Source code in vllm/model_executor/models/pixtral.py
def get_multimodal_embeddings(self,
                              **kwargs: object) -> MultiModalEmbeddings:
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is None:
        return []

    return self._process_image_input(image_input)

get_placeholder_str classmethod

get_placeholder_str(modality: str, i: int) -> Optional[str]
Source code in vllm/model_executor/models/pixtral.py
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    if modality.startswith("image"):
        return None

    raise ValueError("Only image modality is supported")

load_weights

load_weights(weights: Iterable[tuple[str, Tensor]])
Source code in vllm/model_executor/models/pixtral.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

    def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
        return weight[0].startswith("vision_encoder")

    def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
        return weight[0].startswith("vision_language_adapter")

    def is_patch_merger(weight: tuple[str, torch.Tensor]):
        return weight[0].startswith("patch_merger")

    def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
        return weight[0].startswith("pre_mm_projector_norm")

    # Get references to parameters for direct loading
    vision_encoder_dict = dict(self.vision_encoder.named_parameters())
    patch_merger_dict = dict(self.patch_merger.named_parameters(
    )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
    pre_mm_projector_norm_dict = dict(
        self.pre_mm_projector_norm.named_parameters(
        )) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
    vision_lang_adapter_dict = dict(
        self.vision_language_adapter.named_parameters())

    def llm_weights_generator():
        # Single pass over weights
        for name, w in weights:
            if is_vision_encoder_weights((name, w)):
                # Load vision encoder weights directly
                trimmed_name = '.'.join(name.split(".")[1:])
                param = vision_encoder_dict[trimmed_name]
                with torch.no_grad():
                    default_weight_loader(param, w)
            elif is_patch_merger((name, w)):
                # Load vision patch merger weights directly
                trimmed_name = '.'.join(name.split(".")[1:])
                param = patch_merger_dict[trimmed_name]
                with torch.no_grad():
                    default_weight_loader(param, w)
            elif is_pre_mm_projector_norm((name, w)):
                # Load vision pre_mm_projector_norm weights directly
                trimmed_name = '.'.join(name.split(".")[1:])
                param = pre_mm_projector_norm_dict[trimmed_name]
                with torch.no_grad():
                    default_weight_loader(param, w)
            elif is_vision_lang_adapter_weights((name, w)):
                # Load vision-language adapter weights directly
                trimmed_name = '.'.join(name.split(".")[1:])
                param = vision_lang_adapter_dict[trimmed_name]
                with torch.no_grad():
                    default_weight_loader(param, w)
            else:
                # LLM weights: yield them to be loaded
                # by language_model.load_weights
                yield (name, w)

    # Now we call the language model load with the generator
    self.language_model.load_weights(llm_weights_generator())
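
The trimmed_name logic strips the leading submodule prefix from a checkpoint key before looking it up in the matching parameter dict; for example (hypothetical key name):

name = "vision_encoder.transformer.layers.0.attention.wq.weight"
trimmed_name = '.'.join(name.split(".")[1:])
print(trimmed_name)  # transformer.layers.0.attention.wq.weight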

PixtralHFAttention

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class PixtralHFAttention(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch, patches, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # Transpose q and k to apply HF's Rotary Position Embedding
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, patches, self.n_heads, self.head_dim)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back for attention
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()

            out = xops.memory_efficient_attention(q,
                                                  k,
                                                  v,
                                                  attn_bias=attention_mask)
        else:
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask)
            out = out.transpose(1, 2)

        out = out.view(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(out)

        return attn_output, None

config instance-attribute

config = config

head_dim instance-attribute

head_dim = hidden_size // num_attention_heads

n_heads instance-attribute

n_heads = divide(num_attention_heads, tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    input_size=hidden_size,
    output_size=hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size=hidden_size,
    head_size=head_dim,
    total_num_heads=total_num_heads,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

total_num_heads instance-attribute

total_num_heads = num_attention_heads

__init__

__init__(
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/pixtral.py
def __init__(
    self,
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> None:
    super().__init__()

    self.config = config
    assert not config.hidden_size % config.num_attention_heads
    self.total_num_heads = config.num_attention_heads
    tp_size = get_tensor_model_parallel_world_size()
    self.n_heads = divide(config.num_attention_heads, tp_size)
    self.head_dim = config.hidden_size // config.num_attention_heads

    self.qkv_proj = QKVParallelLinear(
        hidden_size=config.hidden_size,
        head_size=self.head_dim,
        total_num_heads=self.total_num_heads,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    assert self.total_num_heads * self.head_dim == config.hidden_size
    self.o_proj = RowParallelLinear(
        input_size=config.hidden_size,
        output_size=config.hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

forward

forward(
    hidden_states: Tensor,
    attention_mask: Tensor,
    position_embeddings: Tensor,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_embeddings: torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    batch, patches, _ = hidden_states.size()

    qkv_states, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv_states.chunk(3, dim=-1)

    # Transpose q and k to apply HF's Rotary Position Embedding
    q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
    k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
    v = v.view(batch, patches, self.n_heads, self.head_dim)
    cos, sin = position_embeddings
    q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

    if USE_XFORMERS_OPS:
        # Transpose q and k back for attention
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()

        out = xops.memory_efficient_attention(q,
                                              k,
                                              v,
                                              attn_bias=attention_mask)
    else:
        v = v.transpose(1, 2)
        out = nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask)
        out = out.transpose(1, 2)

    out = out.view(batch, patches, self.n_heads * self.head_dim)
    attn_output, _ = self.o_proj(out)

    return attn_output, None
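
The two branches above differ only in tensor layout: xformers' memory_efficient_attention expects (batch, seq, heads, head_dim), while torch's scaled_dot_product_attention expects (batch, heads, seq, head_dim). Illustrated with a bare tensor (hypothetical sizes):

import torch

batch, seq, heads, head_dim = 1, 8, 4, 32
q = torch.randn(batch, heads, seq, head_dim)  # layout after apply_rotary_pos_emb

q_xformers = q.transpose(1, 2).contiguous()   # (1, 8, 4, 32) for xformers
q_sdpa = q                                    # (1, 4, 8, 32) for torch SDPA
print(q_xformers.shape, q_sdpa.shape)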

PixtralHFEncoderInfo

Bases: VisionEncoderInfo[PixtralVisionConfig]

Source code in vllm/model_executor/models/pixtral.py
class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        ncols, nrows = self.get_patch_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return ncols * nrows

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        # spatial_merge_size is needed for Mistral3
        spatial_merge_size = getattr(self.hf_config, "spatial_merge_size", 1)
        return self.vision_config.patch_size * spatial_merge_size

    def get_patch_grid_length(self) -> int:
        image_size, patch_size = self.get_image_size(), self.get_patch_size()

        # Since interpolation is applied, the image size need not be divisible
        # assert image_size % patch_size == 0
        return image_size // patch_size

    # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99
    def get_patch_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        max_width = max_height = self.get_image_size()
        patch_width = patch_height = self.get_patch_size()

        ratio = max(image_width / max_width, image_height / max_height)

        if ratio > 1:
            image_width = int(math.floor(image_width / ratio))
            image_height = int(math.floor(image_height / ratio))

        nrows, ncols = _get_pixtral_hf_num_image_tokens(
            (image_height, image_width),
            (patch_height, patch_width),
        )  # type: ignore

        return ncols, nrows

get_image_size

get_image_size() -> int
Source code in vllm/model_executor/models/pixtral.py
def get_image_size(self) -> int:
    return self.vision_config.image_size

get_num_image_tokens

get_num_image_tokens(
    *, image_width: int, image_height: int
) -> int
Source code in vllm/model_executor/models/pixtral.py
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    ncols, nrows = self.get_patch_grid_size(
        image_width=image_width,
        image_height=image_height,
    )
    return ncols * nrows

get_patch_grid_length

get_patch_grid_length() -> int
Source code in vllm/model_executor/models/pixtral.py
def get_patch_grid_length(self) -> int:
    image_size, patch_size = self.get_image_size(), self.get_patch_size()

    # Since interpolation is applied, the image size need not be divisible
    # assert image_size % patch_size == 0
    return image_size // patch_size

get_patch_grid_size

get_patch_grid_size(
    *, image_width: int, image_height: int
) -> tuple[int, int]
Source code in vllm/model_executor/models/pixtral.py
def get_patch_grid_size(
    self,
    *,
    image_width: int,
    image_height: int,
) -> tuple[int, int]:
    max_width = max_height = self.get_image_size()
    patch_width = patch_height = self.get_patch_size()

    ratio = max(image_width / max_width, image_height / max_height)

    if ratio > 1:
        image_width = int(math.floor(image_width / ratio))
        image_height = int(math.floor(image_height / ratio))

    nrows, ncols = _get_pixtral_hf_num_image_tokens(
        (image_height, image_width),
        (patch_height, patch_width),
    )  # type: ignore

    return ncols, nrows
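
A worked example of get_patch_grid_size with hypothetical inputs (image_size = 1024, patch_size = 16, a 1300 x 800 input image), assuming _get_pixtral_hf_num_image_tokens amounts to a ceiling division of the resized image by the patch size:

import math

max_size, patch = 1024, 16
image_width, image_height = 1300, 800

ratio = max(image_width / max_size, image_height / max_size)  # ~1.27, so downscale
if ratio > 1:
    image_width = int(math.floor(image_width / ratio))    # 1024
    image_height = int(math.floor(image_height / ratio))  # 630

ncols = math.ceil(image_width / patch)   # 64
nrows = math.ceil(image_height / patch)  # 40
print(ncols, nrows, ncols * nrows)       # 64 40 2560 image tokens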

get_patch_size

get_patch_size() -> int
Source code in vllm/model_executor/models/pixtral.py
def get_patch_size(self) -> int:
    # spatial_merge_size is needed for Mistral3
    spatial_merge_size = getattr(self.hf_config, "spatial_merge_size", 1)
    return self.vision_config.patch_size * spatial_merge_size

PixtralHFMLP

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class PixtralHFMLP(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert config.intermediate_size is not None
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=config.hidden_size,
            output_sizes=[config.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
                                           output_size=config.hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_and_mul(gate_up)
        x, _ = self.down_proj(x)
        return x

act_and_mul instance-attribute

act_and_mul = get_act_and_mul_fn(hidden_act)

down_proj instance-attribute

down_proj = RowParallelLinear(
    input_size=intermediate_size,
    output_size=hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    input_size=hidden_size,
    output_sizes=[intermediate_size] * 2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

__init__

__init__(
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/pixtral.py
def __init__(
    self,
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> None:
    super().__init__()

    assert config.intermediate_size is not None
    self.gate_up_proj = MergedColumnParallelLinear(
        input_size=config.hidden_size,
        output_sizes=[config.intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj")
    self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
                                       output_size=config.hidden_size,
                                       bias=False,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.down_proj")
    self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    gate_up, _ = self.gate_up_proj(x)
    x = self.act_and_mul(gate_up)
    x, _ = self.down_proj(x)
    return x

PixtralHFTransformer

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class PixtralHFTransformer(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            PixtralHFTransformerBlock(config=config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor:
        hidden_states_pool = [x]

        for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)
            if return_all_hidden_states:
                hidden_states_pool.append(x)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return x

layers instance-attribute

layers = ModuleList(
    [
        PixtralHFTransformerBlock(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.layers.{layer_idx}",
        )
        for layer_idx in range(num_hidden_layers)
    ]
)

__init__

__init__(
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_hidden_layers_override: Optional[int] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/pixtral.py
def __init__(
    self,
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_hidden_layers_override: Optional[int] = None,
    prefix: str = "",
) -> None:
    super().__init__()

    if num_hidden_layers_override is None:
        num_hidden_layers = config.num_hidden_layers
    else:
        num_hidden_layers = num_hidden_layers_override

    self.layers = nn.ModuleList([
        PixtralHFTransformerBlock(config=config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.layers.{layer_idx}")
        for layer_idx in range(num_hidden_layers)
    ])

forward

forward(
    x: Tensor,
    attention_mask: Tensor,
    position_embeddings: Tensor,
    return_all_hidden_states: bool,
) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    x: torch.Tensor,
    attention_mask: torch.Tensor,
    position_embeddings: torch.Tensor,
    return_all_hidden_states: bool,
) -> torch.Tensor:
    hidden_states_pool = [x]

    for layer in self.layers:
        x = layer(x, attention_mask, position_embeddings)
        if return_all_hidden_states:
            hidden_states_pool.append(x)
    # If we have multiple feature sample layers, we return all hidden
    # states in order and grab the ones we need by index.
    if return_all_hidden_states:
        return hidden_states_pool
    return x

PixtralHFTransformerBlock

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class PixtralHFTransformerBlock(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.attention = PixtralHFAttention(config,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.attention")
        self.feed_forward = PixtralHFMLP(config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.feed_forward")
        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        r, _ = self.attention.forward(self.attention_norm(hidden_states),
                                      attention_mask=attention_mask,
                                      position_embeddings=position_embeddings)
        h = hidden_states + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out

attention instance-attribute

attention = PixtralHFAttention(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.attention",
)

attention_norm instance-attribute

attention_norm = RMSNorm(hidden_size, eps=1e-05)

feed_forward instance-attribute

feed_forward = PixtralHFMLP(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.feed_forward",
)

ffn_norm instance-attribute

ffn_norm = RMSNorm(hidden_size, eps=1e-05)

__init__

__init__(
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/pixtral.py
def __init__(
    self,
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> None:
    super().__init__()

    self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
    self.attention = PixtralHFAttention(config,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.attention")
    self.feed_forward = PixtralHFMLP(config,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.feed_forward")
    self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

forward

forward(
    hidden_states: Tensor,
    attention_mask: Tensor,
    position_embeddings: Tensor,
) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_embeddings: torch.Tensor,
) -> torch.Tensor:
    r, _ = self.attention.forward(self.attention_norm(hidden_states),
                                  attention_mask=attention_mask,
                                  position_embeddings=position_embeddings)
    h = hidden_states + r
    r = self.feed_forward.forward(self.ffn_norm(h))
    out = h + r
    return out
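
The block follows the standard pre-norm residual pattern: normalize, apply the sublayer, add the input back. A standalone sketch with stand-in sublayers (nn.LayerNorm stands in for the RMSNorm used in the real block):

import torch
import torch.nn as nn

dim = 16
attention_norm, ffn_norm = nn.LayerNorm(dim), nn.LayerNorm(dim)
attention = nn.Linear(dim, dim)     # stand-in for PixtralHFAttention
feed_forward = nn.Linear(dim, dim)  # stand-in for PixtralHFMLP

x = torch.randn(2, 4, dim)
h = x + attention(attention_norm(x))   # residual around attention
out = h + feed_forward(ffn_norm(h))    # residual around the feed-forward
print(out.shape)  # torch.Size([2, 4, 16])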

PixtralHFVisionModel

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class PixtralHFVisionModel(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers.")

        if require_post_norm is True:
            msg = "PixtralHFVisionModel does not have post-layernorm"
            raise ValueError(msg)

        self.dtype = next(self.parameters()).dtype
        self.device = next(self.parameters()).device
        self.patch_positional_embedding = PixtralRotaryEmbedding(
            config, self.device)

    def forward(
        self,
        pixel_values: list[torch.Tensor],
        feature_sample_layers: Optional[list[int]] = None,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            pixel_values: Each image to be processed will be a separate tensor
                in pixel_values. This means it will be a list of tensors
                because multiple requests batched can have multiple images,
                each with their own shape potentially
            feature_sample_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.

        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype))
            for img in pixel_values
        ]

        patch_embeds = [
            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
        ]
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list,
            max_width=self.config.image_size // self.config.patch_size).to(
                self.device)
        position_embedding = self.patch_positional_embedding(
            patch_embeds, position_ids)

        if USE_XFORMERS_OPS:
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask)
            attention_mask = generate_block_attention_mask(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
                patch_embeds)

        return_all_hidden_states = feature_sample_layers is not None
        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=return_all_hidden_states)

        out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                             self.config.num_hidden_layers)

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if name.startswith("transformer.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

device instance-attribute

device = device

dtype instance-attribute

dtype = dtype

ln_pre instance-attribute

ln_pre = RMSNorm(hidden_size, eps=1e-05)

patch_conv instance-attribute

patch_conv = Conv2d(
    in_channels=num_channels,
    out_channels=hidden_size,
    kernel_size=patch_size,
    stride=patch_size,
    bias=False,
)

patch_positional_embedding instance-attribute

patch_positional_embedding = PixtralRotaryEmbedding(
    config, device
)

transformer instance-attribute

transformer = PixtralHFTransformer(
    config,
    quant_config,
    num_hidden_layers_override=num_hidden_layers_override,
    prefix=f"{prefix}.transformer",
)

__init__

__init__(
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_hidden_layers_override: Optional[int] = None,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/pixtral.py
def __init__(
    self,
    config: PixtralVisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_hidden_layers_override: Optional[int] = None,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> None:
    super().__init__()

    self.config = config

    self.patch_conv = nn.Conv2d(
        in_channels=config.num_channels,
        out_channels=config.hidden_size,
        kernel_size=config.patch_size,
        stride=config.patch_size,
        bias=False,
    )
    self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
    self.transformer = PixtralHFTransformer(
        config,
        quant_config,
        num_hidden_layers_override=num_hidden_layers_override,
        prefix=f"{prefix}.transformer",
    )

    num_hidden_layers = config.num_hidden_layers
    if len(self.transformer.layers) > config.num_hidden_layers:
        raise ValueError(
            f"The original encoder only has {num_hidden_layers} "
            f"layers, but you requested {len(self.transformer.layers)} "
            "layers.")

    if require_post_norm is True:
        msg = "PixtralHFVisionModel does not have post-layernorm"
        raise ValueError(msg)

    self.dtype = next(self.parameters()).dtype
    self.device = next(self.parameters()).device
    self.patch_positional_embedding = PixtralRotaryEmbedding(
        config, self.device)

forward

forward(
    pixel_values: list[Tensor],
    feature_sample_layers: Optional[list[int]] = None,
) -> tuple[Tensor, ...]

Parameters:

    pixel_values (list[Tensor], required):
        Each image to be processed is a separate tensor in pixel_values,
        i.e. a list of tensors, since a batch of requests may contain
        multiple images, each with its own shape.
    feature_sample_layers (Optional[list[int]], default None):
        Layer indices whose features should be concatenated and used as
        the visual encoder output. If none are provided, the last layer
        is used.

Returns:

    image_features (tuple[Tensor, ...]):
        Tensor of token features for all tokens of all images, of shape
        (N_toks, D).

Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    pixel_values: list[torch.Tensor],
    feature_sample_layers: Optional[list[int]] = None,
) -> tuple[torch.Tensor, ...]:
    """
    Args:
        pixel_values: Each image to be processed will be a separate tensor
            in pixel_values. This means it will be a list of tensors
            because multiple requests batched can have multiple images,
            each with their own shape potentially
        feature_sample_layers: Layer indices whose features should be
            concatenated and used as the visual encoder output. If none
            are provided, the last layer is used.

    Returns:
        image_features: tensor of token features for
            all tokens of all images of shape (N_toks, D)
    """
    # pass images through initial convolution independently
    patch_embeds_list = [
        self.patch_conv(img.unsqueeze(0).to(self.dtype))
        for img in pixel_values
    ]

    patch_embeds = [
        p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
    ]
    embed_sizes = [p.shape[1] for p in patch_embeds]

    # flatten to a single sequence
    patch_embeds = torch.cat(patch_embeds, dim=1)
    patch_embeds = self.ln_pre(patch_embeds)

    # positional embeddings
    position_ids = position_ids_in_meshgrid(
        patch_embeds_list,
        max_width=self.config.image_size // self.config.patch_size).to(
            self.device)
    position_embedding = self.patch_positional_embedding(
        patch_embeds, position_ids)

    if USE_XFORMERS_OPS:
        attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
    else:
        from transformers.models.pixtral.modeling_pixtral import (
            generate_block_attention_mask)
        attention_mask = generate_block_attention_mask(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
            patch_embeds)

    return_all_hidden_states = feature_sample_layers is not None
    out = self.transformer(
        patch_embeds,
        attention_mask,
        position_embedding,
        return_all_hidden_states=return_all_hidden_states)

    out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                         self.config.num_hidden_layers)

    # squeeze dim 0 and split into separate tensors for each image
    return torch.split(out.squeeze(0), embed_sizes)
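
The flatten-then-split bookkeeping above is easiest to see on dummy tensors. A minimal, self-contained sketch in plain torch (not the vLLM classes) of how embed_sizes lets the concatenated sequence be split back into per-image features:

import torch

# Hypothetical patch grids for two images of different sizes:
# image A -> 4 x 6 patches, image B -> 3 x 3 patches, hidden size D = 8.
D = 8
patch_embeds_list = [torch.randn(1, D, 4, 6), torch.randn(1, D, 3, 3)]

# Same reshaping as in forward(): (1, D, h, w) -> (1, h*w, D)
patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
embed_sizes = [p.shape[1] for p in patch_embeds]  # [24, 9]

# Flatten to a single sequence so one transformer call covers both images.
seq = torch.cat(patch_embeds, dim=1)  # (1, 33, D)

# After the transformer, split back into one tensor per image.
per_image = torch.split(seq.squeeze(0), embed_sizes)
print([tuple(t.shape) for t in per_image])  # [(24, 8), (9, 8)]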

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/pixtral.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    layer_count = len(self.transformer.layers)

    for name, loaded_weight in weights:
        # omit layers when num_hidden_layers_override is set
        if name.startswith("transformer.layers"):
            layer_idx = int(name.split(".")[2])
            if layer_idx >= layer_count:
                continue

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
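
For intuition, the stacked_params_mapping loop above renames per-projection checkpoint weights (q_proj/k_proj/v_proj, gate_proj/up_proj) onto their fused vLLM parameters and passes a shard id to that parameter's weight_loader. A small illustrative sketch of just the name rewriting (the mapping mirrors the code above; the example names are made up):

stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

def rewrite(name: str):
    """Return (fused parameter name, shard id), or (name, None) if unmapped."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

print(rewrite("transformer.layers.0.attention.k_proj.weight"))
# ('transformer.layers.0.attention.qkv_proj.weight', 'k')
print(rewrite("transformer.layers.0.feed_forward.up_proj.weight"))
# ('transformer.layers.0.feed_forward.gate_up_proj.weight', 1)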

PixtralImagePixelInputs

Bases: TypedDict

Source code in vllm/model_executor/models/pixtral.py
class PixtralImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]

    images: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, image_width, image_height)`

    The result of stacking `ImageEncoding.tokens` from each prompt.
    """

images instance-attribute

images: Union[Tensor, list[Tensor]]

Shape: (batch_size * num_images, num_channels, image_width, image_height)

The result of stacking ImageEncoding.tokens from each prompt.

type instance-attribute

type: Literal['pixel_values']

PixtralMultiModalProcessor

Bases: BaseMultiModalProcessor[PixtralProcessingInfo]

Source code in vllm/model_executor/models/pixtral.py
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
                                 ):

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(images=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_break_id = processor.image_break_id
        image_token_id = processor.image_token_id
        image_end_id = processor.image_end_id

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = processor.image_processor._image_to_num_tokens(
                Image.new("RGB", (image_size.width, image_size.height)))

            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target="",  # Never match the prompt (see below note)
                replacement=get_replacement,
            ),
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        return_mm_hashes: bool,
    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
        (
            prompt_ids,
            mm_kwargs,
            mm_hashes,
            _,
        ) = super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            return_mm_hashes=return_mm_hashes,
        )

        # NOTE: The tokens are already inserted by the chat template
        return prompt_ids, mm_kwargs, mm_hashes, True

_cached_apply_hf_processor

_cached_apply_hf_processor(
    prompt: Union[str, list[int]],
    mm_data_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
    *,
    return_mm_hashes: bool,
) -> tuple[
    list[int],
    MultiModalKwargs,
    Optional[MultiModalHashes],
    bool,
]
Source code in vllm/model_executor/models/pixtral.py
def _cached_apply_hf_processor(
    self,
    prompt: Union[str, list[int]],
    mm_data_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
    *,
    return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
    (
        prompt_ids,
        mm_kwargs,
        mm_hashes,
        _,
    ) = super()._cached_apply_hf_processor(
        prompt=prompt,
        mm_data_items=mm_data_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        tokenization_kwargs=tokenization_kwargs,
        return_mm_hashes=return_mm_hashes,
    )

    # NOTE: The tokens are already inserted by the chat template
    return prompt_ids, mm_kwargs, mm_hashes, True

_get_mm_fields_config

_get_mm_fields_config(
    hf_inputs: Mapping[str, NestedTensors],
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]
Source code in vllm/model_executor/models/pixtral.py
def _get_mm_fields_config(
    self,
    hf_inputs: Mapping[str, NestedTensors],
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    return dict(images=MultiModalFieldConfig.batched("image"))

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]
Source code in vllm/model_executor/models/pixtral.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
    processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

    image_break_id = processor.image_break_id
    image_token_id = processor.image_token_id
    image_end_id = processor.image_end_id

    def get_replacement(item_idx: int):
        images = mm_items.get_items("image", ImageProcessorItems)
        image_size = images.get_image_size(item_idx)

        ncols, nrows = processor.image_processor._image_to_num_tokens(
            Image.new("RGB", (image_size.width, image_size.height)))

        tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
        tokens[-1] = image_end_id

        return PromptUpdateDetails.select_token_id(tokens, image_token_id)

    return [
        PromptReplacement(
            modality="image",
            target="",  # Never match the prompt (see below note)
            replacement=get_replacement,
        ),
    ]
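
The placeholder layout built by get_replacement() is: ncols image tokens followed by an image-break token for each of the nrows rows, with the very last break replaced by the image-end token. A small sketch with made-up token ids (the real ids come from the Mistral tokenizer):

IMG, BREAK, END = 10, 12, 13  # hypothetical ids for [IMG], [IMG_BREAK], [IMG_END]

def image_placeholder(ncols: int, nrows: int) -> list[int]:
    tokens = ([IMG] * ncols + [BREAK]) * nrows
    tokens[-1] = END  # the final row ends with [IMG_END] instead of [IMG_BREAK]
    return tokens

# A 3 x 2 patch grid: two rows of three image tokens, each row closed by a break.
print(image_placeholder(ncols=3, nrows=2))
# [10, 10, 10, 12, 10, 10, 10, 13]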

PixtralProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/pixtral.py
class PixtralProcessingInfo(BaseProcessingInfo):

    def get_tokenizer(self) -> MistralTokenizer:
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("This model requires `--tokenizer-mode mistral`")

        return tokenizer

    def get_hf_processor(self) -> PixtralProcessorAdapter:
        return PixtralProcessorAdapter(self.get_tokenizer())

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_vision_config(
        self,
        processor: Optional[PixtralProcessorAdapter] = None,
    ):
        if processor is None:
            processor = self.get_hf_processor()

        return PixtralVisionConfig(
            image_size=processor.image_size,
            patch_size=processor.patch_size,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[PixtralProcessorAdapter] = None,
    ) -> int:
        if processor is None:
            processor = self.get_hf_processor()

        ncols, nrows = processor.image_processor._image_to_num_tokens(
            Image.new("RGB", (image_width, image_height)))

        return ncols * nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        image_processor = self.get_hf_processor().image_processor
        max_image_size = image_processor.mm_config.max_image_size

        return ImageSize(width=max_image_size, height=max_image_size)

get_hf_processor

get_hf_processor() -> PixtralProcessorAdapter
Source code in vllm/model_executor/models/pixtral.py
def get_hf_processor(self) -> PixtralProcessorAdapter:
    return PixtralProcessorAdapter(self.get_tokenizer())

get_image_size_with_most_features

get_image_size_with_most_features() -> ImageSize
Source code in vllm/model_executor/models/pixtral.py
def get_image_size_with_most_features(self) -> ImageSize:
    image_processor = self.get_hf_processor().image_processor
    max_image_size = image_processor.mm_config.max_image_size

    return ImageSize(width=max_image_size, height=max_image_size)

get_num_image_tokens

get_num_image_tokens(
    *,
    image_width: int,
    image_height: int,
    processor: Optional[PixtralProcessorAdapter] = None,
) -> int
Source code in vllm/model_executor/models/pixtral.py
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
    processor: Optional[PixtralProcessorAdapter] = None,
) -> int:
    if processor is None:
        processor = self.get_hf_processor()

    ncols, nrows = processor.image_processor._image_to_num_tokens(
        Image.new("RGB", (image_width, image_height)))

    return ncols * nrows
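
The column/row counts come from mistral_common's image encoder, so the exact resizing rules live there. Purely for intuition, a rough stand-in that assumes the image is downscaled to fit within max_image_size and then counted in patch_size steps (an approximation, not the library's exact logic):

import math

def approx_num_image_tokens(width: int, height: int,
                            patch_size: int = 16,
                            max_image_size: int = 1024) -> int:
    # Assumed behaviour for illustration only: scale the longer side down to
    # max_image_size, then count patch_size x patch_size tiles.
    scale = min(1.0, max_image_size / max(width, height))
    ncols = math.ceil(width * scale / patch_size)
    nrows = math.ceil(height * scale / patch_size)
    return ncols * nrows

print(approx_num_image_tokens(512, 512))    # 32 * 32 = 1024
print(approx_num_image_tokens(2048, 1024))  # 64 * 32 = 2048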

get_supported_mm_limits

get_supported_mm_limits() -> Mapping[str, Optional[int]]
Source code in vllm/model_executor/models/pixtral.py
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    return {"image": None}

get_tokenizer

get_tokenizer() -> MistralTokenizer
Source code in vllm/model_executor/models/pixtral.py
def get_tokenizer(self) -> MistralTokenizer:
    tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
    if not isinstance(tokenizer, MistralTokenizer):
        raise ValueError("This model requires `--tokenizer-mode mistral`")

    return tokenizer

get_vision_config

get_vision_config(
    processor: Optional[PixtralProcessorAdapter] = None,
)
Source code in vllm/model_executor/models/pixtral.py
def get_vision_config(
    self,
    processor: Optional[PixtralProcessorAdapter] = None,
):
    if processor is None:
        processor = self.get_hf_processor()

    return PixtralVisionConfig(
        image_size=processor.image_size,
        patch_size=processor.patch_size,
    )

PixtralProcessorAdapter

Provide a HF-compatible interface for mistral_common.tokens.tokenizers.multimodal.ImageEncoder.

Source code in vllm/model_executor/models/pixtral.py
class PixtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()

        self.tokenizer = tokenizer

    @property
    def image_processor(self) -> ImageEncoder:
        image_encoder = self.tokenizer.instruct.mm_encoder
        assert isinstance(image_encoder, ImageEncoder)
        return image_encoder

    @cached_property
    def image_break_id(self) -> int:
        return self.image_processor.special_ids.img_break

    @cached_property
    def image_token_id(self) -> int:
        return self.image_processor.special_ids.img

    @cached_property
    def image_end_id(self) -> int:
        return self.image_processor.special_ids.img_end

    @cached_property
    def image_size(self) -> int:
        return self.image_processor.mm_config.max_image_size

    @cached_property
    def patch_size(self) -> int:
        return self.image_processor.mm_config.image_patch_size

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if not images:
            input_ids = self.tokenizer(text).input_ids

            return {"input_ids": torch.tensor(input_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(len(t) > 0 for t in text):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411.")

        images_processed = list[torch.Tensor]()
        images_tokens = list[torch.Tensor]()

        for image in images:
            image_inputs = self.image_processor(ImageChunk(image=image))
            image_processed = torch.tensor(image_inputs.image)
            image_tokens = torch.tensor(image_inputs.tokens)

            images_processed.append(image_processed)
            images_tokens.append(image_tokens)

        return {
            "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
            "images": images_processed,
        }

image_break_id cached property

image_break_id: int

image_end_id cached property

image_end_id: int

image_processor property

image_processor: ImageEncoder

image_size cached property

image_size: int

image_token_id cached property

image_token_id: int

patch_size cached property

patch_size: int

tokenizer instance-attribute

tokenizer = tokenizer

__call__

__call__(
    text: Optional[
        Union[TextInput, list[TextInput]]
    ] = None,
    images: Optional[
        Union[ImageInput, list[ImageInput]]
    ] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs,
) -> Mapping[str, NestedTensors]
Source code in vllm/model_executor/models/pixtral.py
def __call__(
    self,
    text: Optional[Union[TextInput, list[TextInput]]] = None,
    images: Optional[Union[ImageInput, list[ImageInput]]] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs,
) -> Mapping[str, NestedTensors]:
    if text is None:
        text = []
    if not isinstance(text, list):
        text = [text]
    if images is None:
        images = []
    if not isinstance(images, list):
        images = [images]

    if not images:
        input_ids = self.tokenizer(text).input_ids

        return {"input_ids": torch.tensor(input_ids)}

    # Allow dummy text, which is used for profiling as well as token inputs
    if any(len(t) > 0 for t in text):
        raise ValueError(
            "You've passed text inputs instead of token inputs. "
            "Make sure to process your input via `mistral_common`'s "
            "tokenizer or pass a chat completion request. "
            "For more info, see: "
            "https://github.com/vllm-project/vllm/issues/8411.")

    images_processed = list[torch.Tensor]()
    images_tokens = list[torch.Tensor]()

    for image in images:
        image_inputs = self.image_processor(ImageChunk(image=image))
        image_processed = torch.tensor(image_inputs.image)
        image_tokens = torch.tensor(image_inputs.tokens)

        images_processed.append(image_processed)
        images_tokens.append(image_tokens)

    return {
        "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
        "images": images_processed,
    }
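
In the multi-image branch, each image's placeholder tokens are concatenated into one sequence and broadcast across the text entries. The shape bookkeeping in isolation, with dummy data:

import torch

# Suppose two images produced 6 and 4 placeholder tokens respectively,
# and there are two (empty) text entries.
images_tokens = [torch.arange(6), torch.arange(4)]
text = ["", ""]

input_ids = torch.cat(images_tokens)[None].expand(len(text), -1)
print(input_ids.shape)  # torch.Size([2, 10])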

__init__

__init__(tokenizer: MistralTokenizer) -> None
Source code in vllm/model_executor/models/pixtral.py
def __init__(self, tokenizer: MistralTokenizer) -> None:
    super().__init__()

    self.tokenizer = tokenizer

Transformer

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class Transformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x

layers instance-attribute

layers = ModuleList()

__init__

__init__(args: VisionEncoderArgs)
Source code in vllm/model_executor/models/pixtral.py
def __init__(self, args: VisionEncoderArgs):
    super().__init__()
    self.layers = torch.nn.ModuleList()
    for _ in range(args.num_hidden_layers):
        self.layers.append(TransformerBlock(args))

forward

forward(
    x: Tensor, mask: Tensor, freqs_cis: Optional[Tensor]
) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    x: torch.Tensor,
    mask: torch.Tensor,
    freqs_cis: Optional[torch.Tensor],
) -> torch.Tensor:
    for layer in self.layers:
        x = layer(x, mask=mask, freqs_cis=freqs_cis)
    return x

TransformerBlock

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class TransformerBlock(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        r = self.attention.forward(self.attention_norm(x),
                                   mask=mask,
                                   freqs_cis=freqs_cis)
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out

attention instance-attribute

attention = Attention(args)

attention_norm instance-attribute

attention_norm = RMSNorm(hidden_size, eps=1e-05)

feed_forward instance-attribute

feed_forward = FeedForward(args)

ffn_norm instance-attribute

ffn_norm = RMSNorm(hidden_size, eps=1e-05)

__init__

__init__(args: VisionEncoderArgs)
Source code in vllm/model_executor/models/pixtral.py
def __init__(self, args: VisionEncoderArgs):
    super().__init__()
    self.attention = Attention(args)
    self.feed_forward = FeedForward(args)
    self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
    self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

forward

forward(
    x: Tensor, mask: Tensor, freqs_cis: Tensor
) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    x: torch.Tensor,
    mask: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    r = self.attention.forward(self.attention_norm(x),
                               mask=mask,
                               freqs_cis=freqs_cis)
    h = x + r
    r = self.feed_forward.forward(self.ffn_norm(h))
    out = h + r
    return out

VisionEncoderArgs dataclass

Source code in vllm/model_executor/models/pixtral.py
@dataclass
class VisionEncoderArgs:
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int
    adapter_bias: bool = True
    spatial_merge_size: int = 1
    add_pre_mm_projector_layer_norm: bool = False
    mm_projector_id: str = ""

adapter_bias class-attribute instance-attribute

adapter_bias: bool = True

add_pre_mm_projector_layer_norm class-attribute instance-attribute

add_pre_mm_projector_layer_norm: bool = False

hidden_size instance-attribute

hidden_size: int

image_size instance-attribute

image_size: int

image_token_id instance-attribute

image_token_id: int

intermediate_size instance-attribute

intermediate_size: int

mm_projector_id class-attribute instance-attribute

mm_projector_id: str = ''

num_attention_heads instance-attribute

num_attention_heads: int

num_channels instance-attribute

num_channels: int

num_hidden_layers instance-attribute

num_hidden_layers: int

patch_size instance-attribute

patch_size: int

rope_theta instance-attribute

rope_theta: float

spatial_merge_size class-attribute instance-attribute

spatial_merge_size: int = 1

__init__

__init__(
    hidden_size: int,
    num_channels: int,
    image_size: int,
    patch_size: int,
    intermediate_size: int,
    num_hidden_layers: int,
    num_attention_heads: int,
    rope_theta: float,
    image_token_id: int,
    adapter_bias: bool = True,
    spatial_merge_size: int = 1,
    add_pre_mm_projector_layer_norm: bool = False,
    mm_projector_id: str = "",
) -> None

VisionLanguageAdapter

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class VisionLanguageAdapter(nn.Module):

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=args.adapter_bias,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))

gelu instance-attribute

gelu = GELU()

w_in instance-attribute

w_in = Linear(hidden_size, dim, bias=adapter_bias)

w_out instance-attribute

w_out = Linear(dim, dim, bias=adapter_bias)

__init__

__init__(args: VisionEncoderArgs, dim: int)
Source code in vllm/model_executor/models/pixtral.py
def __init__(self, args: VisionEncoderArgs, dim: int):
    super().__init__()
    assert isinstance(args, VisionEncoderArgs)
    self.w_in = nn.Linear(
        args.hidden_size,
        dim,
        bias=args.adapter_bias,
    )
    self.gelu = nn.GELU()
    self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias)

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.w_out(self.gelu(self.w_in(x)))

VisionTransformer

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class VisionTransformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.types.Device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: list[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        patch_embeds = [
            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
        ]
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        if USE_XFORMERS_OPS:
            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
        else:
            raise ImportError("Xformers is required for Pixtral inference "
                              "with the Mistral format")
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)

_freqs_cis instance-attribute

_freqs_cis: Optional[Tensor] = None

args instance-attribute

args = args

device property

device: Device

dtype property

dtype: dtype

freqs_cis property

freqs_cis: Tensor

ln_pre instance-attribute

ln_pre = RMSNorm(hidden_size, eps=1e-05)

max_patches_per_side property

max_patches_per_side: int

patch_conv instance-attribute

patch_conv = Conv2d(
    in_channels=num_channels,
    out_channels=hidden_size,
    kernel_size=patch_size,
    stride=patch_size,
    bias=False,
)

transformer instance-attribute

transformer = Transformer(args)

__init__

__init__(args: VisionEncoderArgs)
Source code in vllm/model_executor/models/pixtral.py
def __init__(self, args: VisionEncoderArgs):
    super().__init__()
    self.args = args
    self.patch_conv = nn.Conv2d(
        in_channels=args.num_channels,
        out_channels=args.hidden_size,
        kernel_size=args.patch_size,
        stride=args.patch_size,
        bias=False,
    )
    self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
    self.transformer = Transformer(args)

    head_dim = self.args.hidden_size // self.args.num_attention_heads
    assert head_dim % 2 == 0, "ROPE requires even head_dim"
    self._freqs_cis: Optional[torch.Tensor] = None

forward

forward(images: list[Tensor]) -> Tensor

Parameters:

    images (list[Tensor], required):
        List of N_img images of variable sizes, each of shape (C, H, W).

Returns:

    image_features: tensor of token features for all tokens of all images,
    of shape (N_toks, D).

Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    images: list[torch.Tensor],
) -> torch.Tensor:
    """
    Args:
        images: list of N_img images of variable sizes,
            each of shape (C, H, W)
    Returns:
        image_features: tensor of token features for
            all tokens of all images of shape (N_toks, D)
    """
    # pass images through initial convolution independently
    patch_embeds_list = [
        self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
    ]

    patch_embeds = [
        p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
    ]
    embed_sizes = [p.shape[1] for p in patch_embeds]

    # flatten to a single sequence
    patch_embeds = torch.cat(patch_embeds, dim=1)
    patch_embeds = self.ln_pre(patch_embeds)

    # positional embeddings
    positions = position_meshgrid(patch_embeds_list).to(self.device)
    freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

    # pass through Transformer with a block diagonal mask delimiting images
    if USE_XFORMERS_OPS:
        mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
    else:
        raise ImportError("Xformers is required for Pixtral inference "
                          "with the Mistral format")
    out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

    # squeeze dim 0 and split into separate tensors for each image
    return torch.split(out.squeeze(0), embed_sizes)
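
The BlockDiagonalMask built from per-image sequence lengths restricts attention so that patches of one image never attend to patches of another. An equivalent boolean mask in plain torch, for illustration only (the model relies on xformers for the efficient fused kernel):

import torch

def block_diagonal_mask(seqlens: list[int]) -> torch.Tensor:
    total = sum(seqlens)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in seqlens:
        mask[start:start + n, start:start + n] = True  # attend within the image
        start += n
    return mask

# Two images with 4 and 2 patches: a 6 x 6 mask with two diagonal blocks.
print(block_diagonal_mask([4, 2]).int())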

_reshape_for_broadcast

_reshape_for_broadcast(
    freqs_cis: Tensor, x: Tensor
) -> Tensor

freqs_cis: complex tensor of shape (seq_len, head_dim / 2)
x: complex tensor of shape (bsz, seq_len, head_dim / 2)

Source code in vllm/model_executor/models/pixtral.py
def _reshape_for_broadcast(freqs_cis: torch.Tensor,
                           x: torch.Tensor) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [
        d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
    ]
    return freqs_cis.view(*shape)

apply_rotary_emb_vit

apply_rotary_emb_vit(
    xq: Tensor, xk: Tensor, freqs_cis: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/pixtral.py
def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.dtype == torch.complex64
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
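
apply_rotary_emb_vit pairs the last dimension into complex numbers and multiplies by unit complex numbers, i.e. it rotates each (even, odd) feature pair by a position-dependent angle. A tiny standalone check of that identity:

import torch

x = torch.tensor([[1.0, 0.0]])                    # one position, head_dim = 2
x_c = torch.view_as_complex(x.reshape(1, -1, 2))  # shape (1, 1), complex

angle = torch.tensor([torch.pi / 2])
freqs_cis = torch.polar(torch.ones_like(angle), angle)  # e^{i*pi/2}

rotated = torch.view_as_real(x_c * freqs_cis).flatten(1)
print(rotated)  # ~[[0.0, 1.0]]: the pair (1, 0) rotated by 90 degrees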

get_sub_grids

get_sub_grids(
    x: Tensor,
    image_sizes: list[tuple[int, int]],
    spatial_merge_size: int,
) -> list[Tensor]
Source code in vllm/model_executor/models/pixtral.py
def get_sub_grids(
    x: torch.Tensor,
    image_sizes: list[tuple[int, int]],
    spatial_merge_size: int,
) -> list[torch.Tensor]:
    # image_sizes specified in tokens
    tokens_per_image = [h * w for h, w in image_sizes]
    d = x.shape[-1]
    all_img_sub_grids: list[torch.Tensor] = []
    sub_grid_size = spatial_merge_size

    for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
        # Reshape image_tokens into a 2D grid
        h, w = image_sizes[image_index]
        image_grid = image_tokens.view(h, w, d).permute(
            2, 0, 1)[None, :, :, :]  # 1 x d x h x w
        sub_grids = torch.nn.functional.unfold(image_grid,
                                               kernel_size=sub_grid_size,
                                               stride=sub_grid_size)
        sub_grids = sub_grids.view(
            1, d, sub_grid_size, sub_grid_size,
            -1)  # 1 x d x sub_grid_size x sub_grid_size x n_patches

        all_img_sub_grids.append(sub_grids[0])

    return all_img_sub_grids
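
get_sub_grids uses unfold to carve each image's patch grid into non-overlapping spatial_merge_size x spatial_merge_size tiles. A minimal sketch of that step on a dummy grid:

import torch
import torch.nn.functional as F

# A 1 x d x 4 x 4 grid with spatial_merge_size = 2 yields four 2 x 2 sub-grids.
d, h, w, merge = 3, 4, 4, 2
grid = torch.arange(d * h * w, dtype=torch.float32).view(1, d, h, w)

sub = F.unfold(grid, kernel_size=merge, stride=merge)  # (1, d*merge*merge, 4)
sub = sub.view(1, d, merge, merge, -1)
print(sub.shape)  # torch.Size([1, 3, 2, 2, 4])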

position_meshgrid

position_meshgrid(
    patch_embeds_list: list[Tensor],
) -> Tensor
Source code in vllm/model_executor/models/pixtral.py
def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor:
    positions = torch.cat([
        torch.stack(
            torch.meshgrid(
                torch.arange(p.shape[-2]),
                torch.arange(p.shape[-1]),
                indexing="ij",
            ),
            dim=-1,
        ).reshape(-1, 2) for p in patch_embeds_list
    ])
    return positions
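
To make the coordinate layout concrete, the same meshgrid logic applied to two tiny patch grids (2 x 3 and 1 x 2) yields one (row, col) pair per patch, image after image:

import torch

patch_embeds_list = [torch.zeros(1, 8, 2, 3), torch.zeros(1, 8, 1, 2)]

positions = torch.cat([
    torch.stack(
        torch.meshgrid(
            torch.arange(p.shape[-2]),
            torch.arange(p.shape[-1]),
            indexing="ij",
        ),
        dim=-1,
    ).reshape(-1, 2) for p in patch_embeds_list
])
print(positions.tolist())
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [0, 0], [0, 1]]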

precompute_freqs_cis_2d

precompute_freqs_cis_2d(
    dim: int, height: int, width: int, theta: float
) -> Tensor
Returns freqs_cis: a 2D complex tensor of shape (height, width, dim // 2),
to be indexed by (height, width) position tuples.

Source code in vllm/model_executor/models/pixtral.py
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
        to be indexed by (height, width) position tuples
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
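
The resulting table has shape (height, width, dim // 2) and is consumed by indexing it with per-patch (row, col) positions, as in VisionTransformer.forward. A small sketch with a dummy complex table:

import torch

H, W, head_dim = 4, 4, 8
angles = torch.rand(H, W, head_dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64 table

positions = torch.tensor([[0, 0], [0, 1], [1, 2]])         # (row, col) pairs
per_patch = freqs_cis[positions[:, 0], positions[:, 1]]
print(per_patch.shape, per_patch.dtype)  # torch.Size([3, 4]) torch.complex64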