
vllm.model_executor.layers.quantization.tpu_int8

ACTIVATION_SCHEMES module-attribute

ACTIVATION_SCHEMES = ['none']

Int8TpuConfig

Bases: QuantizationConfig

Int8 Quantization Config class for TPU Backend.

Source code in vllm/model_executor/layers/quantization/tpu_int8.py
class Int8TpuConfig(QuantizationConfig):
    """Int8 Quantization Config class for TPU Backend."""

    def __init__(
        self,
        activation_scheme: str = "none",
    ) -> None:
        super().__init__()
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    def get_name(self) -> QuantizationMethods:
        return "tpu_int8"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "This function should not be called with TPU Backend")

    @staticmethod
    def get_config_filenames() -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(activation_scheme=activation_scheme)

    def get_quant_method(self, layer: Module,
                         prefix: str) -> Optional["TPUInt8LinearMethod"]:
        if isinstance(layer, LinearBase):
            return TPUInt8LinearMethod(self)
        return None

activation_scheme instance-attribute

activation_scheme = activation_scheme

__init__

__init__(activation_scheme: str = 'none') -> None
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
def __init__(
    self,
    activation_scheme: str = "none",
) -> None:
    super().__init__()
    if activation_scheme not in ACTIVATION_SCHEMES:
        raise ValueError(
            f"Unsupported activation scheme {activation_scheme}")
    self.activation_scheme = activation_scheme
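
A minimal usage sketch (not part of the source file): constructing the config validates the scheme against ACTIVATION_SCHEMES, so only the default "none" is accepted.

from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

config = Int8TpuConfig()                       # default activation_scheme="none"
assert config.activation_scheme == "none"

# Any other scheme is rejected at construction time.
try:
    Int8TpuConfig(activation_scheme="static")
except ValueError as err:
    print(err)                                 # Unsupported activation scheme static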

from_config classmethod

from_config(config: dict[str, Any]) -> Int8TpuConfig
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
    activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
    return cls(activation_scheme=activation_scheme)
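
A hedged sketch of building the config from a checkpoint's quantization dict; only the "activation_scheme" key is read, and the dict shown here is illustrative.

from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

quant_dict = {"activation_scheme": "none"}     # illustrative config dict
config = Int8TpuConfig.from_config(quant_dict)
print(config.get_name())                       # tpu_int8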

get_config_filenames staticmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
@staticmethod
def get_config_filenames() -> list[str]:
    return []

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
@classmethod
def get_min_capability(cls) -> int:
    raise NotImplementedError(
        "This function should not be called with TPU Backend")

get_name

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
def get_name(self) -> QuantizationMethods:
    return "tpu_int8"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[TPUInt8LinearMethod]
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
def get_quant_method(self, layer: Module,
                     prefix: str) -> Optional["TPUInt8LinearMethod"]:
    if isinstance(layer, LinearBase):
        return TPUInt8LinearMethod(self)
    return None
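
A small sketch of the dispatch contract: modules that are not LinearBase subclasses get no quant method, while vLLM's linear layers would receive a TPUInt8LinearMethod bound to this config.

import torch

from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

config = Int8TpuConfig()
# A non-linear module (here a LayerNorm) is not a LinearBase, so no method is returned.
assert config.get_quant_method(torch.nn.LayerNorm(16), prefix="") is None
# A LinearBase subclass (e.g. RowParallelLinear) would get TPUInt8LinearMethod(config).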

get_supported_act_dtypes

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
def get_supported_act_dtypes(self) -> list[torch.dtype]:
    return [torch.float16, torch.bfloat16]

TPUInt8LinearMethod

Bases: LinearMethodBase

Int8 Linear method for TPU Quant.

Source code in vllm/model_executor/layers/quantization/tpu_int8.py
class TPUInt8LinearMethod(LinearMethodBase):
    """Int8 Linear method for TPU Quant. """

    def __init__(self, quant_config: Int8TpuConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):

        weight_loader = extra_weight_attrs.get("weight_loader")
        weight = ModelWeightParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition,
            dtype=params_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

    def _quantize_weight(
            self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        weight_dtype = weight.dtype
        weight = weight.cpu().to(torch.float32)
        n_bit = 8
        eps = 1e-5
        max_int = 2**(n_bit - 1) - 1
        min_int = -(2**(n_bit - 1))
        max_val = weight.abs().amax(dim=-1, keepdim=True)
        max_val = max_val.clamp(min=eps)
        qscale = max_val / max_int
        qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
                              max_int).to(torch.int8)
        qscale = qscale.squeeze().to(weight_dtype)
        return qweight, qscale

    def process_weights_after_loading(self, layer: Module) -> None:
        layer.weight = Parameter(layer.weight.data, requires_grad=False)
        device = layer.weight.device
        qweight, qscale = self._quantize_weight(layer.weight)
        qweight = qweight.to(device)
        qscale = qscale.to(device)
        layer.weight = Parameter(qweight, requires_grad=False)
        layer.scale = Parameter(qscale, requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        try:
            import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "Please install torch_xla by following the instructions at "
                "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "  # noqa: E501
                "to run vLLM on TPU.") from err
        weight = layer.weight
        scale = layer.scale
        out = torch.ops.xla.quantized_matmul(x, weight, scale)
        if bias is not None:
            out = out + bias
        return out

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: Int8TpuConfig)
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
def __init__(self, quant_config: Int8TpuConfig):
    self.quant_config = quant_config

_quantize_weight

_quantize_weight(weight: Tensor) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
def _quantize_weight(
        self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    weight_dtype = weight.dtype
    weight = weight.cpu().to(torch.float32)
    n_bit = 8
    eps = 1e-5
    max_int = 2**(n_bit - 1) - 1
    min_int = -(2**(n_bit - 1))
    max_val = weight.abs().amax(dim=-1, keepdim=True)
    max_val = max_val.clamp(min=eps)
    qscale = max_val / max_int
    qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
                          max_int).to(torch.int8)
    qscale = qscale.squeeze().to(weight_dtype)
    return qweight, qscale
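
A standalone sketch of the same symmetric, per-output-channel scheme used above: each weight row gets a scale of max_abs / 127, values are rounded into [-128, 127], and dequantizing with the scale approximately recovers the original tensor.

import torch

torch.manual_seed(0)
weight = torch.randn(4, 8, dtype=torch.bfloat16)             # [out_features, in_features]

w = weight.float()
max_val = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
qscale = max_val / 127                                        # one scale per output row
qweight = torch.clamp(torch.round(w / qscale), -128, 127).to(torch.int8)

# Dequantize and check the round-trip error.
w_dequant = qweight.float() * qscale
print(qweight.dtype, qscale.squeeze().shape)                  # torch.int8 torch.Size([4])
print((w - w_dequant).abs().max())                            # small round-off error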

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    try:
        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "Please install torch_xla by following the instructions at "
            "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "  # noqa: E501
            "to run vLLM on TPU.") from err
    weight = layer.weight
    scale = layer.scale
    out = torch.ops.xla.quantized_matmul(x, weight, scale)
    if bias is not None:
        out = out + bias
    return out
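
The XLA kernel itself requires torch_xla on TPU. As a hedged CPU reference (an assumption for illustration, not the kernel), torch.ops.xla.quantized_matmul is expected to be numerically close to dequantizing the int8 weight with its per-row scale and performing an ordinary matmul, with the bias added afterwards as in the code above.

from typing import Optional

import torch

def reference_quantized_matmul(x: torch.Tensor, qweight: torch.Tensor,
                               scale: torch.Tensor,
                               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # qweight: [out, in] int8; scale: [out], one scale per output row.
    w = qweight.to(x.dtype) * scale.to(x.dtype).unsqueeze(-1)
    out = x @ w.t()
    if bias is not None:
        out = out + bias
    return out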

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
def create_weights(self, layer: Module, input_size_per_partition: int,
                   output_partition_sizes: list[int], input_size: int,
                   output_size: int, params_dtype: torch.dtype,
                   **extra_weight_attrs):

    weight_loader = extra_weight_attrs.get("weight_loader")
    weight = ModelWeightParameter(data=torch.empty(
        sum(output_partition_sizes),
        input_size_per_partition,
        dtype=params_dtype),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight", weight)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/tpu_int8.py
def process_weights_after_loading(self, layer: Module) -> None:
    layer.weight = Parameter(layer.weight.data, requires_grad=False)
    device = layer.weight.device
    qweight, qscale = self._quantize_weight(layer.weight)
    qweight = qweight.to(device)
    qscale = qscale.to(device)
    layer.weight = Parameter(qweight, requires_grad=False)
    layer.scale = Parameter(qscale, requires_grad=False)
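
An end-to-end sketch of the post-loading step on a dummy module (an assumption used only to hold a weight attribute; no TPU or torch_xla is needed for this part): the floating-point weight is replaced by its int8 counterpart and a per-row scale parameter is attached.

import torch
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.tpu_int8 import (
    Int8TpuConfig, TPUInt8LinearMethod)

method = TPUInt8LinearMethod(Int8TpuConfig())

layer = torch.nn.Module()                      # stand-in for a LinearBase layer
layer.weight = Parameter(torch.randn(256, 512, dtype=torch.bfloat16),
                         requires_grad=False)

method.process_weights_after_loading(layer)
print(layer.weight.dtype, layer.weight.shape)  # torch.int8 torch.Size([256, 512])
print(layer.scale.dtype, layer.scale.shape)    # torch.bfloat16 torch.Size([256])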