Add fp8 kv cache for ROCm (#2856)

* add fp8 kv cache for rocm

* improvements

* update log statement

* remove bookkeeping field
Mohit Sharma 2025-01-17 18:43:29 +05:30 committed by GitHub
parent de19e7e844
commit c20025dbf7
2 changed files with 62 additions and 34 deletions

File 1 of 2

@@ -52,13 +52,18 @@ class KVCache:
         device: torch.device,
     ):
         """Construct the key-value cache for a layer."""
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
@@ -113,21 +118,17 @@ class KVCache:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
             )
             return False
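Note on what these scales mean: with a per-tensor FP8 scale, values are stored as value / scale cast to float8_e4m3fn and recovered as stored * scale, so a scale of 1.0 is just an unscaled cast, which is why key_scale_cpu == value_scale_cpu == 1.0 is treated as "no scales provided". A minimal sketch of that convention (not TGI's fp8_quantize itself):

    import torch

    k = torch.randn(4, 8, dtype=torch.float16)

    # Calibrated per-tensor scale: map the observed max magnitude onto the FP8 range.
    scale = k.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max  # e4m3fn max is 448

    k_fp8 = (k.float() / scale).to(torch.float8_e4m3fn)  # quantize
    k_hat = k_fp8.float() * scale                        # dequantize
    print((k.float() - k_hat).abs().max())               # small round-trip error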
@@ -161,7 +162,7 @@ class KVCache:
         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]
 
-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
             if kv_scales.key_scale_cpu != 1.0:
                 key = fp8_quantize(
                     key.float(),
@@ -197,7 +198,15 @@ class KVCache:
                     key, value, key_cache, value_cache, slots
                 )
             else:
-                paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+                paged_reshape_and_cache(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slots,
+                    kv_scales.key_scale_cpu,
+                    kv_scales.value_scale_cpu,
+                )
 
 
 def paged_reshape_and_cache(
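For readers unfamiliar with reshape_and_cache: "slots" are flat positions in the paged cache at which each new token's key/value is scattered. A heavily simplified illustration of that indexing, assuming a flat [num_slots, num_heads, head_size] layout rather than the real blocked layout the kernels use:

    import torch

    num_slots, num_heads, head_size = 16, 2, 4
    key_cache = torch.zeros(num_slots, num_heads, head_size, dtype=torch.float16)

    key = torch.randn(3, num_heads, head_size, dtype=torch.float16)  # 3 new tokens
    slots = torch.tensor([5, 6, 12])                                 # target cache positions

    key_cache[slots] = key  # scatter each token's key into its slot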
@@ -206,7 +215,10 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
     if SYSTEM == "cuda":
         try:
             import attention_kernels
@@ -224,8 +236,15 @@ def paged_reshape_and_cache(
             raise ImportError(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            key_cache = key_cache.view(torch.uint8)
+            value_cache = value_cache.view(torch.uint8)
+            kv_cache_dtype = "fp8"
+
         ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
         )
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
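The .view(torch.uint8) calls are bitwise reinterpretations, not conversions: the vLLM ROCm kernels take the FP8 cache as raw uint8 storage together with kv_cache_dtype="fp8" and the scales. A small sketch showing that view() aliases the same memory without copying:

    import torch

    cache = torch.zeros(4, 4, dtype=torch.float8_e4m3fn)
    as_u8 = cache.view(torch.uint8)

    assert as_u8.dtype == torch.uint8
    assert as_u8.data_ptr() == cache.data_ptr()  # same storage, no copy
    assert as_u8.shape == cache.shape            # uint8 and e4m3fn are both 1 byte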

File 2 of 2

@@ -133,6 +133,15 @@ def paged_attention(
     out = torch.empty_like(query)
 
+    if kv_cache.dtype == torch.float8_e4m3fn:
+        key = kv_cache.key.view(torch.uint8)
+        value = kv_cache.value.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+    else:
+        key = kv_cache.key
+        value = kv_cache.value
+        kv_cache_dtype = "auto"
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
@@ -147,8 +156,8 @@ def paged_attention(
         ops.paged_attention_v1(
             out,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -156,9 +165,9 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
-            1.0,
+            kv_cache_dtype,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
         )
     else:
         # Run PagedAttention V2.
@@ -182,8 +191,8 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
@@ -191,9 +200,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
             )
         else:
             ops.paged_attention_rocm(
@@ -202,8 +211,8 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
@@ -211,9 +220,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
                 None,
                 _PARTITION_SIZE,
             )
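Putting the pieces together: passing kv_cache_dtype="fp8" plus the key/value scales asks the ROCm paged-attention kernels to treat the uint8-viewed cache as float8_e4m3fn and apply the scales while reading it. Conceptually, in Python (a sketch of the idea only; the real work happens inside ops.paged_attention_* on the GPU):

    import torch

    def read_fp8_cache(cache_u8: torch.Tensor, scale: float) -> torch.Tensor:
        # Reinterpret the raw bytes as e4m3fn, then dequantize with the per-tensor scale.
        return cache_u8.view(torch.float8_e4m3fn).to(torch.float32) * scale

    # Round trip: quantize with a scale, store as uint8, read back.
    k = torch.randn(8, dtype=torch.float32)
    scale = (k.abs().max() / torch.finfo(torch.float8_e4m3fn).max).item()
    stored = (k / scale).to(torch.float8_e4m3fn).view(torch.uint8)
    print((read_fp8_cache(stored, scale) - k).abs().max())  # small quantization error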