vllm.compilation.activation_quant_fusion

logger module-attribute

logger = init_logger(__name__)

ActivationQuantFusionPass

Bases: VllmInductorPass

This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them.

Because patterns can only be registered once, the pass is a singleton. This will be addressed in a future version of PyTorch: https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
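
A minimal usage sketch (an illustration, not part of the vLLM docs; it assumes vLLM is installed with its compiled _C kernels and that VllmConfig can be default-constructed): the pass is built once from the engine config and then called on Inductor's post-grad FX graph.

import torch
from vllm.config import VllmConfig
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass

fusion_pass = ActivationQuantFusionPass(VllmConfig())

# In production the graph comes from torch.compile's post-grad pass hook;
# an empty graph here only demonstrates the call shape (zero matches).
graph = torch.fx.Graph()
fusion_pass(graph)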

Source code in vllm/compilation/activation_quant_fusion.py
class ActivationQuantFusionPass(VllmInductorPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass")

        inputs = [
            empty_fp8(5, 4),  # Quant output
            empty_bf16(5, 4),  # Silu_and_mul output
            empty_bf16(5, 4),  # Input
            empty_fp32(1, 1)  # Scale
        ]
        register_replacement(silu_mul_pattern_static,
                             silu_mul_replacement_static, inputs, fwd_only,
                             self.patterns)

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_act_quant_fusion")

        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
                     count)

        self.dump_graph(graph, "after_act_quant_fusion")
        self.end_and_log()
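
In __init__ above, the inputs list maps positionally onto the parameters of silu_mul_pattern_static, and fwd_only tells register_replacement to trace only the forward (inference) graph. A sketch of that correspondence (the shapes are placeholder example tensors used to trace the pattern; matching is structural, so they are not match-time constraints):

inputs = [
    empty_fp8(5, 4),   # -> result: fp8 quant output buffer
    empty_bf16(5, 4),  # -> result_silu_mul: activation output buffer
    empty_bf16(5, 4),  # -> input: activation input
    empty_fp32(1, 1),  # -> scale: static per-tensor quantization scale
]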

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="activation_quant_fusion_pass"
)

__call__

__call__(graph: Graph)
Source code in vllm/compilation/activation_quant_fusion.py
def __call__(self, graph: torch.fx.Graph):
    self.begin()
    self.dump_graph(graph, "before_act_quant_fusion")

    count = self.patterns.apply(graph)
    logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
                 count)

    self.dump_graph(graph, "after_act_quant_fusion")
    self.end_and_log()

__init__

__init__(config: VllmConfig)
Source code in vllm/compilation/activation_quant_fusion.py
def __init__(self, config: VllmConfig):
    super().__init__(config)

    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="activation_quant_fusion_pass")

    inputs = [
        empty_fp8(5, 4),  # Quant output
        empty_bf16(5, 4),  # Silu_and_mul output
        empty_bf16(5, 4),  # Input
        empty_fp32(1, 1)  # Scale
    ]
    register_replacement(silu_mul_pattern_static,
                         silu_mul_replacement_static, inputs, fwd_only,
                         self.patterns)

empty_bf16

empty_bf16(*args, **kwargs)
Source code in vllm/compilation/activation_quant_fusion.py
def empty_bf16(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")

empty_fp32

empty_fp32(*args, **kwargs)
Source code in vllm/compilation/activation_quant_fusion.py
def empty_fp32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")

empty_fp8

empty_fp8(*args, **kwargs)
Source code in vllm/compilation/activation_quant_fusion.py
def empty_fp8(*args, **kwargs):
    fp8 = current_platform.fp8_dtype()
    return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
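
These helpers only allocate example tensors for tracing the pattern; current_platform.fp8_dtype() resolves the platform's FP8 format. A quick illustration (the per-backend dtype mapping is an assumption here; on most NVIDIA GPUs it is torch.float8_e4m3fn):

import torch
from vllm.platforms import current_platform

fp8 = current_platform.fp8_dtype()
print(fp8)  # e.g. torch.float8_e4m3fn on recent NVIDIA GPUs
print(torch.empty(5, 4, dtype=fp8, device="cuda").element_size())  # 1 byte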

silu_mul_pattern_static

silu_mul_pattern_static(
    result: Tensor,
    result_silu_mul: Tensor,
    input: Tensor,
    scale: Tensor,
)
Source code in vllm/compilation/activation_quant_fusion.py
def silu_mul_pattern_static(result: torch.Tensor,
                            result_silu_mul: torch.Tensor, input: torch.Tensor,
                            scale: torch.Tensor):
    at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
                              result=result_silu_mul,
                              input=input)
    at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
                              result=result,
                              input=at1[1],
                              scale=scale)
    return at2[1]
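
auto_functionalized wraps an in-place custom op so it can appear in a functional FX graph; index [1] selects the mutated result buffer from the returned tuple. A hedged eager reading of the pattern (assumes vLLM's _C extension; this helper is illustrative, not vLLM source):

import torch

def silu_mul_quant_unfused(result: torch.Tensor,
                           result_silu_mul: torch.Tensor,
                           input: torch.Tensor,
                           scale: torch.Tensor) -> torch.Tensor:
    # SiLU(x1) * x2, written into result_silu_mul
    torch.ops._C.silu_and_mul(result_silu_mul, input)
    # static per-tensor fp8 quantization of the activation output
    torch.ops._C.static_scaled_fp8_quant(result, result_silu_mul, scale)
    return result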

silu_mul_replacement_static

silu_mul_replacement_static(
    result: Tensor,
    result_silu_mul: Tensor,
    input: Tensor,
    scale: Tensor,
)
Source code in vllm/compilation/activation_quant_fusion.py
def silu_mul_replacement_static(result: torch.Tensor,
                                result_silu_mul: torch.Tensor,
                                input: torch.Tensor, scale: torch.Tensor):
    at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
                             result=result,
                             input=input,
                             scale=scale)
    return at[1]
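
The replacement collapses the two kernels above into a single fused launch. A hedged sanity check of the equivalence (assumes vLLM's compiled _C ops and a CUDA device; the tolerance is loose to allow for FP8 rounding differences between the fused and unfused kernels):

import torch
from vllm.platforms import current_platform

n, d = 5, 4
x = torch.randn(n, 2 * d, dtype=torch.bfloat16, device="cuda")  # silu_and_mul halves the last dim
scale = torch.ones(1, 1, dtype=torch.float32, device="cuda")
fp8 = current_platform.fp8_dtype()

# Unfused: activation kernel followed by quantization kernel.
tmp = torch.empty(n, d, dtype=torch.bfloat16, device="cuda")
unfused = torch.empty(n, d, dtype=fp8, device="cuda")
torch.ops._C.silu_and_mul(tmp, x)
torch.ops._C.static_scaled_fp8_quant(unfused, tmp, scale)

# Fused: one kernel doing both.
fused = torch.empty(n, d, dtype=fp8, device="cuda")
torch.ops._C.silu_and_mul_quant(fused, x, scale)

torch.testing.assert_close(fused.to(torch.float32), unfused.to(torch.float32),
                           rtol=0.0, atol=0.1)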