vllm.model_executor.model_loader.sharded_state_loader

logger module-attribute

logger = init_logger(__name__)

ShardedStateLoader

Bases: BaseModelLoader

Model loader that directly loads each worker's model state dict, which enables a fast load path for large tensor-parallel models where each worker only needs to read its own shard rather than the entire checkpoint. See examples/offline_inference/save_sharded_state.py for creating a sharded checkpoint.
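
A sharded checkpoint can be produced and re-loaded along these lines (a minimal sketch mirroring examples/offline_inference/save_sharded_state.py; the model name and paths are placeholders, and the exact attribute path to save_sharded_state may differ across vLLM versions):

from vllm import LLM

# Save: load the model once, then have each rank write its own shard.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
llm.llm_engine.model_executor.save_sharded_state(
    path="/ckpt/sharded",        # output directory (must already exist)
    pattern=None,                # None -> DEFAULT_PATTERN
    max_size=5 * 1024**3,        # start a new part file past ~5 GiB
)

# Load: point vLLM at the sharded directory with the same TP size.
llm = LLM(
    model="/ckpt/sharded",
    load_format="sharded_state",
    tensor_parallel_size=2,
)

The companion script also copies the model's config and tokenizer files into the output directory, since only the *.safetensors shards are written here.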

Source code in vllm/model_executor/model_loader/sharded_state_loader.py
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self,
                 load_config: LoadConfig,
                 runai_model_streamer: bool = False):
        super().__init__(load_config)

        self.runai_model_streamer = runai_model_streamer
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
        tensors: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list))
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        from vllm.distributed import get_tensor_model_parallel_rank

        model_weights = model_config.model
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights
        local_model_path = model_weights

        rank = get_tensor_model_parallel_rank()
        pattern = os.path.join(
            local_model_path,
            self.pattern.format(rank=rank, part="*"),
        )

        filepaths = []
        if is_s3(local_model_path):
            file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
            filepaths = s3_glob(path=local_model_path,
                                allow_pattern=[file_pattern])
        else:
            filepaths = glob.glob(pattern)
        if not filepaths:
            # TODO: support un-sharded checkpoints too
            raise ValueError(
                f"Could not find checkpoint files '{pattern}', only "
                f"pre-sharded checkpoints are currently supported!")
        state_dict = self._filter_subtensors(model.state_dict())
        for key, tensor in self.iterate_over_files(filepaths):
            # If loading with LoRA enabled, additional padding may
            # be added to certain parameters. We only load into a
            # narrowed view of the parameter data.
            param_data = state_dict[key].data
            param_shape = state_dict[key].shape
            for dim, size in enumerate(tensor.shape):
                if size < param_shape[dim]:
                    param_data = param_data.narrow(dim, 0, size)
            if tensor.shape != param_shape:
                logger.warning(
                    "loading tensor of shape %s into "
                    "parameter '%s' of shape %s",
                    tensor.shape,
                    key,
                    param_shape,
                )
            param_data.copy_(tensor)
            state_dict.pop(key)
        if state_dict:
            raise ValueError(
                f"Missing keys {tuple(state_dict)} in loaded state!")

    def iterate_over_files(
            self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
        if self.runai_model_streamer:
            yield from runai_safetensors_weights_iterator(paths, True)
        else:
            from safetensors.torch import safe_open
            for path in paths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        yield key, tensor

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )

DEFAULT_PATTERN class-attribute instance-attribute

DEFAULT_PATTERN = (
    "model-rank-{rank}-part-{part}.safetensors"
)
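
For example, rank 0's first shard file name expands to:

>>> "model-rank-{rank}-part-{part}.safetensors".format(rank=0, part=0)
'model-rank-0-part-0.safetensors'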

pattern instance-attribute

pattern = extra_config.pop("pattern", DEFAULT_PATTERN)

runai_model_streamer instance-attribute

runai_model_streamer = runai_model_streamer

__init__

__init__(
    load_config: LoadConfig,
    runai_model_streamer: bool = False,
)
Source code in vllm/model_executor/model_loader/sharded_state_loader.py
def __init__(self,
             load_config: LoadConfig,
             runai_model_streamer: bool = False):
    super().__init__(load_config)

    self.runai_model_streamer = runai_model_streamer
    extra_config = ({} if load_config.model_loader_extra_config is None
                    else load_config.model_loader_extra_config.copy())
    self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
    if extra_config:
        raise ValueError(f"Unexpected extra config keys for load format "
                         f"{load_config.load_format}: "
                         f"{extra_config.keys()}")

_filter_subtensors staticmethod

_filter_subtensors(
    tensors: dict[str, Tensor],
) -> dict[str, Tensor]

Filter out all tensors that share the same memory or a subset of the memory of another tensor.
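
A toy illustration (a sketch that calls the private helper directly; real state dicts hit this case with tied or stacked weights):

import torch

from vllm.model_executor.model_loader.sharded_state_loader import (
    ShardedStateLoader)

big = torch.arange(12, dtype=torch.float32)
view = big[:6]  # contiguous view over a strict subset of big's storage
kept = ShardedStateLoader._filter_subtensors({"big": big, "view": view})
assert set(kept) == {"big"}  # "view" is dropped as a sub-tensor of "big"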

Source code in vllm/model_executor/model_loader/sharded_state_loader.py
@staticmethod
def _filter_subtensors(
    tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """
    Filter out all tensors that share the same memory or a subset of the
    memory of another tensor.
    """
    same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
        collections.defaultdict(list))
    for key, tensor in tensors.items():
        if tensor.numel():
            ptr = tensor.untyped_storage().data_ptr()
            same_storage_groups[tensor.device, ptr].append((key, tensor))

    def get_end_ptr(tensor: torch.Tensor) -> int:
        return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

    result: dict[str, torch.Tensor] = {}
    for group in same_storage_groups.values():
        for k, t in group:
            a, b = t.data_ptr(), get_end_ptr(t)
            for k2, t2 in group:
                if not t2.is_contiguous():
                    continue
                a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                if a < a2 or b2 < b:
                    continue
                if a2 < a or b < b2 or not t.is_contiguous():
                    break  # t2 covers strictly more memory than t.
                if k2 < k:
                    # Same tensors, keep the one with the smaller key.
                    break
            else:
                result[k] = t
    return result

_prepare_weights

_prepare_weights(
    model_name_or_path: str, revision: Optional[str]
)
Source code in vllm/model_executor/model_loader/sharded_state_loader.py
def _prepare_weights(self, model_name_or_path: str,
                     revision: Optional[str]):
    if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
        return model_name_or_path
    else:
        allow_patterns = ["*.safetensors"]
        return download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            allow_patterns,
            revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )

download_model

download_model(model_config: ModelConfig) -> None
Source code in vllm/model_executor/model_loader/sharded_state_loader.py
def download_model(self, model_config: ModelConfig) -> None:
    self._prepare_weights(model_config.model, model_config.revision)

iterate_over_files

iterate_over_files(
    paths,
) -> Generator[tuple[str, Tensor], None, None]
Source code in vllm/model_executor/model_loader/sharded_state_loader.py
def iterate_over_files(
        self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
    if self.runai_model_streamer:
        yield from runai_safetensors_weights_iterator(paths, True)
    else:
        from safetensors.torch import safe_open
        for path in paths:
            with safe_open(path, framework="pt") as f:
                for key in f.keys():  # noqa: SIM118
                    tensor = f.get_tensor(key)
                    yield key, tensor

load_weights

load_weights(
    model: Module, model_config: ModelConfig
) -> None
Source code in vllm/model_executor/model_loader/sharded_state_loader.py
def load_weights(self, model: nn.Module,
                 model_config: ModelConfig) -> None:
    from vllm.distributed import get_tensor_model_parallel_rank

    model_weights = model_config.model
    if hasattr(model_config, "model_weights"):
        model_weights = model_config.model_weights
    local_model_path = model_weights

    rank = get_tensor_model_parallel_rank()
    pattern = os.path.join(
        local_model_path,
        self.pattern.format(rank=rank, part="*"),
    )

    filepaths = []
    if is_s3(local_model_path):
        file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
        filepaths = s3_glob(path=local_model_path,
                            allow_pattern=[file_pattern])
    else:
        filepaths = glob.glob(pattern)
    if not filepaths:
        # TODO: support un-sharded checkpoints too
        raise ValueError(
            f"Could not find checkpoint files '{pattern}', only "
            f"pre-sharded checkpoints are currently supported!")
    state_dict = self._filter_subtensors(model.state_dict())
    for key, tensor in self.iterate_over_files(filepaths):
        # If loading with LoRA enabled, additional padding may
        # be added to certain parameters. We only load into a
        # narrowed view of the parameter data.
        param_data = state_dict[key].data
        param_shape = state_dict[key].shape
        for dim, size in enumerate(tensor.shape):
            if size < param_shape[dim]:
                param_data = param_data.narrow(dim, 0, size)
        if tensor.shape != param_shape:
            logger.warning(
                "loading tensor of shape %s into "
                "parameter '%s' of shape %s",
                tensor.shape,
                key,
                param_shape,
            )
        param_data.copy_(tensor)
        state_dict.pop(key)
    if state_dict:
        raise ValueError(
            f"Missing keys {tuple(state_dict)} in loaded state!")

save_model staticmethod

save_model(
    model: Module,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None
Source code in vllm/model_executor/model_loader/sharded_state_loader.py
@staticmethod
def save_model(
    model: torch.nn.Module,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None:
    from safetensors.torch import save_file

    from vllm.distributed import get_tensor_model_parallel_rank

    if pattern is None:
        pattern = ShardedStateLoader.DEFAULT_PATTERN
    rank = get_tensor_model_parallel_rank()
    part_idx = 0
    total_size = 0
    state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
    state_dict_part: dict[str, torch.Tensor] = {}
    for key, tensor in state_dict.items():
        param_size = tensor.nelement() * tensor.element_size()
        if max_size is not None and total_size + param_size > max_size:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
            part_idx += 1
            total_size = 0
            state_dict_part = {}
        state_dict_part[key] = tensor
        total_size += param_size
    if len(state_dict_part) > 0:
        filename = pattern.format(rank=rank, part=part_idx)
        save_file(
            state_dict_part,
            os.path.join(path, filename),
        )
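
A usage sketch (hypothetical paths; model stands for the rank-local module, this must run inside a worker where tensor parallelism is initialized, and the output directory must already exist):

from vllm.model_executor.model_loader.sharded_state_loader import (
    ShardedStateLoader)

ShardedStateLoader.save_model(
    model,                    # the rank-local nn.Module
    path="/ckpt/sharded",
    max_size=4 * 1024**3,     # roll over to a new part file past ~4 GiB
)
# For rank 0 this writes e.g. model-rank-0-part-0.safetensors,
# model-rank-0-part-1.safetensors, ... under /ckpt/sharded.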