can_scale: check that the attention is flashinfer

Daniël de Kok 2024-10-24 12:35:30 +00:00
parent 1f18cb6aa6
commit a68fae05e9


@@ -110,7 +110,11 @@ class KVCache:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif self.dtype == torch.float8_e4m3fn and SYSTEM == "cuda":
+        elif (
+            self.dtype == torch.float8_e4m3fn
+            and ATTENTION == "flashinfer"
+            and SYSTEM == "cuda"
+        ):
             log_once(
                 logger.info,
                 "Using FP8 KV cache scales",
@@ -120,7 +124,7 @@ class KVCache:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on CUDA is supported",
+                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
             )
             return False
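
For context, a minimal sketch of the gating logic this change enforces: FP8 KV cache scales are only applied when the cache dtype is float8_e4m3fn, the attention backend is flashinfer, and the system is CUDA. The standalone function below is illustrative only; its name and signature are assumptions and do not appear in the repository.

```python
import torch


def fp8_kv_scales_usable(
    dtype: torch.dtype,
    attention: str,
    system: str,
    key_scale: float,
    value_scale: float,
) -> bool:
    """Illustrative sketch of the check added in this commit (names assumed)."""
    # Scales of exactly 1.0 are a no-op, so there is nothing to apply.
    if key_scale == 1.0 and value_scale == 1.0:
        return False
    # Scales only take effect with an e4m3 FP8 KV cache, the flashinfer
    # attention backend, and CUDA; any other combination is ignored.
    return (
        dtype == torch.float8_e4m3fn
        and attention == "flashinfer"
        and system == "cuda"
    )
```

In the actual method, the fallthrough branch additionally emits a one-time log message via log_once when scales are present but cannot be used.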