Skip to content

vllm.transformers_utils.processors.moondream3

Custom processor for Moondream3 model.

Moondream3Processor

Bases: ProcessorMixin

Constructs a Moondream3 processor which handles image preprocessing and tokenization for the Moondream3 multimodal model.

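Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tokenizer` | `PreTrainedTokenizerBase \| None` | The tokenizer to use for text processing. | `None` |
| `chat_template` | `str \| None` | Optional chat template string. | `None` |
| `crop_size` | `int` | Side length of each square image crop, in pixels. | `378` |
| `max_crops` | `int` | Maximum number of crops per image. | `12` |
| `overlap_margin` | `int` | Margin for overlapping crops, in patches. | `4` |
| `patch_size` | `int` | Side length of each patch, in pixels. | `14` |

Source code in vllm/transformers_utils/processors/moondream3.py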
Source code in vllm/transformers_utils/processors/moondream3.py
(Listing of source lines 69–538.)
class Moondream3Processor(ProcessorMixin):
    """
    Constructs a Moondream3 processor which handles image preprocessing
    and tokenization for the Moondream3 multimodal model.

    Args:
        tokenizer: The tokenizer to use for text processing.
        chat_template: Optional chat template string. When ``None`` the
            built-in ``_default_chat_template`` is used.
        crop_size: Side length, in pixels, of each square image crop.
        max_crops: Upper bound handed to ``select_tiling`` for the number
            of local crops per image; one global crop is always added.
        overlap_margin: Margin for overlapping crops, in patches.
        patch_size: Side length, in pixels, of each patch.

    The four crop settings are stored on the instance. ``__call__`` uses
    them for any setting the merged per-call image kwargs do not provide.
    """

    attributes = ["tokenizer"]
    valid_kwargs = [
        "chat_template",
        "crop_size",
        "max_crops",
        "overlap_margin",
        "patch_size",
    ]

    tokenizer_class = "AutoTokenizer"
    # Dedicated tokenizer repo, used by from_pretrained as a fallback when
    # the model repo itself has no loadable tokenizer.
    _tokenizer_repo = "moondream/starmie-v1"

    # Default chat template for Moondream3
    # Moondream uses special tokens for prompting:
    # - Token 0 (<|endoftext|>): BOS token (ALWAYS present at position 0)
    # - Token 1 (<|md_reserved_0|>): Start of instruction
    # - Token 2 (<|md_reserved_1|>): Separator before question
    # - Token 3 (<|md_reserved_2|>): End of question / start of answer
    #
    # Task routing based on text prefix:
    #   "caption" / "describe" (bare)  → describe<|md_reserved_1|>normal
    #   "caption [short|normal|long]" → describe<|md_reserved_1|>{length}
    #   "describe [short|normal|long]" → describe<|md_reserved_1|>{length}
    #   otherwise                      → query<|md_reserved_1|><text>
    #
    # Format with image (string content always assumes one image):
    #   <|endoftext|><image><|md_reserved_0|>{task}<|md_reserved_1|>{q}<|md_reserved_2|>
    # Format without image (list content with no image parts):
    #   <|endoftext|><|md_reserved_0|>{task}<|md_reserved_1|>{q}<|md_reserved_2|>
    _default_chat_template = (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{% if message['content'] is string %}"
        # Simple string content (with image assumed) - route by prefix
        "<|endoftext|><image><|md_reserved_0|>"
        "{% if message['content'] == 'caption' %}"
        "describe<|md_reserved_1|>normal<|md_reserved_2|>"
        "{% elif message['content'].startswith('caption ') %}"
        "describe<|md_reserved_1|>{{ message['content'][8:] }}<|md_reserved_2|>"
        "{% elif message['content'] == 'describe' %}"
        "describe<|md_reserved_1|>normal<|md_reserved_2|>"
        "{% elif message['content'].startswith('describe ') %}"
        "describe<|md_reserved_1|>{{ message['content'][9:] }}<|md_reserved_2|>"
        "{% else %}"
        "query<|md_reserved_1|>{{ message['content'] }}<|md_reserved_2|>"
        "{% endif %}"
        "{% else %}"
        # List content - build Moondream's image prefix independently of
        # OpenAI-style content part order, then render the text task.
        "<|endoftext|>"
        "{% for content in message['content'] %}"
        "{% if content['type'] in ['image', 'image_url', 'input_image', 'image_pil'] %}"  # noqa: E501
        "<image>"
        "{% endif %}"
        "{% endfor %}"
        "{% for content in message['content'] %}"
        "{% if content['type'] == 'text' %}"
        "<|md_reserved_0|>"
        "{% if content['text'] == 'caption' %}"
        "describe<|md_reserved_1|>normal<|md_reserved_2|>"
        "{% elif content['text'].startswith('caption ') %}"
        "describe<|md_reserved_1|>{{ content['text'][8:] }}<|md_reserved_2|>"
        "{% elif content['text'] == 'describe' %}"
        "describe<|md_reserved_1|>normal<|md_reserved_2|>"
        "{% elif content['text'].startswith('describe ') %}"
        "describe<|md_reserved_1|>{{ content['text'][9:] }}<|md_reserved_2|>"
        "{% else %}"
        "query<|md_reserved_1|>{{ content['text'] }}<|md_reserved_2|>"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ message['content'] }}"
        "{% endif %}"
        "{% endfor %}"
    )

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase | None = None,
        chat_template: str | None = None,
        crop_size: int = 378,
        max_crops: int = 12,
        overlap_margin: int = 4,
        patch_size: int = 14,
        **kwargs,
    ):
        # Any extra **kwargs are accepted for compatibility and ignored.
        self.image_token = "<image>"
        self.crop_size = crop_size
        self.max_crops = max_crops
        self.overlap_margin = overlap_margin
        self.patch_size = patch_size

        # Number of patches per crop (27x27 = 729 for 378/14)
        self.patches_per_crop = (crop_size // patch_size) ** 2

        # Use default chat template if none provided
        if chat_template is None:
            chat_template = self._default_chat_template

        super().__init__(tokenizer, chat_template=chat_template)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        **kwargs,
    ):
        """
        Load the processor and resolve its tokenizer.

        Tokenizer resolution order:

        1. The ``tokenizer`` keyword argument. A tokenizer object is used
           as-is; a string is treated as a repo id / path and loaded.
        2. Tokenizer files in ``pretrained_model_name_or_path``.
        3. The dedicated Moondream tokenizer repo ``_tokenizer_repo``
           ('moondream/starmie-v1').

        Each load first tries ``AutoTokenizer.from_pretrained``. On failure
        it builds a ``PreTrainedTokenizerFast`` directly from the repo's
        ``tokenizer.json``. BOS and EOS are then both set to
        ``<|endoftext|>`` (id 0).

        The processor options ``crop_size``, ``max_crops``,
        ``overlap_margin``, ``patch_size`` and ``chat_template`` are read
        from ``kwargs``. The chat template is also installed on the
        tokenizer if it has none. Other ``kwargs`` are only used as hub
        options for tokenizer loading.
        """
        from transformers import AutoTokenizer, PreTrainedTokenizerFast
        from transformers.utils import cached_file

        tokenizer = kwargs.pop("tokenizer", None)

        # Hub/download options forwarded to tokenizer loading.
        tokenizer_kwargs = {
            "trust_remote_code": kwargs.get("trust_remote_code", False),
        }
        for key in (
            "cache_dir",
            "force_download",
            "local_files_only",
            "revision",
            "subfolder",
            "token",
            "use_fast",
        ):
            if key in kwargs:
                tokenizer_kwargs[key] = kwargs[key]

        # Subset of the above that cached_file() understands.
        cached_file_kwargs = {
            key: tokenizer_kwargs[key]
            for key in (
                "cache_dir",
                "force_download",
                "local_files_only",
                "revision",
                "subfolder",
                "token",
            )
            if key in tokenizer_kwargs
        }

        def load_tokenizer(repo_or_path):
            try:
                return AutoTokenizer.from_pretrained(repo_or_path, **tokenizer_kwargs)
            except Exception:
                # AutoTokenizer could not resolve a tokenizer class; build a
                # fast tokenizer straight from tokenizer.json instead.
                tokenizer_file = cached_file(
                    repo_or_path,
                    "tokenizer.json",
                    **cached_file_kwargs,
                )
                return PreTrainedTokenizerFast(
                    tokenizer_file=tokenizer_file,
                    clean_up_tokenization_spaces=False,
                )

        if isinstance(tokenizer, str):
            tokenizer = load_tokenizer(tokenizer)

        if tokenizer is None:
            # Prefer model-local tokenizer files first. If unavailable, fall
            # back to moondream's dedicated tokenizer repository.
            try:
                tokenizer = load_tokenizer(pretrained_model_name_or_path)
            except Exception:
                tokenizer = load_tokenizer(cls._tokenizer_repo)

        # Configure special tokens for Moondream3
        # BOS and EOS are both token 0 (<|endoftext|>), matching the native
        # config (TokenizerConfig.bos_id=0, eos_id=0). This is standard for
        # GPT-2 style models where <|endoftext|> signals both start and end.
        # Token 1 (<|md_reserved_0|>) is a template delimiter, NOT the EOS.
        tokenizer.bos_token = "<|endoftext|>"
        tokenizer.bos_token_id = 0
        tokenizer.eos_token = "<|endoftext|>"
        tokenizer.eos_token_id = 0

        # Extract processor-specific kwargs
        crop_size = kwargs.pop("crop_size", 378)
        max_crops = kwargs.pop("max_crops", 12)
        overlap_margin = kwargs.pop("overlap_margin", 4)
        patch_size = kwargs.pop("patch_size", 14)
        chat_template = kwargs.pop("chat_template", None)

        # Set default chat template on tokenizer if not already set
        if chat_template is None:
            chat_template = cls._default_chat_template
        if tokenizer.chat_template is None:
            tokenizer.chat_template = chat_template

        return cls(
            tokenizer=tokenizer,
            chat_template=chat_template,
            crop_size=crop_size,
            max_crops=max_crops,
            overlap_margin=overlap_margin,
            patch_size=patch_size,
        )

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput
        | PreTokenizedInput
        | list[TextInput]
        | list[PreTokenizedInput] = None,
        **kwargs: Unpack[Moondream3ProcessorKwargs],
    ) -> BatchFeature:
        """
        Process images and text for Moondream3 model.

        Args:
            images: Input image (PIL image, numpy array or torch tensor) or
                a list of them. Each image is preprocessed independently.
            text: Input text or list of texts.
            **kwargs: Additional processing arguments, merged through
                ``self._merge_kwargs``.

        Returns:
            BatchFeature with processed inputs. When ``text`` is given it
            holds the tokenizer outputs as PyTorch tensors. When ``images``
            is given it also holds ``pixel_values`` (one entry per image)
            and ``tilings`` (one ``(rows, cols)`` tuple per image).
        """
        output_kwargs = self._merge_kwargs(
            Moondream3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Process images
        image_features = {}
        if images is not None:
            # The crop settings configured on this processor act as
            # defaults; any value already present in the merged image kwargs
            # wins. Without this, constructor / from_pretrained crop settings
            # might never reach preprocess_image.
            images_kwargs = {
                "crop_size": self.crop_size,
                "max_crops": self.max_crops,
                "overlap_margin": self.overlap_margin,
                "patch_size": self.patch_size,
                **output_kwargs["images_kwargs"],
            }

            processed_images = []
            tilings = []

            images_list = images if isinstance(images, list) else [images]
            for image in images_list:
                pixel_values, tiling = self.preprocess_image(image, **images_kwargs)
                processed_images.append(pixel_values)
                tilings.append(tiling)

            if processed_images:
                image_features["pixel_values"] = processed_images
                image_features["tilings"] = tilings

        # Process text
        if text is not None:
            if not isinstance(text, list):
                text = [text]

            # Get text kwargs, remove keys we set ourselves
            text_kwargs = output_kwargs.get("text_kwargs", {}).copy()
            text_kwargs.pop("return_tensors", None)
            text_kwargs.pop("add_special_tokens", None)

            # Tokenize text
            tokenized = self.tokenizer(
                text,
                add_special_tokens=True,
                return_tensors="pt",
                **text_kwargs,
            )

            output = BatchFeature(data=dict(tokenized))

            # Add image features
            if image_features:
                output["pixel_values"] = image_features["pixel_values"]
                output["tilings"] = image_features["tilings"]

            return output

        # If only images were provided
        return BatchFeature(data=image_features)

    @staticmethod
    def _image_array_to_uint8(array: np.ndarray) -> np.ndarray:
        """
        Convert an image array to a C-contiguous ``uint8`` array.

        - ``uint8`` input is returned as-is (made contiguous).
        - ``bool`` input maps False/True to 0/255.
        - Floating input has NaN -> 0, +inf -> 255 and -inf -> 0. If its
          maximum is <= 1.0 it is treated as [0, 1] data and scaled by 255.
          It is then rounded to the nearest integer.
        - Everything else (including wider integer types) is clipped to
          [0, 255], not rescaled.
        """
        if array.dtype == np.uint8:
            return np.ascontiguousarray(array)

        if array.dtype == np.bool_:
            return np.ascontiguousarray(array.astype(np.uint8) * 255)

        if np.issubdtype(array.dtype, np.floating):
            array = np.nan_to_num(array, nan=0.0, posinf=255.0, neginf=0.0)
            if array.size > 0 and array.max() <= 1.0:
                array = array * 255.0
            array = np.rint(array)

        return np.ascontiguousarray(np.clip(array, 0, 255).astype(np.uint8))

    @staticmethod
    def _to_pil_image(image: ImageInput) -> Image.Image:
        """
        Convert a PIL image, numpy array or torch tensor to a PIL image.

        Arrays and tensors must be 2-D (grayscale) or 3-D. A 3-D input is
        treated as HWC when its last axis has 1, 3 or 4 entries. Otherwise
        it is treated as CHW when its first axis does, and transposed to
        HWC. Single-channel input is squeezed to a 2-D grayscale image.

        Raises:
            TypeError: for unsupported input types.
            ValueError: for unsupported ranks or channel layouts.
        """
        if isinstance(image, Image.Image):
            return image

        if isinstance(image, torch.Tensor):
            tensor = image.detach().cpu()
            # numpy has no bfloat16, so upcast before .numpy().
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.to(torch.float32)
            image_array = tensor.numpy()
        elif isinstance(image, np.ndarray):
            image_array = image
        else:
            raise TypeError(
                "Moondream3 images must be PIL images, numpy arrays, "
                f"or torch tensors, got {type(image)!r}."
            )

        if image_array.ndim == 2:
            image_array = Moondream3Processor._image_array_to_uint8(image_array)
            return Image.fromarray(image_array)

        if image_array.ndim != 3:
            raise ValueError(
                "Moondream3 image arrays must have 2 or 3 dimensions, "
                f"got shape {image_array.shape}."
            )

        channel_dims = (1, 3, 4)
        if image_array.shape[-1] not in channel_dims:
            if image_array.shape[0] not in channel_dims:
                raise ValueError(
                    "Moondream3 image arrays must be HWC or CHW with 1, 3, "
                    f"or 4 channels, got shape {image_array.shape}."
                )
            image_array = np.transpose(image_array, (1, 2, 0))

        image_array = Moondream3Processor._image_array_to_uint8(image_array)
        if image_array.shape[-1] == 1:
            image_array = image_array[..., 0]

        return Image.fromarray(image_array)

    def preprocess_image(
        self,
        image: ImageInput,
        max_crops: int = 12,
        overlap_margin: int = 4,
        crop_size: int = 378,
        patch_size: int = 14,
        convert_to_rgb: bool = True,
        return_tensors: str = "pt",
    ) -> tuple[torch.Tensor | np.ndarray, tuple[int, int]]:
        """
        Preprocess an image using overlap-and-resize cropping strategy.

        The image is resized so that a ``tiling`` grid of overlapping
        ``crop_size`` windows (stride ``crop_size - 2 * overlap_margin *
        patch_size``) covers it exactly. Crop 0 is the whole image resized
        to ``crop_size x crop_size``; crops 1.. are the grid windows in
        row-major order.

        Args:
            image: Input PIL image, numpy array, or torch tensor.
            max_crops: Maximum number of crops.
            overlap_margin: Margin for overlapping, in patches.
            crop_size: Side length of each crop, in pixels.
            patch_size: Side length of each patch, in pixels.
            convert_to_rgb: Whether to convert to RGB.
            return_tensors: Return type ("pt" for PyTorch).

        Returns:
            Tuple of (pixel_values, tiling). ``pixel_values`` has shape
            ``(1 + rows * cols, 3, crop_size, crop_size)``, normalized to
            [-1, 1]. It is a bfloat16 torch tensor when ``return_tensors ==
            "pt"`` and a float32 numpy array otherwise. ``tiling`` is
            ``(rows, cols)``.

        Raises:
            ValueError: if the image is not 3-channel after the optional RGB
                conversion (only possible with ``convert_to_rgb=False``).
        """
        image = self._to_pil_image(image)
        if convert_to_rgb:
            image = convert_image_mode(image, "RGB")

        # Convert to numpy array
        image_array = np.array(image)
        # The crop buffer below is 3-channel; fail early with a clear message
        # instead of an opaque numpy broadcasting error further down.
        if image_array.ndim != 3 or image_array.shape[-1] != 3:
            raise ValueError(
                "Moondream3 expects 3-channel RGB images, got mode "
                f"{image.mode!r} with array shape {image_array.shape}; "
                "use convert_to_rgb=True to convert it."
            )
        original_h, original_w = image_array.shape[:2]

        margin_pixels = patch_size * overlap_margin
        total_margin_pixels = margin_pixels * 2

        # Only the interior of each crop (minus the overlap margins on both
        # sides) advances the grid; that interior is the crop window.
        crop_patches = crop_size // patch_size
        crop_window_patches = crop_patches - (2 * overlap_margin)
        crop_window_size = crop_window_patches * patch_size

        tiling = select_tiling(
            original_h - total_margin_pixels,
            original_w - total_margin_pixels,
            crop_window_size,
            max_crops,
        )

        # One global crop plus one crop per tile.
        n_crops = tiling[0] * tiling[1] + 1
        crops = np.zeros((n_crops, crop_size, crop_size, 3), dtype=np.uint8)

        # (height, width) the image is resized to so the grid fits exactly.
        target_size = (
            tiling[0] * crop_window_size + total_margin_pixels,
            tiling[1] * crop_window_size + total_margin_pixels,
        )

        # Resize image (PIL takes (width, height))
        pil_img = Image.fromarray(image_array)
        resized = pil_img.resize(
            (int(target_size[1]), int(target_size[0])),
            resample=Image.Resampling.LANCZOS,
        )
        resized_array = np.asarray(resized)

        # Create global crop
        global_pil = pil_img.resize(
            (crop_size, crop_size), resample=Image.Resampling.LANCZOS
        )
        crops[0] = np.asarray(global_pil)

        # Create local crops; any part of a crop that runs past the resized
        # image stays zero-padded.
        for i in range(tiling[0]):
            for j in range(tiling[1]):
                y0 = i * crop_window_size
                x0 = j * crop_window_size
                y_end = min(y0 + crop_size, resized_array.shape[0])
                x_end = min(x0 + crop_size, resized_array.shape[1])

                crop_region = resized_array[y0:y_end, x0:x_end]
                crop_idx = 1 + i * tiling[1] + j
                h_slice = slice(None, crop_region.shape[0])
                w_slice = slice(None, crop_region.shape[1])
                crops[crop_idx, h_slice, w_slice] = crop_region

        # Convert to tensor: (n_crops, H, W, C) -> (n_crops, C, H, W)
        pixel_values = np.transpose(crops, (0, 3, 1, 2))

        if return_tensors == "pt":
            # Match HF reference preprocessing exactly: convert uint8 crops to
            # bfloat16 before in-place normalization.
            pixel_values = (
                torch.from_numpy(pixel_values)
                .to(dtype=torch.bfloat16)
                .div_(255.0)
                .sub_(0.5)
                .div_(0.5)
            )
        else:
            pixel_values = pixel_values.astype(np.float32) / 255.0
            pixel_values = (pixel_values - 0.5) / 0.5

        return pixel_values, tiling

    def get_num_image_tokens(self) -> int:
        """Return the number of image tokens per crop.

        This is ``(crop_size // patch_size) ** 2``: 729 = 27x27 patches
        with the default 378 / 14 settings.
        """
        return self.patches_per_crop

    def batch_decode(self, *args, **kwargs):
        """Forward to tokenizer's batch_decode."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to tokenizer's decode."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Tokenizer input names plus ``pixel_values`` and ``tilings``."""
        tokenizer_input_names = self.tokenizer.model_input_names
        return tokenizer_input_names + ["pixel_values", "tilings"]

__call__

__call__(
    images: ImageInput = None,
    text: TextInput
    | PreTokenizedInput
    | list[TextInput]
    | list[PreTokenizedInput] = None,
    **kwargs: Unpack[Moondream3ProcessorKwargs],
) -> BatchFeature

Process images and text for Moondream3 model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `images` | `ImageInput` | Input image (PIL image, numpy array, or torch tensor) or a list of them. | `None` |
| `text` | `TextInput \| PreTokenizedInput \| list[TextInput] \| list[PreTokenizedInput]` | Input text or list of texts. | `None` |
| `**kwargs` | `Unpack[Moondream3ProcessorKwargs]` | Additional processing arguments. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `BatchFeature` | BatchFeature with processed inputs. |

Source code in vllm/transformers_utils/processors/moondream3.py
def __call__(
    self,
    images: ImageInput = None,
    text: TextInput
    | PreTokenizedInput
    | list[TextInput]
    | list[PreTokenizedInput] = None,
    **kwargs: Unpack[Moondream3ProcessorKwargs],
) -> BatchFeature:
    """
    Process images and text for Moondream3 model.

    Args:
        images: Input image (PIL image, numpy array or torch tensor) or a
            list of them. Each image is preprocessed independently.
        text: Input text or list of texts.
        **kwargs: Additional processing arguments, merged through
            ``self._merge_kwargs``.

    Returns:
        BatchFeature with processed inputs. When ``text`` is given it holds
        the tokenizer outputs as PyTorch tensors. When ``images`` is given
        it also holds ``pixel_values`` (one entry per image) and ``tilings``
        (one ``(rows, cols)`` tuple per image).
    """
    output_kwargs = self._merge_kwargs(
        Moondream3ProcessorKwargs,
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )

    # Process images
    image_features = {}
    if images is not None:
        # The crop settings configured on the processor act as defaults;
        # any value already present in the merged image kwargs wins.
        # Without this, constructor / from_pretrained crop settings might
        # never reach preprocess_image.
        images_kwargs = {
            "crop_size": self.crop_size,
            "max_crops": self.max_crops,
            "overlap_margin": self.overlap_margin,
            "patch_size": self.patch_size,
            **output_kwargs["images_kwargs"],
        }

        processed_images = []
        tilings = []

        images_list = images if isinstance(images, list) else [images]
        for image in images_list:
            pixel_values, tiling = self.preprocess_image(image, **images_kwargs)
            processed_images.append(pixel_values)
            tilings.append(tiling)

        if processed_images:
            image_features["pixel_values"] = processed_images
            image_features["tilings"] = tilings

    # Process text
    if text is not None:
        if not isinstance(text, list):
            text = [text]

        # Get text kwargs, remove keys we set ourselves
        text_kwargs = output_kwargs.get("text_kwargs", {}).copy()
        text_kwargs.pop("return_tensors", None)
        text_kwargs.pop("add_special_tokens", None)

        # Tokenize text
        tokenized = self.tokenizer(
            text,
            add_special_tokens=True,
            return_tensors="pt",
            **text_kwargs,
        )

        output = BatchFeature(data=dict(tokenized))

        # Add image features
        if image_features:
            output["pixel_values"] = image_features["pixel_values"]
            output["tilings"] = image_features["tilings"]

        return output

    # If only images were provided
    return BatchFeature(data=image_features)

batch_decode

batch_decode(*args, **kwargs)

Forward to tokenizer's batch_decode.

Source code in vllm/transformers_utils/processors/moondream3.py
def batch_decode(self, *args, **kwargs):
    """Decode a batch of token-id sequences with the wrapped tokenizer.

    Every positional and keyword argument goes unchanged to
    ``self.tokenizer.batch_decode``, whose result is returned as-is.
    """
    delegate = self.tokenizer.batch_decode
    return delegate(*args, **kwargs)

decode

decode(*args, **kwargs)

Forward to tokenizer's decode.

Source code in vllm/transformers_utils/processors/moondream3.py
def decode(self, *args, **kwargs):
    """Decode a single token-id sequence with the wrapped tokenizer.

    Every positional and keyword argument goes unchanged to
    ``self.tokenizer.decode``, whose result is returned as-is.
    """
    delegate = self.tokenizer.decode
    return delegate(*args, **kwargs)

from_pretrained classmethod

from_pretrained(pretrained_model_name_or_path, **kwargs)

Load the processor and resolve its tokenizer.

The tokenizer is taken from the `tokenizer` keyword argument when given. Otherwise the loader first tries tokenizer files in the model repository, then falls back to the dedicated 'moondream/starmie-v1' tokenizer repository.

Source code in vllm/transformers_utils/processors/moondream3.py
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path,
    **kwargs,
):
    """
    Build a Moondream3 processor, resolving its tokenizer first.

    Tokenizer resolution order:

    1. ``kwargs["tokenizer"]``. A tokenizer object is used as-is; a string
       is treated as a repo id / path and loaded.
    2. Tokenizer files in ``pretrained_model_name_or_path``.
    3. The dedicated Moondream tokenizer repo ``cls._tokenizer_repo``.

    Each load tries ``AutoTokenizer.from_pretrained`` first. If that fails,
    it builds a ``PreTrainedTokenizerFast`` from the repo's
    ``tokenizer.json``. BOS and EOS are then both set to ``<|endoftext|>``
    (id 0).

    ``crop_size``, ``max_crops``, ``overlap_margin``, ``patch_size`` and
    ``chat_template`` are read from ``kwargs`` and passed to the
    constructor. The chat template is also installed on the tokenizer when
    it has none.
    """
    from transformers import AutoTokenizer, PreTrainedTokenizerFast
    from transformers.utils import cached_file

    # Download / hub options understood by cached_file().
    hub_keys = (
        "cache_dir",
        "force_download",
        "local_files_only",
        "revision",
        "subfolder",
        "token",
    )

    requested = kwargs.pop("tokenizer", None)

    auto_kwargs = {"trust_remote_code": kwargs.get("trust_remote_code", False)}
    auto_kwargs.update(
        {name: kwargs[name] for name in (*hub_keys, "use_fast") if name in kwargs}
    )
    file_kwargs = {
        name: auto_kwargs[name] for name in hub_keys if name in auto_kwargs
    }

    def _load(source):
        try:
            return AutoTokenizer.from_pretrained(source, **auto_kwargs)
        except Exception:
            # No usable tokenizer class; read tokenizer.json directly.
            json_path = cached_file(source, "tokenizer.json", **file_kwargs)
            return PreTrainedTokenizerFast(
                tokenizer_file=json_path,
                clean_up_tokenization_spaces=False,
            )

    if isinstance(requested, str):
        tokenizer = _load(requested)
    elif requested is None:
        # Model-local tokenizer files win; the dedicated repo is the fallback.
        try:
            tokenizer = _load(pretrained_model_name_or_path)
        except Exception:
            tokenizer = _load(cls._tokenizer_repo)
    else:
        tokenizer = requested

    # <|endoftext|> (id 0) doubles as BOS and EOS; <|md_reserved_0|> (id 1)
    # is only a prompt-template delimiter.
    for attr, value in (
        ("bos_token", "<|endoftext|>"),
        ("bos_token_id", 0),
        ("eos_token", "<|endoftext|>"),
        ("eos_token_id", 0),
    ):
        setattr(tokenizer, attr, value)

    processor_options = {
        name: kwargs.pop(name, default)
        for name, default in (
            ("crop_size", 378),
            ("max_crops", 12),
            ("overlap_margin", 4),
            ("patch_size", 14),
        )
    }
    chat_template = kwargs.pop("chat_template", None)
    if chat_template is None:
        chat_template = cls._default_chat_template
    if tokenizer.chat_template is None:
        tokenizer.chat_template = chat_template

    return cls(
        tokenizer=tokenizer,
        chat_template=chat_template,
        **processor_options,
    )

get_num_image_tokens

get_num_image_tokens() -> int

Return the number of image tokens (729 = 27x27 patches).

Source code in vllm/transformers_utils/processors/moondream3.py
def get_num_image_tokens(self) -> int:
    """Number of image tokens produced for one crop.

    Reads ``self.patches_per_crop``, which the processor constructor sets
    to ``(crop_size // patch_size) ** 2`` (27 * 27 = 729 by default).
    """
    tokens_per_crop = self.patches_per_crop
    return tokens_per_crop

preprocess_image

preprocess_image(
    image: ImageInput,
    max_crops: int = 12,
    overlap_margin: int = 4,
    crop_size: int = 378,
    patch_size: int = 14,
    convert_to_rgb: bool = True,
    return_tensors: str = "pt",
) -> tuple[Tensor, tuple[int, int]]

Preprocess an image using overlap-and-resize cropping strategy.

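Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image` | `ImageInput` | Input PIL image, numpy array, or torch tensor. | required |
| `max_crops` | `int` | Maximum number of crops. | `12` |
| `overlap_margin` | `int` | Margin for overlapping, in patches. | `4` |
| `crop_size` | `int` | Side length of each crop, in pixels. | `378` |
| `patch_size` | `int` | Side length of each patch, in pixels. | `14` |
| `convert_to_rgb` | `bool` | Whether to convert to RGB. | `True` |
| `return_tensors` | `str` | Return type (`"pt"` for PyTorch). | `'pt'` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, tuple[int, int]]` | Tuple of (pixel_values, tiling). `pixel_values` is a torch tensor when `return_tensors == "pt"` and a NumPy array otherwise. |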

Source code in vllm/transformers_utils/processors/moondream3.py
def preprocess_image(
    self,
    image: ImageInput,
    max_crops: int = 12,
    overlap_margin: int = 4,
    crop_size: int = 378,
    patch_size: int = 14,
    convert_to_rgb: bool = True,
    return_tensors: str = "pt",
) -> tuple[torch.Tensor | np.ndarray, tuple[int, int]]:
    """
    Preprocess an image using overlap-and-resize cropping strategy.

    The image is resized so that a ``tiling`` grid of overlapping
    ``crop_size`` windows (stride ``crop_size - 2 * overlap_margin *
    patch_size``) covers it exactly. Crop 0 is the whole image resized to
    ``crop_size x crop_size``; crops 1.. are the grid windows in row-major
    order.

    Args:
        image: Input PIL image, numpy array, or torch tensor.
        max_crops: Maximum number of crops.
        overlap_margin: Margin for overlapping, in patches.
        crop_size: Side length of each crop, in pixels.
        patch_size: Side length of each patch, in pixels.
        convert_to_rgb: Whether to convert to RGB.
        return_tensors: Return type ("pt" for PyTorch).

    Returns:
        Tuple of (pixel_values, tiling). ``pixel_values`` has shape
        ``(1 + rows * cols, 3, crop_size, crop_size)``, normalized to
        [-1, 1]. It is a bfloat16 torch tensor when ``return_tensors == "pt"``
        and a float32 numpy array otherwise. ``tiling`` is ``(rows, cols)``.

    Raises:
        ValueError: if the image is not 3-channel after the optional RGB
            conversion (only possible with ``convert_to_rgb=False``).
    """
    image = self._to_pil_image(image)
    if convert_to_rgb:
        image = convert_image_mode(image, "RGB")

    # Convert to numpy array
    image_array = np.array(image)
    # The crop buffer below is 3-channel; fail early with a clear message
    # instead of an opaque numpy broadcasting error further down.
    if image_array.ndim != 3 or image_array.shape[-1] != 3:
        raise ValueError(
            "Moondream3 expects 3-channel RGB images, got mode "
            f"{image.mode!r} with array shape {image_array.shape}; "
            "use convert_to_rgb=True to convert it."
        )
    original_h, original_w = image_array.shape[:2]

    margin_pixels = patch_size * overlap_margin
    total_margin_pixels = margin_pixels * 2

    # Only the interior of each crop (minus the overlap margins on both
    # sides) advances the grid; that interior is the crop window.
    crop_patches = crop_size // patch_size
    crop_window_patches = crop_patches - (2 * overlap_margin)
    crop_window_size = crop_window_patches * patch_size

    tiling = select_tiling(
        original_h - total_margin_pixels,
        original_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    # One global crop plus one crop per tile.
    n_crops = tiling[0] * tiling[1] + 1
    crops = np.zeros((n_crops, crop_size, crop_size, 3), dtype=np.uint8)

    # (height, width) the image is resized to so the grid fits exactly.
    target_size = (
        tiling[0] * crop_window_size + total_margin_pixels,
        tiling[1] * crop_window_size + total_margin_pixels,
    )

    # Resize image (PIL takes (width, height))
    pil_img = Image.fromarray(image_array)
    resized = pil_img.resize(
        (int(target_size[1]), int(target_size[0])),
        resample=Image.Resampling.LANCZOS,
    )
    resized_array = np.asarray(resized)

    # Create global crop
    global_pil = pil_img.resize(
        (crop_size, crop_size), resample=Image.Resampling.LANCZOS
    )
    crops[0] = np.asarray(global_pil)

    # Create local crops; any part of a crop that runs past the resized
    # image stays zero-padded.
    for i in range(tiling[0]):
        for j in range(tiling[1]):
            y0 = i * crop_window_size
            x0 = j * crop_window_size
            y_end = min(y0 + crop_size, resized_array.shape[0])
            x_end = min(x0 + crop_size, resized_array.shape[1])

            crop_region = resized_array[y0:y_end, x0:x_end]
            crop_idx = 1 + i * tiling[1] + j
            h_slice = slice(None, crop_region.shape[0])
            w_slice = slice(None, crop_region.shape[1])
            crops[crop_idx, h_slice, w_slice] = crop_region

    # Convert to tensor: (n_crops, H, W, C) -> (n_crops, C, H, W)
    pixel_values = np.transpose(crops, (0, 3, 1, 2))

    if return_tensors == "pt":
        # Match HF reference preprocessing exactly: convert uint8 crops to
        # bfloat16 before in-place normalization.
        pixel_values = (
            torch.from_numpy(pixel_values)
            .to(dtype=torch.bfloat16)
            .div_(255.0)
            .sub_(0.5)
            .div_(0.5)
        )
    else:
        pixel_values = pixel_values.astype(np.float32) / 255.0
        pixel_values = (pixel_values - 0.5) / 0.5

    return pixel_values, tiling

select_tiling

select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> tuple[int, int]

Determine the optimal number of tiles to cover an image.

Source code in vllm/transformers_utils/processors/moondream3.py
def select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> tuple[int, int]:
    """Choose a ``(rows, cols)`` grid of crop windows to cover an image.

    - If either side already fits in a single window, the grid is ``(1, 1)``.
    - If even the minimal covering grid needs more than ``max_crops``
      windows, both sides are shrunk by the same factor
      ``sqrt(max_crops / needed)`` and floored (never below 1).
    - Otherwise the grid follows the image aspect ratio, is raised to at
      least the minimal cover, and the larger side is trimmed if the total
      exceeds ``max_crops``.
    """
    if min(height, width) <= crop_size:
        return (1, 1)

    rows_needed = math.ceil(height / crop_size)
    cols_needed = math.ceil(width / crop_size)
    needed = rows_needed * cols_needed

    if needed > max_crops:
        shrink = math.sqrt(max_crops / needed)
        return (
            max(1, math.floor(rows_needed * shrink)),
            max(1, math.floor(cols_needed * shrink)),
        )

    # Aspect-ratio-proportional grid, but never smaller than the cover.
    rows = max(math.floor(math.sqrt(max_crops * height / width)), rows_needed)
    cols = max(math.floor(math.sqrt(max_crops * width / height)), cols_needed)

    if rows * cols > max_crops:
        if cols > rows:
            cols = math.floor(max_crops / rows)
        else:
            rows = math.floor(max_crops / cols)

    return (max(1, rows), max(1, cols))