diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index b02b46c9..ba070d04 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -65,11 +65,6 @@ class KVCache:
                     "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )
 
-        self.kv_cache_dtype = "auto"
-        if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
-            self.kv_cache_dtype = "fp8"
-            dtype = torch.uint8
-
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
             x = 1
@@ -120,12 +115,9 @@ class KVCache:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
-        ) or (
-            self.kv_cache_dtype == "fp8" and ATTENTION == "paged" and SYSTEM == "rocm"
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
             log_once(logger.info, "Using FP8 KV cache scales")
             return True
@@ -203,7 +195,6 @@ class KVCache:
                 key_cache,
                 value_cache,
                 slots,
-                self.kv_cache_dtype,
                 kv_scales.key_scale_cpu,
                 kv_scales.value_scale_cpu,
             )
@@ -215,7 +206,6 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
-    kv_cache_dtype: str = "auto",
     k_scale: float = 1.0,
     v_scale: float = 1.0,
 ):
@@ -237,6 +227,13 @@ def paged_reshape_and_cache(
             raise ImportError(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            key_cache = key_cache.view(torch.uint8)
+            value_cache = value_cache.view(torch.uint8)
+            kv_cache_dtype = "fp8"
+
         ops.reshape_and_cache(
             key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
         )
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 146c15e9..65f3ea41 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -133,6 +133,15 @@ def paged_attention(
 
     out = torch.empty_like(query)
 
+    if kv_cache.dtype == torch.float8_e4m3fn:
+        key = kv_cache.key.view(torch.uint8)
+        value = kv_cache.value.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+    else:
+        key = kv_cache.key
+        value = kv_cache.value
+        kv_cache_dtype = "auto"
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
@@ -147,8 +156,8 @@
         ops.paged_attention_v1(
             out,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -156,7 +165,7 @@
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype,
+            kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -182,8 +191,8 @@
             max_logits,
             tmp_output,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -191,7 +200,7 @@
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype,
+            kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -202,8 +211,8 @@
             max_logits,
             tmp_output,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -211,7 +220,7 @@
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype,
+            kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
             None,