From c20025dbf7708b953a4c39600a8569fe9d8f023b Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Fri, 17 Jan 2025 18:43:29 +0530
Subject: [PATCH] Add fp8 kv cache for ROCm (#2856)

* add fp8 kv cache for rocm

* improvements

* update log statement

* remove bookkeeping field
---
 .../layers/attention/kv_cache.py | 57 ++++++++++++-------
 .../layers/attention/rocm.py     | 39 ++++++++-----
 2 files changed, 62 insertions(+), 34 deletions(-)

diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 00308601..77e761a3 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -52,13 +52,18 @@ class KVCache:
         device: torch.device,
     ):
         """Construct the key-value cache for a layer."""
-
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
@@ -113,21 +118,17 @@ class KVCache:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
             )
             return False
 
@@ -161,7 +162,7 @@ class KVCache:
         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]
 
-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
             if kv_scales.key_scale_cpu != 1.0:
                 key = fp8_quantize(
                     key.float(),
@@ -197,7 +198,15 @@ class KVCache:
                 key, value, key_cache, value_cache, slots
             )
         else:
-            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
+            )
 
 
 def paged_reshape_and_cache(
@@ -206,7 +215,10 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
+
     if SYSTEM == "cuda":
         try:
             import attention_kernels
@@ -224,8 +236,15 @@ def paged_reshape_and_cache(
             raise ImportError(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            key_cache = key_cache.view(torch.uint8)
+            value_cache = value_cache.view(torch.uint8)
+            kv_cache_dtype = "fp8"
+
         ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
         )
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index b94b737d..65f3ea41 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -133,6 +133,15 @@ def paged_attention(
 
     out = torch.empty_like(query)
 
+    if kv_cache.dtype == torch.float8_e4m3fn:
+        key = kv_cache.key.view(torch.uint8)
+        value = kv_cache.value.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+    else:
+        key = kv_cache.key
+        value = kv_cache.value
+        kv_cache_dtype = "auto"
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
@@ -147,8 +156,8 @@ def paged_attention(
         ops.paged_attention_v1(
             out,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -156,9 +165,9 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
-            1.0,
+            kv_cache_dtype,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
         )
     else:
         # Run PagedAttention V2.
@@ -182,8 +191,8 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
@@ -191,9 +200,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
             )
         else:
             ops.paged_attention_rocm(
@@ -202,8 +211,8 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
@@ -211,9 +220,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
                 None,
                 _PARTITION_SIZE,
             )
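
The core of the change is the dtype plumbing in front of the vLLM ROCm kernels: a float8_e4m3fn KV cache is reinterpreted as uint8, the kernels are told the cache dtype is "fp8", and the per-tensor scales (kv_scales.key_scale_cpu / kv_scales.value_scale_cpu) are forwarded instead of the previous hard-coded 1.0. The snippet below is a minimal, self-contained sketch of that selection logic in plain PyTorch; select_kv_cache_view is a hypothetical helper and the cache shape is illustrative, neither is part of the text-generation-inference codebase.

import torch


def select_kv_cache_view(key_cache: torch.Tensor, value_cache: torch.Tensor):
    """Mirror the patch: fp8 caches are handed to the kernels as uint8 views
    with kv_cache_dtype="fp8"; anything else passes through as "auto"."""
    if key_cache.dtype == torch.float8_e4m3fn:
        # .view(torch.uint8) reinterprets the 1-byte fp8 storage without copying.
        return key_cache.view(torch.uint8), value_cache.view(torch.uint8), "fp8"
    return key_cache, value_cache, "auto"


if __name__ == "__main__":
    # Illustrative cache shape only: (num_blocks, block_size, num_kv_heads, head_dim).
    key_cache = torch.zeros(4, 16, 2, 64, dtype=torch.float8_e4m3fn)
    value_cache = torch.zeros_like(key_cache)
    key, value, kv_cache_dtype = select_kv_cache_view(key_cache, value_cache)
    # These, together with the key/value scales, are what the patch forwards to
    # ops.reshape_and_cache and ops.paged_attention_v1/_v2/_rocm.
    print(key.dtype, value.dtype, kv_cache_dtype)  # torch.uint8 torch.uint8 fp8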