
vllm.model_executor.models.granite4_vision

vLLM implementation of Granite 4 Vision.

Uses GraniteForCausalLM as the language backbone, with a SigLIP vision encoder and deepstack feature injection via WindowQFormer projectors.

LoRA support: use --enable-lora --default-mm-loras for LM-only LoRA adapters.

Granite4VisionForConditionalGeneration

Bases: Module, SupportsLoRA, SupportsMultiModal, SupportsPP

vLLM implementation of Granite 4 Vision.

Architecture:

- SigLIP vision tower -> WindowQFormerDownsampler projectors
- Deepstack: 4 vision layers projected and injected at 4 LLM layers
- Spatial: 4 offset groups from last vision layer injected at 4 more LLM layers
- Granite language backbone with embedding_multiplier
- logits_scaling via LogitsProcessor

The outer model runs the LLM layer loop directly (like HF does) to inject deepstack features. This avoids wrapping the inner model and keeps weight loading simple.

LoRA support:

- Full merge: --hf-overrides '{"adapter_path": "path/to/lora"}' merges LM-only LoRA deltas at load time (W += scaling * B @ A).
- Native LoRA: --enable-lora --default-mm-loras '{"image": "path/to/lora"}' lets the vLLM runtime serve the LM LoRA per-request.

Both modes expect an LM-only adapter (no modules_to_save).
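
For intuition, here is a minimal standalone sketch of what the full-merge path computes for one linear layer, per the W += scaling * B @ A formula above; the helper name and the alpha/rank scaling convention are illustrative assumptions, not the module's actual loader code.

import torch

def merge_lora_delta(
    weight: torch.Tensor,   # (out_features, in_features) base LM weight
    lora_a: torch.Tensor,   # (rank, in_features) adapter A matrix
    lora_b: torch.Tensor,   # (out_features, rank) adapter B matrix
    lora_alpha: float,
    rank: int,
) -> torch.Tensor:
    # Common LoRA convention: scaling = alpha / rank (assumed here).
    scaling = lora_alpha / rank
    # W += scaling * B @ A, written out of place for clarity.
    return weight + scaling * (lora_b @ lora_a)

Because the merge happens at load time, the served weights already contain the adapter and no per-request LoRA machinery is involved; the native --enable-lora mode instead keeps the base weights untouched and applies the adapter at runtime.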

Source code in vllm/model_executor/models/granite4_vision.py
@MULTIMODAL_REGISTRY.register_processor(
    Granite4VisionMultiModalProcessor,
    info=Granite4VisionProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class Granite4VisionForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
):
    """vLLM implementation of Granite 4 Vision.

    Architecture:
    - SigLIP vision tower -> WindowQFormerDownsampler projectors
    - Deepstack: 4 vision layers projected and injected at 4 LLM layers
    - Spatial: 4 offset groups from last vision layer injected at 4 more LLM layers
    - Granite language backbone with embedding_multiplier
    - logits_scaling via LogitsProcessor

    The outer model runs the LLM layer loop directly (like HF does) to inject
    deepstack features. This avoids wrapping the inner model and keeps weight
    loading simple.

    LoRA support:
    - Full merge: --hf-overrides '{"adapter_path": "path/to/lora"}' merges
      LM-only LoRA deltas at load time (W += scaling * B @ A).
    - Native LoRA: --enable-lora --default-mm-loras '{"image": "path/to/lora"}'
      lets vLLM runtime serve LM LoRA per-request.
    Both modes expect a LM-only adapter (no modules_to_save).
    """

    # LoRA class attributes (matches GraniteForCausalLM)
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    embedding_modules = {}

    # Weight mapping: HF checkpoint -> vLLM parameter names
    # HF: model.language_model.layers.0...
    # vLLM: language_model.model.layers.0...
    # (because GraniteForCausalLM.model = GraniteModel)
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.language_model.": "language_model.model.",
            "model.layerwise_projectors.": "layerwise_projectors.",
            "model.spatial_projectors.": "spatial_projectors.",
            "model.image_newline": "image_newline",
            "model.vision_tower.": "vision_tower.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<image>"
        raise ValueError(f"Only image modality is supported, got {modality}")

    def get_mm_mapping(self) -> MultiModelKeys:
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector=["layerwise_projectors", "spatial_projectors"],
            tower_model="vision_tower",
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.vllm_config = vllm_config

        # ----- Vision tower + projectors (marked as tower) -----
        with self._mark_tower_model(vllm_config, "image"):
            # Do NOT use init_vision_tower_for_llava here — it truncates the
            # encoder to vision_feature_layer depth. Deepstack needs ALL hidden
            # states (deepstack_layer_map uses negative indices into the full
            # encoder output list).
            self.vision_tower = SiglipVisionModel(
                config.vision_config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )

            # image_newline parameter
            if config.use_image_newline_parameter:
                self.image_newline = nn.Parameter(
                    torch.empty(config.text_config.hidden_size)
                )
            else:
                self.image_newline = None

            cache_config = vllm_config.cache_config

            # Deepstack projectors: one per (vision_layer, llm_layer) pair
            self.layerwise_projectors = nn.ModuleList(
                [
                    WindowQFormerDownsampler(
                        config,
                        quant_config=quant_config,
                        cache_config=cache_config,
                        prefix=maybe_prefix(prefix, f"layerwise_projectors.{i}"),
                    )
                    for i in range(len(config.deepstack_layer_map))
                ]
            )

            # Spatial projectors: 4 offset groups
            self.spatial_projectors = None
            if config.use_spatial_sampling:
                self.spatial_projectors = nn.ModuleList(
                    [
                        WindowQFormerDownsampler(
                            config,
                            quant_config=quant_config,
                            cache_config=cache_config,
                            spatial_offset=i,
                            prefix=maybe_prefix(prefix, f"spatial_projectors.{i}"),
                        )
                        for i in range(4)
                    ]
                )

        # ----- Language model (marked as LM) -----
        with self._mark_language_model(vllm_config):
            self.language_model = Granite4VisionLLMForCausalLM(
                vllm_config=vllm_config.with_hf_config(config.text_config),
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        # Store config values we need
        self._deepstack_layer_map = config.deepstack_layer_map  # [[-19, 9], ...]
        self._use_spatial_sampling = getattr(config, "use_spatial_sampling", False)
        self._spatial_vision_layer = getattr(config, "spatial_vision_layer", -1)
        self._spatial_target_layers = getattr(config, "spatial_target_layers", [])
        self._vision_feature_select_strategy = getattr(
            config, "vision_feature_select_strategy", "full"
        )
        self._downsample_rate = Fraction(config.downsample_rate)

        # Ordered list of LLM layer indices for each deepstack level.
        # Pre-populated from config so it's available during CUDA graph capture
        # (before any embed_multimodal call).
        self._ds_layer_indices: list[int] = [
            llm_layer for _, llm_layer in config.deepstack_layer_map
        ] + list(getattr(config, "spatial_target_layers", []))

        # Share ds_layer_indices with the LLM causal model so
        # make_empty_intermediate_tensors includes the correct keys
        # (its self.config is text_config, no deepstack_layer_map).
        self.language_model._ds_layer_indices = self._ds_layer_indices

        # Pre-allocated persistent GPU buffers for deepstack features.
        # Written via .copy_() in embed_input_ids(), read by forward() via a
        # slice. Because the buffer address is fixed, CUDA graph replay sees
        # the updated values written just before each prefill.
        # Shape: (max_num_batched_tokens, lm_hidden_size) per level.
        n_layerwise = len(config.deepstack_layer_map)
        n_spatial = len(getattr(config, "spatial_target_layers", []))
        num_ds_levels = n_layerwise + n_spatial
        lm_hidden = config.text_config.hidden_size
        max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        # Allocated on CPU first; moved to GPU in embed_input_ids on first use.
        self._ds_buffers: list[torch.Tensor] = [
            torch.zeros(max_tokens, lm_hidden) for _ in range(num_ds_levels)
        ]
        self._ds_num_tokens: int = 0  # tokens written in last embed_input_ids call

    # ----- Vision feature extraction -----

    def _get_vision_hidden_states(
        self, pixel_values: torch.Tensor
    ) -> list[torch.Tensor]:
        """Run vision tower and return all hidden states (including input embeddings).

        Uses SiglipEncoder's built-in return_all_hidden_states support.
        Returns list[Tensor] where index 0 = embeddings, index i = after layer i-1.
        """
        vt = self.vision_tower
        vm = vt.vision_model if hasattr(vt, "vision_model") else vt

        hidden_states = vm.embeddings(pixel_values)
        all_hidden_states = vm.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=True,
        )
        return all_hidden_states

    def _pack_and_unpad_image_features(
        self,
        image_features: list[torch.Tensor] | tuple[torch.Tensor, ...],
        image_sizes: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Reshape, unpad, and pack image features.

        Matches HF Granite4VisionModel.pack_and_unpad_image_features exactly.
        """
        config = self.config
        ds_rate = self._downsample_rate
        new_image_features = []

        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                # Multi-patch: first is base, rest are high-res
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]

                height = width = (
                    config.vision_config.image_size // config.vision_config.patch_size
                )
                # After QFormer downsampling
                height = int(height * ds_rate)
                width = int(width * ds_rate)

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    config.image_grid_pinpoints,
                    config.vision_config.image_size,
                )

                image_feature = image_feature.view(
                    num_patch_height, num_patch_width, height, width, -1
                )
                image_feature = (
                    image_feature.permute(4, 0, 2, 1, 3)
                    .contiguous()
                    .flatten(1, 2)
                    .flatten(2, 3)
                )
                image_feature = unpad_image(image_feature, image_sizes[image_idx])

                if self.image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            self.image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )

                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if self.image_newline is not None:
                    image_feature = torch.cat(
                        (image_feature, self.image_newline[None].to(image_feature)),
                        dim=0,
                    )

            new_image_features.append(image_feature)

        return new_image_features

    def _get_all_layer_features(
        self,
        pixel_values: torch.Tensor,
        image_sizes: torch.Tensor,
    ) -> tuple[list[int], list[torch.Tensor]]:
        """Extract deepstack + spatial features for all levels.

        Returns:
          llm_layer_indices: ordered list of target LLM layer indices
          per_image_packed:  one tensor per image, shape
                             (num_tokens_i, lm_hidden_size * num_levels),
                             all levels packed on dim=-1.

        Packing on dim=-1 means the framework's token-level slicing for
        chunked prefill preserves all levels intact.
        """
        select_strategy = self._vision_feature_select_strategy

        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]

        if pixel_values.dim() == 5:
            pixel_values = torch.cat(
                [pv[:np_] for pv, np_ in zip(pixel_values, image_num_patches)],
                dim=0,
            )

        all_hidden_states = self._get_vision_hidden_states(pixel_values)

        # Collect per-level: (llm_layer, [per_image_tensor, ...])
        levels: list[tuple[int, list[torch.Tensor]]] = []

        for proj_idx, (vision_layer, llm_layer) in enumerate(self._deepstack_layer_map):
            selected = all_hidden_states[vision_layer]
            if select_strategy == "default":
                selected = selected[:, 1:]
            projected = self.layerwise_projectors[proj_idx](selected)
            per_image = self._pack_and_unpad_image_features(
                torch.split(projected, image_num_patches, dim=0), image_sizes
            )
            levels.append((llm_layer, per_image))

        if self._use_spatial_sampling and self.spatial_projectors is not None:
            spatial_hidden = all_hidden_states[self._spatial_vision_layer]
            if select_strategy == "default":
                spatial_hidden = spatial_hidden[:, 1:]
            for group_idx, llm_layer in enumerate(self._spatial_target_layers):
                projected = self.spatial_projectors[group_idx](spatial_hidden)
                per_image = self._pack_and_unpad_image_features(
                    torch.split(projected, image_num_patches, dim=0), image_sizes
                )
                levels.append((llm_layer, per_image))

        llm_layer_indices = [llm_layer for llm_layer, _ in levels]
        num_images = len(image_sizes)
        per_image_packed = [
            torch.cat([levels[lvl][1][img] for lvl in range(len(levels))], dim=-1)
            for img in range(num_images)
        ]

        return llm_layer_indices, per_image_packed

    # ----- Multimodal interface -----

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaNextImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaNextImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                resolve_bindings={"h": expected_h, "w": expected_w},
            )

        if image_embeds is not None:
            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("Unreachable")

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Run vision tower and return per-image packed feature tensors.

        Each returned tensor has shape (num_tokens_i, lm_hidden_size * num_levels)
        with all deepstack levels packed on dim=-1. The framework caches these
        tensors and slices along dim=0 for chunked prefill — all levels survive
        intact because slicing is token-wise, not feature-wise.

        embed_input_ids() splits the packed tensor back into per-level buffers.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        pixel_values = image_input["pixel_values"]
        image_sizes = image_input.get("image_sizes")

        if isinstance(pixel_values, list):
            pixel_values = torch.cat(pixel_values, dim=0)

        llm_layer_indices, per_image_packed = self._get_all_layer_features(
            pixel_values, image_sizes
        )
        self._ds_layer_indices = llm_layer_indices
        return per_image_packed

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Merge text and vision embeddings, apply embedding_multiplier.

        HF flow:
        1. inputs_embeds = embed_tokens(input_ids)
        2. inputs_embeds.masked_fill(vision_mask, 0.0)
        3. hidden_states = inputs_embeds * embedding_multiplier
        4. layer loop injects deepstack features at target layers

        multimodal_embeddings contains packed tensors from embed_multimodal():
        shape (num_tokens_i, lm_hidden_size * num_levels). We split on dim=-1
        to get per-level features, build batch-sized buffers (zero at text
        positions), and store in self._ds_buffers for forward().
        """
        lm_inner = self.language_model.model

        has_vision = (
            multimodal_embeddings is not None
            and is_multimodal is not None
            and len(multimodal_embeddings) > 0
            and is_multimodal.any()
        )

        if not has_vision:
            self._ds_num_tokens = 0
            embeds = lm_inner.embed_input_ids(input_ids)
            return embeds * lm_inner.config.embedding_multiplier

        # 1. Text embeddings
        text_embeds = lm_inner.embed_input_ids(input_ids)

        # 2. Zero image positions (matches HF masked_fill(vision_mask, 0.0))
        text_embeds[is_multimodal] = 0.0

        # 3. Apply embedding_multiplier
        inputs_embeds = text_embeds * lm_inner.config.embedding_multiplier

        # 4. Split packed tensors into per-level features and build buffers.
        #    multimodal_embeddings is a list of per-image packed tensors
        #    (possibly a chunk slice from the framework's encoder cache).
        #    Concatenate along token dim → (total_mm_tokens, lm_h * num_levels).
        N, lm_h = inputs_embeds.shape
        all_packed = torch.cat(
            [t.to(dtype=inputs_embeds.dtype) for t in multimodal_embeddings],
            dim=0,
        )
        level_features = all_packed.split(lm_h, dim=-1)  # num_levels tensors

        # Ensure persistent buffers are on the right device/dtype (first call).
        buf0 = self._ds_buffers[0]
        if buf0.device != inputs_embeds.device or buf0.dtype != inputs_embeds.dtype:
            self._ds_buffers = [
                b.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
                for b in self._ds_buffers
            ]

        for level_idx in range(len(self._ds_layer_indices)):
            target = self._ds_buffers[level_idx][:N]
            target.zero_()
            target[is_multimodal] = level_features[level_idx]

        self._ds_num_tokens = N
        return inputs_embeds

    # ----- Forward -----

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Build IntermediateTensors from pre-allocated persistent buffers.
        # Always pass deepstack when inputs_embeds is non-None (prefill path),
        # including during CUDA graph capture (buffers are zero → no-op injection).
        # This ensures the graph captures the injection code path.
        if (
            inputs_embeds is not None
            and get_pp_group().is_first_rank
            and self._ds_layer_indices
        ):
            ds: IntermediateTensors | None = IntermediateTensors(
                {
                    f"ds_{llm_layer}": self._ds_buffers[lvl]
                    for lvl, llm_layer in enumerate(self._ds_layer_indices)
                }
            )
        else:
            ds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            deepstack_input_embeds=ds,
        )

        # Clear buffers after use so stale features don't leak into the next request.
        if (
            inputs_embeds is not None
            and get_pp_group().is_first_rank
            and self._ds_num_tokens > 0
        ):
            n = self._ds_num_tokens
            for buf in self._ds_buffers:
                buf[:n].zero_()
            self._ds_num_tokens = 0

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        # GraniteForCausalLM.compute_logits uses
        # LogitsProcessor(scale=1/logits_scaling)
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

_get_all_layer_features

_get_all_layer_features(
    pixel_values: Tensor, image_sizes: Tensor
) -> tuple[list[int], list[Tensor]]

Extract deepstack + spatial features for all levels.

Returns:

llm_layer_indices (list[int]): ordered list of target LLM layer indices.

per_image_packed (list[Tensor]): one tensor per image, shape (num_tokens_i, lm_hidden_size * num_levels), all levels packed on dim=-1.

Packing on dim=-1 means the framework's token-level slicing for chunked prefill preserves all levels intact.
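
A rough standalone illustration (made-up sizes, not the module's code) of why packing on dim=-1 survives chunked prefill:

import torch

num_tokens, lm_h, num_levels = 10, 4, 3
levels = [torch.randn(num_tokens, lm_h) for _ in range(num_levels)]

# Pack all levels on the feature dim: (num_tokens, lm_h * num_levels).
packed = torch.cat(levels, dim=-1)

# Chunked prefill slices token-wise (dim=0); every level stays intact.
chunk = packed[:6]

# embed_input_ids-style recovery: split the feature dim back into levels.
recovered = chunk.split(lm_h, dim=-1)
assert all(torch.equal(r, lvl[:6]) for r, lvl in zip(recovered, levels))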

Source code in vllm/model_executor/models/granite4_vision.py
def _get_all_layer_features(
    self,
    pixel_values: torch.Tensor,
    image_sizes: torch.Tensor,
) -> tuple[list[int], list[torch.Tensor]]:
    """Extract deepstack + spatial features for all levels.

    Returns:
      llm_layer_indices: ordered list of target LLM layer indices
      per_image_packed:  one tensor per image, shape
                         (num_tokens_i, lm_hidden_size * num_levels),
                         all levels packed on dim=-1.

    Packing on dim=-1 means the framework's token-level slicing for
    chunked prefill preserves all levels intact.
    """
    select_strategy = self._vision_feature_select_strategy

    image_num_patches = [
        image_size_to_num_patches(
            image_size=imsize,
            grid_pinpoints=self.config.image_grid_pinpoints,
            patch_size=self.config.vision_config.image_size,
        )
        for imsize in image_sizes
    ]

    if pixel_values.dim() == 5:
        pixel_values = torch.cat(
            [pv[:np_] for pv, np_ in zip(pixel_values, image_num_patches)],
            dim=0,
        )

    all_hidden_states = self._get_vision_hidden_states(pixel_values)

    # Collect per-level: (llm_layer, [per_image_tensor, ...])
    levels: list[tuple[int, list[torch.Tensor]]] = []

    for proj_idx, (vision_layer, llm_layer) in enumerate(self._deepstack_layer_map):
        selected = all_hidden_states[vision_layer]
        if select_strategy == "default":
            selected = selected[:, 1:]
        projected = self.layerwise_projectors[proj_idx](selected)
        per_image = self._pack_and_unpad_image_features(
            torch.split(projected, image_num_patches, dim=0), image_sizes
        )
        levels.append((llm_layer, per_image))

    if self._use_spatial_sampling and self.spatial_projectors is not None:
        spatial_hidden = all_hidden_states[self._spatial_vision_layer]
        if select_strategy == "default":
            spatial_hidden = spatial_hidden[:, 1:]
        for group_idx, llm_layer in enumerate(self._spatial_target_layers):
            projected = self.spatial_projectors[group_idx](spatial_hidden)
            per_image = self._pack_and_unpad_image_features(
                torch.split(projected, image_num_patches, dim=0), image_sizes
            )
            levels.append((llm_layer, per_image))

    llm_layer_indices = [llm_layer for llm_layer, _ in levels]
    num_images = len(image_sizes)
    per_image_packed = [
        torch.cat([levels[lvl][1][img] for lvl in range(len(levels))], dim=-1)
        for img in range(num_images)
    ]

    return llm_layer_indices, per_image_packed

_get_vision_hidden_states

_get_vision_hidden_states(
    pixel_values: Tensor,
) -> list[Tensor]

Run vision tower and return all hidden states (including input embeddings).

Uses SiglipEncoder's built-in return_all_hidden_states support. Returns list[Tensor] where index 0 = embeddings, index i = after layer i-1.
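
For intuition, a toy example of how the negative indices in deepstack_layer_map address this list, assuming a 27-layer encoder (the layer count is an assumption for illustration):

# 27 encoder layers -> 28 entries: index 0 = embeddings, index i = layer i-1.
all_hidden_states = [f"hs_{i}" for i in range(28)]  # stand-ins for tensors

# A map entry like [-19, 9] (see the config comment in __init__) selects
# position 28 - 19 = 9, i.e. the output of encoder layer 8, and targets
# LLM layer 9 after projection.
vision_layer, llm_layer = -19, 9
assert all_hidden_states[vision_layer] == "hs_9"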

Source code in vllm/model_executor/models/granite4_vision.py
def _get_vision_hidden_states(
    self, pixel_values: torch.Tensor
) -> list[torch.Tensor]:
    """Run vision tower and return all hidden states (including input embeddings).

    Uses SiglipEncoder's built-in return_all_hidden_states support.
    Returns list[Tensor] where index 0 = embeddings, index i = after layer i-1.
    """
    vt = self.vision_tower
    vm = vt.vision_model if hasattr(vt, "vision_model") else vt

    hidden_states = vm.embeddings(pixel_values)
    all_hidden_states = vm.encoder(
        inputs_embeds=hidden_states,
        return_all_hidden_states=True,
    )
    return all_hidden_states

_pack_and_unpad_image_features

_pack_and_unpad_image_features(
    image_features: list[Tensor] | tuple[Tensor, ...],
    image_sizes: Tensor,
) -> list[Tensor]

Reshape, unpad, and pack image features.

Matches HF Granite4VisionModel.pack_and_unpad_image_features exactly.

Source code in vllm/model_executor/models/granite4_vision.py
def _pack_and_unpad_image_features(
    self,
    image_features: list[torch.Tensor] | tuple[torch.Tensor, ...],
    image_sizes: torch.Tensor,
) -> list[torch.Tensor]:
    """Reshape, unpad, and pack image features.

    Matches HF Granite4VisionModel.pack_and_unpad_image_features exactly.
    """
    config = self.config
    ds_rate = self._downsample_rate
    new_image_features = []

    for image_idx, image_feature in enumerate(image_features):
        if image_feature.shape[0] > 1:
            # Multi-patch: first is base, rest are high-res
            base_image_feature = image_feature[0]
            image_feature = image_feature[1:]

            height = width = (
                config.vision_config.image_size // config.vision_config.patch_size
            )
            # After QFormer downsampling
            height = int(height * ds_rate)
            width = int(width * ds_rate)

            num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                image_sizes[image_idx],
                config.image_grid_pinpoints,
                config.vision_config.image_size,
            )

            image_feature = image_feature.view(
                num_patch_height, num_patch_width, height, width, -1
            )
            image_feature = (
                image_feature.permute(4, 0, 2, 1, 3)
                .contiguous()
                .flatten(1, 2)
                .flatten(2, 3)
            )
            image_feature = unpad_image(image_feature, image_sizes[image_idx])

            if self.image_newline is not None:
                image_feature = torch.cat(
                    (
                        image_feature,
                        self.image_newline[:, None, None]
                        .expand(*image_feature.shape[:-1], 1)
                        .to(image_feature.device, image_feature.dtype),
                    ),
                    dim=-1,
                )

            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
        else:
            image_feature = image_feature[0]
            if self.image_newline is not None:
                image_feature = torch.cat(
                    (image_feature, self.image_newline[None].to(image_feature)),
                    dim=0,
                )

        new_image_features.append(image_feature)

    return new_image_features

embed_input_ids

embed_input_ids(
    input_ids: Tensor,
    multimodal_embeddings: MultiModalEmbeddings
    | None = None,
    *,
    is_multimodal: Tensor | None = None,
    handle_oov_mm_token: bool = True,
) -> Tensor

Merge text and vision embeddings, apply embedding_multiplier.

HF flow:

1. inputs_embeds = embed_tokens(input_ids)
2. inputs_embeds.masked_fill(vision_mask, 0.0)
3. hidden_states = inputs_embeds * embedding_multiplier
4. layer loop injects deepstack features at target layers

multimodal_embeddings contains packed tensors from embed_multimodal(): shape (num_tokens_i, lm_hidden_size * num_levels). We split on dim=-1 to get per-level features, build batch-sized buffers (zero at text positions), and store in self._ds_buffers for forward().
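
A minimal standalone sketch of the zero-then-scale step and the per-level buffer build described above, with made-up shapes and an illustrative multiplier (not the module's code):

import torch

N, lm_h, num_levels = 8, 4, 2
embedding_multiplier = 12.0  # illustrative value

text_embeds = torch.randn(N, lm_h)
is_multimodal = torch.tensor([0, 0, 1, 1, 1, 0, 0, 0], dtype=torch.bool)

# Zero the image positions, then scale everything.
text_embeds[is_multimodal] = 0.0
inputs_embeds = text_embeds * embedding_multiplier

# Split the packed vision features on dim=-1 and scatter each level into a
# zeroed per-level buffer at the multimodal positions.
packed = torch.randn(int(is_multimodal.sum()), lm_h * num_levels)
level_features = packed.split(lm_h, dim=-1)
buffers = [torch.zeros(N, lm_h) for _ in range(num_levels)]
for buf, feat in zip(buffers, level_features):
    buf[is_multimodal] = feat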

Source code in vllm/model_executor/models/granite4_vision.py
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
    handle_oov_mm_token: bool = True,
) -> torch.Tensor:
    """Merge text and vision embeddings, apply embedding_multiplier.

    HF flow:
    1. inputs_embeds = embed_tokens(input_ids)
    2. inputs_embeds.masked_fill(vision_mask, 0.0)
    3. hidden_states = inputs_embeds * embedding_multiplier
    4. layer loop injects deepstack features at target layers

    multimodal_embeddings contains packed tensors from embed_multimodal():
    shape (num_tokens_i, lm_hidden_size * num_levels). We split on dim=-1
    to get per-level features, build batch-sized buffers (zero at text
    positions), and store in self._ds_buffers for forward().
    """
    lm_inner = self.language_model.model

    has_vision = (
        multimodal_embeddings is not None
        and is_multimodal is not None
        and len(multimodal_embeddings) > 0
        and is_multimodal.any()
    )

    if not has_vision:
        self._ds_num_tokens = 0
        embeds = lm_inner.embed_input_ids(input_ids)
        return embeds * lm_inner.config.embedding_multiplier

    # 1. Text embeddings
    text_embeds = lm_inner.embed_input_ids(input_ids)

    # 2. Zero image positions (matches HF masked_fill(vision_mask, 0.0))
    text_embeds[is_multimodal] = 0.0

    # 3. Apply embedding_multiplier
    inputs_embeds = text_embeds * lm_inner.config.embedding_multiplier

    # 4. Split packed tensors into per-level features and build buffers.
    #    multimodal_embeddings is a list of per-image packed tensors
    #    (possibly a chunk slice from the framework's encoder cache).
    #    Concatenate along token dim → (total_mm_tokens, lm_h * num_levels).
    N, lm_h = inputs_embeds.shape
    all_packed = torch.cat(
        [t.to(dtype=inputs_embeds.dtype) for t in multimodal_embeddings],
        dim=0,
    )
    level_features = all_packed.split(lm_h, dim=-1)  # num_levels tensors

    # Ensure persistent buffers are on the right device/dtype (first call).
    buf0 = self._ds_buffers[0]
    if buf0.device != inputs_embeds.device or buf0.dtype != inputs_embeds.dtype:
        self._ds_buffers = [
            b.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            for b in self._ds_buffers
        ]

    for level_idx in range(len(self._ds_layer_indices)):
        target = self._ds_buffers[level_idx][:N]
        target.zero_()
        target[is_multimodal] = level_features[level_idx]

    self._ds_num_tokens = N
    return inputs_embeds

embed_multimodal

embed_multimodal(**kwargs: object) -> MultiModalEmbeddings

Run vision tower and return per-image packed feature tensors.

Each returned tensor has shape (num_tokens_i, lm_hidden_size * num_levels) with all deepstack levels packed on dim=-1. The framework caches these tensors and slices along dim=0 for chunked prefill — all levels survive intact because slicing is token-wise, not feature-wise.

embed_input_ids() splits the packed tensor back into per-level buffers.

Source code in vllm/model_executor/models/granite4_vision.py
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
    """Run vision tower and return per-image packed feature tensors.

    Each returned tensor has shape (num_tokens_i, lm_hidden_size * num_levels)
    with all deepstack levels packed on dim=-1. The framework caches these
    tensors and slices along dim=0 for chunked prefill — all levels survive
    intact because slicing is token-wise, not feature-wise.

    embed_input_ids() splits the packed tensor back into per-level buffers.
    """
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is None:
        return []

    if image_input["type"] == "image_embeds":
        return [image_input["data"]]

    pixel_values = image_input["pixel_values"]
    image_sizes = image_input.get("image_sizes")

    if isinstance(pixel_values, list):
        pixel_values = torch.cat(pixel_values, dim=0)

    llm_layer_indices, per_image_packed = self._get_all_layer_features(
        pixel_values, image_sizes
    )
    self._ds_layer_indices = llm_layer_indices
    return per_image_packed

Granite4VisionLLMForCausalLM

Bases: GraniteForCausalLM

GraniteForCausalLM backed by Granite4VisionLLMModel.

Source code in vllm/model_executor/models/granite4_vision.py
class Granite4VisionLLMForCausalLM(GraniteForCausalLM):
    """GraniteForCausalLM backed by Granite4VisionLLMModel."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Granite4VisionLLMModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head.weight = self.model.embed_tokens.weight
            logit_scale = getattr(config, "logit_scale", 1.0)
            if hasattr(config, "logits_scaling"):
                logit_scale /= config.logits_scaling
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        tensors = super().make_empty_intermediate_tensors(batch_size, dtype, device)
        # Include deepstack buffers so non-first PP ranks receive them.
        # _ds_layer_indices is set directly on this instance by the outer model.
        for llm_layer in getattr(self, "_ds_layer_indices", []):
            tensors.tensors[f"ds_{llm_layer}"] = torch.zeros(
                (batch_size, self.config.hidden_size), dtype=dtype, device=device
            )
        return tensors

Granite4VisionLLMModel

Bases: GraniteModel

GraniteModel with deepstack feature injection in the layer loop.

Source code in vllm/model_executor/models/granite4_vision.py
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": 0,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        "deepstack_input_embeds": 0,
    }
)
class Granite4VisionLLMModel(GraniteModel):
    """GraniteModel with deepstack feature injection in the layer loop."""

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        deepstack_input_embeds: IntermediateTensors | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
                hidden_states = hidden_states * self.config.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            # Recover deepstack features forwarded from the previous PP rank.
            if deepstack_input_embeds is None:
                ds_keys = [
                    k for k in intermediate_tensors.tensors if k.startswith("ds_")
                ]
                if ds_keys:
                    deepstack_input_embeds = IntermediateTensors(
                        {k: intermediate_tensors[k] for k in ds_keys}
                    )

        for layer_idx, layer in islice(
            enumerate(self.layers), self.start_layer, self.end_layer
        ):
            if deepstack_input_embeds is not None:
                key = f"ds_{layer_idx}"
                if key in deepstack_input_embeds.tensors:
                    feat = deepstack_input_embeds[key]
                    # Resize to match hidden_states in case of CUDA graph padding
                    num_tokens = hidden_states.size(0)
                    buf_len = feat.shape[0]
                    if buf_len != num_tokens:
                        feat = torch.nn.functional.pad(
                            feat[:num_tokens],
                            (0, 0, 0, max(0, num_tokens - buf_len)),
                        )
                    hidden_states = hidden_states + feat
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            # Forward hidden_states and any deepstack features for later ranks.
            it = {"hidden_states": hidden_states}
            if deepstack_input_embeds is not None:
                remaining = {
                    k: v
                    for k, v in deepstack_input_embeds.tensors.items()
                    if int(k.split("_")[1]) >= self.end_layer
                }
                it.update(remaining)
            return IntermediateTensors(it)

        hidden_states = self.norm(hidden_states)
        return hidden_states

InterpolateDownsampler

Spatial downsampling via area interpolation.
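
A toy worked example (standalone, made-up 4x4 grid, assumed 1/2 downsample rate) of the area interpolation: each 2x2 block of tokens is averaged into one token.

import torch
import torch.nn.functional as F

# 16 patch tokens; the single channel holds each token's flat index.
side, new_side = 4, 2
grid = torch.arange(side * side, dtype=torch.float32).view(1, side * side, 1)

# Match the downsampler's layout handling: tokens -> NCHW, interpolate, back.
large = grid.view(1, side, side, 1).permute(0, 3, 1, 2)
small = F.interpolate(large, size=(new_side, new_side), mode="area")
tokens = small.permute(0, 2, 3, 1).flatten(1, 2)
print(tokens.flatten().tolist())  # [2.5, 4.5, 10.5, 12.5]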

Source code in vllm/model_executor/models/granite4_vision.py
class InterpolateDownsampler:
    """Spatial downsampling via area interpolation."""

    def __init__(self, config, mode="area"):
        self.orig_image_side = (
            config.vision_config.image_size // config.vision_config.patch_size
        )
        self.new_image_side = int(
            self.orig_image_side * Fraction(config.downsample_rate)
        )
        self.mode = mode

    def __call__(self, image_features: torch.Tensor) -> torch.Tensor:
        batch_size, _, dim = image_features.size()
        up_shape = [batch_size, self.orig_image_side, self.orig_image_side, dim]
        large = image_features.view(up_shape).permute(0, 3, 1, 2)
        small = torch.nn.functional.interpolate(
            large,
            size=(self.new_image_side, self.new_image_side),
            mode=self.mode,
        )
        return small.permute(0, 2, 3, 1).flatten(1, 2)

SpatialOffsetDownsampler

Sample one position from each 2x2 block (offset 0-3 = TL/TR/BL/BR).
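
A toy worked example (standalone, made-up 4x4 grid) showing what offset=1, i.e. the top-right position of each 2x2 block, selects:

import torch

# 16 patch tokens laid out as a 4x4 grid; the single channel holds the index.
side = 4
grid = torch.arange(side * side, dtype=torch.float32).view(1, side * side, 1)

# Reshape into 2x2 blocks and keep one position per block (offset (0, 1)),
# mirroring SpatialOffsetDownsampler with offset=1.
n = side // 2
blocks = grid.view(1, n, 2, n, 2, 1)
sampled = blocks[:, :, 0, :, 1, :].reshape(1, -1, 1)
print(sampled.flatten().tolist())  # [1.0, 3.0, 9.0, 11.0]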

Source code in vllm/model_executor/models/granite4_vision.py
class SpatialOffsetDownsampler:
    """Sample one position from each 2x2 block (offset 0-3 = TL/TR/BL/BR)."""

    def __init__(self, config, offset: int = 0):
        self.orig_image_side = (
            config.vision_config.image_size // config.vision_config.patch_size
        )
        self.new_image_side = self.orig_image_side // 2
        offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
        self.offset_h, self.offset_w = offsets[offset]

    def __call__(self, image_features: torch.Tensor) -> torch.Tensor:
        B, _, C = image_features.shape
        features_2d = image_features.reshape(
            B, self.orig_image_side, self.orig_image_side, C
        )
        n = self.new_image_side
        blocks = features_2d.reshape(B, n, 2, n, 2, C)
        sampled = blocks[:, :, self.offset_h, :, self.offset_w, :]
        return sampled.reshape(B, -1, C)

WindowQFormerDownsampler

Bases: Module

Window-based QFormer downsampler (matches HF downsampling.py exactly).

Source code in vllm/model_executor/models/granite4_vision.py
class WindowQFormerDownsampler(nn.Module):
    """Window-based QFormer downsampler (matches HF downsampling.py exactly)."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        spatial_offset: int | None = None,
        prefix: str = "",
    ):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size

        self.dropout = nn.Dropout(config.projector_dropout)

        if spatial_offset is not None:
            self.downsampler = SpatialOffsetDownsampler(config, offset=spatial_offset)
        else:
            self.downsampler = InterpolateDownsampler(config)

        qformer_config = Blip2QFormerConfig(
            hidden_size=vision_hidden_size,
            num_attention_heads=vision_hidden_size // 64,
            intermediate_size=3072,
            num_hidden_layers=1,
            encoder_hidden_size=vision_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(
            qformer_config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.qformer",
        )

        self.image_side = (
            config.vision_config.image_size // config.vision_config.patch_size
        )
        q, w = config.downsample_rate.split("/")
        self.query_side, self.window_side = int(q), int(w)
        self.query_length = self.query_side**2

        embed_std = 1 / math.sqrt(vision_hidden_size)
        self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6)
        self.query = nn.Parameter(
            torch.randn(1, self.query_length, vision_hidden_size) * embed_std
        )
        self.image_positions = nn.Parameter(
            torch.randn(1, self.window_side**2, vision_hidden_size) * embed_std
        )
        self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True)

    def _win(self, x: torch.Tensor, side: int, win: int) -> torch.Tensor:
        """(B, side*side, C) → (B*n*n, win*win, C) where n=side//win."""
        B, _, C = x.shape
        n = side // win
        return (
            x.view(B, side, side, C)
            .view(B, n, win, n, win, C)
            .transpose(2, 3)
            .flatten(0, 2)
            .flatten(1, 2)
        )

    def _unwin(self, xw: torch.Tensor, n: int, win: int) -> torch.Tensor:
        """(B*n*n, win*win, C) → (B, (n*win)^2, C)."""
        Bnn, _, C = xw.shape
        B = Bnn // (n * n)
        side = n * win
        return (
            xw.view(B, n, n, win, win, C)
            .transpose(2, 3)
            .contiguous()
            .view(B, side, side, C)
            .flatten(1, 2)
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        B, HW, C = image_features.shape
        assert self.image_side * self.image_side == HW
        n = self.image_side // self.window_side

        image_features = self.norm(image_features)
        enc = self._win(image_features, self.image_side, self.window_side)

        downsampled = self.downsampler(image_features)
        new_side = n * self.query_side
        downsampled_w = self._win(downsampled, new_side, self.query_side)

        query_embeds = self.query + downsampled_w
        encoder_embeds = self.dropout(enc + self.image_positions)
        out_w = self.qformer(
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_embeds,
        )

        out = self._unwin(out_w, n=n, win=self.query_side)
        out = self.dropout(out)
        return self.out_linear(out)

_unwin

_unwin(xw: Tensor, n: int, win: int) -> Tensor

(B*n*n, win*win, C) → (B, (n*win)^2, C).

Source code in vllm/model_executor/models/granite4_vision.py
def _unwin(self, xw: torch.Tensor, n: int, win: int) -> torch.Tensor:
    """(B*n*n, win*win, C) → (B, (n*win)^2, C)."""
    Bnn, _, C = xw.shape
    B = Bnn // (n * n)
    side = n * win
    return (
        xw.view(B, n, n, win, win, C)
        .transpose(2, 3)
        .contiguous()
        .view(B, side, side, C)
        .flatten(1, 2)
    )

_win

_win(x: Tensor, side: int, win: int) -> Tensor

(B, side*side, C) → (B*n*n, win*win, C) where n=side//win.
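
A standalone toy check of this window partition and its _unwin inverse with side=4, win=2 (so n=2: 16 tokens become 4 windows of 4 tokens, and back); the code below repeats the reshape chains inline rather than calling the methods.

import torch

B, side, win, C = 1, 4, 2, 3
n = side // win
x = torch.randn(B, side * side, C)

# _win: (B, side*side, C) -> (B*n*n, win*win, C)
windows = (
    x.view(B, side, side, C)
    .view(B, n, win, n, win, C)
    .transpose(2, 3)
    .flatten(0, 2)
    .flatten(1, 2)
)
assert windows.shape == (B * n * n, win * win, C)

# _unwin: (B*n*n, win*win, C) -> (B, side*side, C), restoring the original.
restored = (
    windows.view(B, n, n, win, win, C)
    .transpose(2, 3)
    .contiguous()
    .view(B, side, side, C)
    .flatten(1, 2)
)
assert torch.equal(restored, x)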

Source code in vllm/model_executor/models/granite4_vision.py
def _win(self, x: torch.Tensor, side: int, win: int) -> torch.Tensor:
    """(B, side*side, C) → (B*n*n, win*win, C) where n=side//win."""
    B, _, C = x.shape
    n = side // win
    return (
        x.view(B, side, side, C)
        .view(B, n, win, n, win, C)
        .transpose(2, 3)
        .flatten(0, 2)
        .flatten(1, 2)
    )