
vllm.compilation.noop_elimination

logger module-attribute

logger = init_logger(__name__)

NoOpEliminationPass

Bases: VllmInductorPass

This is an Inductor pass that removes redundant reshape/slice operations. It is required for RMSNorm-quant fusion to work properly, because apply_fp8_linear adds a reshape that is redundant in the 2D case. Additionally, torch's internal no-op elimination pass does not handle certain slice variants.

Cases handled
  1. A chain of reshapes is equivalent to the last reshape called on the base tensor (input of the first reshape).
  2. A reshape that produces the shape of the input is redundant
  3. A slice that produces the shape of the input is redundant

Example graph 1:
    mul_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
    view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
    view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])

Can be replaced with:
    mul_1: "f16[s0, 4096]" = ...
    view_3: "f16[s0, 128, 32]" = ...
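To see why Case 1 is safe, the pure-Python sketch below models a reshape as a row-major reinterpretation of the same flat element order (the order torch.reshape preserves), using nested lists instead of tensors. The helpers resolve_shape, flatten, build and reshape are hypothetical illustrations, not vLLM or torch APIs, and the sizes of Example graph 1 are scaled down (s0 = 3, 4096 to 8, [128, 32] to [4, 2]).

```python
from math import prod


def resolve_shape(numel, shape):
    # Replace a single -1 with the size implied by the total element count.
    known = prod(d for d in shape if d != -1)
    return [numel // known if d == -1 else d for d in shape]


def flatten(x):
    # Row-major flattening of a nested list.
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def build(flat, shape):
    # Refill a flat list into nested lists of a fully resolved shape.
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [build(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def reshape(x, shape):
    # Hypothetical stand-in for torch.reshape: same flat order, new shape.
    flat = flatten(x)
    return build(flat, resolve_shape(len(flat), shape))


# Scaled-down Example graph 1.
mul_1 = build(list(range(24)), [3, 8])
view_1 = reshape(mul_1, [-1, 4, 2])
view_2 = reshape(view_1, [-1, 8])
view_3 = reshape(view_2, [-1, 4, 2])

# Case 1: the whole chain equals one reshape applied to the base tensor.
assert view_3 == reshape(mul_1, [-1, 4, 2])
```

Because every reshape only reinterprets the same flat sequence, intermediate shapes never affect the result; only the final target shape matters.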

Example graph 2:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]
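The Case 2 test that removes view_1 in Example graph 2 can be sketched as follows. This is a simplified, hypothetical helper (is_noop_reshape) restricted to concrete integer sizes; the real pass also accepts SymInt sizes and torch.fx.Node arguments through dims_equivalent. The checks follow the same order as __call__: reject rank changes, reject more than one -1, then compare dimension by dimension.

```python
def is_noop_reshape(input_shape, target_shape):
    # Simplified Case 2 check for concrete integer sizes only.
    if len(target_shape) != len(input_shape):
        return False  # rank-changing reshape: never a no-op
    if list(target_shape).count(-1) > 1:
        return False  # invalid reshape args
    # -1 is accepted for any size: once every other dim matches, the inferred
    # size is numel / product(other dims), i.e. the input's own size there.
    return all(t == i or t == -1 for t, i in zip(target_shape, input_shape))


# Example graph 2 with s0 = 5: reshape(getitem_1, [-1, 4096]) on a [5, 4096] tensor.
input_shape = [5, 4096]
assert (5 * 4096) // 4096 == 5          # the -1 resolves to the leading dim
assert is_noop_reshape(input_shape, [-1, 4096])
assert not is_noop_reshape(input_shape, [-1, 128, 32])
```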

Example graph 3:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, 0, 0, arg0)
    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)

Can be replaced with:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
    out: "f16[s0, 4096]" = at[1]
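The rewrite in Example graph 3 relies on two facts: a slice along dim 0 from 0 to the full size returns the input unchanged, and a slice_scatter whose source has the same shape as the base and covers it entirely simply yields the source. The sketch below illustrates both on lists of rows. slice_dim0 and slice_scatter_dim0 are hypothetical stand-ins for aten.slice and aten.slice_scatter on dim 0, the sizes are scaled down (s0 = 3, 4096 to 2), the "normalization" is a placeholder rather than a real RMSNorm, and a step of 1 is assumed, as in the example graph.

```python
def slice_dim0(x, start, end, step=1):
    # Hypothetical stand-in for aten.slice on dim 0 of a list of rows.
    return x[start:end:step]


def slice_scatter_dim0(base, src, start, end):
    # Hypothetical stand-in for aten.slice_scatter on dim 0 (step 1):
    # copy base, then overwrite rows [start:end] with the rows of src.
    out = [list(row) for row in base]
    out[start:end] = [list(row) for row in src]
    return out


s0 = 3
scaled_mm = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

slice_1 = slice_dim0(scaled_mm, 0, s0)
assert slice_1 == scaled_mm      # start 0, end == size: the slice is a no-op

# Placeholder for at[1], the fused_add_rms_norm output computed on slice_1.
at_1 = [[2.0 * v for v in row] for row in slice_1]
out = slice_scatter_dim0(scaled_mm, at_1, 0, s0)
assert out == at_1               # src covers base fully: result is just at_1

# A non-unit step is not a no-op, even with start 0 and end == size.
assert slice_dim0(scaled_mm, 0, s0, 2) != scaled_mm
```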

TODO(luka): This is currently tested in test_fusion, but separate tests could be good.

Source code in vllm/compilation/noop_elimination.py
class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Cases handled:
      1. A chain of reshapes is equivalent to the last reshape called on the
      base tensor (input of the first reshape).
      2. A reshape that produces the shape of the input is redundant
      3. A slice that produces the shape of the input is redundant

    Example graph 1:
    mul_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
    view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
    view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])

    Can be replaced with:
    mul_1: "f16[s0, 4096]" = ...
    view_3: "f16[s0, 128, 32]" = ...

    Example graph 2:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Example graph 3:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, 0, 0, arg0)
    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)

    Can be replaced with:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
    out: "f16[s0, 4096]" = at[1]

    TODO(luka): This is currently tested in test_fusion,
     but separate tests could be good.
    """

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_noop_elimination")
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                # Case 1: rewrite reshape chains to reshapes on the base tensor
                input = node.args[0]
                # If the input is a reshape, rebind to that node
                if is_func(input, torch.ops.aten.reshape.default):
                    # The new input is guaranteed not to be a reshape,
                    # because we process nodes in order
                    node.update_arg(0, input.args[0])
                    if len(input.users) == 0:
                        graph.erase_node(input)
                        count += 1

                # Case 2: remove this reshape if it produces the original shape
                input, shape = node.args[:2]
                input_shape = input.meta["val"].shape
                if len(shape) != len(input_shape):
                    # Reshape changing rank, skip
                    continue

                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue

                if self.all_dims_equivalent(shape, input_shape):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice.Tensor):
                input, dim_index, start, end = node.args[:4]
                input_shape = input.meta["val"].shape
                i_dim = input_shape[dim_index]

                if start == 0 and self.dims_equivalent(end, i_dim):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice_scatter.default):
                base, view, dim_index, start, end = node.args[:5]
                base_shape = base.meta["val"].shape
                view_shape = view.meta["val"].shape

                view_dim = view_shape[dim_index]

                # Check that view fully covers base and the full view is used
                # (if the view fully covered the base after slicing but was not
                # fully used, we could replace slice_scatter with a simple slice
                # but that's a niche case).
                if (base_shape == view_shape and start == 0
                        and self.dims_equivalent(end, view_dim)):
                    node.replace_all_uses_with(view)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes and slices", count)
        self.dump_graph(graph, "after_noop_elimination")
        self.end_and_log()

    def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
                            i_dims: Iterable[Union[int, SymInt]]):
        return all(
            self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))

    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                        i_dim: Union[int, SymInt]) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape/slice
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are three cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The reshape dimension is -1 (i.e. inferred)
        3. The dimensions both correspond to the same SymInt

        While case 2 does not guarantee the dimensions are equal,
        they are equal if all other dimensions are equal.

        In case 3, the reshape dimension is a torch.fx.Node,
        and its value is a SymInt. That value is equal to the
        input dimension.

        """
        # Case 1 and 2
        if dim == i_dim or dim == -1:
            return True
        # Case 3
        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim

__call__

__call__(graph: Graph)
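The reshape handling in __call__ (Cases 1 and 2) can be sketched on a toy graph. This is a minimal illustration under stated assumptions: ToyNode is a hypothetical stand-in for torch.fx.Node, a plain list stands in for torch.fx.Graph, users_of mimics node.users, shapes are concrete integers, and the slice cases are omitted.

```python
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node (identity-based equality)."""

    def __init__(self, name, op, args, shape=()):
        self.name = name
        self.op = op          # "input", "reshape", or "output"
        self.args = args      # reshape: [source_node, target_shape]
        self.shape = shape    # stand-in for node.meta["val"].shape


def users_of(nodes, target):
    # Nodes whose args reference target (like node.users in FX).
    return [n for n in nodes if any(a is target for a in n.args)]


def noop_elimination(nodes):
    """Toy version of reshape Cases 1 and 2; mutates nodes in place."""
    count = 0
    for node in list(nodes):  # iterate over a snapshot, since we erase
        if node.op != "reshape":
            continue
        # Case 1: rebind a reshape-of-a-reshape to the base tensor. Nodes are
        # visited in order, so the earlier reshape was already rebound.
        src = node.args[0]
        if src.op == "reshape":
            node.args[0] = src.args[0]
            if not users_of(nodes, src):
                nodes.remove(src)
                count += 1
        # Case 2: drop the reshape if it reproduces its input's shape.
        src, target = node.args[0], node.args[1]
        if (len(target) == len(src.shape) and target.count(-1) <= 1
                and all(t == s or t == -1 for t, s in zip(target, src.shape))):
            for user in users_of(nodes, node):
                user.args = [src if a is node else a for a in user.args]
            nodes.remove(node)
            count += 1
    return count


# Scaled-down Example graph 1: s0 = 3, 4096 -> 8, [128, 32] -> [4, 2].
mul_1 = ToyNode("mul_1", "input", [], (3, 8))
view_1 = ToyNode("view_1", "reshape", [mul_1, [-1, 4, 2]], (3, 4, 2))
view_2 = ToyNode("view_2", "reshape", [view_1, [-1, 8]], (3, 8))
view_3 = ToyNode("view_3", "reshape", [view_2, [-1, 4, 2]], (3, 4, 2))
out = ToyNode("out", "output", [view_3])
graph = [mul_1, view_1, view_2, view_3, out]

removed = noop_elimination(graph)
```

Because the loop visits nodes in order, view_2 is rebound to mul_1 (and then dropped as a no-op) before view_3 is visited, so a single pass leaves only mul_1, view_3 and out, with view_3 reading directly from mul_1. This mirrors the comment in __call__ that the new input is guaranteed not to be a reshape.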

all_dims_equivalent

all_dims_equivalent(
    dims: Iterable[Union[int, Node]],
    i_dims: Iterable[Union[int, SymInt]],
)
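all_dims_equivalent pairs dimensions with zip, which stops at the shorter input, so on its own it does not detect a rank mismatch. The simplified sketch below (integers only, with a reduced dims_equivalent covering only its first two cases) shows this; it is why __call__ compares len(shape) with len(input_shape) before calling it.

```python
def dims_equivalent(dim, i_dim):
    # Reduced version: concrete ints only (equal sizes, or an inferred -1).
    return dim == i_dim or dim == -1


def all_dims_equivalent(dims, i_dims):
    # Pairwise check via zip, as in the documented helper.
    return all(dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))


assert all_dims_equivalent([-1, 4096], [7, 4096])
assert not all_dims_equivalent([-1, 2048], [7, 4096])
# zip truncates, so a rank-changing reshape would pass without the length check.
assert all_dims_equivalent([-1, 128], [7, 128, 32])
```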

dims_equivalent

dims_equivalent(
    dim: Union[int, Node], i_dim: Union[int, SymInt]
) -> bool

This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape/slice
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?

There are three cases in which the dimensions are equivalent:
  1. The dimensions are equal (both integers)
  2. The reshape dimension is -1 (i.e. inferred)
  3. The dimensions both correspond to the same SymInt

While case 2 does not by itself guarantee that the dimensions are equal, they are equal if all other dimensions are equal, because a reshape preserves the total number of elements.

In case 3, the reshape dimension is a torch.fx.Node, and its value is a SymInt. That value is equal to the input dimension.
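The three cases can be illustrated without torch. In this sketch, FakeNode and FakeSymInt are hypothetical stand-ins for torch.fx.Node and torch.SymInt: FakeNode only carries a meta dict with a "val" entry, and FakeSymInt compares by identity. Real SymInt comparisons go through PyTorch's symbolic-shape machinery, so identity equality here is a deliberate simplification.

```python
class FakeSymInt:
    """Hypothetical stand-in for torch.SymInt; equality is by identity."""

    def __init__(self, name):
        self.name = name


class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; only carries meta["val"]."""

    def __init__(self, val):
        self.meta = {"val": val}


def dims_equivalent(dim, i_dim):
    # Case 1 and 2: equal values, or an inferred (-1) reshape dimension.
    if dim == i_dim or dim == -1:
        return True
    # Case 3: dim is a graph node whose traced value is the input's symbolic size.
    return isinstance(dim, FakeNode) and dim.meta["val"] == i_dim


s0 = FakeSymInt("s0")
assert dims_equivalent(4096, 4096)                           # case 1
assert dims_equivalent(-1, s0)                               # case 2
assert dims_equivalent(FakeNode(s0), s0)                     # case 3
assert not dims_equivalent(FakeNode(FakeSymInt("s1")), s0)   # different symbol
assert not dims_equivalent(128, 4096)
```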
