vllm.compilation.passes.fusion.collective_fusion

API reference for the async-TP collective-fusion pattern-matcher passes.

FP8_DTYPE module-attribute

FP8_DTYPE = fp8_dtype()

logger module-attribute

logger = init_logger(__name__)

AllGatherCutlassScaledMMPattern

Bases: BasePattern

Source code in vllm/compilation/passes/fusion/collective_fusion.py
class AllGatherCutlassScaledMMPattern(BasePattern):
    """Match ``vllm.all_gather`` -> ``_C.cutlass_scaled_mm`` and replace the
    pair with the fused ``torch.ops.symm_mem.fused_all_gather_scaled_matmul``
    op (async tensor-parallelism).
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return fake sample tensors used to trace this pattern."""
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        # contiguous().transpose(0, 1) yields a column-major weight — the
        # layout cutlass_scaled_mm is traced with.
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # Row count after the all-gather: local rows (8) times TP world size.
        s1 = x.shape[0] * self.tp_size

        # Per-row scale for the activation, per-column scale for the weight.
        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        s2 = weight.shape[1]
        # Preallocated buffer that cutlass_scaled_mm writes into.
        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)

        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement rewrite on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            # Graph shape to match: TP all-gather of the activation ...
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            # ... followed by the auto-functionalized in-place scaled GEMM.
            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=all_gather,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Element 1 is the functionalized `out` tensor.
            return cutlass_scaled_mm[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,  # unused; signature must mirror `pattern`
        ) -> torch.Tensor:
            # Fused op returns (gathered activation, [one result per weight]).
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

get_inputs

get_inputs() -> list[Tensor]
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def get_inputs(self) -> list[torch.Tensor]:
    """Return fake sample tensors used to trace this pattern."""
    x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
    # contiguous().transpose(0, 1) yields a column-major weight — the
    # layout cutlass_scaled_mm is traced with.
    weight = (
        torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        .contiguous()
        .transpose(0, 1)
    )

    # Row count after the all-gather: local rows (8) times TP world size.
    s1 = x.shape[0] * self.tp_size

    # Per-row scale for the activation, per-column scale for the weight.
    scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
    scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

    s2 = weight.shape[1]
    # Preallocated buffer that cutlass_scaled_mm writes into.
    output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)

    return [x, weight, scale_a, scale_b, output]

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the pattern/replacement rewrite on ``pm_pass``."""

    def pattern(
        x: torch.Tensor,
        weight: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
        output: torch.Tensor,
    ) -> torch.Tensor:
        # Graph shape to match: TP all-gather of the activation ...
        all_gather = torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
        )

        # ... followed by the auto-functionalized in-place scaled GEMM.
        cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
            torch.ops._C.cutlass_scaled_mm.default,
            out=output,
            a=all_gather,
            b=weight,
            a_scales=scale_a,
            b_scales=scale_b,
            bias=None,
        )
        # Element 1 is the functionalized `out` tensor.
        return cutlass_scaled_mm[1]

    def replacement(
        x: torch.Tensor,
        weight: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
        output: torch.Tensor,  # unused; signature must mirror `pattern`
    ) -> torch.Tensor:
        # Fused op returns (gathered activation, [one result per weight]).
        ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
            x,
            [weight],
            scale_a,
            [scale_b],
            gather_dim=0,
            biases=[None],
            result_scales=[None],
            out_dtypes=[self.dtype],
            use_fast_accum=[False],
            group_name=self.tp.device_group.group_name,
        )
        return mm_outputs

    pm.register_replacement(
        pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
    )

AllGatherGEMMPattern

Bases: BasePattern

Source code in vllm/compilation/passes/fusion/collective_fusion.py
class AllGatherGEMMPattern(BasePattern):
    """Match ``vllm.all_gather`` -> ``aten.mm`` and replace the pair with
    the fused ``torch.ops.symm_mem.fused_all_gather_matmul`` op.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return fake sample tensors used to trace this pattern."""
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement rewrite on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            # TP all-gather of the activation followed by a plain matmul.
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            # Fused op returns (gathered activation, [one result per weight]).
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

get_inputs

get_inputs() -> list[Tensor]
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def get_inputs(self) -> list[torch.Tensor]:
    """Return fake sample tensors used to trace this pattern."""
    x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
    weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

    return [x, weight]

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the pattern/replacement rewrite on ``pm_pass``."""

    def pattern(
        x: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        # TP all-gather of the activation followed by a plain matmul.
        all_gather = torch.ops.vllm.all_gather.default(
            x,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp.unique_name,
        )

        return torch.ops.aten.mm.default(all_gather, weight)

    def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Fused op returns (gathered activation, [one result per weight]).
        ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
            x,
            [weight],
            gather_dim=0,
            group_name=self.tp.device_group.group_name,
        )
        return mm_outputs

    pm.register_replacement(
        pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
    )

AllGatherScaledMMPattern

Bases: BasePattern

Source code in vllm/compilation/passes/fusion/collective_fusion.py
class AllGatherScaledMMPattern(BasePattern):
    """Match ``vllm.all_gather`` -> ``aten._scaled_mm`` and replace the pair
    with the fused ``torch.ops.symm_mem.fused_all_gather_scaled_matmul`` op.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return fake sample tensors used to trace this pattern."""
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        # contiguous().transpose(0, 1) yields a column-major weight — the
        # layout _scaled_mm is traced with.
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # Row count after the all-gather: local rows (8) times TP world size.
        s1 = x.shape[0] * self.tp_size

        # Per-row scale for the activation, per-column scale for the weight.
        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement rewrite on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Graph shape to match: TP all-gather of the activation ...
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            # ... followed by an FP8 scaled matmul producing self.dtype output.
            return torch.ops.aten._scaled_mm.default(
                all_gather,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Fused op returns (gathered activation, [one result per weight]).
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

get_inputs

get_inputs() -> list[Tensor]
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def get_inputs(self) -> list[torch.Tensor]:
    """Return fake sample tensors used to trace this pattern."""
    x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
    # contiguous().transpose(0, 1) yields a column-major weight — the
    # layout _scaled_mm is traced with.
    weight = (
        torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        .contiguous()
        .transpose(0, 1)
    )

    # Row count after the all-gather: local rows (8) times TP world size.
    s1 = x.shape[0] * self.tp_size

    # Per-row scale for the activation, per-column scale for the weight.
    scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
    scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

    return [x, weight, scale_a, scale_b]

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the pattern/replacement rewrite on ``pm_pass``."""

    def pattern(
        x: torch.Tensor,
        weight: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
    ) -> torch.Tensor:
        # Graph shape to match: TP all-gather of the activation ...
        all_gather = torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
        )

        # ... followed by an FP8 scaled matmul producing self.dtype output.
        return torch.ops.aten._scaled_mm.default(
            all_gather,
            mat2=weight,
            scale_a=scale_a,
            scale_b=scale_b,
            bias=None,
            scale_result=None,
            out_dtype=self.dtype,
        )

    def replacement(
        x: torch.Tensor,
        weight: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
    ) -> torch.Tensor:
        # Fused op returns (gathered activation, [one result per weight]).
        ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
            x,
            [weight],
            scale_a,
            [scale_b],
            gather_dim=0,
            biases=[None],
            result_scales=[None],
            out_dtypes=[self.dtype],
            use_fast_accum=[False],
            group_name=self.tp.device_group.group_name,
        )
        return mm_outputs

    pm.register_replacement(
        pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
    )

AsyncTPPass

Bases: VllmPatternMatcherPass

Source code in vllm/compilation/passes/fusion/collective_fusion.py
class AsyncTPPass(VllmPatternMatcherPass):
    """Inductor pass that rewrites TP collective + GEMM pairs into fused
    torch symm_mem (async tensor-parallel) ops.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Build and register all async-TP fusion patterns."""
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )
        # Dtype-agnostic mm fusions are always registered.
        GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)

        AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)

        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if self.model_dtype == torch.bfloat16:
            ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )
            AllGatherScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )

            CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )
            AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True if this pass may run for the given compile range."""
        # This pass is applied on top of the sequence parallelism pass.
        # It inherits the same applicability condition as `SequenceParallelismPass`.
        # See `SequenceParallelismPass.is_applicable` for more details.
        if (
            not self.compilation_config.splitting_ops
            or self.compilation_config.use_inductor_graph_partition
        ):
            return True
        # Otherwise only apply for a single size divisible by the TP degree.
        tp_size = get_tensor_model_parallel_world_size()
        return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered async-TP rewrites to `graph`."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="async_tp_pass"
)

__call__

__call__(graph: Graph) -> None
Source code in vllm/compilation/passes/fusion/collective_fusion.py
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
    """Apply all registered async-TP rewrites to `graph`."""
    self.matched_count = self.patterns.apply(graph)
    logger.debug("Replaced %s patterns", self.matched_count)

__init__

__init__(config: VllmConfig) -> None
Source code in vllm/compilation/passes/fusion/collective_fusion.py
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
    """Build and register all async-TP fusion patterns."""
    super().__init__(config)

    # Enable symmetric memory for the TP process group
    enable_symm_mem_for_group(get_tp_group().device_group.group_name)
    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="async_tp_pass"
    )
    # Dtype-agnostic mm fusions are always registered.
    GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)

    AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)

    # These fusions are enabled only for bfloat16 models because
    # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
    # only supports bfloat16 as the output dtype.
    if self.model_dtype == torch.bfloat16:
        ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
            self.patterns
        )
        AllGatherScaledMMPattern(self.model_dtype, self.device).register(
            self.patterns
        )

        CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
            self.patterns
        )
        AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
            self.patterns
        )

    self.dump_patterns(config, self.patterns)

is_applicable_for_range

is_applicable_for_range(compile_range: Range) -> bool
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def is_applicable_for_range(self, compile_range: Range) -> bool:
    """Return True if this pass may run for the given compile range."""
    # This pass is applied on top of the sequence parallelism pass.
    # It inherits the same applicability condition as `SequenceParallelismPass`.
    # See `SequenceParallelismPass.is_applicable` for more details.
    if (
        not self.compilation_config.splitting_ops
        or self.compilation_config.use_inductor_graph_partition
    ):
        return True
    # Otherwise only apply for a single size divisible by the TP degree.
    tp_size = get_tensor_model_parallel_world_size()
    return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)

BasePattern

Source code in vllm/compilation/passes/fusion/collective_fusion.py
class BasePattern:
    """Common state shared by all async-TP fusion patterns."""

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        """Capture model dtype, device and TP-group state for tracing."""
        self.dtype = dtype  # model (output) dtype used when tracing patterns
        self.device = device
        self.tp = get_tp_group()  # tensor-parallel group coordinator
        self.tp_size = get_tensor_model_parallel_world_size()

device instance-attribute

device = device

dtype instance-attribute

dtype = dtype

tp instance-attribute

tp = get_tp_group()

tp_size instance-attribute

__init__

__init__(dtype: dtype, device: str | None) -> None
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
    """Capture model dtype, device and TP-group state for tracing."""
    self.dtype = dtype  # model (output) dtype used when tracing patterns
    self.device = device
    self.tp = get_tp_group()  # tensor-parallel group coordinator
    self.tp_size = get_tensor_model_parallel_world_size()

CutlassScaledMMReduceScatterPattern

Bases: BasePattern

Source code in vllm/compilation/passes/fusion/collective_fusion.py
class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Match ``_C.cutlass_scaled_mm`` -> ``vllm.reduce_scatter`` and replace
    the pair with the fused
    ``torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter`` op.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return fake sample tensors used to trace this pattern."""
        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        # contiguous().transpose(0, 1) yields a column-major weight — the
        # layout cutlass_scaled_mm is traced with.
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        # Per-row scale for the activation, per-column scale for the weight.
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        # Preallocated buffer that cutlass_scaled_mm writes into.
        cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            # Graph shape to match: auto-functionalized in-place scaled GEMM ...
            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )

            # ... followed by a TP reduce-scatter of its output (element 1 of
            # the auto_functionalized result is the functionalized `out`).
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                cutlass_scaled_mm[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,  # unused; mirrors `pattern`
        ) -> torch.Tensor:
            # Calculate output shape: input @ mat2 with scatter_dim reduced
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            # NOTE(review): reduce op is "avg" — confirm it matches the
            # semantics of vllm.reduce_scatter being replaced.
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

get_inputs

get_inputs() -> list[Tensor]
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def get_inputs(self) -> list[torch.Tensor]:
    """Return fake sample tensors used to trace this pattern."""
    input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
    # contiguous().transpose(0, 1) yields a column-major weight — the
    # layout cutlass_scaled_mm is traced with.
    mm_weight = (
        torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        .contiguous()
        .transpose(0, 1)
    )
    # Per-row scale for the activation, per-column scale for the weight.
    scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
    scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

    # Preallocated buffer that cutlass_scaled_mm writes into.
    cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
    return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the pattern/replacement rewrite on ``pm_pass``."""

    def pattern(
        input: torch.Tensor,
        weight: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
        cutlass_mm_output: torch.Tensor,
    ) -> torch.Tensor:
        # Graph shape to match: auto-functionalized in-place scaled GEMM ...
        cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
            torch.ops._C.cutlass_scaled_mm.default,
            out=cutlass_mm_output,
            a=input,
            b=weight,
            a_scales=scale_a,
            b_scales=scale_b,
            bias=None,
        )

        # ... followed by a TP reduce-scatter of its output (element 1 of
        # the auto_functionalized result is the functionalized `out`).
        reduce_scatter = torch.ops.vllm.reduce_scatter.default(
            cutlass_scaled_mm[1],
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp.unique_name,
        )
        return reduce_scatter

    def replacement(
        input: torch.Tensor,
        mat2: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
        cutlass_mm_output: torch.Tensor,  # unused; mirrors `pattern`
    ) -> torch.Tensor:
        # Calculate output shape: input @ mat2 with scatter_dim reduced
        output_shape = [*input.shape[:-1], mat2.shape[1]]
        scatter_dim = 0
        # NOTE(review): reduce op is "avg" — confirm it matches the
        # semantics of vllm.reduce_scatter being replaced.
        gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
            input,
            mat2,
            scale_a,
            scale_b,
            "avg",
            scatter_dim,  # orig_scatter_dim
            scatter_dim,  # scatter_dim_after_maybe_reshape
            self.tp.device_group.group_name,
            output_shape,
            None,  # bias
            None,  # result_scale
            self.dtype,  # out_dtype
            False,  # use_fast_accum
        )

        return gemm_rs

    pm.register_replacement(
        pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
    )

GEMMReduceScatterPattern

Bases: BasePattern

Source code in vllm/compilation/passes/fusion/collective_fusion.py
class GEMMReduceScatterPattern(BasePattern):
    """Match ``aten.mm`` -> ``vllm.reduce_scatter`` and replace the pair
    with the fused ``torch.ops.symm_mem.fused_matmul_reduce_scatter`` op.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return fake sample tensors used to trace this pattern."""
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement rewrite on ``pm_pass``."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            # Graph shape to match: matmul followed by TP reduce-scatter.
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            # NOTE(review): reduce op is "avg" — confirm it matches the
            # semantics of vllm.reduce_scatter being replaced.
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

get_inputs

get_inputs() -> list[Tensor]
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def get_inputs(self) -> list[torch.Tensor]:
    """Return fake sample tensors used to trace this pattern."""
    mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
    mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
    return [mul, mm_weight]

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the pattern/replacement rewrite on ``pm_pass``."""

    def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
        # Graph shape to match: matmul followed by TP reduce-scatter.
        mm = torch.ops.aten.mm.default(mul, mm_weight)
        reduce_scatter = torch.ops.vllm.reduce_scatter.default(
            mm,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp.unique_name,
        )
        return reduce_scatter

    def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
        # NOTE(review): reduce op is "avg" — confirm it matches the
        # semantics of vllm.reduce_scatter being replaced.
        gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
            mul,
            mm_weight,
            "avg",
            scatter_dim=0,
            group_name=self.tp.device_group.group_name,
        )

        return gemm_rs

    pm.register_replacement(
        pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
    )

ScaledMMReduceScatterPattern

Bases: BasePattern

Source code in vllm/compilation/passes/fusion/collective_fusion.py
class ScaledMMReduceScatterPattern(BasePattern):
    """Match ``aten._scaled_mm`` -> ``vllm.reduce_scatter`` and replace the
    pair with the fused
    ``torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter`` op.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return fake sample tensors used to trace this pattern."""
        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        # contiguous().transpose(0, 1) yields a column-major weight — the
        # layout _scaled_mm is traced with.
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        # Per-row scale for the activation, per-column scale for the weight.
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
        return [input, mm_weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Graph shape to match: FP8 scaled matmul ...
            scaled_mm = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            # ... followed by a TP reduce-scatter of its output.
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                scaled_mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Calculate output shape: input @ mat2 with scatter_dim reduced
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            # NOTE(review): reduce op is "avg" — confirm it matches the
            # semantics of vllm.reduce_scatter being replaced.
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

get_inputs

get_inputs() -> list[Tensor]
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def get_inputs(self) -> list[torch.Tensor]:
    """Return fake sample tensors used to trace this pattern."""
    input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
    # contiguous().transpose(0, 1) yields a column-major weight — the
    # layout _scaled_mm is traced with.
    mm_weight = (
        torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        .contiguous()
        .transpose(0, 1)
    )
    # Per-row scale for the activation, per-column scale for the weight.
    scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
    scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
    return [input, mm_weight, scale_a, scale_b]

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/passes/fusion/collective_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the pattern/replacement rewrite on ``pm_pass``."""

    def pattern(
        input: torch.Tensor,
        mat2: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
    ) -> torch.Tensor:
        # Graph shape to match: FP8 scaled matmul ...
        scaled_mm = torch.ops.aten._scaled_mm.default(
            input,
            mat2=mat2,
            scale_a=scale_a,
            scale_b=scale_b,
            bias=None,
            scale_result=None,
            out_dtype=self.dtype,
        )
        # ... followed by a TP reduce-scatter of its output.
        reduce_scatter = torch.ops.vllm.reduce_scatter.default(
            scaled_mm,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp.unique_name,
        )
        return reduce_scatter

    def replacement(
        input: torch.Tensor,
        mat2: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate output shape: input @ mat2 with scatter_dim reduced
        output_shape = [*input.shape[:-1], mat2.shape[1]]
        scatter_dim = 0
        # NOTE(review): reduce op is "avg" — confirm it matches the
        # semantics of vllm.reduce_scatter being replaced.
        gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
            input,
            mat2,
            scale_a,
            scale_b,
            "avg",
            scatter_dim,  # orig_scatter_dim
            scatter_dim,  # scatter_dim_after_maybe_reshape
            self.tp.device_group.group_name,
            output_shape,
            None,  # bias
            None,  # result_scale
            self.dtype,  # out_dtype
            False,  # use_fast_accum
        )

        return gemm_rs

    pm.register_replacement(
        pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
    )