Add fp8 kv cache for ROCm (#2856)

* add fp8 kv cache for rocm

* improvements

* update log statement

* remove bookkeeping field
Mohit Sharma 2025-01-17 18:43:29 +05:30 committed by GitHub
parent de19e7e844
commit c20025dbf7
2 changed files with 62 additions and 34 deletions


@@ -52,12 +52,17 @@ class KVCache:
         device: torch.device,
     ):
         """Construct the key-value cache for a layer."""
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
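
In short, the constructor now accepts an FP8 KV-cache dtype only when flashinfer runs on CUDA or paged attention runs on ROCm, and it additionally rejects float8_e5m2 on ROCm, leaving float8_e4m3fn as the only FP8 option there. A minimal standalone sketch of that rule, with illustrative lowercase parameters standing in for the module's ATTENTION and SYSTEM globals:

import torch

# Hypothetical helper mirroring the constructor check above; `attention` and
# `system` stand in for the ATTENTION / SYSTEM globals used in the diff.
def validate_fp8_kv_cache_dtype(dtype: torch.dtype, attention: str, system: str) -> None:
    if dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
        return  # non-FP8 dtypes need no backend check
    supported = (attention == "flashinfer" and system == "cuda") or (
        attention == "paged" and system == "rocm"
    )
    if not supported:
        raise ValueError(
            "FP8 KV cache is currently only supported for flashinfer on CUDA "
            "and paged attention on ROCm."
        )
    if system == "rocm" and dtype == torch.float8_e5m2:
        raise ValueError("float8_e5m2 FP8 KV cache is not supported on AMD ROCm")

validate_fp8_kv_cache_dtype(torch.float8_e4m3fn, "paged", "rocm")  # accepted
validate_fp8_kv_cache_dtype(torch.float8_e5m2, "paged", "rocm")    # raises ValueError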
@@ -113,21 +118,17 @@ class KVCache:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
             )
             return False
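
The reworked can_scale check treats paged attention on ROCm the same way as flashinfer on CUDA: identity scales mean there is nothing to apply, a float8_e4m3fn cache on a supported backend can consume the scales, and any other combination warns and ignores them. A hedged sketch of the same decision as a free function (logging omitted, parameter names illustrative):

import torch

def can_scale_fp8_kv_cache(
    cache_dtype: torch.dtype,
    attention: str,      # stands in for the ATTENTION global
    system: str,         # stands in for the SYSTEM global
    key_scale: float,
    value_scale: float,
) -> bool:
    # Identity scales: applying them would be a no-op.
    if key_scale == 1.0 and value_scale == 1.0:
        return False
    # Only float8_e4m3fn caches on a supported backend can use the scales.
    if cache_dtype == torch.float8_e4m3fn and (
        (attention == "flashinfer" and system == "cuda")
        or (attention == "paged" and system == "rocm")
    ):
        return True
    # Scales are present but unusable; the real method warns once here.
    return False

print(can_scale_fp8_kv_cache(torch.float8_e4m3fn, "paged", "rocm", 0.02, 0.03))  # True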
@@ -161,7 +162,7 @@ class KVCache:
         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]
 
-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
             if kv_scales.key_scale_cpu != 1.0:
                 key = fp8_quantize(
                     key.float(),
@@ -197,7 +198,15 @@ class KVCache:
                 key, value, key_cache, value_cache, slots
             )
         else:
-            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
+            )
 
 
 def paged_reshape_and_cache(
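
The store path now splits by platform: when the scales are usable on CUDA, key and value are quantized in software with fp8_quantize before the cache write, while on ROCm the unquantized tensors and the CPU-side scales are handed to paged_reshape_and_cache so the vLLM kernel can apply them during the write. For intuition only, a torch-only illustration of scaled e4m3 quantization (this is not the repo's fp8_quantize, just the usual divide, clamp, cast pattern with dequantization by multiplying back):

import torch

def quantize_e4m3(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Assumed convention: store x / scale in FP8, recover with fp8 * scale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

key = torch.randn(4, 8, 64)               # [tokens, kv_heads, head_dim]
key_fp8 = quantize_e4m3(key, scale=0.5)
key_roundtrip = key_fp8.float() * 0.5     # dequantize
print((key - key_roundtrip).abs().max())  # small FP8 rounding error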
@@ -206,7 +215,10 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
     if SYSTEM == "cuda":
         try:
             import attention_kernels
@@ -224,8 +236,15 @@ def paged_reshape_and_cache(
             raise ImportError(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            key_cache = key_cache.view(torch.uint8)
+            value_cache = value_cache.view(torch.uint8)
+            kv_cache_dtype = "fp8"
+
         ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
         )
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
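
On the ROCm branch, an FP8 cache is handed to the vLLM op as raw uint8 storage together with kv_cache_dtype="fp8"; the .view(torch.uint8) call reinterprets the existing bytes rather than converting values. A small torch-only check of that reinterpretation:

import torch

cache = torch.zeros(2, 16, dtype=torch.float8_e4m3fn)
cache_u8 = cache.view(torch.uint8)       # same bytes, 1-byte elements either way

assert cache_u8.dtype == torch.uint8
assert cache_u8.shape == cache.shape
assert cache_u8.data_ptr() == cache.data_ptr()  # aliases the same memory, no copy

cache_u8.fill_(0x38)                     # write a raw bit pattern via the uint8 view
print(cache[0, :4].float())              # the fp8 view sees it (0x38 is 1.0 in e4m3)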


@@ -133,6 +133,15 @@ def paged_attention(
     out = torch.empty_like(query)
 
+    if kv_cache.dtype == torch.float8_e4m3fn:
+        key = kv_cache.key.view(torch.uint8)
+        value = kv_cache.value.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+    else:
+        key = kv_cache.key
+        value = kv_cache.value
+        kv_cache_dtype = "auto"
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
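
The same uint8 reinterpretation now happens once at the top of paged_attention, so every kernel call site below receives a consistent key/value pair and a matching kv_cache_dtype string instead of reaching into kv_cache directly. A hedged sketch of that selection pulled out as a helper (the function name is illustrative, and it takes the two cache tensors rather than the repo's kv_cache object):

import torch
from typing import Tuple

def resolve_kv_cache_views(
    key_cache: torch.Tensor, value_cache: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, str]:
    if key_cache.dtype == torch.float8_e4m3fn:
        # FP8 caches go to the kernels as raw uint8 plus an explicit flag.
        return key_cache.view(torch.uint8), value_cache.view(torch.uint8), "fp8"
    # Other dtypes pass through unchanged with the "auto" flag.
    return key_cache, value_cache, "auto"

k = torch.zeros(1, 8, 16, 16, dtype=torch.float8_e4m3fn)
v = torch.zeros_like(k)
key, value, kv_cache_dtype = resolve_kv_cache_views(k, v)
print(key.dtype, kv_cache_dtype)  # torch.uint8 fp8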
@@ -147,8 +156,8 @@ def paged_attention(
         ops.paged_attention_v1(
             out,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -156,9 +165,9 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
-            1.0,
+            kv_cache_dtype,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
         )
     else:
         # Run PagedAttention V2.
@@ -182,8 +191,8 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
@@ -191,9 +200,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
             )
         else:
             ops.paged_attention_rocm(
@@ -202,8 +211,8 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
@@ -211,9 +220,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
                 None,
                 _PARTITION_SIZE,
             )
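
All three ROCm call sites (paged_attention_v1, paged_attention_v2, and the custom paged_attention_rocm kernel) now receive the kv_cache_dtype string and the CPU-side key/value scales in place of the hard-coded "auto", 1.0, 1.0, so attention runs over FP8 K/V that the kernel dequantizes with those scales. A torch-only reference (no vLLM needed) showing the accuracy to expect from that round trip, assuming the usual dequantize-by-multiplying convention:

import torch

torch.manual_seed(0)
q = torch.randn(1, 64)            # one query, head_dim = 64
k = torch.randn(32, 64)           # 32 cached keys
v = torch.randn(32, 64)           # 32 cached values

# Per-tensor scales chosen so the largest magnitude maps to the e4m3 maximum.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
k_scale = k.abs().max().item() / fp8_max
v_scale = v.abs().max().item() / fp8_max
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)
v_fp8 = (v / v_scale).to(torch.float8_e4m3fn)

def attn(q, k, v):
    scores = torch.softmax(q @ k.T / 64**0.5, dim=-1)
    return scores @ v

ref = attn(q, k, v)
approx = attn(q, k_fp8.float() * k_scale, v_fp8.float() * v_scale)
print((ref - approx).abs().max())  # small deviation from FP8 rounding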