Skip to content

vllm.model_executor.models.paligemma

PaliGemmaImageInputs module-attribute

logger module-attribute

logger = init_logger(__name__)

PaliGemmaDummyInputsBuilder

Bases: BaseDummyInputsBuilder[PaliGemmaProcessingInfo]

Source code in vllm/model_executor/models/paligemma.py
class PaliGemmaDummyInputsBuilder(
        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
    """Produces the dummy text and dummy image inputs for PaliGemma."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text; it is empty for any ``mm_counts``."""
        empty_prompt = ""
        return empty_prompt

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images keyed by ``"image"``.

        The images are square: both sides equal the ``image_size`` of the
        HF vision config. The number of images is ``mm_counts["image"]``,
        or 0 when that key is absent. ``seq_len`` is not read here.
        """
        side = self.info.get_hf_config().vision_config.image_size
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=image_count,
        )
        return {"image": dummy_images}

get_dummy_mm_data

get_dummy_mm_data(
    seq_len: int, mm_counts: Mapping[str, int]
) -> MultiModalDataDict
Source code in vllm/model_executor/models/paligemma.py
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
    """Return dummy images keyed by ``"image"``.

    The images are square: both sides equal the ``image_size`` of the HF
    vision config. The number of images is ``mm_counts["image"]``, or 0
    when that key is absent. ``seq_len`` is not read here.
    """
    side = self.info.get_hf_config().vision_config.image_size
    image_count = mm_counts.get("image", 0)

    dummy_images = self._get_dummy_images(
        width=side,
        height=side,
        num_images=image_count,
    )
    return {"image": dummy_images}

get_dummy_text

get_dummy_text(mm_counts: Mapping[str, int]) -> str
Source code in vllm/model_executor/models/paligemma.py
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    """Return the dummy prompt text; it is empty for any ``mm_counts``."""
    empty_prompt = ""
    return empty_prompt

PaliGemmaForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP

Source code in vllm/model_executor/models/paligemma.py
@MULTIMODAL_REGISTRY.register_processor(
    PaliGemmaMultiModalProcessor,
    info=PaliGemmaProcessingInfo,
    dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """PaliGemma: a SigLIP vision tower and a linear projector placed in
    front of a Gemma or Gemma2 language model."""

    # Packed module name -> names of the individual projections it covers.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return ``None`` for any modality starting with ``"image"``.

        Raises:
            ValueError: for every other modality.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, projector and language model."""
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config

        self.config = hf_config
        self.multimodal_config = vllm_config.model_config.multimodal_config
        self.quant_config = quantization

        vision_config = hf_config.vision_config
        self.vision_tower = SiglipVisionModel(
            vision_config,
            quantization,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=vision_config.hidden_size,
            projection_dim=vision_config.projection_dim,
        )

        # A text config with model_type "gemma" selects GemmaForCausalLM;
        # every other model_type falls back to Gemma2ForCausalLM.
        text_config = hf_config.text_config
        text_config.architectures = [
            "GemmaForCausalLM"
            if text_config.model_type == "gemma" else "Gemma2ForCausalLM"
        ]
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Fold the optional top-level ``logit_scale`` (1.0 when absent) into
        # the language model's logits processor.
        extra_scale = getattr(hf_config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= extra_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Return ``data`` unchanged if its non-batch dims are ``(3, S, S)``,
        where ``S`` is the vision config's ``image_size``.

        Raises:
            ValueError: if the dimensions after the first do not match.
        """
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        """Turn ``pixel_values`` / ``image_embeds`` kwargs into typed inputs.

        ``pixel_values`` wins when both are given; ``None`` is returned when
        neither is present.

        Raises:
            ValueError: if a supplied value has an unsupported type, or the
                pixel values fail shape validation.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            flat_pixels = flatten_bn(pixel_values, concat=True)
            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flat_pixels),
            )

        if image_embeds is None:
            # Neither kind of image input was supplied.
            return None

        if not isinstance(image_embeds, torch.Tensor):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")

        return PaliGemmaImageEmbeddingInputs(
            type="image_embeds",
            data=flatten_bn(image_embeds, concat=True),
        )

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``vision_tower`` on ``pixel_values`` cast to the dtype of the
        tower's input-embedding weight."""
        embed_dtype = vision_tower.get_input_embeddings().weight.dtype
        return vision_tower(pixel_values.to(dtype=embed_dtype))

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Return embeddings for ``image_input``.

        Pre-computed embeddings are passed through as-is; pixel values go
        through the vision tower and then the multimodal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        features = self._image_pixels_to_features(
            self.vision_tower,
            image_input["data"],
        )
        return self.multi_modal_projector(features)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return scaled image embeddings, or ``[]`` without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        # Scale by hidden_size**-0.5, following the HF implementation:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
        return self._process_image_input(image_input) * (
            self.config.hidden_size**-0.5)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in any multimodal embeddings at the
        positions holding ``config.image_token_index``."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            multimodal_embeddings,
            self.config.image_token_index,
        )

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> IntermediateTensors:
        """Run the language model's inner ``model`` and return its output.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is ignored.
        Otherwise, missing ``inputs_embeds`` are built from ``input_ids`` and
        the image inputs in ``kwargs``, and ``input_ids`` is then dropped.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            mm_embeds = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, mm_embeds)
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states,
                                             sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``weights`` through ``AutoWeightsLoader``, renaming
        checkpoint prefixes with ``hf_to_vllm_mapper``; return its result."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper)

config instance-attribute

config = config

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "model.language_model.": "language_model.model.",
        "model.vision_tower.": "vision_tower.",
        "model.multi_modal_projector.": "multi_modal_projector.",
        "lm_head.": "language_model.lm_head.",
    }
)

language_model instance-attribute

language_model = init_vllm_registered_model(
    vllm_config=vllm_config,
    hf_config=text_config,
    prefix=maybe_prefix(prefix, "language_model"),
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

multi_modal_projector instance-attribute

multi_modal_projector = PaliGemmaMultiModalProjector(
    vision_hidden_size=hidden_size,
    projection_dim=projection_dim,
)

multimodal_config instance-attribute

multimodal_config = multimodal_config

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

quant_config instance-attribute

quant_config = quant_config

vision_tower instance-attribute

vision_tower = SiglipVisionModel(
    vision_config,
    quant_config,
    prefix=maybe_prefix(prefix, "vision_tower"),
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/paligemma.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the vision tower, projector and language model."""
    super().__init__()
    hf_config = vllm_config.model_config.hf_config
    quantization = vllm_config.quant_config

    self.config = hf_config
    self.multimodal_config = vllm_config.model_config.multimodal_config
    self.quant_config = quantization

    vision_config = hf_config.vision_config
    self.vision_tower = SiglipVisionModel(
        vision_config,
        quantization,
        prefix=maybe_prefix(prefix, "vision_tower"),
    )
    self.multi_modal_projector = PaliGemmaMultiModalProjector(
        vision_hidden_size=vision_config.hidden_size,
        projection_dim=vision_config.projection_dim,
    )

    # A text config with model_type "gemma" selects GemmaForCausalLM;
    # every other model_type falls back to Gemma2ForCausalLM.
    text_config = hf_config.text_config
    text_config.architectures = [
        "GemmaForCausalLM"
        if text_config.model_type == "gemma" else "Gemma2ForCausalLM"
    ]
    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=text_config,
        prefix=maybe_prefix(prefix, "language_model"),
    )

    # Fold the optional top-level ``logit_scale`` (1.0 when absent) into
    # the language model's logits processor.
    extra_scale = getattr(hf_config, "logit_scale", 1.0)
    self.language_model.logits_processor.scale *= extra_scale

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)

_image_pixels_to_features

_image_pixels_to_features(
    vision_tower: SiglipVisionModel, pixel_values: Tensor
) -> Tensor
Source code in vllm/model_executor/models/paligemma.py
def _image_pixels_to_features(
    self,
    vision_tower: SiglipVisionModel,
    pixel_values: torch.Tensor,
) -> torch.Tensor:

    target_dtype = vision_tower.get_input_embeddings().weight.dtype
    image_features = vision_tower(pixel_values.to(dtype=target_dtype))

    return image_features

_parse_and_validate_image_input

_parse_and_validate_image_input(
    **kwargs: object,
) -> Optional[PaliGemmaImageInputs]
Source code in vllm/model_executor/models/paligemma.py
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
    pixel_values = kwargs.pop("pixel_values", None)
    image_embeds = kwargs.pop("image_embeds", None)

    if pixel_values is None and image_embeds is None:
        return None

    if pixel_values is not None:
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        pixel_values = flatten_bn(pixel_values, concat=True)

        return PaliGemmaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    if image_embeds is not None:
        if not isinstance(image_embeds, torch.Tensor):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")

        image_embeds = flatten_bn(image_embeds, concat=True)

        return PaliGemmaImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    raise AssertionError("This line should be unreachable.")

_process_image_input

_process_image_input(
    image_input: PaliGemmaImageInputs,
) -> Tensor
Source code in vllm/model_executor/models/paligemma.py
def _process_image_input(
    self,
    image_input: PaliGemmaImageInputs,
) -> torch.Tensor:

    if image_input["type"] == "image_embeds":
        return image_input["data"]

    assert self.vision_tower is not None
    pixel_values = image_input["data"]
    image_features = self._image_pixels_to_features(
        self.vision_tower,
        pixel_values,
    )

    return self.multi_modal_projector(image_features)

_validate_pixel_values

_validate_pixel_values(data: Tensor) -> Tensor
Source code in vllm/model_executor/models/paligemma.py
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
    h = w = self.config.vision_config.image_size
    expected_dims = (3, h, w)
    actual_dims = tuple(data.shape[1:])

    if actual_dims != expected_dims:
        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    return data

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/paligemma.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Delegate logits computation to the language model."""
    language_model = self.language_model
    return language_model.compute_logits(hidden_states, sampling_metadata)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs: object,
) -> IntermediateTensors
Source code in vllm/model_executor/models/paligemma.py
def forward(self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            **kwargs: object) -> IntermediateTensors:
    """Run the language model's inner ``model`` and return its output.

    When ``intermediate_tensors`` is given, ``inputs_embeds`` is ignored.
    Otherwise, missing ``inputs_embeds`` are built from ``input_ids`` and
    the image inputs in ``kwargs``, and ``input_ids`` is then dropped.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None
    elif inputs_embeds is None:
        # NOTE: In v1, inputs_embeds is always generated at model runner,
        # this condition is for v0 compatibility.
        mm_embeds = self.get_multimodal_embeddings(**kwargs)
        inputs_embeds = self.get_input_embeddings(input_ids, mm_embeds)
        input_ids = None

    return self.language_model.model(input_ids,
                                     positions,
                                     intermediate_tensors,
                                     inputs_embeds=inputs_embeds)

get_input_embeddings

get_input_embeddings(
    input_ids: Tensor,
    multimodal_embeddings: Optional[
        MultiModalEmbeddings
    ] = None,
) -> Tensor
Source code in vllm/model_executor/models/paligemma.py
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings at the
    positions holding ``config.image_token_index``."""
    inputs_embeds = self.language_model.get_input_embeddings(input_ids)

    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds

    return merge_multimodal_embeddings(
        input_ids,
        inputs_embeds,
        multimodal_embeddings,
        self.config.image_token_index,
    )

get_language_model

get_language_model() -> Module
Source code in vllm/model_executor/models/paligemma.py
def get_language_model(self) -> torch.nn.Module:
    """Return the wrapped language model module."""
    language_model = self.language_model
    return language_model

get_multimodal_embeddings

get_multimodal_embeddings(
    **kwargs: object,
) -> MultiModalEmbeddings
Source code in vllm/model_executor/models/paligemma.py
def get_multimodal_embeddings(self,
                              **kwargs: object) -> MultiModalEmbeddings:
    """Return scaled image embeddings, or ``[]`` without image input."""
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is None:
        return []

    # Scale by hidden_size**-0.5, following the HF implementation:
    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
    return self._process_image_input(image_input) * (
        self.config.hidden_size**-0.5)

get_placeholder_str classmethod

get_placeholder_str(modality: str, i: int) -> Optional[str]
Source code in vllm/model_executor/models/paligemma.py
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    """Return ``None`` for any modality starting with ``"image"``.

    Raises:
        ValueError: for every other modality.
    """
    if not modality.startswith("image"):
        raise ValueError("Only image modality is supported")

    return None

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/paligemma.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load ``weights`` through ``AutoWeightsLoader``, renaming checkpoint
    prefixes with ``hf_to_vllm_mapper``; return the loader's result."""
    return AutoWeightsLoader(self).load_weights(
        weights, mapper=self.hf_to_vllm_mapper)

PaliGemmaImageEmbeddingInputs

Bases: TypedDict

Source code in vllm/model_executor/models/paligemma.py
class PaliGemmaImageEmbeddingInputs(TypedDict):
    """Image input given as pre-computed embeddings."""
    # Tag that identifies this variant; always "image_embeds".
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """

data instance-attribute

data: Tensor

Shape: (batch_size * num_images, image_feature_size, hidden_size)

hidden_size must match the hidden size of language model backbone.

type instance-attribute

type: Literal['image_embeds']

PaliGemmaImagePixelInputs

Bases: TypedDict

Source code in vllm/model_executor/models/paligemma.py
class PaliGemmaImagePixelInputs(TypedDict):
    """Image input given as raw pixel values."""
    # Tag that identifies this variant; always "pixel_values".
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""

data instance-attribute

data: Tensor

Shape: (batch_size * num_images, num_channels, height, width)

type instance-attribute

type: Literal['pixel_values']

PaliGemmaMultiModalProcessor

Bases: BaseMultiModalProcessor[PaliGemmaProcessingInfo]

Source code in vllm/model_executor/models/paligemma.py
class PaliGemmaMultiModalProcessor(
        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
    """Multimodal processor that applies PaliGemma's prompt layout."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the base HF processor, or only tokenize when ``mm_data``
        is empty."""
        tokenizer = self.info.get_tokenizer()
        if mm_data:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        # Text-only prompt: wrap the token ids in a single-row batch.
        token_ids = tokenizer.encode(prompt)
        return BatchFeature({"input_ids": [token_ids]}, tensor_type="pt")

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as a batched field of the image
        modality."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the single prompt insertion that places the image tokens
        followed by ``<bos>`` at the start of the prompt."""
        image_token_id = self.info.get_hf_config().image_token_index

        tokenizer = self.info.get_tokenizer()
        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        def get_insertion(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if not isinstance(images, ImageEmbeddingItems):
                # Raw image: the token count depends on the image size.
                size = images.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                token_count = images.get_feature_size(item_idx)

            return PromptUpdateDetails.select_token_id(
                [image_token_id] * token_count + [bos_token_id],
                embed_token_id=image_token_id,
            )

        # Paligemma 1 and 2 have different tokenizer.add_bos_token
        # Insert <image>*n + <bos> after <bos> for Paligemma 1
        # Insert <image>*n + <bos> for Paligemma 2
        prefix_tokens = [bos_token_id] if tokenizer.add_bos_token else []
        insertion = PromptInsertion(
            modality="image",
            target=PromptIndexTargets.prefix(prefix_tokens),
            insertion=get_insertion,
        )
        return [insertion]

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Run the base ``apply`` and make a non-empty prompt end with a
        newline token, updating both the token ids and the prompt text."""
        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                  tokenization_kwargs, return_mm_hashes)
        token_ids = mm_inputs["prompt_token_ids"]

        newline = "\n"
        newline_id = self.info.get_tokenizer().encode(newline)[-1]  # 108

        # Force to add newline at the end of prompt for paligemma's format
        # This step can NOT be replaced by current PromptUpdate methods
        if len(token_ids) == 0 or token_ids[-1] == newline_id:
            return mm_inputs

        token_ids.append(newline_id)
        mm_inputs["prompt_token_ids"] = token_ids
        mm_inputs["prompt"] += newline
        return mm_inputs

_call_hf_processor

_call_hf_processor(
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature
Source code in vllm/model_executor/models/paligemma.py
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Call the base HF processor, or only tokenize when ``mm_data`` is
    empty."""
    tokenizer = self.info.get_tokenizer()
    if mm_data:
        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    # Text-only prompt: wrap the token ids in a single-row batch.
    token_ids = tokenizer.encode(prompt)
    return BatchFeature({"input_ids": [token_ids]}, tensor_type="pt")

_get_mm_fields_config

_get_mm_fields_config(
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]
Source code in vllm/model_executor/models/paligemma.py
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Declare ``pixel_values`` as a batched field of the image modality."""
    return {"pixel_values": MultiModalFieldConfig.batched("image")}

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]
Source code in vllm/model_executor/models/paligemma.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
    """Return the single prompt insertion that places the image tokens
    followed by ``<bos>`` at the start of the prompt."""
    image_token_id = self.info.get_hf_config().image_token_index

    tokenizer = self.info.get_tokenizer()
    bos_token_id = tokenizer.bos_token_id
    assert isinstance(bos_token_id, int)

    def get_insertion(item_idx: int):
        images = mm_items.get_items(
            "image", (ImageEmbeddingItems, ImageProcessorItems))

        if not isinstance(images, ImageEmbeddingItems):
            # Raw image: the token count depends on the image size.
            size = images.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
        else:
            token_count = images.get_feature_size(item_idx)

        return PromptUpdateDetails.select_token_id(
            [image_token_id] * token_count + [bos_token_id],
            embed_token_id=image_token_id,
        )

    # Paligemma 1 and 2 have different tokenizer.add_bos_token
    # Insert <image>*n + <bos> after <bos> for Paligemma 1
    # Insert <image>*n + <bos> for Paligemma 2
    prefix_tokens = [bos_token_id] if tokenizer.add_bos_token else []
    insertion = PromptInsertion(
        modality="image",
        target=PromptIndexTargets.prefix(prefix_tokens),
        insertion=get_insertion,
    )
    return [insertion]

apply

apply(
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Optional[
        Mapping[str, object]
    ] = None,
    return_mm_hashes: bool = False,
) -> MultiModalInputs
Source code in vllm/model_executor/models/paligemma.py
def apply(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Optional[Mapping[str, object]] = None,
    return_mm_hashes: bool = False,
) -> MultiModalInputs:
    """Run the base ``apply`` and make a non-empty prompt end with a
    newline token, updating both the token ids and the prompt text."""
    mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                              tokenization_kwargs, return_mm_hashes)
    token_ids = mm_inputs["prompt_token_ids"]

    newline = "\n"
    newline_id = self.info.get_tokenizer().encode(newline)[-1]  # 108

    # Force to add newline at the end of prompt for paligemma's format
    # This step can NOT be replaced by current PromptUpdate methods
    if len(token_ids) == 0 or token_ids[-1] == newline_id:
        return mm_inputs

    token_ids.append(newline_id)
    mm_inputs["prompt_token_ids"] = token_ids
    mm_inputs["prompt"] += newline
    return mm_inputs

PaliGemmaMultiModalProjector

Bases: Module

Source code in vllm/model_executor/models/paligemma.py
class PaliGemmaMultiModalProjector(nn.Module):
    """A single biased linear layer applied to the image features."""

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        """Create the ``vision_hidden_size -> projection_dim`` layer."""
        super().__init__()

        self.linear = nn.Linear(
            in_features=vision_hidden_size,
            out_features=projection_dim,
            bias=True,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Return ``image_features`` passed through the linear layer."""
        return self.linear(image_features)

linear instance-attribute

linear = Linear(
    vision_hidden_size, projection_dim, bias=True
)

__init__

__init__(vision_hidden_size: int, projection_dim: int)
Source code in vllm/model_executor/models/paligemma.py
def __init__(self, vision_hidden_size: int, projection_dim: int):
    """Create the ``vision_hidden_size -> projection_dim`` linear layer."""
    super().__init__()

    self.linear = nn.Linear(
        in_features=vision_hidden_size,
        out_features=projection_dim,
        bias=True,
    )

forward

forward(image_features: Tensor) -> Tensor
Source code in vllm/model_executor/models/paligemma.py
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
    """Return ``image_features`` passed through the linear layer."""
    return self.linear(image_features)

PaliGemmaProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/paligemma.py
class PaliGemmaProcessingInfo(BaseProcessingInfo):
    """Gives access to the PaliGemma config and image-token counts."""

    def get_hf_config(self):
        """Return the context's HF config, requested as a
        ``PaliGemmaConfig``."""
        hf_config = self.ctx.get_hf_config(PaliGemmaConfig)
        return hf_config

    def get_vision_encoder_info(self):
        """Return the vision encoder info built from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits: one image."""
        limits = {"image": 1}
        return limits

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the vision encoder's image-token count for the given
        image size."""
        return self.get_vision_encoder_info().get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

get_hf_config

get_hf_config()
Source code in vllm/model_executor/models/paligemma.py
def get_hf_config(self):
    """Return the context's HF config, requested as a ``PaliGemmaConfig``."""
    hf_config = self.ctx.get_hf_config(PaliGemmaConfig)
    return hf_config

get_num_image_tokens

get_num_image_tokens(
    *, image_width: int, image_height: int
) -> int
Source code in vllm/model_executor/models/paligemma.py
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    """Return the vision encoder's image-token count for the given image
    size."""
    return self.get_vision_encoder_info().get_num_image_tokens(
        image_width=image_width,
        image_height=image_height,
    )

get_supported_mm_limits

get_supported_mm_limits() -> Mapping[str, Optional[int]]
Source code in vllm/model_executor/models/paligemma.py
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality limits: one image."""
    limits = {"image": 1}
    return limits

get_vision_encoder_info

get_vision_encoder_info()
Source code in vllm/model_executor/models/paligemma.py
def get_vision_encoder_info(self):
    """Return the vision encoder info built from the HF config."""
    hf_config = self.get_hf_config()
    return get_vision_encoder_info(hf_config)