Skip to content

vllm.model_executor.models.minicpmv4_6

Inference-only MiniCPM-V 4.6 model (MiniCPMV4_6ForConditionalGeneration).

MiniCPMV4_6DownsampleMLP

Bases: Module

Match HF (transformers v5.7+) parameter naming: pre_norm / linear_1 / act / linear_2 (instead of pre_norm + Sequential(mlp.0 / mlp.2)).

Source code in vllm/model_executor/models/minicpmv4_6.py
class MiniCPMV4_6DownsampleMLP(nn.Module):
    """LayerNorm -> Linear -> GELU -> Linear projector over merged windows.

    The input width is ``hidden_size`` times the number of tokens in one
    ``merge_kernel_size`` window (the window's tokens concatenated
    feature-wise); the output width is ``llm_embed_dim``.

    Submodule names (``pre_norm`` / ``linear_1`` / ``act`` / ``linear_2``)
    match the HF transformers v5.7+ parameter naming instead of the older
    ``pre_norm`` + ``Sequential`` (``mlp.0`` / ``mlp.2``) layout.
    """

    def __init__(
        self,
        hidden_size: int,
        llm_embed_dim: int,
        merge_kernel_size: tuple[int, int] = (2, 2),
    ):
        super().__init__()
        self.merge_kernel_size = merge_kernel_size
        window_area = merge_kernel_size[0] * merge_kernel_size[1]
        self.hidden_size = hidden_size * window_area

        merged_dim = self.hidden_size
        # Registration order is kept as pre_norm, linear_1, act, linear_2.
        self.pre_norm = nn.LayerNorm(merged_dim, eps=1e-6)
        self.linear_1 = nn.Linear(merged_dim, merged_dim, bias=True)
        self.act = get_act_fn("gelu")
        self.linear_2 = nn.Linear(merged_dim, llm_embed_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, expand through GELU, and project to ``llm_embed_dim``."""
        normed = self.pre_norm(x)
        activated = self.act(self.linear_1(normed))
        return self.linear_2(activated)

MiniCPMV4_6ForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP, HasInnerState, IsHybrid, SupportsMRoPE

Source code in vllm/model_executor/models/minicpmv4_6.py
(Source lines 814–1205.)
@MULTIMODAL_REGISTRY.register_processor(
    MiniCPMV4_6MultiModalProcessor,
    info=MiniCPMV4_6ProcessingInfo,
    dummy_inputs=MiniCPMVDummyInputsBuilder,
)
class MiniCPMV4_6ForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    HasInnerState,
    IsHybrid,
    SupportsMRoPE,
):
    """MiniCPM-V 4.6 multimodal model.

    Components:
      * ``vpm``: ``Idefics2VisionTransformer`` vision encoder.
      * ``vit_merger``: ``MiniCPMV4_6ViTWindowAttentionMerger``. It is applied
        between encoder layers, right after layer ``config.insert_layer_id``,
        unless the downsample mode is ``"4x"``.
      * ``merger``: ``MiniCPMV4_6Merger`` MLP producing per-slice embeddings of
        width ``config.text_config.hidden_size``.
      * ``language_model``: ``Qwen3_5ForCausalLM`` text backbone.

    Images and videos share the same vision pipeline. Their embeddings are
    merged into the text embeddings at the positions flagged by
    ``is_multimodal`` in ``embed_input_ids``.
    """

    # NOTE(review): presumably opts the vision encoder into data-parallel
    # execution (see ``mm_encoder_tp_mode == "data"`` in __init__); confirm.
    supports_encoder_tp_data = True

    # Maps HF checkpoint parameter prefixes onto this module's attribute names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # transformers v5.7+ uses `vision_tower` and nests `vit_merger`
            # inside it. Order matters: more specific prefix must come first.
            "model.vision_tower.vit_merger.": "vit_merger.",
            "model.vision_tower.": "vpm.",
            "model.vpm.": "vpm.",
            "model.vit_merger.": "vit_merger.",
            "model.merger.": "merger.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    # Fused parameter names -> the checkpoint sub-projections packed into them.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
        "in_proj_ba": ["in_proj_b", "in_proj_a"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder token for ``modality``.

        Any modality name starting with ``"image"`` maps to ``<|image_pad|>``.
        Any name starting with ``"video"`` maps to ``<|video_pad|>``.
        ``i`` is unused.

        Raises:
            ValueError: for any other modality.
        """
        # transformers v5.7+ chat_template uses these tokens.
        if modality.startswith("image"):
            return "<|image_pad|>"
        if modality.startswith("video"):
            return "<|video_pad|>"
        raise ValueError("Only image or video modality is supported")

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list["MultiModalFeatureSpec"],
    ) -> tuple[torch.Tensor, int]:
        """MiniCPM-V uses embedding injection for vision, not spatial M-RoPE.

        All tokens (text and vision placeholders) get identical sequential
        positions duplicated across the 3 M-RoPE channels expected by the
        Qwen3.5 backbone.

        Returns:
            ``(positions, 0)``. ``positions`` has shape
            ``(3, len(input_tokens))`` and is an expanded (stride-0) view of
            one ``arange``. ``mm_features`` is not used.
        """
        seq_len = len(input_tokens)
        positions = torch.arange(seq_len).unsqueeze(0).expand(3, -1)
        return positions, 0

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, both mergers and the Qwen3.5 LM."""
        super().__init__()
        config: MiniCPMV4_6Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        # --- Vision tower ---
        # Both images and videos run through vpm / vit_merger / merger (see
        # ``embed_multimodal``), so both modalities are declared for these
        # modules instead of only "image".
        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.vpm = Idefics2VisionTransformer(
                config.vision_config,
                quant_config=quant_config,
                apply_encoder_attention_mask=True,
                prefix=maybe_prefix(prefix, "vpm"),
            )
            if config.drop_vision_last_layer:
                # Discard the final encoder layer when the config asks for it.
                self.vpm.encoder.layers = self.vpm.encoder.layers[:-1]

            self.vit_merger = MiniCPMV4_6ViTWindowAttentionMerger(
                config.vision_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vit_merger"),
            )
            self.merger = MiniCPMV4_6Merger(
                hidden_size=config.vision_config.hidden_size,
                llm_embed_dim=config.text_config.hidden_size,
            )

        # --- Language model ---
        # Temporarily swap top-level model_type so that Qwen3_5ForCausalLM
        # picks up the expected text config when introspecting the hf config.
        with self._mark_language_model(vllm_config):
            saved_model_type = config.model_type
            config.model_type = "qwen3_5_text"
            try:
                self.language_model = Qwen3_5ForCausalLM(
                    vllm_config=vllm_config,
                    prefix=maybe_prefix(prefix, "language_model"),
                )
            finally:
                # Restore even if construction fails, so the shared config is
                # not left mutated.
                config.model_type = saved_model_type

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    # ----- Multimodal parsing -----

    def _parse_and_validate_vision_input(
        self,
        **kwargs: object,
    ) -> MiniCPMVImagePixelInputs | MiniCPMVImageEmbeddingInputs | None:
        """Turn raw vision kwargs into a typed input dict.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
        given. ``image_embeds`` takes precedence over ``pixel_values``.

        For pixel inputs, ``tgt_sizes`` is required; a missing key raises
        ``KeyError``. ``num_slices`` records ``len(...)`` of each entry of
        ``pixel_values``. This presumably assumes one nested entry per
        item, listing its slices (TODO confirm).
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            return MiniCPMVImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
            )

        tgt_sizes = kwargs.pop("tgt_sizes")
        num_slices_flat = torch.tensor([len(ps) for ps in pixel_values])
        pixel_values_flat = flatten_bn(pixel_values)
        tgt_sizes_flat = flatten_bn(tgt_sizes, concat=True)

        return MiniCPMVImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values_flat,
            tgt_sizes=tgt_sizes_flat,
            num_slices=num_slices_flat,
        )

    # ----- Vision forward -----

    def get_vision_hidden_states(
        self,
        data: MiniCPMVImagePixelInputs,
        downsample_mode: str | None = None,
    ) -> list[torch.Tensor]:
        """Encode a flat batch of slices and return one tensor per slice.

        Steps:
          1. Pad slices to a common width.
          2. Embed the patches.
          3. Run the encoder layers up to ``insert_layer_id``.
          4. Optionally apply ``vit_merger``.
          5. Run the remaining layers and the post layernorm.
          6. Apply the MLP ``merger``.

        Args:
            data: pixel inputs with ``pixel_values`` and ``tgt_sizes``.
            downsample_mode: ``"4x"`` skips ``vit_merger``. ``None`` falls back
                to ``config.downsample_mode`` (default ``"16x"``).
        """
        pixel_values = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        B = len(pixel_values)
        P = pixel_values[0].shape[-2]
        L = max(item.shape[-1] for item in pixel_values)
        device = pixel_values[0].device
        target_dtype = self.vpm.embeddings.patch_embedding.weight.dtype

        # Zero-pad every slice on the last axis up to the widest slice L.
        all_pixel_values = torch.zeros(
            B,
            3,
            P,
            L,
            dtype=target_dtype,
            device=device,
        )
        for i, pv in enumerate(pixel_values):
            all_pixel_values[i, ..., : pv.shape[-1]] = pv.to(target_dtype)

        # True for the first H*W (real) patch positions of each slice.
        num_patches = tgt_sizes.prod(-1)
        max_patches = int(num_patches.max().item())
        patch_attn_mask = torch.zeros(
            B,
            max_patches,
            dtype=torch.bool,
            device=device,
        )
        for i in range(B):
            patch_attn_mask[i, : num_patches[i]] = True

        hidden_states = self.vpm.embeddings(
            all_pixel_values,
            patch_attention_mask=patch_attn_mask.unsqueeze(1),
            tgt_sizes=tgt_sizes,
        )

        # Additive mask of shape (B, 1, 1, max_patches). It is the dtype's
        # minimum at padded positions and 0 elsewhere, and is skipped
        # entirely when nothing is padded.
        if torch.any(~patch_attn_mask):
            mask_dtype = hidden_states.dtype
            min_val = torch.finfo(mask_dtype).min
            attention_mask = (~patch_attn_mask).to(dtype=mask_dtype) * min_val
            attention_mask = attention_mask[:, None, None, :]
        else:
            attention_mask = None

        # Encoder layers with mid-encoder merger injection
        insert_layer_id = getattr(self.config, "insert_layer_id", -1)
        if downsample_mode is None:
            downsample_mode = getattr(self.config, "downsample_mode", "16x")
        use_vit_merger = downsample_mode != "4x" and insert_layer_id >= 0

        # With insert_layer_id == -1 this first slice is empty and all layers
        # run in the second loop.
        for layer in self.vpm.encoder.layers[: insert_layer_id + 1]:
            hidden_states = layer(hidden_states, attention_mask=attention_mask)

        if use_vit_merger:
            # The merger shrinks the token grid, so it also returns updated
            # tgt_sizes and attention_mask for the remaining layers.
            hidden_states, tgt_sizes, attention_mask = self.vit_merger(
                hidden_states,
                tgt_sizes,
                attention_mask,
            )

        for layer in self.vpm.encoder.layers[insert_layer_id + 1 :]:
            hidden_states = layer(hidden_states, attention_mask=attention_mask)

        # 4. Post layernorm
        hidden_states = self.vpm.post_layernorm(hidden_states)

        # 5. MLP merger → list of per-slice tensors
        return self.merger(hidden_states, tgt_sizes)

    def _process_vision_input(self, image_input, use_vit_merger=None):
        """Return one embedding tensor per image/video item.

        Precomputed ``image_embeds`` are returned as-is. Otherwise the
        per-slice features are concatenated along dim 0, in groups of
        ``num_slices[i]`` consecutive slices.

        ``use_vit_merger`` controls the downsample mode:
          * ``True`` selects ``"16x"``.
          * ``False`` selects ``"4x"``.
          * ``None`` defers to the config.
        """
        if image_input["type"] == "image_embeds":
            return image_input["image_embeds"]

        downsample_mode = None
        if use_vit_merger is not None:
            downsample_mode = "16x" if use_vit_merger else "4x"
        image_features = self.get_vision_hidden_states(
            image_input,
            downsample_mode=downsample_mode,
        )
        num_slices = image_input["num_slices"]
        results = []
        idx = 0
        for n in num_slices.tolist():
            group = image_features[idx : idx + n]
            results.append(torch.cat(group, dim=0))
            idx += n
        return results

    # ----- Multimodal embedding interface -----

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute embeddings for all image and video items in ``kwargs``.

        If any entry of ``use_vit_merger`` is truthy, the ViT merger is used
        for every item in this call. If the key is absent, the config default
        applies.

        Image embeddings come first, followed by video embeddings. An empty
        list is returned when there is no vision input.
        """
        use_vit_merger_tensors = kwargs.pop("use_vit_merger", None)
        use_vit_merger = None
        if use_vit_merger_tensors is not None:
            if isinstance(use_vit_merger_tensors, torch.Tensor):
                use_vit_merger = bool(use_vit_merger_tensors.any().item())
            elif isinstance(use_vit_merger_tensors, list | tuple):
                use_vit_merger = any(
                    bool(t.any().item()) if isinstance(t, torch.Tensor) else bool(t)
                    for t in use_vit_merger_tensors
                )

        # Split kwargs into image / video buckets (videos are processed via
        # the same vision pipeline; their fields just carry a ``video_`` prefix).
        image_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k in ("pixel_values", "image_embeds", "tgt_sizes")
        }
        video_kwargs = {
            k.removeprefix("video_"): v
            for k, v in kwargs.items()
            if k.startswith("video_")
        }

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        if (
            image_kwargs.get("pixel_values") is not None
            or image_kwargs.get("image_embeds") is not None
        ):
            image_input = self._parse_and_validate_vision_input(**image_kwargs)
            if image_input is not None:
                multimodal_embeddings += tuple(
                    self._process_vision_input(
                        image_input,
                        use_vit_merger=use_vit_merger,
                    )
                )

        if (
            video_kwargs.get("pixel_values") is not None
            or video_kwargs.get("image_embeds") is not None
        ):
            video_input = self._parse_and_validate_vision_input(**video_kwargs)
            if video_input is not None:
                multimodal_embeddings += tuple(
                    self._process_vision_input(
                        video_input,
                        use_vit_merger=use_vit_merger,
                    )
                )

        if not multimodal_embeddings:
            return []
        return multimodal_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed text tokens, then scatter in vision embeddings.

        Text embeddings come from the language model. When
        ``multimodal_embeddings`` is non-empty, ``is_multimodal`` must be
        provided (enforced by ``_require_is_multimodal``). The vision
        embeddings are then merged in at those positions.
        """
        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
        )
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        is_multimodal = _require_is_multimodal(is_multimodal)
        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    # ----- Forward / Logits -----

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the language-model trunk and return its hidden states.

        When ``intermediate_tensors`` is given (presumably a non-first
        pipeline stage), ``inputs_embeds`` is discarded. Extra ``kwargs`` are
        ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    # ----- Weight loading -----

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load HF weights through ``hf_to_vllm_mapper``.

        Checkpoint names starting with ``"mtp."`` are skipped. Returns the set
        reported by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["mtp."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Name the LM, connector (both mergers) and vision-tower submodules."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector=["vit_merger", "merger"],
            tower_model="vpm",
        )

    # ----- Mamba / Hybrid state helpers (same as Qwen3.5 VLM) -----

    @classmethod
    def get_mamba_state_dtype_from_config(cls, vllm_config):
        """Gated-delta-net state dtypes derived from the model and cache dtypes."""
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(cls, vllm_config):
        """Gated-delta-net state shapes for this TP size and speculative setup.

        Linear-attention head counts and dims are read from
        ``hf_text_config``. ``num_speculative_tokens`` is 0 when no
        speculative config is set.
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_text_config
        tp_size = parallel_config.tensor_parallel_size
        num_spec = (
            vllm_config.speculative_config.num_speculative_tokens
            if vllm_config.speculative_config
            else 0
        )
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            tp_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls):
        """Return the gated-delta-net state copy function."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

get_mrope_input_positions

get_mrope_input_positions(
    input_tokens: list[int],
    mm_features: list[MultiModalFeatureSpec],
) -> tuple[Tensor, int]

MiniCPM-V uses embedding injection for vision, not spatial M-RoPE.

All tokens (text and vision placeholders) get identical sequential positions duplicated across the 3 M-RoPE channels expected by the Qwen3.5 backbone.

Source code in vllm/model_executor/models/minicpmv4_6.py
def get_mrope_input_positions(
    self,
    input_tokens: list[int],
    mm_features: list["MultiModalFeatureSpec"],
) -> tuple[torch.Tensor, int]:
    """MiniCPM-V uses embedding injection for vision, not spatial M-RoPE.

    Every token (text or vision placeholder) gets the same sequential
    position on all 3 M-RoPE channels expected by the Qwen3.5 backbone.

    Returns:
        ``(positions, delta)``. ``positions`` has shape
        ``(3, len(input_tokens))`` and ``delta`` is always 0.
        ``mm_features`` is not used.
    """
    num_tokens = len(input_tokens)
    linear_positions = torch.arange(num_tokens)
    # Broadcast the single 1-D ramp to 3 channels (expanded view, no copy).
    mrope_positions = linear_positions[None, :].expand(3, num_tokens)
    mrope_delta = 0
    return mrope_positions, mrope_delta

MiniCPMV4_6Merger

Bases: Module

Source code in vllm/model_executor/models/minicpmv4_6.py
class MiniCPMV4_6Merger(nn.Module):
    """Final vision-to-LLM projector.

    Each sample's row-major (H, W) patch grid is folded into
    ``merge_kernel_size`` windows. The tokens of each window are concatenated
    feature-wise, and the result is projected by a
    ``MiniCPMV4_6DownsampleMLP``.

    With ``times > 1`` the fold + projection is repeated on the shrunken
    grid. Only the last stage projects to ``llm_embed_dim``.
    """

    def __init__(
        self,
        hidden_size: int,
        llm_embed_dim: int,
        merge_kernel_size: tuple[int, int] = (2, 2),
        times: int = 1,
    ):
        super().__init__()
        self.merge_kernel_size = merge_kernel_size
        self.times = times
        # Intermediate stages keep ``hidden_size`` so they can be folded
        # again; the last stage outputs the LLM embedding width.
        self.mlp = nn.ModuleList(
            [
                MiniCPMV4_6DownsampleMLP(
                    hidden_size,
                    llm_embed_dim if i == times - 1 else hidden_size,
                    merge_kernel_size,
                )
                for i in range(times)
            ]
        )

    @staticmethod
    def _merge_windows(
        hs: torch.Tensor,
        grid_h: int,
        grid_w: int,
        m1: int,
        m2: int,
    ) -> torch.Tensor:
        """Fold a row-major (grid_h, grid_w) token grid into (m1, m2) windows.

        ``hs`` has ``grid_h * grid_w`` rows of width D. The result has
        ``(grid_h // m1) * (grid_w // m2)`` rows of width ``m1 * m2 * D``,
        one per window, in row-major window order.
        """
        out_h, out_w = grid_h // m1, grid_w // m2
        hs = hs.reshape(out_h, m1, out_w, m2, -1)
        # (out_h, m1, out_w, m2, D) -> (out_h, out_w, m1, m2, D) so each
        # window's m1 * m2 tokens are contiguous before flattening.
        return hs.permute(0, 2, 1, 3, 4).reshape(
            out_h * out_w,
            m1 * m2 * hs.shape[-1],
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        tgt_sizes: torch.Tensor,
    ) -> list[torch.Tensor]:
        """
        Args:
            hidden_states: (B, max_patches, D) padded batch.
            tgt_sizes: (B, 2) actual (H, W) per sample.

        Returns:
            One tensor per sample. Its first ``H * W`` patches are folded and
            projected by ``self.mlp[0]``, then by each remaining stage in turn.
        """
        m1, m2 = self.merge_kernel_size
        results = []

        for b in range(len(tgt_sizes)):
            h, w = tgt_sizes[b].tolist()
            # Drop the padding beyond this sample's real h * w patches.
            hs = hidden_states[b, : h * w, :]

            hs = self.mlp[0](self._merge_windows(hs, h, w, m1, m2))
            cur_h, cur_w = h // m1, w // m2
            for t in range(1, self.times):
                hs = self.mlp[t](self._merge_windows(hs, cur_h, cur_w, m1, m2))
                cur_h, cur_w = cur_h // m1, cur_w // m2

            results.append(hs)

        return results

forward

forward(
    hidden_states: Tensor, tgt_sizes: Tensor
) -> list[Tensor]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | Tensor | (B, max_patches, D) padded batch. | required |
| tgt_sizes | Tensor | (B, 2) actual (H, W) per sample. | required |

Source code in vllm/model_executor/models/minicpmv4_6.py
def forward(
    self,
    hidden_states: torch.Tensor,
    tgt_sizes: torch.Tensor,
) -> list[torch.Tensor]:
    """
    Args:
        hidden_states: (B, max_patches, D) padded batch.
        tgt_sizes: (B, 2) actual (H, W) per sample.

    Returns:
        One tensor per sample. Its (H, W) patch grid is folded into
        ``merge_kernel_size`` windows and projected by ``self.mlp[0]``, then
        re-folded and projected by each remaining ``self.mlp`` stage.
    """
    kh, kw = self.merge_kernel_size

    def fold(tokens, rows, cols):
        # Shrink the grid by (kh, kw); concatenate each window's tokens.
        rows, cols = rows // kh, cols // kw
        windows = tokens.reshape(rows, kh, cols, kw, -1)
        feat_dim = windows.shape[-1]
        flat = windows.permute(0, 2, 1, 3, 4).reshape(
            rows * cols,
            kh * kw * feat_dim,
        )
        return flat, rows, cols

    merged = []
    for sample_idx in range(len(tgt_sizes)):
        rows, cols = tgt_sizes[sample_idx].tolist()
        tokens = hidden_states[sample_idx, : rows * cols, :]
        tokens, rows, cols = fold(tokens, rows, cols)
        tokens = self.mlp[0](tokens)
        for stage in range(1, self.times):
            tokens, rows, cols = fold(tokens, rows, cols)
            tokens = self.mlp[stage](tokens)
        merged.append(tokens)
    return merged

MiniCPMV4_6ProcessingInfo

Bases: MiniCPMVProcessingInfo

Source code in vllm/model_executor/models/minicpmv4_6.py
class MiniCPMV4_6ProcessingInfo(MiniCPMVProcessingInfo):
    """Prompt-side processing info for MiniCPM-V 4.6.

    It provides:
      * the transformers v5.7+ ``<|image_pad|>`` / ``<|video_pad|>``
        placeholders;
      * the hidden size and slice limits, read from a possibly nested HF
        config;
      * visual-token counts and placeholder strings for the ``"16x"`` and
        ``"4x"`` downsample modes.
    """

    # transformers v5.7+ chat_template emits these as image/video placeholders.
    image_pattern = "<|image_pad|>"
    video_pattern = "<|video_pad|>"

    def get_hf_config(self):
        """Return the HF config held by the processing context."""
        return self.ctx.get_hf_config()

    def _get_expected_hidden_size(self) -> int:
        """Return ``text_config.hidden_size`` if set, else top-level ``hidden_size``."""
        cfg = self.get_hf_config()
        text_cfg = getattr(cfg, "text_config", None)
        if text_cfg is None:
            return cfg.hidden_size
        return text_cfg.hidden_size

    def get_model_version(self):
        """Return the MiniCPM-V version tuple ``(4, 6)``."""
        return (4, 6)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Images and videos are both supported; ``None`` means no cap."""
        return dict(image=None, video=None)

    def get_image_max_slice_num(self) -> int:
        """Return ``max_slice_nums`` from ``slice_config`` if present, else from the config (default 9)."""
        cfg = self.get_hf_config()
        slice_cfg = getattr(cfg, "slice_config", None)
        source = cfg if slice_cfg is None else slice_cfg
        return getattr(source, "max_slice_nums", 9)

    def get_video_max_slice_num(self) -> int:
        """Return the video processor's ``max_slice_nums`` (default 9).

        This overrides the base-class default of 1. The transformers v5.7+
        ``MiniCPMV4_6VideoProcessor`` uses the same ``max_slice_nums`` as the
        image processor, so high-res frames get sliced too.

        Any failure while querying the HF processor, or a missing video
        processor, falls back to the image slice limit.
        """
        try:
            video_processor = getattr(
                self.get_hf_processor(), "video_processor", None
            )
            if video_processor is not None:
                return int(getattr(video_processor, "max_slice_nums", 9))
        except Exception:
            # Deliberate best-effort: fall through to the image setting.
            pass
        return self.get_image_max_slice_num()

    def _get_downsample_mode(
        self,
        downsample_mode: str | None = None,
    ) -> str:
        """Return ``downsample_mode``, defaulting to the image processor's (or ``"16x"``)."""
        if downsample_mode is None:
            return getattr(self.get_image_processor(), "downsample_mode", "16x")
        return downsample_mode

    def _compute_visual_tokens(
        self,
        image_size,
        max_slice_nums: int | None = None,
        downsample_mode: str | None = None,
    ) -> tuple[list[int], int, int]:
        """Compute grid, source_image_visual_tokens and patch_visual_tokens.

        Args:
            downsample_mode: ``"16x"`` (default, full merge) or ``"4x"``
                (skip vit_merger, 4x more visual tokens).

        Returns:
            (grids, source_image_visual_tokens, patch_visual_tokens)
            grids is [0, 0] when no slicing occurs.
        """
        processor = self.get_image_processor()
        slice_limit = (
            processor.max_slice_nums if max_slice_nums is None else max_slice_nums
        )
        patch_size = processor.patch_size
        scale_res = processor.scale_resolution
        mode = self._get_downsample_mode(downsample_mode)
        # Pixel area covered by one visual token: one patch times the
        # downsample factor (4 for "4x", otherwise 16).
        area_per_token = patch_size * patch_size * (4 if mode == "4x" else 16)

        # transformers v5.7+ requires `scale_resolution` arg
        try:
            grids = processor.get_sliced_grid(image_size, slice_limit, scale_res)
        except TypeError:
            grids = processor.get_sliced_grid(image_size, slice_limit)

        if grids is None:
            # No slicing: only the (possibly upscaled) source image is encoded.
            best_size = processor.find_best_resize(
                image_size, scale_res, patch_size, allow_upscale=True
            )
            return [0, 0], best_size[0] * best_size[1] // area_per_token, 0

        source_size = processor.find_best_resize(image_size, scale_res, patch_size)
        source_tokens = source_size[0] * source_size[1] // area_per_token
        refine_size = processor.get_refine_size(
            image_size, grids, scale_res, patch_size, allow_upscale=True
        )
        slice_area = (refine_size[0] // grids[0]) * (refine_size[1] // grids[1])
        return grids, source_tokens, slice_area // area_per_token

    def get_slice_image_placeholder(
        self,
        image_size,
        image_idx: int = 0,
        max_slice_nums: int | None = None,
        use_image_id: bool = True,
        downsample_mode: str | None = None,
    ) -> str:
        """Build the prompt placeholder string for one (possibly sliced) image.

        The string is an optional ``<image_id>idx</image_id>`` prefix, then the
        source-image span. When slicing occurred, ``num_rows`` newline-separated
        rows of ``num_cols`` slice spans follow.
        """
        grids, source_tokens, patch_tokens = self._compute_visual_tokens(
            image_size,
            max_slice_nums,
            downsample_mode=downsample_mode,
        )
        image_processor = self.get_image_processor()
        # Delegate when the image processor still provides the helper.
        if hasattr(image_processor, "get_slice_image_placeholder"):
            return image_processor.get_slice_image_placeholder(
                grids,
                image_idx=image_idx,
                max_slice_nums=max_slice_nums,
                use_image_id=use_image_id,
                source_image_visual_tokens=source_tokens,
                patch_visual_tokens=patch_tokens,
            )

        # transformers v5.7+ moved this logic into MiniCPMV4_6Processor;
        # rebuild it here from the tokenizer's special tokens.
        tokenizer = self.get_tokenizer()
        image_token = getattr(tokenizer, "image_token", "<|image_pad|>")
        image_start = getattr(tokenizer, "image_start_token", "<image>")
        image_end = getattr(tokenizer, "image_end_token", "</image>")
        slice_start = getattr(tokenizer, "slice_start_token", "<slice>")
        slice_end = getattr(tokenizer, "slice_end_token", "</slice>")
        id_start = getattr(tokenizer, "image_id_start_token", "<image_id>")
        id_end = getattr(tokenizer, "image_id_end_token", "</image_id>")

        pieces = []
        if use_image_id:
            pieces.append(f"{id_start}{image_idx}{id_end}")
        pieces.append(image_start + image_token * source_tokens + image_end)

        num_cols, num_rows = grids[0], grids[1]
        if num_cols > 0 and num_rows > 0 and patch_tokens > 0:
            row = (slice_start + image_token * patch_tokens + slice_end) * num_cols
            pieces.append("\n".join([row] * num_rows))
        return "".join(pieces)

    def get_num_image_tokens(
        self,
        image_size,
        max_slice_nums: int | None = None,
        downsample_mode: str | None = None,
    ) -> int:
        """Total visual tokens: source image plus all ``grids[0] * grids[1]`` slices."""
        grids, source_tokens, patch_tokens = self._compute_visual_tokens(
            image_size,
            max_slice_nums,
            downsample_mode=downsample_mode,
        )
        num_slices = grids[0] * grids[1]
        return source_tokens + num_slices * patch_tokens

_compute_visual_tokens

_compute_visual_tokens(
    image_size,
    max_slice_nums: int | None = None,
    downsample_mode: str | None = None,
) -> tuple[list[int], int, int]

Compute grid, source_image_visual_tokens and patch_visual_tokens.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| downsample_mode | str \| None | "16x" (default, full merge) or "4x" (skip vit_merger, 4x more visual tokens). | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[list[int], int, int] | (grids, source_image_visual_tokens, patch_visual_tokens); grids is [0, 0] when no slicing occurs. |


Source code in vllm/model_executor/models/minicpmv4_6.py
def _compute_visual_tokens(
    self,
    image_size,
    max_slice_nums: int | None = None,
    downsample_mode: str | None = None,
) -> tuple[list[int], int, int]:
    """Compute grid, source_image_visual_tokens and patch_visual_tokens.

    Args:
        downsample_mode: ``"16x"`` (default, full merge) or ``"4x"``
            (skip vit_merger, 4x more visual tokens).

    Returns:
        (grids, source_image_visual_tokens, patch_visual_tokens)
        grids is [0, 0] when no slicing occurs.
    """
    image_processor = self.get_image_processor()
    if max_slice_nums is None:
        max_slice_nums = image_processor.max_slice_nums

    patch_size = image_processor.patch_size
    scale_res = image_processor.scale_resolution
    downsample_mode = self._get_downsample_mode(downsample_mode)
    token_divisor = 4 if downsample_mode == "4x" else 16

    # transformers v5.7+ requires `scale_resolution` arg
    try:
        grids = image_processor.get_sliced_grid(
            image_size,
            max_slice_nums,
            scale_res,
        )
    except TypeError:
        grids = image_processor.get_sliced_grid(
            image_size,
            max_slice_nums,
        )

    if grids is None:
        best_size = image_processor.find_best_resize(
            image_size,
            scale_res,
            patch_size,
            allow_upscale=True,
        )
        source_tokens = (
            best_size[0] * best_size[1] // (patch_size * patch_size * token_divisor)
        )
        return [0, 0], source_tokens, 0

    best_resize = image_processor.find_best_resize(
        image_size,
        scale_res,
        patch_size,
    )
    source_tokens = (
        best_resize[0] * best_resize[1] // (patch_size * patch_size * token_divisor)
    )
    refine_size = image_processor.get_refine_size(
        image_size,
        grids,
        scale_res,
        patch_size,
        allow_upscale=True,
    )
    patch_w = refine_size[0] // grids[0]
    patch_h = refine_size[1] // grids[1]
    patch_tokens = patch_w * patch_h // (patch_size * patch_size * token_divisor)
    return grids, source_tokens, patch_tokens