Skip to content

vllm.model_executor.model_loader.reload.torchao_decorator

support_quantized_model_reload_from_hp_weights

support_quantized_model_reload_from_hp_weights(
    original_load_weights: FunctionType,
)

Decorator for the load_weights method of AutoWeightsLoader.load_weights that supports reloading high-precision (bfloat16/float16/float32) weights into an already quantized model. This involves restoring the weights to high precision and then quantizing them online.

Only applies to torchao quantized models. Assumes that all model weights are loaded within a single weights iterator (cannot perform batched updates)

Source code in vllm/model_executor/model_loader/reload/torchao_decorator.py
def support_quantized_model_reload_from_hp_weights(original_load_weights: FunctionType):
    """
    Wrap an `AutoWeightsLoader.load_weights` implementation so that a model
    already quantized with torchao can be re-fed high-precision
    (bfloat16/float16/float32) weights: the quantized parameters are restored
    to high precision, the new weights are loaded, and the model is then
    re-quantized online.

    Only torchao-quantized models are affected; for every other model the
    original `load_weights` is invoked untouched. Assumes all model weights
    arrive in a single weights iterator (batched updates are not supported).
    """

    @wraps(original_load_weights)
    def _reload_aware_load_weights(
        self: "AutoWeightsLoader",
        weights: Iterable[tuple[str, torch.Tensor]],
        *args,
        **kwargs,
    ):
        target = self.module

        # Fast path: model is not flagged for torchao reload — delegate as-is.
        if not getattr(target, "_do_torchao_reload", False):
            return original_load_weights(self, weights, *args, **kwargs)

        # Restore quantized params to high precision, load, then re-quantize.
        initialize_layerwise_reload(target)
        result = original_load_weights(self, weights, *args, **kwargs)
        finalize_layerwise_reload(target, target._model_config)
        return result

    return _reload_aware_load_weights