add kvcache dtype

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2025-04-02 19:29:01 -07:00
Parent: 065f87a337
Commit: 102e29902a

2 changed files with 33 additions and 1 deletion

@@ -29,6 +29,11 @@ def attention(
         raise NotImplementedError("softcap is not available in IPEX")
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     if ATTENTION == "flashdecoding-ipex":
@@ -45,6 +50,7 @@ def attention(
             causal,
             block_tables,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
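
The same dtype-to-string mapping is repeated at every call site this patch touches. A minimal sketch of the rule it encodes, written as a stand-alone helper (hypothetical refactor, not part of the commit):

import torch

def kv_cache_dtype_str(dtype: torch.dtype) -> str:
    # Mirror of the mapping added above: fp8 cache tensors are announced to the
    # IPEX kernels by name; anything else falls back to "auto", i.e. the cache
    # is stored in the model's compute dtype.
    if dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    return "auto"

# e.g. kv_cache_dtype = kv_cache_dtype_str(kv_cache.key.dtype)
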
@@ -86,6 +92,11 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
     if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
@@ -101,6 +112,7 @@ def paged_attention(
             True,
             block_tables,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
@@ -118,6 +130,7 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
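
For reference, the two fp8 formats that the new strings distinguish split their 8 bits differently between exponent and mantissa, which shows up in their dynamic range (plain PyTorch, independent of IPEX):

import torch

# float8_e5m2: 5 exponent bits, 2 mantissa bits -> wide range, coarse precision.
# float8_e4m3fn: 4 exponent bits, 3 mantissa bits -> narrow range, finer precision.
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
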

@@ -213,12 +213,18 @@ class KVCache:
         elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
             import intel_extension_for_pytorch as ipex
+            kv_cache_dtype = "auto"
+            if key_cache.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            if key_cache.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
             ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
                 slots,
+                kv_cache_dtype=kv_cache_dtype,
                 k_scale=kv_scales.key_scale_cpu,
                 v_scale=kv_scales.value_scale_cpu,
             )
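
The k_scale / v_scale values passed next to the new kv_cache_dtype argument parameterize the usual per-tensor fp8 quantize/dequantize round trip; a rough sketch of that round trip (illustration only, the exact scheme is up to the IPEX kernels):

import torch

# Quantize a key tile into the fp8 cache dtype with a per-tensor scale,
# then dequantize the way an attention kernel conceptually would.
k = torch.randn(4, 8) * 10.0
k_scale = k.abs().max() / torch.finfo(torch.float8_e4m3fn).max
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)   # what would live in key_cache
k_deq = k_fp8.to(torch.float32) * k_scale       # what the kernel computes with
print((k - k_deq).abs().max())                  # small quantization error
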
@@ -279,8 +285,21 @@ def paged_reshape_and_cache(
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e5m2:
+            kv_cache_dtype = "fp8_e5m2"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            kv_cache_dtype = "fp8_e4m3"
         ipex.llm.modules.PagedAttention.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, k_scale=k_scale, v_scale=v_scale
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slots,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
     else:
         raise NotImplementedError(