diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 8a79eebb..9d739da5 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -110,7 +110,11 @@ class KVCache:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif self.dtype == torch.float8_e4m3fn and SYSTEM == "cuda":
+        elif (
+            self.dtype == torch.float8_e4m3fn
+            and ATTENTION == "flashinfer"
+            and SYSTEM == "cuda"
+        ):
             log_once(
                 logger.info,
                 "Using FP8 KV cache scales",
@@ -120,7 +124,7 @@ class KVCache:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on CUDA is supported",
+                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
             )
             return False
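
For reference, a minimal, self-contained sketch of the gating condition in the hunks above. `SYSTEM` and `ATTENTION` are hard-coded stand-ins for the values the real module imports (the exact imports and the standalone `can_scale` signature are assumptions here, not the server's API), so the branch behavior can be checked in isolation:

```python
import torch

# Stand-ins for the module-level globals used in kv_cache.py; in the
# server these come from its own utility modules, hard-coded here only
# for illustration.
SYSTEM = "cuda"
ATTENTION = "flashinfer"


def can_scale(cache_dtype: torch.dtype, key_scale: float, value_scale: float) -> bool:
    """Sketch of the check: FP8 KV cache scales are applied only for a
    float8_e4m3fn cache with the flashinfer backend on CUDA; otherwise
    non-trivial scales are ignored."""
    if key_scale == 1.0 and value_scale == 1.0:
        # Scales of 1.0 are a no-op, nothing to apply.
        return False
    elif (
        cache_dtype == torch.float8_e4m3fn
        and ATTENTION == "flashinfer"
        and SYSTEM == "cuda"
    ):
        return True
    else:
        # Scales are present, but this cache type/backend cannot use them.
        return False


# float16 cache: scales are ignored even though they are non-trivial.
print(can_scale(torch.float16, key_scale=0.5, value_scale=0.5))        # False
# float8_e4m3fn cache with flashinfer on CUDA: scales are used.
print(can_scale(torch.float8_e4m3fn, key_scale=0.5, value_scale=0.5))  # True
```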