
vllm.model_executor.model_loader.runai_streamer_loader

RunaiModelStreamerLoader

Bases: BaseModelLoader

Model loader that can load safetensors files from a local filesystem or an S3 bucket.
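
In practice the loader is selected through vLLM's "runai_streamer" load format rather than instantiated directly. A minimal sketch, assuming the LLM entrypoint and a hypothetical S3 prefix (any local directory or Hugging Face repo id containing *.safetensors files would also work):

from vllm import LLM

# Sketch only: the model path below is a hypothetical S3 prefix.
llm = LLM(
    model="s3://example-bucket/my-model/",
    load_format="runai_streamer",  # routes weight loading through RunaiModelStreamerLoader
)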

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
class RunaiModelStreamerLoader(BaseModelLoader):
    """
        Model loader that can load safetensors
        files from local FS or S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            extra_config = load_config.model_loader_extra_config

            if ("concurrency" in extra_config
                    and isinstance(extra_config.get("concurrency"), int)):
                os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
                    extra_config.get("concurrency"))

            if ("memory_limit" in extra_config
                    and isinstance(extra_config.get("memory_limit"), int)):
                os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
                    extra_config.get("memory_limit"))

            runai_streamer_s3_endpoint = os.getenv(
                'RUNAI_STREAMER_S3_ENDPOINT')
            aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
            if (runai_streamer_s3_endpoint is None
                    and aws_endpoint_url is not None):
                os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> list[str]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""

        is_s3_path = is_s3(model_name_or_path)
        is_local = os.path.isdir(model_name_or_path)
        safetensors_pattern = "*.safetensors"
        index_file = SAFE_WEIGHTS_INDEX_NAME

        hf_folder = (model_name_or_path if
                     (is_local or is_s3_path) else download_weights_from_hf(
                         model_name_or_path,
                         self.load_config.download_dir,
                         [safetensors_pattern],
                         revision,
                         ignore_patterns=self.load_config.ignore_patterns,
                     ))
        if is_s3_path:
            hf_weights_files = s3_glob(path=hf_folder,
                                       allow_pattern=[safetensors_pattern])
        else:
            hf_weights_files = glob.glob(
                os.path.join(hf_folder, safetensors_pattern))

        if not is_local and not is_s3_path:
            download_safetensors_index_file_from_hf(
                model_name_or_path, index_file, self.load_config.download_dir,
                revision)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with "
                f"`{model_name_or_path}`")

        return hf_weights_files

    def _get_weights_iterator(
            self, model_or_path: str,
            revision: str) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_weights_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Download model if necessary"""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Load weights into a model."""
        model_weights = model_config.model
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights
        model.load_weights(
            self._get_weights_iterator(model_weights, model_config.revision))

__init__

__init__(load_config: LoadConfig)
Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
def __init__(self, load_config: LoadConfig):
    super().__init__(load_config)
    if load_config.model_loader_extra_config:
        extra_config = load_config.model_loader_extra_config

        if ("concurrency" in extra_config
                and isinstance(extra_config.get("concurrency"), int)):
            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
                extra_config.get("concurrency"))

        if ("memory_limit" in extra_config
                and isinstance(extra_config.get("memory_limit"), int)):
            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
                extra_config.get("memory_limit"))

        runai_streamer_s3_endpoint = os.getenv(
            'RUNAI_STREAMER_S3_ENDPOINT')
        aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
        if (runai_streamer_s3_endpoint is None
                and aws_endpoint_url is not None):
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url

_get_weights_iterator

_get_weights_iterator(
    model_or_path: str, revision: str
) -> Generator[tuple[str, Tensor], None, None]

Get an iterator for the model weights based on the load format.

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
def _get_weights_iterator(
        self, model_or_path: str,
        revision: str) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Get an iterator for the model weights based on the load format."""
    hf_weights_files = self._prepare_weights(model_or_path, revision)
    return runai_safetensors_weights_iterator(
        hf_weights_files,
        self.load_config.use_tqdm_on_load,
    )
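
The generator yields (parameter name, tensor) pairs that load_weights() consumes. A rough sketch of consuming it directly; this is an internal helper, and the path and revision below are assumptions:

loader = RunaiModelStreamerLoader(load_config)
weights = loader._get_weights_iterator("s3://example-bucket/my-model/", revision="main")
for name, tensor in weights:
    print(name, tuple(tensor.shape), tensor.dtype)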

_prepare_weights

_prepare_weights(
    model_name_or_path: str, revision: Optional[str]
) -> list[str]

Prepare weights for the model.

If the model is neither a local directory nor an S3 path, the safetensors files are downloaded from the Hugging Face Hub.

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
def _prepare_weights(self, model_name_or_path: str,
                     revision: Optional[str]) -> list[str]:
    """Prepare weights for the model.

    If the model is not local, it will be downloaded."""

    is_s3_path = is_s3(model_name_or_path)
    is_local = os.path.isdir(model_name_or_path)
    safetensors_pattern = "*.safetensors"
    index_file = SAFE_WEIGHTS_INDEX_NAME

    hf_folder = (model_name_or_path if
                 (is_local or is_s3_path) else download_weights_from_hf(
                     model_name_or_path,
                     self.load_config.download_dir,
                     [safetensors_pattern],
                     revision,
                     ignore_patterns=self.load_config.ignore_patterns,
                 ))
    if is_s3_path:
        hf_weights_files = s3_glob(path=hf_folder,
                                   allow_pattern=[safetensors_pattern])
    else:
        hf_weights_files = glob.glob(
            os.path.join(hf_folder, safetensors_pattern))

    if not is_local and not is_s3_path:
        download_safetensors_index_file_from_hf(
            model_name_or_path, index_file, self.load_config.download_dir,
            revision)

    if not hf_weights_files:
        raise RuntimeError(
            f"Cannot find any safetensors model weights with "
            f"`{model_name_or_path}`")

    return hf_weights_files
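
A sketch of the three resolution paths the method takes (all paths and repo ids below are hypothetical):

loader = RunaiModelStreamerLoader(load_config)

# Local directory: *.safetensors files are globbed in place, nothing is downloaded.
local_files = loader._prepare_weights("/models/my-model", revision=None)

# S3 prefix: files are listed with s3_glob and later streamed, not downloaded.
s3_files = loader._prepare_weights("s3://example-bucket/my-model/", revision=None)

# Hugging Face repo id: the *.safetensors files and the safetensors index are downloaded first.
hub_files = loader._prepare_weights("org/my-model", revision="main")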

download_model

download_model(model_config: ModelConfig) -> None

Download the model if necessary.

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
def download_model(self, model_config: ModelConfig) -> None:
    """Download model if necessary"""
    self._prepare_weights(model_config.model, model_config.revision)

load_weights

load_weights(
    model: Module, model_config: ModelConfig
) -> None

Load weights into a model.

Source code in vllm/model_executor/model_loader/runai_streamer_loader.py
def load_weights(self, model: nn.Module,
                 model_config: ModelConfig) -> None:
    """Load weights into a model."""
    model_weights = model_config.model
    if hasattr(model_config, "model_weights"):
        model_weights = model_config.model_weights
    model.load_weights(
        self._get_weights_iterator(model_weights, model_config.revision))
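
The engine normally drives this method; if model_config carries a model_weights attribute (for example, a separate S3 prefix holding the tensors), it takes precedence over model_config.model. A minimal sketch:

loader = RunaiModelStreamerLoader(load_config)
loader.download_model(model_config)       # downloads from the Hub only if the model is neither local nor on S3
loader.load_weights(model, model_config)  # streams tensors straight into the module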