Skip to content

vllm.model_executor.layers.quantization.deepgemm

logger module-attribute

# Module-level logger, named after this module via getLogger(__name__).
# NOTE(review): not referenced by the functions visible in this section.
logger = getLogger(__name__)

prepare_block_fp8_matmul_inputs

prepare_block_fp8_matmul_inputs(
    A: Tensor,
    B: Tensor,
    As: Tensor,
    Bs: Tensor,
    block_size: list[int],
    output_dtype: dtype = float16,
) -> tuple[int, int, int, Tensor]
Source code in vllm/model_executor/layers/quantization/deepgemm.py
def prepare_block_fp8_matmul_inputs(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> tuple[int, int, int, torch.Tensor]:
    """Validate block-scaled GEMM operands and allocate the output tensor.

    Checks that the activation ``A`` (with scales ``As``) and the 2-D weight
    ``B`` (with scales ``Bs``) have mutually consistent shapes for a block
    size of ``[block_n, block_k]``. It then allocates an uninitialized output
    ``C`` of shape ``A.shape[:-1] + (N,)``.

    Operand dtypes are not checked here. Going by the callers' names, ``A``
    and ``B`` are presumably FP8-quantized; confirm at the call sites.

    Args:
        A: Contiguous tensor whose last dim is the reduction dim ``K``.
            Its leading dims are flattened into ``M``.
        B: Contiguous 2-D tensor of shape ``(N, K)``.
        As: Scales for ``A``. They have the same leading dims as ``A`` and
            ``ceil(K / block_k)`` entries in the last dim.
        Bs: 2-D scales for ``B`` of shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: ``[block_n, block_k]``.
        output_dtype: dtype of the allocated output ``C``.

    Returns:
        ``(M, N, K, C)``. ``M`` is the product of ``A``'s leading dims, and
        ``C`` is a freshly allocated, uninitialized output on ``A``'s device.

    Raises:
        AssertionError: if any shape or contiguity check fails. These are
            ``assert`` statements, so they are skipped under ``python -O``.
    """
    import math  # local import: only needed for the leading-dim product

    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}")
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A.shape={tuple(A.shape)}, B.shape={tuple(B.shape)}")
    assert A.shape[:-1] == As.shape[:-1], (
        f"A and As leading dims differ: A.shape={tuple(A.shape)}, "
        f"As.shape={tuple(As.shape)}")
    assert A.is_contiguous(), "A must be contiguous"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], (
        f"As.shape[-1]={As.shape[-1]} does not match "
        f"ceil(K={A.shape[-1]} / block_k={block_k})")

    # Product of the leading dims of A. The previous formula,
    # A.numel() // A.shape[-1], raised ZeroDivisionError when K == 0.
    # That shape passes every check above.
    M = math.prod(A.shape[:-1])

    assert B.ndim == 2, f"B must be 2-D, got shape {tuple(B.shape)}"
    assert B.is_contiguous(), "B must be contiguous"
    assert Bs.ndim == 2, f"Bs must be 2-D, got shape {tuple(Bs.shape)}"
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], (
        f"Bs.shape[0]={Bs.shape[0]} does not match "
        f"ceil(N={N} / block_n={block_n})")
    assert triton.cdiv(K, block_k) == Bs.shape[1], (
        f"Bs.shape[1]={Bs.shape[1]} does not match "
        f"ceil(K={K} / block_k={block_k})")

    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

    return M, N, K, C

w8a8_block_fp8_matmul_deepgemm

w8a8_block_fp8_matmul_deepgemm(
    A: Tensor,
    B: Tensor,
    As: Tensor,
    Bs: Tensor,
    block_size: list[int],
    output_dtype: dtype,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/deepgemm.py
def w8a8_block_fp8_matmul_deepgemm(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    """Run a block-scaled GEMM of ``A`` and ``B`` through DeepGEMM.

    Operands are validated and the output is allocated by
    ``prepare_block_fp8_matmul_inputs``. The product is then computed by
    ``deep_gemm.gemm_fp8_fp8_bf16_nt``.

    Args:
        A: Activation tensor. Its last dim is ``K``.
        B: 2-D weight tensor of shape ``(N, K)``.
        As: Scales for ``A``.
        Bs: Scales for ``B``.
        block_size: ``[block_n, block_k]``.
        output_dtype: Requested output dtype. Must be ``torch.bfloat16``.

    Returns:
        Output tensor of shape ``A.shape[:-1] + (N,)`` in bfloat16.

    Raises:
        AssertionError: if ``output_dtype`` is not bfloat16 or the operand
            checks fail.
    """
    # DeepGEMM only supports bfloat16 as the output tensor type. Check the
    # requested dtype before allocating, so an unsupported dtype fails
    # without a wasted allocation.
    assert output_dtype == torch.bfloat16, (
        f"DeepGEMM only supports bfloat16 output, got {output_dtype}")
    _, _, _, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    # The kernel's return value is not used; the result is presumably
    # written into C in place (confirm against the deep_gemm API).
    deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
    return C

w8a8_block_fp8_matmul_deepgemm_fake

w8a8_block_fp8_matmul_deepgemm_fake(
    A: Tensor,
    B: Tensor,
    As: Tensor,
    Bs: Tensor,
    block_size: list[int],
    output_dtype: dtype,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/deepgemm.py
def w8a8_block_fp8_matmul_deepgemm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    """Shape-only counterpart of ``w8a8_block_fp8_matmul_deepgemm``.

    Performs the same operand validation and output allocation as the real
    implementation but does not launch DeepGEMM. The returned tensor is
    uninitialized. Judging by the name, this is presumably registered as the
    fake (meta) implementation used during tracing; confirm at the
    registration site.

    Args:
        A: Activation tensor. Its last dim is ``K``.
        B: 2-D weight tensor of shape ``(N, K)``.
        As: Scales for ``A``.
        Bs: Scales for ``B``.
        block_size: ``[block_n, block_k]``.
        output_dtype: Requested output dtype. Must be ``torch.bfloat16``,
            matching the real implementation.

    Returns:
        Uninitialized tensor of shape ``A.shape[:-1] + (N,)``.
    """
    # Mirror the real op's dtype contract so an unsupported dtype is
    # rejected at trace time instead of only when the real kernel runs.
    assert output_dtype == torch.bfloat16, (
        f"DeepGEMM only supports bfloat16 output, got {output_dtype}")
    _, _, _, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    return C