Mirror of https://github.com/huggingface/text-generation-inference.git
add kvcache dtype
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit 102e29902a
parent 065f87a337
@@ -29,6 +29,11 @@ def attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     if ATTENTION == "flashdecoding-ipex":
@@ -45,6 +50,7 @@ def attention(
             causal,
             block_tables,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
@@ -86,6 +92,11 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
 
     if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
@@ -101,6 +112,7 @@ def paged_attention(
             True,
             block_tables,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
@@ -118,6 +130,7 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
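The same fp8-dtype-to-string mapping is repeated at every IPEX call site touched by this diff. As a minimal sketch of that pattern (the helper name kv_cache_dtype_str is hypothetical and not something this commit adds), the logic amounts to:

import torch


def kv_cache_dtype_str(cache_dtype: torch.dtype) -> str:
    # Map a torch KV-cache dtype to the string passed to the IPEX paged-attention kernels.
    if cache_dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if cache_dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    # Any non-fp8 cache dtype falls back to "auto", i.e. use the tensor's own dtype.
    return "auto"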
@@ -213,12 +213,18 @@ class KVCache:
         elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
             import intel_extension_for_pytorch as ipex
 
+            kv_cache_dtype = "auto"
+            if key_cache.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            if key_cache.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
             ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
                 slots,
+                kv_cache_dtype=kv_cache_dtype,
                 k_scale=kv_scales.key_scale_cpu,
                 v_scale=kv_scales.value_scale_cpu,
             )
@@ -279,8 +285,21 @@ def paged_reshape_and_cache(
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
 
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e5m2:
+            kv_cache_dtype = "fp8_e5m2"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            kv_cache_dtype = "fp8_e4m3"
+
        ipex.llm.modules.PagedAttention.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, k_scale=k_scale, v_scale=v_scale
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slots,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
     else:
         raise NotImplementedError(
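For illustration only, the hypothetical helper sketched above applied to an fp8 KV-cache tensor (the shape here is arbitrary, not TGI's actual cache layout):

import torch

# Arbitrary illustrative shape; TGI allocates its real KV cache elsewhere.
key_cache = torch.empty(16, 64, 8, 128, dtype=torch.float8_e5m2)

assert kv_cache_dtype_str(key_cache.dtype) == "fp8_e5m2"
assert kv_cache_dtype_str(torch.float16) == "auto"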