Skip to content

vllm.compilation.collective_fusion

logger module-attribute

# Module-level logger; the async-TP pass uses it to report how many
# pattern replacements it made.
logger = init_logger(__name__)

AllGatherGEMMPattern

Bases: BasePattern

Source code in vllm/compilation/collective_fusion.py
class AllGatherGEMMPattern(BasePattern):
    """Fuse ``all_gather(x, dim=0)`` followed by ``mm(gathered, weight)``.

    The matched subgraph is replaced by a single
    ``torch.ops.symm_mem.fused_all_gather_matmul`` call on the TP device
    group (presumably so the gather can overlap with the matmul — this is
    the async-TP fusion).
    """

    def get_inputs(self):
        """Return example ``[x, weight]`` tensors used to trace the pattern.

        Both are uninitialized 4x4 tensors on ``self.device`` with
        ``self.dtype``; only their dtype/device/rank matter for tracing.
        """
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the all-gather + GEMM replacement on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            # The vLLM custom op identifies the group by the coordinator's
            # ``unique_name``.
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

            # Single tensor output (not a tuple).
            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor,
                        weight: torch.Tensor) -> list[torch.Tensor]:
            # The symm_mem op takes the underlying torch device group's
            # name. It returns (gathered input, [one mm output per weight]);
            # only the matmul outputs are needed here.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            # NOTE(review): a one-element list is returned while the pattern
            # returns a bare tensor; presumably the pattern matcher flattens
            # outputs so these line up — TODO confirm.
            return mm_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)

get_inputs

get_inputs()
Source code in vllm/compilation/collective_fusion.py
def get_inputs(self):
    """Build the example ``[x, weight]`` inputs for pattern tracing.

    Two independent, uninitialized 4x4 tensors on ``self.device`` with
    ``self.dtype``.
    """
    square = [4, 4]
    x, weight = (torch.empty(square, device=self.device, dtype=self.dtype)
                 for _ in range(2))
    return [x, weight]

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/collective_fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the all-gather + GEMM -> fused_all_gather_matmul rewrite.

    Matches ``all_gather(x, dim=0)`` followed by ``aten.mm`` and replaces it
    with ``torch.ops.symm_mem.fused_all_gather_matmul`` on the TP group.
    """

    def pattern(
        x: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        # The vLLM custom op identifies the group by the coordinator's
        # ``unique_name``.
        all_gather = torch.ops.vllm.all_gather.default(
            x,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp.unique_name)

        # Single tensor output (not a tuple).
        return torch.ops.aten.mm.default(all_gather, weight)

    def replacement(x: torch.Tensor,
                    weight: torch.Tensor) -> list[torch.Tensor]:
        # The symm_mem op takes the underlying torch device group's name.
        # It returns (gathered input, [one mm output per weight]); only the
        # matmul outputs are needed here.
        _, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
            x,
            [weight],
            gather_dim=0,
            group_name=self.tp.device_group.group_name,
        )
        # NOTE(review): a one-element list is returned while the pattern
        # returns a bare tensor; presumably the pattern matcher flattens
        # outputs so these line up — TODO confirm.
        return mm_outputs

    pm.register_replacement(pattern, replacement, self.get_inputs(),
                            pm.fwd_only, pm_pass)

AsyncTPPass

Bases: VllmInductorPass

Source code in vllm/compilation/collective_fusion.py
class AsyncTPPass(VllmInductorPass):
    """Inductor pass that rewrites TP collectives into fused symm_mem ops.

    On construction it enables symmetric memory for the tensor-parallel
    device group and registers two replacements on one
    ``PatternMatcherPass``: GEMM + reduce-scatter and all-gather + GEMM.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group before the
        # symm_mem-based replacements are registered.
        tp_device_group_name = get_tp_group().device_group.group_name
        enable_symm_mem_for_group(tp_device_group_name)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass")
        # Registration order is preserved: reduce-scatter fusion first,
        # then all-gather fusion.
        for pattern_cls in (GEMMReduceScatterPattern, AllGatherGEMMPattern):
            pattern_cls(self.model_dtype, self.device).register(self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        """Report whether this pass should run for a compiled ``shape``.

        Only concrete (non-``None``) shapes that divide evenly by the
        tensor-parallel world size are eligible.
        """
        world_size = get_tensor_model_parallel_world_size()
        if shape is None:
            return False
        return shape % world_size == 0

    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph``, dumping it before
        and after the rewrite and logging the replacement count."""
        self.begin()
        self.dump_graph(graph, "before_async_tp_pass")
        num_replaced = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", num_replaced)
        self.dump_graph(graph, "after_async_tp_pass")
        self.end_and_log()

patterns instance-attribute

# Rendered instance attribute: the pattern-matcher pass that holds every
# async-TP replacement registered in AsyncTPPass.__init__.
patterns: PatternMatcherPass = PatternMatcherPass(pass_name="async_tp_pass")

__call__

__call__(graph: Graph)
Source code in vllm/compilation/collective_fusion.py
def __call__(self, graph: fx.Graph):
    """Run the async-TP pattern rewrite over ``graph``.

    Dumps the graph before and after, and logs how many patterns matched.
    """
    self.begin()
    self.dump_graph(graph, "before_async_tp_pass")
    num_replaced = self.patterns.apply(graph)
    logger.debug("Replaced %s patterns", num_replaced)
    self.dump_graph(graph, "after_async_tp_pass")
    self.end_and_log()

__init__

__init__(config: VllmConfig)
Source code in vllm/compilation/collective_fusion.py
def __init__(self, config: VllmConfig):
    """Set up the async-TP pass and register its fusion patterns."""
    super().__init__(config)

    # Enable symmetric memory for the TP process group before the
    # symm_mem-based replacements are registered.
    tp_device_group_name = get_tp_group().device_group.group_name
    enable_symm_mem_for_group(tp_device_group_name)

    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="async_tp_pass")
    # Registration order is preserved: reduce-scatter fusion first, then
    # all-gather fusion.
    for pattern_cls in (GEMMReduceScatterPattern, AllGatherGEMMPattern):
        pattern_cls(self.model_dtype, self.device).register(self.patterns)

is_applicable_for_shape

is_applicable_for_shape(shape: Optional[int]) -> bool
Source code in vllm/compilation/collective_fusion.py
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
    """Return True only for a concrete ``shape`` divisible by the TP size."""
    world_size = get_tensor_model_parallel_world_size()
    if shape is None:
        return False
    return shape % world_size == 0

BasePattern

Source code in vllm/compilation/collective_fusion.py
class BasePattern:
    """Shared state for the collective-fusion patterns in this module.

    Stores the dtype/device used to build example inputs, plus the TP group
    handle and world size, both looked up once at construction time.
    """

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype, self.device = dtype, device
        # TP group handle and world size, read when building the
        # collective op calls in the patterns.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

device instance-attribute

# Rendered instance attribute: BasePattern.__init__ stores its ``device``
# argument here (used to place example inputs).
device = device

dtype instance-attribute

# Rendered instance attribute: BasePattern.__init__ stores its ``dtype``
# argument here (dtype of the example inputs).
dtype = dtype

tp instance-attribute

# Rendered instance attribute: the tensor-parallel group coordinator,
# looked up once in BasePattern.__init__.
tp = get_tp_group()

tp_size instance-attribute

__init__

__init__(dtype: dtype, device: str)
Source code in vllm/compilation/collective_fusion.py
def __init__(self, dtype: torch.dtype, device: str):
    """Record example-input dtype/device and capture the TP group state."""
    self.dtype, self.device = dtype, device
    # TP group handle and world size, read when building the collective
    # op calls in the patterns.
    self.tp = get_tp_group()
    self.tp_size = get_tensor_model_parallel_world_size()

GEMMReduceScatterPattern

Bases: BasePattern

Source code in vllm/compilation/collective_fusion.py
class GEMMReduceScatterPattern(BasePattern):
    """Fuse ``mm(mul, mm_weight)`` followed by ``reduce_scatter(dim=0)``.

    The matched subgraph is replaced by a single
    ``torch.ops.symm_mem.fused_matmul_reduce_scatter`` call on the TP
    device group.
    """

    def get_inputs(self):
        """Return example ``[mul, mm_weight]`` tensors used for tracing.

        ``mul`` is 16x4 and ``mm_weight`` is 4x4, both uninitialized on
        ``self.device`` with ``self.dtype``.
        """
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the GEMM + reduce-scatter replacement on ``pm_pass``."""

        def pattern(mul: torch.Tensor,
                    mm_weight: torch.Tensor) -> torch.Tensor:
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            # The vLLM custom op identifies the group by the coordinator's
            # ``unique_name``.
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            return reduce_scatter

        def replacement(mul: torch.Tensor,
                        mm_weight: torch.Tensor) -> torch.Tensor:
            # The reduction must match the op being replaced. vLLM's TP
            # reduce_scatter sums partial products across ranks
            # (torch.distributed.reduce_scatter_tensor with the default
            # ReduceOp.SUM). Using "avg" would additionally divide the
            # result by tp_size.
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "sum",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

            return gemm_rs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)

get_inputs

get_inputs()
Source code in vllm/compilation/collective_fusion.py
def get_inputs(self):
    """Build example ``[mul, mm_weight]`` tensors for pattern tracing.

    ``mul`` is 16x4 and ``mm_weight`` is 4x4; contents are uninitialized.
    """
    def _blank(*dims):
        return torch.empty(list(dims), device=self.device, dtype=self.dtype)

    return [_blank(16, 4), _blank(4, 4)]

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/collective_fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the GEMM + reduce-scatter -> fused_matmul_reduce_scatter rewrite.

    Matches ``aten.mm`` followed by ``reduce_scatter(dim=0)`` and replaces
    it with ``torch.ops.symm_mem.fused_matmul_reduce_scatter`` on the TP
    group.
    """

    def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
        mm = torch.ops.aten.mm.default(mul, mm_weight)
        # The vLLM custom op identifies the group by the coordinator's
        # ``unique_name``.
        reduce_scatter = torch.ops.vllm.reduce_scatter.default(
            mm,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp.unique_name)
        return reduce_scatter

    def replacement(mul: torch.Tensor,
                    mm_weight: torch.Tensor) -> torch.Tensor:
        # The reduction must match the op being replaced. vLLM's TP
        # reduce_scatter sums partial products across ranks
        # (torch.distributed.reduce_scatter_tensor with the default
        # ReduceOp.SUM). Using "avg" would additionally divide the result
        # by tp_size.
        gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
            mul,
            mm_weight,
            "sum",
            scatter_dim=0,
            group_name=self.tp.device_group.group_name,
        )

        return gemm_rs

    pm.register_replacement(pattern, replacement, self.get_inputs(),
                            pm.fwd_only, pm_pass)