Bases: CutlassScaledMMLinearKernel
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
    """Scaled-mm linear kernel that relies on Triton availability.

    Weight post-processing and the actual GEMM are delegated unchanged to
    the Cutlass parent implementation; this subclass only adjusts the
    capability/eligibility checks.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        # 75 — presumably CUDA compute capability 7.5 (Turing); confirm
        # against the capability convention used by the platform layer.
        return 75

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        """Return ``(ok, reason)``; ``reason`` is ``None`` when usable."""
        reason: Optional[str] = None
        if current_platform.is_cpu():
            reason = ("TritonScaledMMLinearKernel requires Triton which is "
                      "not currently supported on CPU.")
        elif not c.input_symmetric:
            reason = ("TritonScaledMMLinearKernel only supports symmetric "
                      "quantization.")
        return (reason is None, reason)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # No Triton-specific preparation; reuse the parent's processing.
        super().process_weights_after_loading(layer)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pure delegation to the Cutlass scaled-mm path.
        return super().apply_weights(layer, x, bias)
apply_weights
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
def apply_weights(self,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the layer's quantized weights to ``x``.

    Delegates entirely to the parent (Cutlass) implementation; there is
    no Triton-specific pre- or post-processing here.
    """
    out = super().apply_weights(layer, x, bias)
    return out
can_implement
classmethod
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
@classmethod
def can_implement(
        cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
    """Check whether this kernel can serve the given config.

    Returns ``(ok, reason)`` where ``reason`` is ``None`` iff ``ok``.
    Triton is unavailable on CPU, and only symmetric input quantization
    is supported.
    """
    reason: Optional[str] = None
    if current_platform.is_cpu():
        reason = ("TritonScaledMMLinearKernel requires Triton which is "
                  "not currently supported on CPU.")
    elif not c.input_symmetric:
        reason = ("TritonScaledMMLinearKernel only supports symmetric "
                  "quantization.")
    return (reason is None, reason)
get_min_capability
classmethod
get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
@classmethod
def get_min_capability(cls) -> int:
    """Minimum device capability this kernel requires."""
    # 75 — presumably compute capability 7.5 (Turing); TODO confirm the
    # encoding convention used by the capability check.
    return 75
process_weights_after_loading
process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-load weight processing; defers entirely to the parent class."""
    super().process_weights_after_loading(layer)