diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index 0b44072c..ee608e71 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -29,6 +29,11 @@ def attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     if ATTENTION == "flashdecoding-ipex":
@@ -45,6 +50,7 @@ def attention(
             causal,
             block_tables,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
@@ -86,6 +92,11 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
 
     if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
@@ -101,6 +112,7 @@ def paged_attention(
             True,
             block_tables,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
@@ -118,6 +130,7 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 7082d3ae..a37ecd4c 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -213,12 +213,18 @@ class KVCache:
         elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
             import intel_extension_for_pytorch as ipex
 
+            kv_cache_dtype = "auto"
+            if key_cache.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            if key_cache.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
             ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
                 slots,
+                kv_cache_dtype=kv_cache_dtype,
                 k_scale=kv_scales.key_scale_cpu,
                 v_scale=kv_scales.value_scale_cpu,
             )
@@ -279,8 +285,21 @@ def paged_reshape_and_cache(
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
 
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e5m2:
+            kv_cache_dtype = "fp8_e5m2"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            kv_cache_dtype = "fp8_e4m3"
+
         ipex.llm.modules.PagedAttention.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, k_scale=k_scale, v_scale=v_scale
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slots,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
     else:
         raise NotImplementedError(
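
The same dtype-to-string mapping is repeated at each of the four call sites above. A minimal sketch of that logic as a standalone helper, purely for illustration (the helper name is hypothetical and not part of the patch):

```python
import torch

# Hypothetical helper (not in the diff): map a KV-cache tensor dtype to the string
# identifier passed to the IPEX PagedAttention kernels via kv_cache_dtype.
def kv_cache_dtype_str(dtype: torch.dtype) -> str:
    if dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    return "auto"

# Example: an fp8 e5m2 cache maps to "fp8_e5m2"; any other dtype falls back to "auto".
key_cache = torch.empty(2, dtype=torch.float8_e5m2)
assert kv_cache_dtype_str(key_cache.dtype) == "fp8_e5m2"
assert kv_cache_dtype_str(torch.empty(2, dtype=torch.bfloat16).dtype) == "auto"
```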