def check_allspark_supported_dtype_shape(input_size_per_partition: int,
                                         output_size_per_partition: int,
                                         group_size: int,
                                         weight_dtype: ScalarType,
                                         act_dtype: torch.dtype):
    """Check whether AllSpark can handle the given layer configuration.

    The device capability is read from ``current_platform``; when the
    platform reports none, ``-1`` is used so the device is rejected.
    Only capabilities in ``[80, 90)`` (treated as Ampere here) are
    accepted, and for those the following must all hold:

    * ``group_size`` is ``-1``;
    * ``weight_dtype`` is in ``ALLSPARK_SUPPORTED_QUANT_TYPES``;
    * ``input_size_per_partition`` is a multiple of
      ``ALLSPARK_AMPERE_K_ALIGN`` and ``output_size_per_partition`` is a
      multiple of ``ALLSPARK_AMPERE_N_ALIGN``;
    * ``act_dtype`` is ``torch.float16`` or ``torch.bfloat16``.

    Args:
        input_size_per_partition: Input size of this partition.
        output_size_per_partition: Output size of this partition.
        group_size: Quantization group size; only ``-1`` is accepted.
        weight_dtype: Quantized weight type.
        act_dtype: Activation dtype.

    Returns:
        ``(True, None)`` when every check passes, otherwise
        ``(False, reason)`` where ``reason`` is a human-readable string
        describing the first failed check.
    """
    capability_tuple = current_platform.get_device_capability()
    device_capability = (-1 if capability_tuple is None else
                         capability_tuple.to_int())

    # Reject anything outside the Ampere capability range up front so the
    # Ampere-specific checks below can stay flat.
    if not 80 <= device_capability < 90:
        return False, "AllSpark currently does not support "\
            f"device_capability = {device_capability}."

    # For Ampere GPU
    if group_size != -1:
        return False, \
            "For Ampere GPU, AllSpark does not support group_size "\
            f"= {group_size}. Only group_size = -1 are supported."

    if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
        return False, "For Ampere GPU, AllSpark does not support "\
            f"quant type ({weight_dtype}). Only quant type "\
            f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported."

    k_misaligned = input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0
    n_misaligned = output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0
    if k_misaligned or n_misaligned:
        return False, \
            "AllSpark needs input_size_per_partition % "\
            f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\
            f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\
            "for Ampere GPU optimized kernels."

    if act_dtype not in (torch.float16, torch.bfloat16):
        # The trailing space keeps the two concatenated literals from
        # running together ("bfloat16,for Ampere GPU").
        return False, \
            "AllSpark only supports act_dtype = float16 or bfloat16, "\
            f"for Ampere GPU, but got act_dtype = {act_dtype}."

    return True, None