vllm.compilation.sequence_parallelism
FirstAllReduceRMSNormPattern
Bases: _SequenceParallelPatternHelper
Source code in vllm/compilation/sequence_parallelism.py
get_inputs
register
FirstAllReduceRMSNormStaticFP8Pattern
Bases: _SequenceParallelPatternHelper
__init__
get_inputs
register
LastAllReduceRMSNormPattern
Bases: _SequenceParallelPatternHelper
get_inputs
register
LastAllReduceRMSNormStaticFP8Pattern
Bases: _SequenceParallelPatternHelper
__init__
get_inputs
register
MiddleAllReduceRMSNormPattern
Bases: _SequenceParallelPatternHelper
get_inputs
register
MiddleAllReduceRMSNormStaticFP8Pattern
Bases: _SequenceParallelPatternHelper
__init__
get_inputs
register
SequenceParallelismPass
Bases: VllmInductorPass
This pass enables sequence parallelism for models. It identifies patterns where an AllReduce operation is followed by an RMSNorm (or RMSNorm and then Quantization) operation. These patterns are replaced with a ReduceScatter operation, followed by a local RMSNorm/Quantization, and then an AllGather operation.
The general transformation is:
Input -> AllReduce -> RMSNorm -> Output
becomes
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
While this pass itself does not directly yield performance improvements, it lays the groundwork for subsequent fusion passes, such as GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can significantly reduce communication overhead and improve overall model performance.
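The equivalence behind this rewrite can be checked with a small simulation. The sketch below is not vLLM code: it models each rank's tensor as a nested Python list, and the helper names (`rms_norm`, `all_reduce`, `reduce_scatter`, `all_gather`) are hypothetical stand-ins for the real collectives. It assumes the scatter happens along the token dimension and that RMSNorm normalizes each token independently, which is why normalizing a shard gives the same rows as normalizing the full sequence.

```python
import math

def rms_norm(row, weight, eps=1e-6):
    """Normalize one token (a list of floats) by its root mean square."""
    mean_sq = sum(v * v for v in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(row, weight)]

def all_reduce(per_rank):
    """Every rank ends up with the element-wise sum of all ranks' tensors."""
    num_tokens, hidden = len(per_rank[0]), len(per_rank[0][0])
    total = [[sum(t[i][j] for t in per_rank) for j in range(hidden)]
             for i in range(num_tokens)]
    return [total for _ in per_rank]

def reduce_scatter(per_rank):
    """Rank r ends up with only the r-th token chunk of the summed tensor."""
    world = len(per_rank)
    total = all_reduce(per_rank)[0]
    chunk = len(total) // world
    return [total[r * chunk:(r + 1) * chunk] for r in range(world)]

def all_gather(shards):
    """Every rank ends up with the concatenation of all ranks' chunks."""
    full = [row for shard in shards for row in shard]
    return [full for _ in shards]

# Two ranks, four tokens, hidden size 2 (illustrative partial sums).
partials = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
    [[0.5, 0.5], [1.0, 1.0], [1.5, 1.5], [2.0, 2.0]],
]
weight = [1.0, 0.5]

# Original graph: AllReduce -> RMSNorm on the full sequence.
baseline = [[rms_norm(row, weight) for row in t] for t in all_reduce(partials)]

# Rewritten graph: ReduceScatter -> local RMSNorm -> AllGather.
shards = reduce_scatter(partials)
local = [[rms_norm(row, weight) for row in s] for s in shards]
rewritten = all_gather(local)

assert rewritten == baseline
```

In the rewritten graph each rank normalizes only its own chunk of tokens, so the result is unchanged while the collective is split into two halves that later passes can fuse with neighboring GEMMs.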
patterns
instance-attribute
__init__
__init__(config: VllmConfig)
is_applicable_for_shape
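This page does not describe what the method checks. One plausible reading, given that ReduceScatter splits the token dimension across ranks, is that the pass only applies when the token count is known and divides evenly by the tensor-parallel world size. The sketch below illustrates that assumption only: the function name `splits_evenly` and the explicit `tp_size` parameter are hypothetical and are not the method's actual signature.

```python
from typing import Optional

def splits_evenly(num_tokens: Optional[int], tp_size: int) -> bool:
    """Hypothetical check: a known token count that divides across all ranks."""
    return num_tokens is not None and num_tokens % tp_size == 0

# 8 tokens over 4 ranks gives 2 tokens per rank; 10 tokens cannot be split evenly.
checks = [splits_evenly(8, 4), splits_evenly(10, 4), splits_evenly(None, 4)]
```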
_RMSNormAndQuantOpHelper
Base helper for RMSNorm and RMSNorm + Quantization functionalization.
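Functionalization, in general terms, means expressing an operation that writes into a caller-supplied buffer as one that returns a fresh result and leaves its arguments untouched, so that graph pattern matching can treat it as a pure node. The sketch below shows the idea in plain Python; `rms_norm_out` and `functional_rms_norm` are hypothetical stand-ins and do not reproduce how this helper is implemented.

```python
import math

def rms_norm_out(out, x, weight, eps):
    """Stand-in for an out-variant kernel: writes into `out`, returns None."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    for i, (v, w) in enumerate(zip(x, weight)):
        out[i] = v * inv_rms * w

def functional_rms_norm(out, x, weight, eps):
    """Functional form: run the kernel on a copy and return the copy."""
    result = list(out)  # the caller's buffer is never mutated
    rms_norm_out(result, x, weight, eps)
    return result

buffer = [0.0, 0.0]
normed = functional_rms_norm(buffer, [3.0, 4.0], [1.0, 1.0], 0.0)
# `buffer` still holds zeros; `normed` holds the normalized values.
```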
__init__
__init__(
    epsilon: float,
    dtype: dtype,
    device: str,
    quant_op: Optional[OpOverload] = None,
    **kwargs,
)
_functional_fused_add_rmsnorm
_functional_fused_add_rmsnorm_then_quant
_functional_fused_add_rmsnorm_then_quant(
    quant_result_buffer,
    input_tensor,
    residual_tensor,
    weight_tensor,
    scale_tensor,
)
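As a rough reference for what a fused add, RMSNorm, then static quantization step computes on one token, the sketch below adds the residual, normalizes the sum, and divides by a fixed scale before clamping. It rests on several assumptions that this page does not confirm: the quantization is taken to be static per-tensor FP8 in the e4m3 format (largest finite value 448), rounding to FP8 precision is omitted, and the function name and `eps` parameter are hypothetical.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value of the e4m3 format

def fused_add_rmsnorm_then_quant_ref(input_row, residual_row, weight, scale, eps=1e-6):
    """Reference semantics for one token; returns (quantized, new_residual)."""
    # Step 1: fold the incoming activation into the residual stream.
    new_residual = [a + b for a, b in zip(input_row, residual_row)]
    # Step 2: RMSNorm over the hidden dimension of the summed row.
    mean_sq = sum(v * v for v in new_residual) / len(new_residual)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    normed = [v * inv_rms * w for v, w in zip(new_residual, weight)]
    # Step 3: static quantization with a fixed scale, clamped to the FP8 range.
    quantized = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in normed]
    return quantized, new_residual

quant, residual = fused_add_rmsnorm_then_quant_ref(
    [1.0, 2.0], [2.0, 2.0], [1.0, 1.0], scale=0.001, eps=0.0
)
# The scale is small enough here that both values saturate at the clamp bound.
```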
_functional_rmsnorm
_functional_rmsnorm_then_quant
_functional_rmsnorm_then_quant(
    rmsnorm_result_buffer,
    quant_result_buffer,
    input_tensor,
    weight_tensor,
    scale_tensor,
)
_SequenceParallelPatternHelper
Bases: _RMSNormAndQuantOpHelper
Helper for sequence parallelism patterns.