diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index f1f9ecce..7f1dd370 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -24,10 +24,8 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""
 
-        if (
-            dtype.itemsize == 1
-            and dtype.is_floating_point
-            and (ATTENTION != "flashinfer" or SYSTEM != "cuda")
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
+            ATTENTION != "flashinfer" or SYSTEM != "cuda"
         ):
             raise ValueError(
                 "FP8 KV cache is currently only supported for flashinfer on CUDA"