Make check more obvious

This commit is contained in:
Daniël de Kok 2024-10-16 13:54:57 +00:00
parent aa92e451a0
commit 751f1bb815

View File

@ -24,10 +24,8 @@ class KVCache:
):
"""Construct the key-value cache for a layer."""
if (
dtype.itemsize == 1
and dtype.is_floating_point
and (ATTENTION != "flashinfer" or SYSTEM != "cuda")
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
ATTENTION != "flashinfer" or SYSTEM != "cuda"
):
raise ValueError(
"FP8 KV cache is currently only supported for flashinfer on CUDA"