vllm.distributed.tpu_distributed_utils
MODULE_TYPE_TO_WRAPPING_FUNC module-attribute
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
[
(
"QKVParallelLinear",
partition_qkv_parallel_linear,
),
(
"ColumnParallelLinear",
partition_column_parallel_linear,
),
(
"RowParallelLinear",
partition_row_parallel_linear,
),
]
)
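The mapping pairs vLLM linear-layer class names with the partition functions documented below, in the order they should be considered. A minimal lookup sketch, assuming the mapping is meant to be matched against a module's class (and base-class) names; the helper name find_wrapping_func and the MRO-based matching are illustrative, not the actual shard_model logic:

```python
import torch.nn as nn

from vllm.distributed.tpu_distributed_utils import MODULE_TYPE_TO_WRAPPING_FUNC

def find_wrapping_func(module: nn.Module):
    """Return the partition function registered for this module's type, if any."""
    # Walk the class and its bases so that, e.g., a QKVParallelLinear (which
    # derives from ColumnParallelLinear in vLLM) matches its own, more
    # specific entry before the generic column-parallel one.
    for base in type(module).__mro__:
        func = MODULE_TYPE_TO_WRAPPING_FUNC.get(base.__name__)
        if func is not None:
            return func
    return None
```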
XlaQKVParallelLinear
Bases: Module
__init__
_load_weights_from_qkv_linear
_load_weights_from_qkv_linear(qkv_linear: Module)
_shard_weight
forward
get_fqn
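XlaQKVParallelLinear's _load_weights_from_qkv_linear and _shard_weight point to a wrap-and-replace flow: a wrapper module is built alongside an existing QKVParallelLinear, copies its weights over, shards them, and then takes the original layer's place. A self-contained sketch of that flow with a simplified stand-in class; the constructor signature and the swap into the parent are assumptions, not the real vLLM API:

```python
import torch
import torch.nn as nn

class XlaLinearStandIn(nn.Module):
    """Simplified stand-in for XlaQKVParallelLinear: copies weights from an
    existing linear layer so the original module can be swapped out."""

    def __init__(self, src: nn.Linear):
        super().__init__()
        # Mirrors _load_weights_from_qkv_linear: take ownership of the weights.
        self.weight = nn.Parameter(src.weight.detach().clone())
        # The real class would apply an XLA SPMD sharding here (_shard_weight).

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight)

# Swap the wrapped layer into its parent, roughly what a partition_* helper does.
parent = nn.Sequential(nn.Linear(16, 48, bias=False))
parent[0] = XlaLinearStandIn(parent[0])
```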
partition_column_parallel_linear
partition_qkv_parallel_linear
partition_row_parallel_linear
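Each partition_* helper annotates one layer type with an SPMD sharding over the given mesh. A minimal sketch, assuming torch_xla's SPMD API and a mesh with a 'model' axis; the partition specs (output dimension sharded for column-parallel/QKV weights, input dimension for row-parallel weights) follow the usual tensor-parallel convention and are not taken from the vLLM source:

```python
import torch.nn as nn
import torch_xla.distributed.spmd as xs

def partition_linear_sketch(linear: nn.Module, mesh: xs.Mesh, column_parallel: bool) -> None:
    # weight is laid out as (out_features, in_features): shard the output
    # dimension across the 'model' axis for column-parallel layers and the
    # input dimension for row-parallel layers.
    spec = ("model", None) if column_parallel else (None, "model")
    xs.mark_sharding(linear.weight, mesh, spec)
```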
shard_model
shard_model(model: Module, mesh: Mesh) -> None
Recursively walk a PyTorch model and apply the appropriate sharding to each supported submodule, based on the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
Parameters:

Name | Type | Description | Default
---|---|---|---
model | Module | torch.nn.Module to process | required
mesh | Mesh | An XLA SPMD mesh object used for sharding | required
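A usage sketch, assuming a torch_xla SPMD runtime; the mesh shape, the axis names ('data', 'model'), and the placeholder model are illustrative, not prescribed by this module:

```python
import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from vllm.distributed.tpu_distributed_utils import shard_model

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(
    np.arange(num_devices),
    mesh_shape=(1, num_devices),
    axis_names=("data", "model"),
)

# Placeholder: a real vLLM TPU model (containing the parallel linear layers
# listed in MODULE_TYPE_TO_WRAPPING_FUNC) would be used here.
model = nn.Linear(8, 8).to(xm.xla_device())
shard_model(model, mesh)
```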