
vllm.model_executor.layers.quantization.compressed_tensors.schemes

Modules:

Name Description
compressed_tensors_24
compressed_tensors_scheme
compressed_tensors_w4a16_24
compressed_tensors_w4a16_nvfp4
compressed_tensors_w4a4_nvfp4
compressed_tensors_w8a16_fp8
compressed_tensors_w8a8_fp8
compressed_tensors_w8a8_int8
compressed_tensors_wNa16

W4A16SPARSE24_SUPPORTED_BITS module-attribute

W4A16SPARSE24_SUPPORTED_BITS = list(keys())

WNA16_SUPPORTED_BITS module-attribute

WNA16_SUPPORTED_BITS = list(keys())

__all__ module-attribute

__all__ = [
    "CompressedTensorsScheme",
    "CompressedTensorsWNA16",
    "CompressedTensorsW8A16Fp8",
    "CompressedTensorsW4A16Sparse24",
    "CompressedTensorsW8A8Int8",
    "CompressedTensorsW8A8Fp8",
    "WNA16_SUPPORTED_BITS",
    "W4A16SPARSE24_SUPPORTED_BITS",
    "CompressedTensors24",
    "CompressedTensorsW4A16Fp4",
    "CompressedTensorsW4A4Fp4",
]
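
As a brief orientation, a minimal usage sketch built only from the names exported above (vLLM must be installed; the values in the comments follow from the sources on this page):

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensors24, CompressedTensorsScheme, WNA16_SUPPORTED_BITS)

# Every scheme reports the minimum GPU compute capability it needs,
# without being instantiated.
assert issubclass(CompressedTensors24, CompressedTensorsScheme)
print(CompressedTensors24.get_min_capability())  # 90, i.e. compute capability 9.0
print(WNA16_SUPPORTED_BITS)                      # bit-widths accepted by the WNA16 scheme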

CompressedTensors24

Bases: CompressedTensorsScheme

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
class CompressedTensors24(CompressedTensorsScheme):

    def __init__(
        self,
        quantized: bool = False,
        weight_quant: Optional[QuantizationArgs] = None,
        input_quant: Optional[QuantizationArgs] = None,
        model_compression_config: Optional[dict[str, Any]] = None,
    ):
        self.quantized = quantized
        self.weight_quant = weight_quant
        self.input_quant = input_quant
        self.model_compressor = (
            ModelCompressor.from_compression_config(model_compression_config)
            if model_compression_config is not None else None)
        self.do_sparse_decompress = (
            self.model_compressor is not None
            and self.model_compressor.sparsity_config.format
            == CompressionFormat.sparse_24_bitmask.value)

    @classmethod
    def get_min_capability(cls) -> int:
        # Only cutlass 3.x kernels are implemented so far
        return 90

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        if not sparse_cutlass_supported():
            raise ValueError(
                "Sparse CUTLASS not supported. vLLM must be built with "
                "CUDA 12.2 or later to use this feature")

        layer.logical_widths = output_partition_sizes
        layer.input_size = input_size
        layer.input_size_per_partition = input_size_per_partition
        self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

        # parameter to store uncompressed weight
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=self.weights_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        if self.do_sparse_decompress:
            assert all(
                partition_size % 8 == 0
                for partition_size in output_partition_sizes), (
                    "All partitions must be divisible by 8 for "
                    "2:4 sparse compressed models")

            shape = BasevLLMParameter(
                data=torch.empty(2, 1, dtype=torch.int64),
                weight_loader=weight_loader,
            )
            compressed_weight = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition // 2,
                    dtype=self.weights_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

            bitmask = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition // 8,
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

            layer.register_parameter("shape", shape)
            layer.register_parameter("compressed", compressed_weight)
            layer.register_parameter("bitmask", bitmask)

        # Check if quantized, not just 2:4 Sparse
        if self.quantized:
            if (self.weight_quant and self.weight_quant.strategy
                    == QuantizationStrategy.CHANNEL.value):
                weight_scale = ChannelQuantScaleParameter(
                    data=torch.empty((sum(output_partition_sizes), 1),
                                     dtype=torch.float32),
                    output_dim=0,
                    weight_loader=weight_loader,
                )
            else:
                assert (self.weight_quant and self.weight_quant.strategy
                        == QuantizationStrategy.TENSOR.value)
                weight_scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )

            layer.register_parameter("weight_scale", weight_scale)

            # input quant will be non-None
            if self.input_quant and not self.input_quant.dynamic:
                # register input quant scale
                assert (self.input_quant.strategy ==
                        QuantizationStrategy.TENSOR.value)
                input_scale = BasevLLMParameter(
                    data=torch.empty(1, dtype=torch.float32),
                    weight_loader=weight_loader,
                )

                layer.register_parameter("input_scale", input_scale)

        else:
            # for sparse-only, pass in 1 for weight/input scales
            weight_scale = torch.nn.Parameter(data=torch.ones(
                1, dtype=torch.float32),
                                              requires_grad=False)
            input_scale = torch.nn.Parameter(data=torch.ones(
                1, dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter("input_scale", input_scale)
            layer.register_parameter("weight_scale", weight_scale)

        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Compress weights after loading. Store compressed weight and meta
            tensor

        :post-condition: layer.w_compressed and layer.meta are
            set to the compressed weight and meta tensor in the
            format expected by the Cutlass kernels
        :param layer: The layer with the weights to be processed

        """
        if self.do_sparse_decompress:
            layer.weight.data = self._decompress_bitmask_compressed_weight(
                compressed=layer.compressed,
                bitmask=layer.bitmask,
                layer=layer,
            )

            # compressed and bitmask tensors
            # are no longer needed after decompression
            del layer.compressed
            del layer.bitmask

        # torch.compile workaround
        if hasattr(layer, "input_scale"):
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                   requires_grad=False)

        if self.weight_quant:
            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
                layer.weight_scale = torch.nn.Parameter(
                    convert_to_channelwise(
                        weight_scale=layer.weight_scale,
                        logical_widths=layer.logical_widths,
                    ),
                    requires_grad=False,
                )
            else:
                # torch.compile workaround
                layer.weight_scale = torch.nn.Parameter(
                    layer.weight_scale.data, requires_grad=False)

        # Set all negative zero values to 0 prior to compression
        if (layer.weight.dtype.is_floating_point
                and layer.weight.dtype.itemsize >= 2):
            layer.weight.data[layer.weight.data == -0.0] = 0.0

        w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
        layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
        layer.meta = torch.nn.Parameter(meta, requires_grad=False)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Returns the output tensor for the layer with 2:4
        sparse compressed weights, given the input tensor
        and bias

        :param layer: The layer with 2:4 sparse compressed
            weights to be used for the computation
        :param x: The input tensor to the layer
        :param bias: The bias to be added to the output tensor
        :return: The output tensor of the layer
        """
        if self.quantized:
            scale = None
            if hasattr(layer, "input_scale"):
                scale = layer.input_scale

            if self.weights_dtype == torch.int8:
                ops_output = ops.scaled_int8_quant(x, scale=scale)
                q_input = ops_output[0]
                input_scale = ops_output[1]
            else:
                assert self.weights_dtype == torch.float8_e4m3fn
                if scale is not None:
                    q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
                else:
                    q_input, input_scale = ops.scaled_fp8_quant(
                        x, use_per_token_if_dynamic=True)

        else:
            # Not quantized, nothing to do with the input_scales, use as is
            input_scale = layer.input_scale
            q_input = x

        out = ops.cutlass_scaled_sparse_mm(
            a=q_input,
            bt_nzs=layer.weight,
            bt_meta=layer.meta,
            scale_a=input_scale,
            scale_b=layer.weight_scale,
            out_dtype=x.dtype,
            bias=bias,
        )

        assert out.is_contiguous()
        return out

    def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
        if not self.quantized:
            return params_dtype

        assert self.weight_quant is not None
        assert self.input_quant is not None

        is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8

        if not is_8_bits:
            raise ValueError("Cutlass only supports 8-bit quantization")

        if (self.weight_quant.type == QuantizationType.FLOAT
                and self.input_quant.type == QuantizationType.FLOAT):
            return torch.float8_e4m3fn

        if (self.weight_quant.type == QuantizationType.INT
                and self.input_quant.type == QuantizationType.INT):
            return torch.int8

        raise ValueError("Quantization type not supported by Cutlass")

    def _decompress_bitmask_compressed_weight(
        self,
        compressed: torch.Tensor,
        bitmask: torch.Tensor,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """
        Decompress a compressed 2:4 sparse weight tensor using the bitmask and
        return the result.

        This function also supports sharded decompression.

        :param compressed: The 2:4 sparse weight tensor compressed using the
            sparse-24-bitmask compressor. This is different from
            `cutlass_sparse_compress` which uses a different scheme (2 bits for
            every nonzero element that represent the coordinate within the block
            of 4). The bitmask compression here uses a bitmask to indicate the
            positions of non-zero elements.
        :param bitmask: The 2:4 bitmask associated with the compressed weights,
            representing the positions of non-zero elements in the compressed
            tensor.
        :param layer: The layer whose weights need to be processed after 
            loading.
        :return: The decompressed 2:4 sparse weight tensor.
        """

        sparsity_compressor = self.model_compressor.sparsity_compressor

        def _process_split(
            bitmask_compressed_weight: torch.Tensor,
            shape,
            bitmask: torch.Tensor,
        ) -> torch.Tensor:
            weight_data = dict(
                compressed=bitmask_compressed_weight,
                shape=shape,
                bitmask=bitmask,
            )
            return sparsity_compressor.decompress_weight(weight_data)

        split_weights: list[torch.Tensor] = []
        split_bitmask: list[torch.Tensor] = []
        split_shape: list[tuple[int, int]] = []

        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
            split_weights = torch.split(compressed, layer.logical_widths)
            split_bitmask = torch.split(bitmask, layer.logical_widths)
            split_shape = [(out, layer.input_size_per_partition)
                           for out in layer.logical_widths]

        if split_weights:
            decompressed_shards = [
                _process_split(compressed_weight, shape, bitmask)
                for compressed_weight, shape, bitmask in zip(
                    split_weights, split_shape, split_bitmask)
            ]
            decompressed = combine_shards(decompressed_shards)
        else:
            decompressed = sparsity_compressor.decompress_weight(
                dict(
                    compressed=compressed,
                    shape=(
                        layer.logical_widths[0],
                        layer.input_size_per_partition,
                    ),
                    bitmask=bitmask,
                ))
        return decompressed

do_sparse_decompress instance-attribute

do_sparse_decompress = (
    model_compressor is not None and format == value
)

input_quant instance-attribute

input_quant = input_quant

model_compressor instance-attribute

model_compressor = (
    from_compression_config(model_compression_config)
    if model_compression_config is not None
    else None
)

quantized instance-attribute

quantized = quantized

weight_quant instance-attribute

weight_quant = weight_quant

__init__

__init__(
    quantized: bool = False,
    weight_quant: Optional[QuantizationArgs] = None,
    input_quant: Optional[QuantizationArgs] = None,
    model_compression_config: Optional[
        dict[str, Any]
    ] = None,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
def __init__(
    self,
    quantized: bool = False,
    weight_quant: Optional[QuantizationArgs] = None,
    input_quant: Optional[QuantizationArgs] = None,
    model_compression_config: Optional[dict[str, Any]] = None,
):
    self.quantized = quantized
    self.weight_quant = weight_quant
    self.input_quant = input_quant
    self.model_compressor = (
        ModelCompressor.from_compression_config(model_compression_config)
        if model_compression_config is not None else None)
    self.do_sparse_decompress = (
        self.model_compressor is not None
        and self.model_compressor.sparsity_config.format
        == CompressionFormat.sparse_24_bitmask.value)
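
For illustration, two hedged construction sketches (the QuantizationArgs field values are assumptions for a typical FP8 W8A8 2:4 sparse layer; in vLLM these objects are parsed from the checkpoint's quantization config rather than written by hand):

from compressed_tensors.quantization import QuantizationArgs

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensors24)

# Sparse-only: weights keep the model dtype and only the 2:4 sparse
# CUTLASS kernel path is used.
sparse_only = CompressedTensors24()

# Quantized 2:4 sparse with FP8 weights and dynamic per-token FP8
# activations (field values assumed for illustration).
quantized = CompressedTensors24(
    quantized=True,
    weight_quant=QuantizationArgs(num_bits=8, type="float", strategy="channel"),
    input_quant=QuantizationArgs(num_bits=8, type="float", strategy="token",
                                 dynamic=True),
)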

_decompress_bitmask_compressed_weight

_decompress_bitmask_compressed_weight(
    compressed: Tensor, bitmask: Tensor, layer: Module
) -> Tensor

Decompress a compressed 2:4 sparse weight tensor using the bitmask and return the result.

This function also supports sharded decompression.

:param compressed: The 2:4 sparse weight tensor compressed using the sparse-24-bitmask compressor. This is different from cutlass_sparse_compress, which uses a different scheme (2 bits for every nonzero element that represent the coordinate within the block of 4). The bitmask compression here uses a bitmask to indicate the positions of non-zero elements.

:param bitmask: The 2:4 bitmask associated with the compressed weights, representing the positions of non-zero elements in the compressed tensor.

:param layer: The layer whose weights need to be processed after loading.

:return: The decompressed 2:4 sparse weight tensor.
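
To make the bitmask layout concrete, here is a toy decompression sketch (an illustration of the idea only, not the compressed-tensors sparsity compressor; the LSB-first bit order and the toy_bitmask_decompress helper are assumptions):

import torch

def toy_bitmask_decompress(values: torch.Tensor,
                           bitmask: torch.Tensor,
                           shape: tuple[int, int]) -> torch.Tensor:
    # Each uint8 bitmask byte covers 8 weight positions; set bits mark the
    # non-zero entries whose values are stored contiguously in `values`.
    bit_weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
    mask = ((bitmask.unsqueeze(-1) & bit_weights) != 0).reshape(shape)
    dense = torch.zeros(shape, dtype=values.dtype)
    dense[mask] = values.reshape(-1)   # row-major scatter of the kept values
    return dense

# 2:4 example row: exactly two non-zeros in each group of four.
dense = torch.tensor([[0.5, 0.0, 0.0, -1.0, 0.0, 2.0, 0.25, 0.0]])
values = dense[dense != 0].reshape(1, -1)                  # (1, 4), i.e. K // 2 values per row
bitmask = torch.tensor([[0b01101001]], dtype=torch.uint8)  # flags for positions 0, 3, 5, 6
assert torch.equal(toy_bitmask_decompress(values, bitmask, (1, 8)), dense)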

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
def _decompress_bitmask_compressed_weight(
    self,
    compressed: torch.Tensor,
    bitmask: torch.Tensor,
    layer: torch.nn.Module,
) -> torch.Tensor:
    """
    Decompress a compressed 2:4 sparse weight tensor using the bitmask and
    return the result.

    This function also supports sharded decompression.

    :param compressed: The 2:4 sparse weight tensor compressed using the
        sparse-24-bitmask compressor. This is different from
        `cutlass_sparse_compress` which uses a different scheme (2 bits for
        every nonzero element that represent the coordinate within the block
        of 4). The bitmask compression here uses a bitmask to indicate the
        positions of non-zero elements.
    :param bitmask: The 2:4 bitmask associated with the compressed weights,
        representing the positions of non-zero elements in the compressed
        tensor.
    :param layer: The layer whose weights need to be processed after 
        loading.
    :return: The decompressed 2:4 sparse weight tensor.
    """

    sparsity_compressor = self.model_compressor.sparsity_compressor

    def _process_split(
        bitmask_compressed_weight: torch.Tensor,
        shape,
        bitmask: torch.Tensor,
    ) -> torch.Tensor:
        weight_data = dict(
            compressed=bitmask_compressed_weight,
            shape=shape,
            bitmask=bitmask,
        )
        return sparsity_compressor.decompress_weight(weight_data)

    split_weights: list[torch.Tensor] = []
    split_bitmask: list[torch.Tensor] = []
    split_shape: list[tuple[int, int]] = []

    if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
        split_weights = torch.split(compressed, layer.logical_widths)
        split_bitmask = torch.split(bitmask, layer.logical_widths)
        split_shape = [(out, layer.input_size_per_partition)
                       for out in layer.logical_widths]

    if split_weights:
        decompressed_shards = [
            _process_split(compressed_weight, shape, bitmask)
            for compressed_weight, shape, bitmask in zip(
                split_weights, split_shape, split_bitmask)
        ]
        decompressed = combine_shards(decompressed_shards)
    else:
        decompressed = sparsity_compressor.decompress_weight(
            dict(
                compressed=compressed,
                shape=(
                    layer.logical_widths[0],
                    layer.input_size_per_partition,
                ),
                bitmask=bitmask,
            ))
    return decompressed

_get_params_dtype

_get_params_dtype(params_dtype: dtype) -> dtype
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
    if not self.quantized:
        return params_dtype

    assert self.weight_quant is not None
    assert self.input_quant is not None

    is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8

    if not is_8_bits:
        raise ValueError("Cutlass only supports 8-bit quantization")

    if (self.weight_quant.type == QuantizationType.FLOAT
            and self.input_quant.type == QuantizationType.FLOAT):
        return torch.float8_e4m3fn

    if (self.weight_quant.type == QuantizationType.INT
            and self.input_quant.type == QuantizationType.INT):
        return torch.int8

    raise ValueError("Quantization type not supported by Cutlass")

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor

Returns the output tensor for the layer with 2:4 sparse compressed weights, given the input tensor and bias

:param layer: The layer with 2:4 sparse compressed weights to be used for the computation

:param x: The input tensor to the layer

:param bias: The bias to be added to the output tensor

:return: The output tensor of the layer
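
When the layer is quantized and carries no static input_scale, the activations are quantized on the fly with one scale per token. A rough numerical reference for that step (a sketch of the semantics, not the ops.scaled_fp8_quant kernel; the per_token_fp8_quant_reference name is hypothetical):

import torch

def per_token_fp8_quant_reference(x: torch.Tensor):
    # One scale per token (row) so each row maps onto the FP8 E4M3 range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max               # 448.0
    scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    scale = scale.clamp(min=torch.finfo(torch.float32).tiny)     # avoid divide-by-zero
    q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale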

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Returns the output tensor for the layer with 2:4
    sparse compressed weights, given the input tensor
    and bias

    :param layer: The layer with 2:4 sparse compressed
        weights to be used for the computation
    :param x: The input tensor to the layer
    :param bias: The bias to be added to the output tensor
    :return: The output tensor of the layer
    """
    if self.quantized:
        scale = None
        if hasattr(layer, "input_scale"):
            scale = layer.input_scale

        if self.weights_dtype == torch.int8:
            ops_output = ops.scaled_int8_quant(x, scale=scale)
            q_input = ops_output[0]
            input_scale = ops_output[1]
        else:
            assert self.weights_dtype == torch.float8_e4m3fn
            if scale is not None:
                q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
            else:
                q_input, input_scale = ops.scaled_fp8_quant(
                    x, use_per_token_if_dynamic=True)

    else:
        # Not quantized, nothing to do with the input_scales, use as is
        input_scale = layer.input_scale
        q_input = x

    out = ops.cutlass_scaled_sparse_mm(
        a=q_input,
        bt_nzs=layer.weight,
        bt_meta=layer.meta,
        scale_a=input_scale,
        scale_b=layer.weight_scale,
        out_dtype=x.dtype,
        bias=bias,
    )

    assert out.is_contiguous()
    return out

create_weights

create_weights(
    layer: Module,
    input_size: int,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size: int,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: torch.dtype,
    weight_loader: Callable,
    **kwargs,
):
    if not sparse_cutlass_supported():
        raise ValueError(
            "Sparse CUTLASS not supported. vLLM must be built with "
            "CUDA 12.2 or later to use this feature")

    layer.logical_widths = output_partition_sizes
    layer.input_size = input_size
    layer.input_size_per_partition = input_size_per_partition
    self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

    # parameter to store uncompressed weight
    weight = ModelWeightParameter(
        data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition,
            dtype=self.weights_dtype,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    if self.do_sparse_decompress:
        assert all(
            partition_size % 8 == 0
            for partition_size in output_partition_sizes), (
                "All partitions must be divisible by 8 for "
                "2:4 sparse compressed models")

        shape = BasevLLMParameter(
            data=torch.empty(2, 1, dtype=torch.int64),
            weight_loader=weight_loader,
        )
        compressed_weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // 2,
                dtype=self.weights_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        bitmask = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // 8,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("shape", shape)
        layer.register_parameter("compressed", compressed_weight)
        layer.register_parameter("bitmask", bitmask)

    # Check if quantized, not just 2:4 Sparse
    if self.quantized:
        if (self.weight_quant and self.weight_quant.strategy
                == QuantizationStrategy.CHANNEL.value):
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1),
                                 dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert (self.weight_quant and self.weight_quant.strategy
                    == QuantizationStrategy.TENSOR.value)
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes),
                                 dtype=torch.float32),
                weight_loader=weight_loader,
            )

        layer.register_parameter("weight_scale", weight_scale)

        # input quant will be non-None
        if self.input_quant and not self.input_quant.dynamic:
            # register input quant scale
            assert (self.input_quant.strategy ==
                    QuantizationStrategy.TENSOR.value)
            input_scale = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.float32),
                weight_loader=weight_loader,
            )

            layer.register_parameter("input_scale", input_scale)

    else:
        # for sparse-only, pass in 1 for weight/input scales
        weight_scale = torch.nn.Parameter(data=torch.ones(
            1, dtype=torch.float32),
                                          requires_grad=False)
        input_scale = torch.nn.Parameter(data=torch.ones(
            1, dtype=torch.float32),
                                         requires_grad=False)
        layer.register_parameter("input_scale", input_scale)
        layer.register_parameter("weight_scale", weight_scale)

    layer.register_parameter("weight", weight)

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@classmethod
def get_min_capability(cls) -> int:
    # Only cutlass 3.x kernels are implemented so far
    return 90

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Compress weights after loading. Store the compressed weight and meta tensor.

:post-condition: layer.weight and layer.meta are set to the compressed weight and meta tensor in the format expected by the Cutlass kernels

:param layer: The layer with the weights to be processed
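
The 2:4 pattern assumed throughout means at most two non-zero values in every contiguous group of four weights along the input dimension. A toy pruning sketch to make that concrete (illustrative only; real checkpoints are pruned during training or compression, and the hypothetical prune_to_2_4 helper is not part of vLLM):

import torch

def prune_to_2_4(w: torch.Tensor) -> torch.Tensor:
    n, k = w.shape                      # k must be divisible by 4
    groups = w.reshape(n, k // 4, 4)
    # keep the two largest-magnitude entries in each group of four
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return (groups * mask).reshape(n, k)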

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """
    Compress weights after loading. Store compressed weight and meta
        tensor

    :post-condition: layer.w_compressed and layer.meta are
        set to the compressed weight and meta tensor in the
        format expected by the Cutlass kernels
    :param layer: The layer with the weights to be processed

    """
    if self.do_sparse_decompress:
        layer.weight.data = self._decompress_bitmask_compressed_weight(
            compressed=layer.compressed,
            bitmask=layer.bitmask,
            layer=layer,
        )

        # compressed and bitmask tensors
        # are no longer needed after decompression
        del layer.compressed
        del layer.bitmask

    # torch.compile workaround
    if hasattr(layer, "input_scale"):
        layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                               requires_grad=False)

    if self.weight_quant:
        if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
            layer.weight_scale = torch.nn.Parameter(
                convert_to_channelwise(
                    weight_scale=layer.weight_scale,
                    logical_widths=layer.logical_widths,
                ),
                requires_grad=False,
            )
        else:
            # torch.compile workaround
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data, requires_grad=False)

    # Set all negative zero values to 0 prior to compression
    if (layer.weight.dtype.is_floating_point
            and layer.weight.dtype.itemsize >= 2):
        layer.weight.data[layer.weight.data == -0.0] = 0.0

    w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
    layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
    layer.meta = torch.nn.Parameter(meta, requires_grad=False)

CompressedTensorsScheme

Bases: ABC

Abstract class used to describe the weight creation and forward pass of different quantization schemes supported by CompressedTensors.
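
To show the contract these abstract methods define, here is a minimal skeleton of a hypothetical scheme (a sketch only: MyDenseScheme and its plain F.linear fallback are not part of vLLM, and a real scheme would register vLLM parameter classes such as ModelWeightParameter so that sharded weight loading works):

from typing import Optional

import torch

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)


class MyDenseScheme(CompressedTensorsScheme):
    """Hypothetical scheme: stores a dense weight and runs a plain linear."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # assumed: Volta or newer

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader, **kwargs):
        weight = torch.nn.Parameter(torch.empty(sum(output_partition_sizes),
                                                input_size_per_partition,
                                                dtype=params_dtype),
                                    requires_grad=False)
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        pass  # nothing to repack for a dense weight

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return torch.nn.functional.linear(x, layer.weight, bias)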

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
class CompressedTensorsScheme(ABC):
    """
    Abstract class used to describe the weight creation and forward pass 
    of different quantization schemes supported by CompressedTensors.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        raise NotImplementedError

    @abstractmethod
    def create_weights(self, *args, **kwargs):
        """
        Weight creation for the particular scheme. Inputs to this function 

        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]):
        """
        Run the forward pass for the particular scheme. This is where 
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and 
            other parameters relevant to the particular scheme. 
        :param x: input to the layer
        :param bias: bias parameter

        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError

apply_weights abstractmethod

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor]
)

Run the forward pass for the particular scheme. This is where scheme-specific dequant/quant steps/kernels should be applied.

:param layer: torch.nn.Module with the registered weights and other parameters relevant to the particular scheme.

:param x: input to the layer

:param bias: bias parameter

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                  bias: Optional[torch.Tensor]):
    """
    Run the forward pass for the particular scheme. This is where 
    scheme-specific dequant/quant steps/kernels should be applied.

    :param layer: torch.nn.Module with the registered weights and 
        other parameters relevant to the particular scheme. 
    :param x: input to the layer
    :param bias: bias parameter

    """
    raise NotImplementedError

create_weights abstractmethod

create_weights(*args, **kwargs)

Weight creation for the particular scheme.

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@abstractmethod
def create_weights(self, *args, **kwargs):
    """
    Weight creation for the particular scheme. Inputs to this function 

    """
    raise NotImplementedError

get_min_capability abstractmethod classmethod

get_min_capability() -> int

Get minimum device capability.

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
    """
    Get minimum device capability.
    """
    raise NotImplementedError

process_weights_after_loading abstractmethod

process_weights_after_loading(layer: Module)

Called after weight loading is complete for any cleanup that needs to occur.

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
    """
    Called after weight loading is complete for any cleanup that
    needs to occur.
    """
    raise NotImplementedError

CompressedTensorsW4A16Fp4

Bases: CompressedTensorsScheme

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):

    def __init__(self, has_input_global_scale: bool = False):
        self.has_input_global_scale = has_input_global_scale
        self.group_size = 16

    @classmethod
    def get_min_capability(cls) -> int:
        # don't restrict, as emulations are supported
        return 80

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Weight
        weight = ModelWeightParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition // 2,
            dtype=torch.uint8),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight_packed", weight)

        # Global Weight Scale
        weight_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("weight_global_scale", weight_global_scale)

        # Per Group Weight Scale
        weight_scale = GroupQuantScaleParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition // self.group_size,
            dtype=torch.float8_e4m3fn,
        ),
                                                input_dim=1,
                                                output_dim=0,
                                                weight_loader=weight_loader)

        layer.register_parameter("weight_scale", weight_scale)

        if self.has_input_global_scale:
            input_global_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes),
                                 dtype=torch.float32),
                weight_loader=weight_loader)
            layer.register_parameter("input_global_scale", input_global_scale)

    def process_weights_after_loading(self, layer) -> None:
        # Process parameters for marlin repacking

        # Rename weight_packed to weight that marlin expects
        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
        del layer.weight_packed
        # Rename weight_global_scale to weight_scale_2 that marlin expects
        # Note: ct stores the inverse of what is expected by the marlin kernel
        layer.weight_scale_2 = Parameter(
            1 / layer.weight_global_scale.max().to(torch.float32),
            requires_grad=False)
        del layer.weight_global_scale

        if self.has_input_global_scale:
            layer.input_global_scale = torch.nn.Parameter(
                layer.input_global_scale.data, requires_grad=False)

        prepare_fp4_layer_for_marlin(layer)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return apply_fp4_marlin_linear(input=x,
                                       weight=layer.weight,
                                       weight_scale=layer.weight_scale,
                                       weight_scale_2=layer.weight_scale_2,
                                       workspace=layer.workspace,
                                       size_n=layer.output_size_per_partition,
                                       size_k=layer.input_size_per_partition,
                                       bias=bias)

group_size instance-attribute

group_size = 16

has_input_global_scale instance-attribute

has_input_global_scale = has_input_global_scale

__init__

__init__(has_input_global_scale: bool = False)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
def __init__(self, has_input_global_scale: bool = False):
    self.has_input_global_scale = has_input_global_scale
    self.group_size = 16

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
def apply_weights(self,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return apply_fp4_marlin_linear(input=x,
                                   weight=layer.weight,
                                   weight_scale=layer.weight_scale,
                                   weight_scale_2=layer.weight_scale_2,
                                   workspace=layer.workspace,
                                   size_n=layer.output_size_per_partition,
                                   size_k=layer.input_size_per_partition,
                                   bias=bias)

create_weights

create_weights(
    layer: Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
def create_weights(self, layer: torch.nn.Module,
                   output_partition_sizes: list[int],
                   input_size_per_partition: int,
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):
    output_size_per_partition = sum(output_partition_sizes)
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition

    # Weight
    weight = ModelWeightParameter(data=torch.empty(
        sum(output_partition_sizes),
        input_size_per_partition // 2,
        dtype=torch.uint8),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight_packed", weight)

    # Global Weight Scale
    weight_global_scale = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader)
    layer.register_parameter("weight_global_scale", weight_global_scale)

    # Per Group Weight Scale
    weight_scale = GroupQuantScaleParameter(data=torch.empty(
        sum(output_partition_sizes),
        input_size_per_partition // self.group_size,
        dtype=torch.float8_e4m3fn,
    ),
                                            input_dim=1,
                                            output_dim=0,
                                            weight_loader=weight_loader)

    layer.register_parameter("weight_scale", weight_scale)

    if self.has_input_global_scale:
        input_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes),
                             dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("input_global_scale", input_global_scale)

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
@classmethod
def get_min_capability(cls) -> int:
    # don't restrict, as emulations are supported
    return 80

process_weights_after_loading

process_weights_after_loading(layer) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
def process_weights_after_loading(self, layer) -> None:
    # Process parameters for marlin repacking

    # Rename weight_packed to weight that marlin expects
    layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
    del layer.weight_packed
    # Rename weight_global_scale to weight_scale_2 that marlin expects
    # Note: ct stores the inverse of what is expected by the marlin kernel
    layer.weight_scale_2 = Parameter(
        1 / layer.weight_global_scale.max().to(torch.float32),
        requires_grad=False)
    del layer.weight_global_scale

    if self.has_input_global_scale:
        layer.input_global_scale = torch.nn.Parameter(
            layer.input_global_scale.data, requires_grad=False)

    prepare_fp4_layer_for_marlin(layer)

CompressedTensorsW4A16Sparse24

Bases: CompressedTensorsScheme

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        self.strategy = strategy
        self.group_size = group_size
        self.tile_size = 16

        if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")

        self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]

        if self.strategy == "group" and self.group_size is None:
            raise ValueError(
                "group_size must be given when using strategy group")

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere + up
        return 80

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # required by torch.compile to be torch.nn.Parameter
        layer.weight_packed = Parameter(layer.weight_packed.data,
                                        requires_grad=False)
        layer.scale_packed = Parameter(layer.scale_packed.data,
                                       requires_grad=False)
        layer.meta = Parameter(layer.meta.data, requires_grad=False)

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        assert params_dtype == torch.float16, (
            "float16 is required for marlin24 compressed models. Set dtype=torch.float16"  # noqa: E501
        )

        pack_factor = 32 // self.quant_type.size_bits
        output_size_per_partition = sum(output_partition_sizes)

        qweight = PackedvLLMParameter(data=torch.empty(
            input_size_per_partition // self.tile_size // 2,
            output_size_per_partition * self.tile_size // pack_factor,
            dtype=torch.int32,
        ),
                                      input_dim=0,
                                      output_dim=1,
                                      packed_dim=1,
                                      packed_factor=pack_factor,
                                      marlin_tile_size=self.tile_size,
                                      weight_loader=weight_loader)

        input_groups = (1 if self.group_size is None else
                        input_size_per_partition // self.group_size)

        weight_scale_args = {
            "data":
            torch.empty(
                input_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }

        if self.group_size is not None:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
        else:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)

        weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                          dtype=torch.int64),
                                         weight_loader=weight_loader)

        meta = PackedvLLMParameter(data=torch.empty(
            input_size_per_partition // 8 // 2 // 2,
            output_size_per_partition * 2,
            dtype=torch.int16,
        ),
                                   input_dim=0,
                                   output_dim=1,
                                   packed_dim=1,
                                   packed_factor=1,
                                   marlin_tile_size=2,
                                   weight_loader=weight_loader)

        layer.register_parameter("weight_packed", qweight)
        layer.register_parameter("weight_shape", weight_shape)
        layer.register_parameter("scale_packed", scales)
        layer.register_parameter("meta", meta)

        max_workspace_size = (
            output_size_per_partition //
            GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL

        workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
                              requires_grad=False)
        layer.workspace = workspace

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:

        qweight = layer.weight_packed
        meta = layer.meta
        scales = layer.scale_packed
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                            workspace, self.quant_type, size_m,
                                            size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output

group_size instance-attribute

group_size = group_size

quant_type instance-attribute

quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]

strategy instance-attribute

strategy = strategy

tile_size instance-attribute

tile_size = 16

__init__

__init__(
    strategy: str,
    num_bits: int,
    group_size: Optional[int] = None,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
def __init__(self,
             strategy: str,
             num_bits: int,
             group_size: Optional[int] = None):
    self.strategy = strategy
    self.group_size = group_size
    self.tile_size = 16

    if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
        raise ValueError(
            f"Unsupported num_bits = {num_bits}. "
            f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")

    self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]

    if self.strategy == "group" and self.group_size is None:
        raise ValueError(
            "group_size must be given when using strategy group")

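For illustration, a hedged construction example (the argument values are assumptions; 4-bit weights with a 128-wide group are a common configuration, and an unsupported num_bits raises the ValueError shown in the source):

scheme = CompressedTensorsW4A16Sparse24(
    strategy="group",   # per-group scales; requires group_size
    num_bits=4,         # must be a key of W4A16SPARSE24_SUPPORTED_TYPES_MAP
    group_size=128,
)
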
apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor]
) -> Tensor
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                  bias: Optional[torch.Tensor]) -> torch.Tensor:

    qweight = layer.weight_packed
    meta = layer.meta
    scales = layer.scale_packed
    workspace = layer.workspace

    x_2d = x.view(-1, x.shape[-1])

    size_m = x_2d.shape[0]
    size_k = x_2d.shape[1]
    size_n = scales.shape[1]

    output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                        workspace, self.quant_type, size_m,
                                        size_n, size_k)

    output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

    if bias is not None:
        output.add_(bias)  # In-place add

    return output

create_weights

create_weights(
    layer: Module,
    input_size: int,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
def create_weights(self, layer: torch.nn.Module, input_size: int,
                   output_partition_sizes: list[int],
                   input_size_per_partition: int,
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):

    assert params_dtype == torch.float16, (
        "float16 is required for marlin24 compressed models. Set dtype=torch.float16"  # noqa: E501
    )

    pack_factor = 32 // self.quant_type.size_bits
    output_size_per_partition = sum(output_partition_sizes)

    qweight = PackedvLLMParameter(data=torch.empty(
        input_size_per_partition // self.tile_size // 2,
        output_size_per_partition * self.tile_size // pack_factor,
        dtype=torch.int32,
    ),
                                  input_dim=0,
                                  output_dim=1,
                                  packed_dim=1,
                                  packed_factor=pack_factor,
                                  marlin_tile_size=self.tile_size,
                                  weight_loader=weight_loader)

    input_groups = (1 if self.group_size is None else
                    input_size_per_partition // self.group_size)

    weight_scale_args = {
        "data":
        torch.empty(
            input_groups,
            output_size_per_partition,
            dtype=params_dtype,
        ),
        "weight_loader":
        weight_loader
    }

    if self.group_size is not None:
        scales = GroupQuantScaleParameter(output_dim=1,
                                          input_dim=0,
                                          **weight_scale_args)
    else:
        scales = ChannelQuantScaleParameter(output_dim=1,
                                            **weight_scale_args)

    weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                      dtype=torch.int64),
                                     weight_loader=weight_loader)

    meta = PackedvLLMParameter(data=torch.empty(
        input_size_per_partition // 8 // 2 // 2,
        output_size_per_partition * 2,
        dtype=torch.int16,
    ),
                               input_dim=0,
                               output_dim=1,
                               packed_dim=1,
                               packed_factor=1,
                               marlin_tile_size=2,
                               weight_loader=weight_loader)

    layer.register_parameter("weight_packed", qweight)
    layer.register_parameter("weight_shape", weight_shape)
    layer.register_parameter("scale_packed", scales)
    layer.register_parameter("meta", meta)

    max_workspace_size = (
        output_size_per_partition //
        GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL

    workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
                          requires_grad=False)
    layer.workspace = workspace

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@classmethod
def get_min_capability(cls) -> int:
    # ampere + up
    return 80

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # required by torch.compile to be torch.nn.Parameter
    layer.weight_packed = Parameter(layer.weight_packed.data,
                                    requires_grad=False)
    layer.scale_packed = Parameter(layer.scale_packed.data,
                                   requires_grad=False)
    layer.meta = Parameter(layer.meta.data, requires_grad=False)

CompressedTensorsW4A4Fp4

Bases: CompressedTensorsScheme

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):

    def __init__(self):
        self.group_size = 16

    @classmethod
    def get_min_capability(cls) -> int:
        if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
            return 80
        return 100

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Weight
        weight = ModelWeightParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition // 2,
            dtype=torch.uint8),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight_packed", weight)

        # Global Weight Scale
        weight_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("weight_global_scale", weight_global_scale)

        # Per Group Weight Scale
        weight_scale = GroupQuantScaleParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition // self.group_size,
            dtype=torch.float8_e4m3fn,
        ),
                                                input_dim=1,
                                                output_dim=0,
                                                weight_loader=weight_loader)

        layer.register_parameter("weight_scale", weight_scale)

        input_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("input_global_scale", input_global_scale)

    def swizzle_blockscale(self, scale: torch.Tensor):
        assert (scale.dtype == torch.float8_e4m3fn)
        # Pad and blockwise interleave weight_scale
        scale_ndim = scale.ndim
        if scale.ndim == 2:
            scale = scale.unsqueeze(0)
        assert scale.ndim == 3
        B, M, K = scale.shape
        round_up_multiple = lambda x, m: (x + m - 1) // m * m
        M_padded = round_up_multiple(M, 128)
        K_padded = round_up_multiple(K, 4)
        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
        padded_scale[:B, :M, :K] = scale
        batches, rows, cols = padded_scale.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                            cols // 4, 4)
        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().cuda()
        return (swizzled_scale.reshape(M, K)
                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

    def process_weights_after_loading(self, layer) -> None:

        global_input_scale = layer.input_global_scale.max().to(torch.float32)
        layer.input_global_scale = Parameter(global_input_scale,
                                             requires_grad=False)

        layer.weight_global_scale = Parameter(
            layer.weight_global_scale.max().to(torch.float32),
            requires_grad=False)

        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                requires_grad=False)

        # required by cutlass kernel; need Parameter, not ModelWeightParameter
        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)

        layer.alpha = Parameter(layer.input_global_scale *
                                layer.weight_global_scale,
                                requires_grad=False)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
            out = run_nvfp4_emulations(
                x=x,
                input_global_scale=layer.input_global_scale,
                weight=layer.weight,
                weight_scale_swizzled=layer.weight_scale_swizzled,
                weight_global_scale=layer.weight_global_scale)
            if bias is not None:
                out = out + bias
            return out

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)

        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                    layer.weight_scale_swizzled,
                                    1 / layer.alpha, output_dtype)
        if bias is not None:
            out = out + bias
        return out.view(*output_shape)

group_size instance-attribute

group_size = 16

__init__

__init__()
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
def __init__(self):
    self.group_size = 16

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
def apply_weights(self,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:

    if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
        out = run_nvfp4_emulations(
            x=x,
            input_global_scale=layer.input_global_scale,
            weight=layer.weight,
            weight_scale_swizzled=layer.weight_scale_swizzled,
            weight_global_scale=layer.weight_global_scale)
        if bias is not None:
            out = out + bias
        return out

    output_dtype = x.dtype
    output_shape = [x.shape[0], layer.weight.shape[0]]

    # quantize BF16 or FP16 to (FP4 and interleaved block scale)
    x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)

    out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                layer.weight_scale_swizzled,
                                1 / layer.alpha, output_dtype)
    if bias is not None:
        out = out + bias
    return out.view(*output_shape)

create_weights

create_weights(
    layer: Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
def create_weights(self, layer: torch.nn.Module,
                   output_partition_sizes: list[int],
                   input_size_per_partition: int,
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):
    output_size_per_partition = sum(output_partition_sizes)
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition

    # Weight
    weight = ModelWeightParameter(data=torch.empty(
        sum(output_partition_sizes),
        input_size_per_partition // 2,
        dtype=torch.uint8),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight_packed", weight)

    # Global Weight Scale
    weight_global_scale = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader)
    layer.register_parameter("weight_global_scale", weight_global_scale)

    # Per Group Weight Scale
    weight_scale = GroupQuantScaleParameter(data=torch.empty(
        sum(output_partition_sizes),
        input_size_per_partition // self.group_size,
        dtype=torch.float8_e4m3fn,
    ),
                                            input_dim=1,
                                            output_dim=0,
                                            weight_loader=weight_loader)

    layer.register_parameter("weight_scale", weight_scale)

    input_global_scale = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader)
    layer.register_parameter("input_global_scale", input_global_scale)
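
A quick way to read the shapes above: NVFP4 packs two 4-bit weight values into each uint8 byte (hence the input_size_per_partition // 2 columns of weight_packed) and keeps one FP8 block scale per group of group_size = 16 input elements (hence the input_size_per_partition // self.group_size columns of weight_scale). A small sketch of that arithmetic with made-up partition sizes (not vLLM code):

# Hypothetical sizes, only to illustrate the shape arithmetic in
# create_weights above.
input_size_per_partition = 4096
output_partition_sizes = [2048, 2048]     # e.g. a fused gate/up projection
group_size = 16

out_rows = sum(output_partition_sizes)                # 4096 output channels
packed_cols = input_size_per_partition // 2           # 2048 uint8 columns,
                                                       # two FP4 values per byte
scale_cols = input_size_per_partition // group_size   # 256 FP8 block scales
                                                       # per output channel
print(out_rows, packed_cols, scale_cols)               # 4096 2048 256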

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@classmethod
def get_min_capability(cls) -> int:
    if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
        return 80
    return 100

process_weights_after_loading

process_weights_after_loading(layer) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
def process_weights_after_loading(self, layer) -> None:

    global_input_scale = layer.input_global_scale.max().to(torch.float32)
    layer.input_global_scale = Parameter(global_input_scale,
                                         requires_grad=False)

    layer.weight_global_scale = Parameter(
        layer.weight_global_scale.max().to(torch.float32),
        requires_grad=False)

    swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
    layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                            requires_grad=False)

    # required by cutlass kernel; need Parameter, not ModelWeightParameter
    layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)

    layer.alpha = Parameter(layer.input_global_scale *
                            layer.weight_global_scale,
                            requires_grad=False)

swizzle_blockscale

swizzle_blockscale(scale: Tensor)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
def swizzle_blockscale(self, scale: torch.Tensor):
    assert (scale.dtype == torch.float8_e4m3fn)
    # Pad and blockwise interleave weight_scale
    scale_ndim = scale.ndim
    if scale.ndim == 2:
        scale = scale.unsqueeze(0)
    assert scale.ndim == 3
    B, M, K = scale.shape
    round_up_multiple = lambda x, m: (x + m - 1) // m * m
    M_padded = round_up_multiple(M, 128)
    K_padded = round_up_multiple(K, 4)
    padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
    padded_scale[:B, :M, :K] = scale
    batches, rows, cols = padded_scale.shape
    assert rows % 128 == 0
    assert cols % 4 == 0
    padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                        cols // 4, 4)
    swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
    swizzled_scale = swizzled_scale.contiguous().cuda()
    return (swizzled_scale.reshape(M, K)
            if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
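
The padding and permutation above appear to rearrange the per-group block scales into the 128x4-tiled layout the FP4 CUTLASS kernel expects. The standalone sketch below reproduces the same pad/reshape/permute on a float32 CPU tensor purely to show the shape bookkeeping; the real method additionally asserts float8_e4m3fn and moves the result to CUDA, and the sizes here are made up:

import torch

# Standalone sketch of the pad + interleave in swizzle_blockscale above.
scale = torch.randn(3, 300, 10)                     # (B, M, K) block scales
B, M, K = scale.shape
round_up = lambda x, m: (x + m - 1) // m * m
M_pad, K_pad = round_up(M, 128), round_up(K, 4)     # 384, 12

padded = torch.zeros(B, M_pad, K_pad, dtype=scale.dtype)
padded[:, :M, :K] = scale
tiled = padded.reshape(B, M_pad // 128, 4, 32, K_pad // 4, 4)
swizzled = tiled.permute(0, 1, 4, 3, 2, 5).contiguous()

print(swizzled.shape)                               # (3, 3, 3, 32, 4, 4)
print(swizzled.numel() == B * M_pad * K_pad)        # True: same data, new order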

CompressedTensorsW8A16Fp8

Bases: CompressedTensorsScheme

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):

    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
    # So if we have a fused module (QKV, MLP) with per tensor scales,
    # we expand each scale to its shard's channels.
    def process_weights_after_loading(self, layer) -> None:
        if self.strategy == QuantizationStrategy.TENSOR:
            ws_channelwise = convert_to_channelwise(layer.weight_scale,
                                                    layer.logical_widths)
            layer.weight_scale = torch.nn.Parameter(ws_channelwise,
                                                    requires_grad=False)
        else:
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)

        # Weights must be transposed for marlin
        layer.weight = torch.nn.Parameter(layer.weight.t(),
                                          requires_grad=False)

        if self.is_static_input_scheme:
            # required by torch.compile to be torch.nn.Parameter
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                   requires_grad=False)
        prepare_fp8_layer_for_marlin(layer)

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1),
                                 dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader)
        elif self.strategy == QuantizationStrategy.TENSOR:
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)
        else:
            raise ValueError(
                f"Unsupported weight strategy={self.strategy}, "
                f"supported strategies are {SUPPORTED_STRATEGIES}")

        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (to deal with converted checkpoints)
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                  weight_loader=weight_loader)
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        return apply_fp8_marlin_linear(input=x,
                                       weight=layer.weight,
                                       weight_scale=layer.weight_scale,
                                       workspace=layer.workspace,
                                       size_n=layer.output_size_per_partition,
                                       size_k=layer.input_size_per_partition,
                                       bias=bias)

is_static_input_scheme instance-attribute

is_static_input_scheme = is_static_input_scheme

strategy instance-attribute

strategy = strategy

__init__

__init__(strategy: str, is_static_input_scheme: bool)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
def __init__(self, strategy: str, is_static_input_scheme: bool):
    self.strategy = strategy
    self.is_static_input_scheme = is_static_input_scheme

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
def apply_weights(self,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:

    return apply_fp8_marlin_linear(input=x,
                                   weight=layer.weight,
                                   weight_scale=layer.weight_scale,
                                   workspace=layer.workspace,
                                   size_n=layer.output_size_per_partition,
                                   size_k=layer.input_size_per_partition,
                                   bias=bias)

create_weights

create_weights(
    layer: Module,
    input_size: int,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
def create_weights(self, layer: torch.nn.Module, input_size: int,
                   output_partition_sizes: list[int],
                   input_size_per_partition: int,
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):

    output_size_per_partition = sum(output_partition_sizes)
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    # WEIGHT
    weight = ModelWeightParameter(data=torch.empty(
        output_size_per_partition,
        input_size_per_partition,
        dtype=torch.float8_e4m3fn),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight", weight)

    # WEIGHT SCALE
    if self.strategy == QuantizationStrategy.CHANNEL:
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1),
                             dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader)
    elif self.strategy == QuantizationStrategy.TENSOR:
        weight_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                               weight_loader=weight_loader)
    else:
        raise ValueError(
            f"Unsupported weight strategy={self.strategy}, "
            f"supported strategies are {SUPPORTED_STRATEGIES}")

    weight_scale[:] = torch.finfo(torch.float32).min
    layer.register_parameter("weight_scale", weight_scale)

    # INPUT SCALE (to deal with converted checkpoints)
    if self.is_static_input_scheme:
        input_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                              weight_loader=weight_loader)
        layer.register_parameter("input_scale", input_scale)

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
@classmethod
def get_min_capability(cls) -> int:
    # ampere and up
    return 80

process_weights_after_loading

process_weights_after_loading(layer) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
def process_weights_after_loading(self, layer) -> None:
    if self.strategy == QuantizationStrategy.TENSOR:
        ws_channelwise = convert_to_channelwise(layer.weight_scale,
                                                layer.logical_widths)
        layer.weight_scale = torch.nn.Parameter(ws_channelwise,
                                                requires_grad=False)
    else:
        # required by torch.compile to be torch.nn.Parameter
        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                requires_grad=False)

    # Weights must be transposed for marlin
    layer.weight = torch.nn.Parameter(layer.weight.t(),
                                      requires_grad=False)

    if self.is_static_input_scheme:
        # required by torch.compile to be torch.nn.Parameter
        layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                               requires_grad=False)
    prepare_fp8_layer_for_marlin(layer)
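
In the TENSOR branch above, the single scale loaded for each fused shard is expanded to a per-output-channel scale before the layer is handed to the Marlin FP8 path. The snippet below is only an illustrative stand-in for that expansion, not vLLM's convert_to_channelwise; the names, widths, and values are made up:

import torch

# Expand one per-shard scale to per-channel scales (illustrative only).
logical_widths = [1024, 1024, 2048]          # e.g. fused Q/K/V output shards
per_shard_scales = [0.02, 0.03, 0.01]        # one scale loaded per shard

channelwise = torch.cat([
    torch.full((width, 1), scale, dtype=torch.float32)
    for scale, width in zip(per_shard_scales, logical_widths)
])                                           # one scale per output channel

print(channelwise.shape)                     # torch.Size([4096, 1])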

CompressedTensorsW8A8Fp8

Bases: CompressedTensorsScheme

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
        self.out_dtype = torch.get_default_dtype()
        self.is_static_input_scheme = is_static_input_scheme
        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer) -> None:
        # If per tensor, when we have a fused module (e.g. QKV) with per
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per tensor
        if self.strategy == QuantizationStrategy.TENSOR:
            max_w_scale, weight = requantize_with_max_scale(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                logical_widths=layer.logical_widths,
            )

            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, 'input_scale', None)

                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=max_w_scale,
                    input_scale=input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)

            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # If channelwise, scales are already lined up, so just transpose.
        elif self.strategy == QuantizationStrategy.CHANNEL:
            weight = layer.weight

            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, 'input_scale', None)

                weight, weight_scale, input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=weight,
                        weight_scale=layer.weight_scale,
                        input_scale=input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)
            else:
                weight_scale = layer.weight_scale.data

            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)

        else:
            raise ValueError(f"Unknown quantization strategy {self.strategy}")

        # INPUT SCALE
        if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)
        else:
            layer.input_scale = None

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        # TODO: update create_xxx_parameter functions to return
        # the newly added parameters
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1),
                                 dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader)
        else:
            assert self.strategy == QuantizationStrategy.TENSOR
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)

        # min requirement for fp8 kernels
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                  weight_loader=weight_loader)
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

is_static_input_scheme instance-attribute

is_static_input_scheme = is_static_input_scheme

out_dtype instance-attribute

out_dtype = get_default_dtype()

strategy instance-attribute

strategy = strategy

__init__

__init__(strategy: str, is_static_input_scheme: bool)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
def __init__(self, strategy: str, is_static_input_scheme: bool):
    self.strategy = strategy
    self.out_dtype = torch.get_default_dtype()
    self.is_static_input_scheme = is_static_input_scheme
    self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
def apply_weights(self,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:

    return self.fp8_linear.apply(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
                                 out_dtype=self.out_dtype,
                                 input_scale=layer.input_scale,
                                 bias=bias)
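
The Fp8LinearOp is constructed with use_per_token_if_dynamic=True, so when no static input_scale is present the activations are quantized with one FP8 scale per token rather than a single per-tensor scale. The sketch below shows the generic arithmetic behind per-token dynamic FP8 quantization; it is not the Fp8LinearOp implementation:

import torch

# Generic per-token dynamic FP8 quantization sketch (not vLLM code).
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max          # 448.0 for e4m3fn

x = torch.randn(4, 8)                                   # (tokens, hidden)
per_token_scale = x.abs().amax(dim=-1, keepdim=True) / FP8_MAX
x_fp8 = (x / per_token_scale).to(torch.float8_e4m3fn)   # one scale per row
x_dequant = x_fp8.to(torch.float32) * per_token_scale   # approximate recovery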

create_weights

create_weights(
    layer: Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
def create_weights(self, layer: torch.nn.Module,
                   output_partition_sizes: list[int],
                   input_size_per_partition: int,
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):
    maybe_create_device_identity()

    output_size_per_partition = sum(output_partition_sizes)
    layer.logical_widths = output_partition_sizes

    # WEIGHT
    weight = ModelWeightParameter(data=torch.empty(
        output_size_per_partition,
        input_size_per_partition,
        dtype=torch.float8_e4m3fn),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight", weight)

    # WEIGHT SCALE
    # TODO: update create_xxx_parameter functions to return
    # the newly added parameters
    if self.strategy == QuantizationStrategy.CHANNEL:
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1),
                             dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader)
    else:
        assert self.strategy == QuantizationStrategy.TENSOR
        weight_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                               weight_loader=weight_loader)

    # min requirement for fp8 kernels
    weight_scale[:] = torch.finfo(torch.float32).min
    layer.register_parameter("weight_scale", weight_scale)

    # INPUT SCALE
    if self.is_static_input_scheme:
        input_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                              weight_loader=weight_loader)
        input_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("input_scale", input_scale)

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@classmethod
def get_min_capability(cls) -> int:
    # lovelace and up
    return 89

process_weights_after_loading

process_weights_after_loading(layer) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
def process_weights_after_loading(self, layer) -> None:
    # If per tensor, when we have a fused module (e.g. QKV) with per
    # tensor scales (thus N scales being passed to the kernel),
    # requantize so we can always run per tensor
    if self.strategy == QuantizationStrategy.TENSOR:
        max_w_scale, weight = requantize_with_max_scale(
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            logical_widths=layer.logical_widths,
        )

        if current_platform.is_fp8_fnuz():
            input_scale = getattr(layer, 'input_scale', None)

            weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                weight=weight,
                weight_scale=max_w_scale,
                input_scale=input_scale)
            if input_scale is not None:
                layer.input_scale = Parameter(input_scale,
                                              requires_grad=False)

        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

    # If channelwise, scales are already lined up, so just transpose.
    elif self.strategy == QuantizationStrategy.CHANNEL:
        weight = layer.weight

        if current_platform.is_fp8_fnuz():
            input_scale = getattr(layer, 'input_scale', None)

            weight, weight_scale, input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=layer.weight_scale,
                    input_scale=input_scale)
            if input_scale is not None:
                layer.input_scale = Parameter(input_scale,
                                              requires_grad=False)
        else:
            weight_scale = layer.weight_scale.data

        layer.weight = Parameter(weight.t(), requires_grad=False)
        # required by torch.compile to be torch.nn.Parameter
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    else:
        raise ValueError(f"Unknown quantization strategy {self.strategy}")

    # INPUT SCALE
    if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
        layer.input_scale = Parameter(layer.input_scale.max(),
                                      requires_grad=False)
    else:
        layer.input_scale = None
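
For the per-tensor case above, the weight shards of a fused module are requantized to a shared maximum scale so the kernel can be called with a single scale. The helper below is a rough illustration of that idea, not vLLM's requantize_with_max_scale; it assumes the weight is already stored as FP8 rows grouped by logical_widths:

import torch

# Rough illustration of max-scale requantization (not vLLM code).
def requant_with_max_scale(weight_fp8, shard_scales, logical_widths):
    max_scale = shard_scales.max()
    out = torch.empty_like(weight_fp8)
    start = 0
    for scale, width in zip(shard_scales, logical_widths):
        end = start + width
        # dequantize the shard with its own scale, requantize with the max
        dequant = weight_fp8[start:end].to(torch.float32) * scale
        out[start:end] = (dequant / max_scale).to(weight_fp8.dtype)
        start = end
    return max_scale, out

w = torch.randn(12, 8).to(torch.float8_e4m3fn)
max_s, w_requant = requant_with_max_scale(
    w, torch.tensor([0.1, 0.2, 0.4]), logical_widths=[4, 4, 4])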

CompressedTensorsW8A8Int8

Bases: CompressedTensorsScheme

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
    _kernel_backends_being_used: set[str] = set()

    def __init__(self, strategy: str, is_static_input_scheme: bool,
                 input_symmetric: bool):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

    @classmethod
    def get_min_capability(cls) -> int:
        # turing and up
        return 75

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        layer.logical_widths = output_partition_sizes

        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
            is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
            is_static_input_scheme=self.is_static_input_scheme,
            input_symmetric=self.input_symmetric)

        kernel_type = choose_scaled_mm_linear_kernel(
            scaled_mm_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsW8A8Int8",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # WEIGHT
        weight = ModelWeightParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition,
            dtype=torch.int8),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)

        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1),
                                 dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader)
        else:
            assert self.strategy == QuantizationStrategy.TENSOR
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = BasevLLMParameter(data=torch.empty(
                1, dtype=torch.float32),
                                            weight_loader=weight_loader)
            layer.register_parameter("input_scale", input_scale)

            if not self.input_symmetric:
                # Note: compressed-tensors stores the zp using the same dtype
                # as the weights
                # AZP loaded as int8 but used as int32
                input_zero_point = BasevLLMParameter(
                    data=torch.empty(1, dtype=torch.int8),
                    weight_loader=weight_loader)
                layer.register_parameter("input_zero_point", input_zero_point)

        self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
                                  w_q_param_name="weight",
                                  w_s_param_name="weight_scale",
                                  i_s_param_name="input_scale",
                                  i_zp_param_name="input_zero_point",
                                  azp_adj_param_name="azp_adj")

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)

_kernel_backends_being_used class-attribute instance-attribute

_kernel_backends_being_used: set[str] = set()

input_symmetric instance-attribute

input_symmetric = input_symmetric

is_static_input_scheme instance-attribute

is_static_input_scheme = is_static_input_scheme

strategy instance-attribute

strategy = strategy

__init__

__init__(
    strategy: str,
    is_static_input_scheme: bool,
    input_symmetric: bool,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
def __init__(self, strategy: str, is_static_input_scheme: bool,
             input_symmetric: bool):
    self.strategy = strategy
    self.is_static_input_scheme = is_static_input_scheme
    self.input_symmetric = input_symmetric

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor]
) -> Tensor
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                  bias: Optional[torch.Tensor]) -> torch.Tensor:
    return self.kernel.apply_weights(layer, x, bias)

create_weights

create_weights(
    layer: Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
def create_weights(self, layer: torch.nn.Module,
                   output_partition_sizes: list[int],
                   input_size_per_partition: int,
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):
    layer.logical_widths = output_partition_sizes

    scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
        is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
        is_static_input_scheme=self.is_static_input_scheme,
        input_symmetric=self.input_symmetric)

    kernel_type = choose_scaled_mm_linear_kernel(
        scaled_mm_linear_kernel_config)

    if kernel_type.__name__ not in self._kernel_backends_being_used:
        logger.info("Using %s for CompressedTensorsW8A8Int8",
                    kernel_type.__name__)
        self._kernel_backends_being_used.add(kernel_type.__name__)

    # WEIGHT
    weight = ModelWeightParameter(data=torch.empty(
        sum(output_partition_sizes),
        input_size_per_partition,
        dtype=torch.int8),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)

    layer.register_parameter("weight", weight)

    # WEIGHT SCALE
    if self.strategy == QuantizationStrategy.CHANNEL:
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1),
                             dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader)
    else:
        assert self.strategy == QuantizationStrategy.TENSOR
        weight_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                               weight_loader=weight_loader)
    layer.register_parameter("weight_scale", weight_scale)

    # INPUT SCALE
    if self.is_static_input_scheme:
        input_scale = BasevLLMParameter(data=torch.empty(
            1, dtype=torch.float32),
                                        weight_loader=weight_loader)
        layer.register_parameter("input_scale", input_scale)

        if not self.input_symmetric:
            # Note: compressed-tensors stores the zp using the same dtype
            # as the weights
            # AZP loaded as int8 but used as int32
            input_zero_point = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.int8),
                weight_loader=weight_loader)
            layer.register_parameter("input_zero_point", input_zero_point)

    self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
                              w_q_param_name="weight",
                              w_s_param_name="weight_scale",
                              i_s_param_name="input_scale",
                              i_zp_param_name="input_zero_point",
                              azp_adj_param_name="azp_adj")
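
The static-activation branch above registers one input_scale and, when input_symmetric is False, an int8 zero point that the kernel later widens to int32. For reference, static int8 activation quantization comes down to the arithmetic below; this is generic int8 math, not the code of the kernel selected by choose_scaled_mm_linear_kernel:

import torch

# Generic static int8 quantization/dequantization sketch (not vLLM code).
def quantize_static_int8(x, scale, zero_point=0):
    q = torch.round(x / scale) + zero_point          # asymmetric if zp != 0
    return q.clamp(-128, 127).to(torch.int8)

def dequantize_int8(q, scale, zero_point=0):
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(4, 8)
scale, zp = x.abs().max() / 127, 0                   # symmetric example
x_q = quantize_static_int8(x, scale, zp)
print(torch.allclose(dequantize_int8(x_q, scale, zp), x, atol=float(scale)))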

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@classmethod
def get_min_capability(cls) -> int:
    # turing and up
    return 75

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    self.kernel.process_weights_after_loading(layer)

CompressedTensorsWNA16

Bases: CompressedTensorsScheme

Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
class CompressedTensorsWNA16(CompressedTensorsScheme):
    _kernel_backends_being_used: set[str] = set()

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None,
                 symmetric: Optional[bool] = True,
                 actorder: Optional[ActivationOrdering] = None):

        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.symmetric = symmetric
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size == -1 and self.strategy != "channel":
            raise ValueError("Marlin kernels require group quantization or "
                             "channelwise quantization, but found no group "
                             "size and strategy is not channelwise.")

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")

        self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
                           if not self.symmetric else
                           WNA16_SUPPORTED_TYPES_MAP[num_bits])

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=\
                (input_size_per_partition, output_size_per_partition),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsWNA16",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = (input_size != input_size_per_partition)
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel)

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        weight = PackedvLLMParameter(input_dim=1,
                                     output_dim=0,
                                     weight_loader=weight_loader,
                                     packed_factor=self.pack_factor,
                                     packed_dim=1,
                                     data=torch.empty(
                                         output_size_per_partition,
                                         input_size_per_partition //
                                         self.pack_factor,
                                         dtype=torch.int32,
                                     ))

        weight_scale_args = {
            "weight_loader":
            weight_loader,
            "data":
            torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            )
        }

        zeros_args = {
            "weight_loader":
            weight_loader,
            "data":
            torch.zeros(
                output_size_per_partition // self.pack_factor,
                scales_and_zp_size,
                dtype=torch.int32,
            )
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)

            if not self.symmetric:
                qzeros = PackedColumnParameter(output_dim=0,
                                               packed_dim=0,
                                               packed_factor=self.pack_factor,
                                               **zeros_args)
        else:
            weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                    input_dim=1,
                                                    **weight_scale_args)
            if not self.symmetric:
                qzeros = PackedvLLMParameter(input_dim=1,
                                             output_dim=0,
                                             packed_dim=0,
                                             packed_factor=self.pack_factor,
                                             **zeros_args)

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                          dtype=torch.int64),
                                         weight_loader=weight_loader)

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        if not self.symmetric:
            layer.register_parameter("weight_zero_point", qzeros)

        # group index (for activation reordering)
        if self.has_g_idx:
            weight_g_idx = RowvLLMParameter(data=torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
                                            input_dim=0,
                                            weight_loader=weight_loader)
            layer.register_parameter("weight_g_idx", weight_g_idx)

        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="weight_packed",
                                  w_s_param_name="weight_scale",
                                  w_zp_param_name="weight_zero_point",
                                  w_gidx_param_name="weight_g_idx")

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)

_kernel_backends_being_used class-attribute instance-attribute

_kernel_backends_being_used: set[str] = set()

group_size instance-attribute

group_size = -1 if group_size is None else group_size

has_g_idx instance-attribute

has_g_idx = actorder == GROUP

pack_factor instance-attribute

pack_factor = 32 // num_bits

quant_type instance-attribute

quant_type = (
    WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
    if not symmetric
    else WNA16_SUPPORTED_TYPES_MAP[num_bits]
)

strategy instance-attribute

strategy = strategy

symmetric instance-attribute

symmetric = symmetric

__init__

__init__(
    strategy: str,
    num_bits: int,
    group_size: Optional[int] = None,
    symmetric: Optional[bool] = True,
    actorder: Optional[ActivationOrdering] = None,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
def __init__(self,
             strategy: str,
             num_bits: int,
             group_size: Optional[int] = None,
             symmetric: Optional[bool] = True,
             actorder: Optional[ActivationOrdering] = None):

    self.pack_factor = 32 // num_bits
    self.strategy = strategy
    self.symmetric = symmetric
    self.group_size = -1 if group_size is None else group_size
    self.has_g_idx = actorder == ActivationOrdering.GROUP

    if self.group_size == -1 and self.strategy != "channel":
        raise ValueError("Marlin kernels require group quantization or "
                         "channelwise quantization, but found no group "
                         "size and strategy is not channelwise.")

    if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
        raise ValueError(
            f"Unsupported num_bits = {num_bits}. "
            f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")

    self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
                       if not self.symmetric else
                       WNA16_SUPPORTED_TYPES_MAP[num_bits])

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor]
) -> Tensor
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                  bias: Optional[torch.Tensor]) -> torch.Tensor:
    return self.kernel.apply_weights(layer, x, bias)

create_weights

create_weights(
    layer: Module,
    output_size: int,
    input_size: int,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
def create_weights(self, layer: torch.nn.Module, output_size: int,
                   input_size: int, output_partition_sizes: list[int],
                   input_size_per_partition: int,
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):

    output_size_per_partition = sum(output_partition_sizes)

    mp_linear_kernel_config = MPLinearLayerConfig(
        full_weight_shape=(input_size, output_size),
        partition_weight_shape=\
            (input_size_per_partition, output_size_per_partition),
        weight_type=self.quant_type,
        act_type=params_dtype,
        group_size=self.group_size,
        zero_points=not self.symmetric,
        has_g_idx=self.has_g_idx
    )

    kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

    if kernel_type.__name__ not in self._kernel_backends_being_used:
        logger.info("Using %s for CompressedTensorsWNA16",
                    kernel_type.__name__)
        self._kernel_backends_being_used.add(kernel_type.__name__)

    # If group_size is -1, we are in channelwise case.
    group_size = self.group_size if self.group_size != -1 else input_size
    row_parallel = (input_size != input_size_per_partition)
    partition_scales = not marlin_repeat_scales_on_all_ranks(
        self.has_g_idx, self.group_size, row_parallel)

    scales_and_zp_size = input_size // group_size

    if partition_scales:
        assert input_size_per_partition % group_size == 0
        scales_and_zp_size = input_size_per_partition // group_size

    weight = PackedvLLMParameter(input_dim=1,
                                 output_dim=0,
                                 weight_loader=weight_loader,
                                 packed_factor=self.pack_factor,
                                 packed_dim=1,
                                 data=torch.empty(
                                     output_size_per_partition,
                                     input_size_per_partition //
                                     self.pack_factor,
                                     dtype=torch.int32,
                                 ))

    weight_scale_args = {
        "weight_loader":
        weight_loader,
        "data":
        torch.empty(
            output_size_per_partition,
            scales_and_zp_size,
            dtype=params_dtype,
        )
    }

    zeros_args = {
        "weight_loader":
        weight_loader,
        "data":
        torch.zeros(
            output_size_per_partition // self.pack_factor,
            scales_and_zp_size,
            dtype=torch.int32,
        )
    }

    if not partition_scales:
        weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                  **weight_scale_args)

        if not self.symmetric:
            qzeros = PackedColumnParameter(output_dim=0,
                                           packed_dim=0,
                                           packed_factor=self.pack_factor,
                                           **zeros_args)
    else:
        weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                input_dim=1,
                                                **weight_scale_args)
        if not self.symmetric:
            qzeros = PackedvLLMParameter(input_dim=1,
                                         output_dim=0,
                                         packed_dim=0,
                                         packed_factor=self.pack_factor,
                                         **zeros_args)

    # A 2D array defining the original shape of the weights
    # before packing
    weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                      dtype=torch.int64),
                                     weight_loader=weight_loader)

    layer.register_parameter("weight_packed", weight)
    layer.register_parameter("weight_scale", weight_scale)
    layer.register_parameter("weight_shape", weight_shape)

    if not self.symmetric:
        layer.register_parameter("weight_zero_point", qzeros)

    # group index (for activation reordering)
    if self.has_g_idx:
        weight_g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            dtype=torch.int32,
        ),
                                        input_dim=0,
                                        weight_loader=weight_loader)
        layer.register_parameter("weight_g_idx", weight_g_idx)

    self.kernel = kernel_type(mp_linear_kernel_config,
                              w_q_param_name="weight_packed",
                              w_s_param_name="weight_scale",
                              w_zp_param_name="weight_zero_point",
                              w_gidx_param_name="weight_g_idx")
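
The packing above follows pack_factor = 32 // num_bits, so for 4-bit weights eight values share one int32 column, while the scale and zero-point tensors get one column per quantization group. A small sketch of the implied shapes, assuming partitioned scales and made-up sizes (not vLLM code):

# Shape arithmetic implied by create_weights above.
num_bits = 4
pack_factor = 32 // num_bits                        # 8 weights per int32
input_size_per_partition = 4096
output_size_per_partition = 11008
group_size = 128

weight_packed_shape = (output_size_per_partition,
                       input_size_per_partition // pack_factor)  # (11008, 512)
scales_shape = (output_size_per_partition,
                input_size_per_partition // group_size)          # (11008, 32)
zeros_shape = (output_size_per_partition // pack_factor,
               input_size_per_partition // group_size)           # (1376, 32)

print(weight_packed_shape, scales_shape, zeros_shape)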

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@classmethod
def get_min_capability(cls) -> int:
    # ampere and up
    return 80

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    self.kernel.process_weights_after_loading(layer)