From a68fae05e9f7081e9bc0eb25f7fd85dbd7ab99c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Thu, 24 Oct 2024 12:35:30 +0000
Subject: [PATCH] `can_scale`: check that the attention is flashinfer

---
 .../text_generation_server/layers/attention/kv_cache.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 8a79eebb..9d739da5 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -110,7 +110,11 @@ class KVCache:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif self.dtype == torch.float8_e4m3fn and SYSTEM == "cuda":
+        elif (
+            self.dtype == torch.float8_e4m3fn
+            and ATTENTION == "flashinfer"
+            and SYSTEM == "cuda"
+        ):
             log_once(
                 logger.info,
                 "Using FP8 KV cache scales",
@@ -120,7 +124,7 @@ class KVCache:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on CUDA is supported",
+                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
            )
            return False
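
For readers who want the resulting check outside of diff context, below is a minimal, self-contained sketch of the logic after this change. It is not the repository's actual code: ATTENTION and SYSTEM are stand-in constants for the backend/platform globals used in text-generation-server, KVScales is reduced to the two fields visible in the diff, the method is shown as a free function instead of KVCache.can_scale, and the log_once calls are omitted.

# Sketch only: stand-ins for the repository's real globals and types.
from dataclasses import dataclass

import torch

ATTENTION = "flashinfer"  # assumption: normally set by the attention-backend selection
SYSTEM = "cuda"           # assumption: normally set by hardware detection


@dataclass
class KVScales:
    key_scale_cpu: float
    value_scale_cpu: float


def can_scale(cache_dtype: torch.dtype, kv_scales: KVScales) -> bool:
    """Return True only when FP8 KV-cache scales can actually be applied."""
    if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
        # Scales of 1.0 are a no-op, so there is nothing to apply.
        return False
    # After this patch, scales are only honoured for a float8_e4m3fn cache
    # on CUDA *and* the flashinfer attention backend; otherwise they are
    # ignored (the real code also warns once in that case).
    return (
        cache_dtype == torch.float8_e4m3fn
        and ATTENTION == "flashinfer"
        and SYSTEM == "cuda"
    )


# Example: scales are ignored unless all three conditions hold.
print(can_scale(torch.float8_e4m3fn, KVScales(0.5, 0.5)))  # True under the assumptions above
print(can_scale(torch.float16, KVScales(0.5, 0.5)))        # False: not an FP8 cache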