Skip to content

vllm.compilation.fix_functionalization

logger module-attribute

# Module-level logger; __call__ uses it to report how many nodes were
# de-functionalized and how many were removed.
logger = init_logger(__name__)

FixFunctionalizationPass

Bases: VllmInductorPass

This pass defunctionalizes certain nodes to avoid redundant tensor copies. After this pass, DCE (dead-code elimination) should never be run, as de-functionalized nodes may appear as dead code.

To add new nodes to defunctionalize, add a branch to the if-elif chain in `__call__`.

Source code in vllm/compilation/fix_functionalization.py
class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    def __call__(self, graph: torch.fx.Graph):
        """
        Rewrite supported auto_functionalized nodes in ``graph`` into direct
        calls to the wrapped op, then erase the nodes made obsolete.

        Obsolete nodes are staged in ``self.nodes_to_remove`` while walking
        the graph and erased in one batch after the walk.
        """
        self.begin()
        self.dump_graph(graph, "before_fix_functionalization")

        self.nodes_to_remove: list[torch.fx.Node] = []
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            at_target = node.args[0]

            if at_target == torch.ops._C.rotary_embedding.default:
                query = kwargs['query']
                getitem_nodes = self.getitem_users(node)

                # rotary_embedding is a special case: the two mutating inputs
                # are query and key, which are usually slices of mm_node.
                # While functionalized, results at[1] and at[2] are scattered
                # back into mm_node. After de-functionalization, we can just
                # use mm_node directly.
                mm_node = None
                if (isinstance(query, torch.fx.Node) and query.args
                        and isinstance(query.args[0], torch.fx.Node)
                        and query.args[0].args
                        and isinstance(query.args[0].args[0], torch.fx.Node)):
                    mm_node = query.args[0].args[0]

                # The mm_node shortcut is only safe when every consumer of
                # the results is a slice_scatter that gets rerouted to
                # mm_node. Any other consumer would keep a staged getitem
                # node alive, and graph.erase_node would fail below.
                scatter_only = all(
                    is_func(user_of_getitem,
                            torch.ops.aten.slice_scatter.default)
                    for getitem_node in getitem_nodes.values()
                    for user_of_getitem in getitem_node.users)

                if mm_node is not None and scatter_only:
                    for getitem_node in getitem_nodes.values():
                        for user_of_getitem in list(getitem_node.users):
                            user_of_getitem.replace_all_uses_with(mm_node)
                            self._remove(user_of_getitem)
                        self._remove(getitem_node)

                    self.insert_defunctionalized(graph, node)
                    self._remove(node)
                else:
                    # Generic fallback: query and key are mutated in place,
                    # so the results can be read directly from those inputs.
                    mutated_args = {1: 'query', 2: 'key'}
                    self.defunctionalize(graph, node, mutated_args)

            # rms_norm replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: 'input', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: 'result', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
                mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target in [
                    torch.ops._C.rms_norm.default,
                    torch.ops._C.rms_norm_static_fp8_quant.default,
            ]:
                mutated_args = {1: 'result'}
                self.defunctionalize(graph, node, mutated_args)
            # For some reason we need to specify the args for both
            # silu_and_mul and silu_and_mul_quant. The kwargs
            # pathway gets the wrong answer.
            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: 'result'}
                self.defunctionalize(graph,
                                     node,
                                     mutated_args,
                                     args=('result', 'input'))
            elif at_target == torch.ops._C.silu_and_mul_quant.default:
                mutated_args = {1: 'result'}
                self.defunctionalize(graph,
                                     node,
                                     mutated_args,
                                     args=('result', 'input', 'scale'))
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_fix_functionalization_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug("De-functionalized %s nodes, removed %s nodes", count,
                     count_removed)
        self.dump_graph(graph, "after_fix_functionalization")
        self.end_and_log()

    def _remove(self, node_or_nodes: Union[torch.fx.Node,
                                           Iterable[torch.fx.Node]]):
        """
        Stage a node (or nodes) for removal at the end of the pass.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(self,
                        graph: torch.fx.Graph,
                        node: torch.fx.Node,
                        mutated_args: dict[int, Union[torch.fx.Node, str]],
                        args: Optional[tuple[Union[torch.fx.Node, str],
                                             ...]] = None):
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(self, node: torch.fx.Node,
                                        mutated_args: dict[int,
                                                           Union[torch.fx.Node,
                                                                 str]]):
        """
        Replace all getitem users of the auto-functionalized node with the
        mutated arguments.
        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        """
        for idx, user in self.getitem_users(node).items():
            arg = mutated_args[idx]
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.
        """
        users = {}
        for user in node.users:
            if is_func(user, operator.getitem):
                idx = user.args[1]
                users[idx] = user
        return users

    def insert_defunctionalized(self,
                                graph: torch.fx.Graph,
                                node: torch.fx.Node,
                                args: Optional[tuple[Union[torch.fx.Node, str],
                                                     ...]] = None):
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), \
            f"node must be auto-functionalized, is {node} instead"

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg
                             for arg in args)
                graph.call_function(function, args=args)

__call__

__call__(graph: Graph)
Source code in vllm/compilation/fix_functionalization.py
def __call__(self, graph: torch.fx.Graph):
    """
    Rewrite supported auto_functionalized nodes in ``graph`` into direct
    calls to the wrapped op, then erase the nodes made obsolete.

    Obsolete nodes are staged in ``self.nodes_to_remove`` while walking the
    graph and erased in one batch after the walk.
    """
    self.begin()
    self.dump_graph(graph, "before_fix_functionalization")

    self.nodes_to_remove: list[torch.fx.Node] = []
    count = 0
    for node in graph.nodes:
        if not is_func(node, auto_functionalized):
            continue  # Avoid deep if-elif nesting

        kwargs = node.kwargs
        at_target = node.args[0]

        if at_target == torch.ops._C.rotary_embedding.default:
            query = kwargs['query']
            getitem_nodes = self.getitem_users(node)

            # rotary_embedding is a special case: the two mutating inputs
            # are query and key, which are usually slices of mm_node.
            # While functionalized, results at[1] and at[2] are scattered
            # back into mm_node. After de-functionalization, we can just
            # use mm_node directly.
            mm_node = None
            if (isinstance(query, torch.fx.Node) and query.args
                    and isinstance(query.args[0], torch.fx.Node)
                    and query.args[0].args
                    and isinstance(query.args[0].args[0], torch.fx.Node)):
                mm_node = query.args[0].args[0]

            # The mm_node shortcut is only safe when every consumer of the
            # results is a slice_scatter that gets rerouted to mm_node. Any
            # other consumer would keep a staged getitem node alive, and
            # graph.erase_node would fail below.
            scatter_only = all(
                is_func(user_of_getitem,
                        torch.ops.aten.slice_scatter.default)
                for getitem_node in getitem_nodes.values()
                for user_of_getitem in getitem_node.users)

            if mm_node is not None and scatter_only:
                for getitem_node in getitem_nodes.values():
                    for user_of_getitem in list(getitem_node.users):
                        user_of_getitem.replace_all_uses_with(mm_node)
                        self._remove(user_of_getitem)
                    self._remove(getitem_node)

                self.insert_defunctionalized(graph, node)
                self._remove(node)
            else:
                # Generic fallback: query and key are mutated in place, so
                # the results can be read directly from those inputs.
                mutated_args = {1: 'query', 2: 'key'}
                self.defunctionalize(graph, node, mutated_args)

        # rms_norm replacements avoid the most copies for LLaMa.
        elif at_target == torch.ops._C.fused_add_rms_norm.default:
            mutated_args = {1: 'input', 2: 'residual'}
            self.defunctionalize(graph, node, mutated_args)
        elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
            mutated_args = {1: 'result', 2: 'residual'}
            self.defunctionalize(graph, node, mutated_args)
        elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
            mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
            self.defunctionalize(graph, node, mutated_args)
        elif at_target in [
                torch.ops._C.rms_norm.default,
                torch.ops._C.rms_norm_static_fp8_quant.default,
        ]:
            mutated_args = {1: 'result'}
            self.defunctionalize(graph, node, mutated_args)
        # For some reason we need to specify the args for both
        # silu_and_mul and silu_and_mul_quant. The kwargs
        # pathway gets the wrong answer.
        elif at_target == torch.ops._C.silu_and_mul.default:
            mutated_args = {1: 'result'}
            self.defunctionalize(graph,
                                 node,
                                 mutated_args,
                                 args=('result', 'input'))
        elif at_target == torch.ops._C.silu_and_mul_quant.default:
            mutated_args = {1: 'result'}
            self.defunctionalize(graph,
                                 node,
                                 mutated_args,
                                 args=('result', 'input', 'scale'))
        else:
            continue  # skip the count

        count += 1

    self.dump_graph(graph, "before_fix_functionalization_cleanup")

    # Remove the nodes all at once
    count_removed = len(self.nodes_to_remove)
    for node in self.nodes_to_remove:
        graph.erase_node(node)

    logger.debug("De-functionalized %s nodes, removed %s nodes", count,
                 count_removed)
    self.dump_graph(graph, "after_fix_functionalization")
    self.end_and_log()

_remove

_remove(node_or_nodes: Union[Node, Iterable[Node]])

Stage a node (or nodes) for removal at the end of the pass.

Source code in vllm/compilation/fix_functionalization.py
def _remove(self, node_or_nodes: Union[torch.fx.Node,
                                       Iterable[torch.fx.Node]]):
    """
    Queue a single node, or every node in an iterable, for deletion once
    the pass has finished rewriting the graph.
    """
    staged = self.nodes_to_remove
    if not isinstance(node_or_nodes, torch.fx.Node):
        # Anything that is not a lone Node is treated as a collection.
        staged.extend(node_or_nodes)
        return
    staged.append(node_or_nodes)

defunctionalize

defunctionalize(
    graph: Graph,
    node: Node,
    mutated_args: dict[int, Union[Node, str]],
    args: Optional[tuple[Union[Node, str], ...]] = None,
)

De-functionalize a node by replacing it with a call to the original. It also replaces the getitem users with the mutated arguments. See replace_users_with_mutated_args and insert_defunctionalized.

Source code in vllm/compilation/fix_functionalization.py
def defunctionalize(self,
                    graph: torch.fx.Graph,
                    node: torch.fx.Node,
                    mutated_args: dict[int, Union[torch.fx.Node, str]],
                    args: Optional[tuple[Union[torch.fx.Node, str],
                                         ...]] = None):
    """
    Swap an auto-functionalized node for a direct call to the wrapped op.

    The work happens in three steps, in this order:
      1. reroute every getitem consumer to its mutated argument
         (replace_users_with_mutated_args),
      2. emit the direct call just before ``node``
         (insert_defunctionalized, forwarding ``args``),
      3. stage ``node`` itself for removal.
    """
    rewire_getitems = self.replace_users_with_mutated_args
    emit_direct_call = self.insert_defunctionalized
    stage_for_removal = self._remove

    rewire_getitems(node, mutated_args)
    emit_direct_call(graph, node, args=args)
    stage_for_removal(node)

getitem_users

getitem_users(node: Node) -> dict[int, Node]

Returns the operator.getitem users of the auto-functionalized node, indexed by the index they are getting.

Source code in vllm/compilation/fix_functionalization.py
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
    """
    Map each index read from the auto-functionalized ``node`` to the
    operator.getitem node that reads it. Non-getitem users are ignored.
    """
    return {
        consumer.args[1]: consumer
        for consumer in node.users
        if is_func(consumer, operator.getitem)
    }

insert_defunctionalized

insert_defunctionalized(
    graph: Graph,
    node: Node,
    args: Optional[tuple[Union[Node, str], ...]] = None,
)

Insert a new defunctionalized node into the graph before node. If one of the kwargs is 'out', provide args directly, as node.kwargs cannot be used. See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

Parameters: `graph` is the graph to insert the defunctionalized node into; `node` is the auto-functionalized node to defunctionalize; `args`, used when kwargs cannot be, specifies the positional arguments directly — a string entry is resolved as `node.kwargs[arg]`.

Source code in vllm/compilation/fix_functionalization.py
def insert_defunctionalized(self,
                            graph: torch.fx.Graph,
                            node: torch.fx.Node,
                            args: Optional[tuple[Union[torch.fx.Node, str],
                                                 ...]] = None):
    """
    Emit a direct call to the op wrapped by ``node``, placed right before it.

    By default the new call reuses ``node.kwargs``. When one of the kwargs
    is 'out', node.kwargs cannot be used (see
    https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351),
    so the caller supplies ``args`` instead.

    :param graph: Graph that receives the new call node
    :param node: The auto-functionalized node being replaced
    :param args: Optional positional args; a string entry is resolved to
        ``node.kwargs[entry]``. When None, ``node.kwargs`` is forwarded.
    """  # noqa: E501
    assert is_func(node, auto_functionalized), \
        f"node must be auto-functionalized, is {node} instead"

    # The wrapped op is the first positional arg of auto_functionalized.
    wrapped_op = node.args[0]
    with graph.inserting_before(node):
        if args is not None:
            positional = tuple(
                node.kwargs[item] if isinstance(item, str) else item
                for item in args)
            graph.call_function(wrapped_op, args=positional)
        else:
            graph.call_function(wrapped_op, kwargs=node.kwargs)

replace_users_with_mutated_args

replace_users_with_mutated_args(
    node: Node, mutated_args: dict[int, Union[Node, str]]
)

Replace all getitem users of the auto-functionalized node with the mutated arguments. Parameters: `node` is the auto-functionalized node; `mutated_args` maps each getitem index to its mutated argument — when a value is a string, `node.kwargs[value]` is used instead.

Source code in vllm/compilation/fix_functionalization.py
def replace_users_with_mutated_args(self, node: torch.fx.Node,
                                    mutated_args: dict[int,
                                                       Union[torch.fx.Node,
                                                             str]]):
    """
    Reroute every getitem consumer of an auto-functionalized node to the
    argument the op mutated, and stage each getitem node for removal.

    :param node: The auto-functionalized node
    :param mutated_args: Maps getitem index -> replacement. A string value
        names an entry of ``node.kwargs``; any other value is used as-is.
    """
    for getitem_idx, getitem_node in self.getitem_users(node).items():
        replacement = mutated_args[getitem_idx]
        if isinstance(replacement, str):
            replacement = node.kwargs[replacement]
        getitem_node.replace_all_uses_with(replacement)
        self._remove(getitem_node)