
vllm.model_executor.model_loader.tpu

logger module-attribute

logger = init_logger(__name__)

TPUModelLoader

Bases: DefaultModelLoader

A TPU model loader for loading models under SPMD mode.

Source code in vllm/model_executor/model_loader/tpu.py
class TPUModelLoader(DefaultModelLoader):
    """
    A TPU model loader for model loading under SPMD mode.
    """

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        mesh: Optional[xs.Mesh] = None,
    ) -> nn.Module:
        # Initialize model and load weights on CPU. Then, during SPMD partition,
        # weights are sharded and transferred to TPUs.
        self.counter_before_loading_weights = time.perf_counter()
        model_config = vllm_config.model_config
        assert model_config.quantization is None, "Quantization not supported"
        target_device = torch.device('cpu')
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)

            load_format = vllm_config.load_config.load_format
            if load_format != "dummy":
                weights_to_load = {
                    name
                    for name, _ in model.named_parameters()
                }
                all_weights = self.get_all_weights(model_config, model)
                loaded_weights = model.load_weights(all_weights)
                self.counter_after_loading_weights = time.perf_counter()
                logger.info(
                    "Loading weights took %.2f seconds",
                    self.counter_after_loading_weights -
                    self.counter_before_loading_weights)
                # The strict check is currently only enabled for non-quantized
                # models that track loaded weights.
                if model_config.quantization is None and \
                    loaded_weights is not None:
                    weights_not_loaded = weights_to_load - loaded_weights
                    if weights_not_loaded:
                        raise ValueError(
                            "Following weights were not initialized from "
                            f"checkpoint: {weights_not_loaded}")
            else:
                logger.info("Use dummy weight during weight loading.")

            process_weights_after_loading(model, model_config, target_device)

        counter_before_partition = time.perf_counter()
        model = model.eval()
        model = model.to('xla')
        shard_model(model, mesh)
        counter_after_partition = time.perf_counter()
        logger.info("Partition model took %.2f seconds",
                    counter_after_partition - counter_before_partition)

        # Ensure the model is properly loaded.
        self._check_model_is_loaded(mesh, model)

        # Torch compilation must happen after model sharding is done, because
        # the compiler hints ('xs.mark_sharding') are torch ops.
        if not model_config.is_multimodal_model:
            model.model = torch.compile(model.model, backend="openxla")
        else:
            model.language_model.model = \
                torch.compile(model.language_model.model, backend="openxla")
        return model

    def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
                               model: nn.Module) -> None:
        """
        Ensure the model is properly loaded.
        1. All model parameters and buffers are on XLA device.
        2. Non-SPMD friendly layers are replaced as expected.
        """
        device = xm.xla_device()
        device_type = str(device.type)

        # Check parameters
        for name, param in model.named_parameters():
            assert param.device.type == device_type, f"Parameter {name} is on \
                {param.device.type} instead of {device_type}"

        # Check buffers
        for name, buffer in model.named_buffers():
            assert buffer.device.type == device_type, \
                f"Buffer {name} is on {buffer.device.type} instead of \
                    {device_type}"

        for module in model.modules():
            if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
                raise AssertionError("QKVParallelLinear should be replaced by \
                            XlaQKVParallelLinear under SPMD mode.")
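
Example usage (a minimal sketch, not the exact call sequence used inside vLLM): the loader is constructed from the engine's load config and driven with an SPMD device mesh. The pre-built vllm_config object, the loader constructor arguments, the 1-D mesh shape, and the axis name below are illustrative assumptions.

# Sketch only: assumes an existing `vllm_config: VllmConfig` built elsewhere
# and a TPU VM with torch_xla installed.
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from vllm.model_executor.model_loader.tpu import TPUModelLoader

xr.use_spmd()  # SPMD mode must be enabled before any XLA tensors are created

num_devices = xr.global_runtime_device_count()
# 1-D mesh over all TPU devices; the axis name is illustrative.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("model",))

loader = TPUModelLoader(vllm_config.load_config)
model = loader.load_model(vllm_config, vllm_config.model_config, mesh=mesh)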

_check_model_is_loaded

_check_model_is_loaded(
    mesh: Optional[Mesh], model: Module
) -> None

Ensure the model is properly loaded:

1. All model parameters and buffers are on the XLA device.
2. Non-SPMD-friendly layers are replaced as expected.

Source code in vllm/model_executor/model_loader/tpu.py
def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
                           model: nn.Module) -> None:
    """
    Ensure the model is properly loaded.
    1. All model parameters and buffers are on XLA device.
    2. Non-SPMD friendly layers are replaced as expected.
    """
    device = xm.xla_device()
    device_type = str(device.type)

    # Check parameters
    for name, param in model.named_parameters():
        assert param.device.type == device_type, f"Parameter {name} is on \
            {param.device.type} instead of {device_type}"

    # Check buffers
    for name, buffer in model.named_buffers():
        assert buffer.device.type == device_type, \
            f"Buffer {name} is on {buffer.device.type} instead of \
                {device_type}"

    for module in model.modules():
        if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
            raise AssertionError("QKVParallelLinear should be replaced by \
                        XlaQKVParallelLinear under SPMD mode.")
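
The device-placement invariant enforced above can be checked on any module in isolation. The helper below is a hypothetical, standalone sketch of the same check; it is not part of vLLM.

# Hypothetical helper illustrating the same check; assumes torch_xla is
# available and the current process has an XLA (TPU) device.
import torch.nn as nn
import torch_xla.core.xla_model as xm


def assert_all_on_xla(model: nn.Module) -> None:
    """Fail if any parameter or buffer is not on the XLA device."""
    device_type = xm.xla_device().type
    tensors = list(model.named_parameters()) + list(model.named_buffers())
    offenders = [name for name, t in tensors if t.device.type != device_type]
    assert not offenders, f"Tensors not on {device_type}: {offenders}"


assert_all_on_xla(nn.Linear(8, 8).to("xla"))  # passes once moved to XLA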

load_model

load_model(
    vllm_config: VllmConfig,
    model_config: ModelConfig,
    mesh: Optional[Mesh] = None,
) -> Module
Source code in vllm/model_executor/model_loader/tpu.py
def load_model(
    self,
    vllm_config: VllmConfig,
    model_config: ModelConfig,
    mesh: Optional[xs.Mesh] = None,
) -> nn.Module:
    # Initialize model and load weights on CPU. Then, during SPMD partition,
    # weights are sharded and transferred to TPUs.
    self.counter_before_loading_weights = time.perf_counter()
    model_config = vllm_config.model_config
    assert model_config.quantization is None, "Quantization not supported"
    target_device = torch.device('cpu')
    with set_default_torch_dtype(model_config.dtype):
        with target_device:
            model = initialize_model(vllm_config=vllm_config)

        load_format = vllm_config.load_config.load_format
        if load_format != "dummy":
            weights_to_load = {
                name
                for name, _ in model.named_parameters()
            }
            all_weights = self.get_all_weights(model_config, model)
            loaded_weights = model.load_weights(all_weights)
            self.counter_after_loading_weights = time.perf_counter()
            logger.info(
                "Loading weights took %.2f seconds",
                self.counter_after_loading_weights -
                self.counter_before_loading_weights)
            # The strict check is currently only enabled for non-quantized
            # models that track loaded weights.
            if model_config.quantization is None and \
                loaded_weights is not None:
                weights_not_loaded = weights_to_load - loaded_weights
                if weights_not_loaded:
                    raise ValueError(
                        "Following weights were not initialized from "
                        f"checkpoint: {weights_not_loaded}")
        else:
            logger.info("Use dummy weight during weight loading.")

        process_weights_after_loading(model, model_config, target_device)

    counter_before_partition = time.perf_counter()
    model = model.eval()
    model = model.to('xla')
    shard_model(model, mesh)
    counter_after_partition = time.perf_counter()
    logger.info("Partition model took %.2f seconds",
                counter_after_partition - counter_before_partition)

    # Ensure the model is properly loaded.
    self._check_model_is_loaded(mesh, model)

    # Torch compilation must happen after model sharding is done, because
    # the compiler hints ('xs.mark_sharding') are torch ops.
    if not model_config.is_multimodal_model:
        model.model = torch.compile(model.model, backend="openxla")
    else:
        model.language_model.model = \
            torch.compile(model.language_model.model, backend="openxla")
    return model
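
As the final comment notes, the 'xs.mark_sharding' hints are themselves torch ops, so compilation has to come after sharding. A minimal standalone sketch of that ordering on a toy layer (illustrative, not the vLLM model) might look like:

# Sketch only: a toy layer showing the shard-then-compile ordering; assumes
# a TPU VM with torch_xla installed and a device count that divides the
# sharded dimension.
import numpy as np
import torch
import torch.nn as nn
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("model",))

layer = nn.Linear(16, 16).to("xla")
# The sharding hint is a torch op, so it must be applied before compiling.
xs.mark_sharding(layer.weight, mesh, ("model", None))

compiled = torch.compile(layer, backend="openxla")
out = compiled(torch.randn(4, 16, device="xla"))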