Mirror of https://github.com/huggingface/text-generation-inference.git
add kvcache dtype
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit 102e29902a (parent 065f87a337)
@@ -29,6 +29,11 @@ def attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     if ATTENTION == "flashdecoding-ipex":
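For reference, the dtype-to-string mapping that this hunk (and the later hunks) repeat inline can be read as the small helper below. This is only an illustrative sketch, not code from the commit; the helper name kv_cache_dtype_for is hypothetical.

    import torch

    def kv_cache_dtype_for(cache_dtype: torch.dtype) -> str:
        # Same decision the added lines make inline: fp8 caches get an
        # explicit tag, every other dtype falls back to "auto".
        if cache_dtype == torch.float8_e5m2:
            return "fp8_e5m2"
        if cache_dtype == torch.float8_e4m3fn:
            return "fp8_e4m3"
        return "auto"

    assert kv_cache_dtype_for(torch.float8_e5m2) == "fp8_e5m2"
    assert kv_cache_dtype_for(torch.bfloat16) == "auto"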
@@ -45,6 +50,7 @@ def attention(
             causal,
             block_tables,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
@@ -86,6 +92,11 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
 
     if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
@@ -101,6 +112,7 @@ def paged_attention(
             True,
             block_tables,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
@@ -118,6 +130,7 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
+            kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
         )
@@ -213,12 +213,18 @@ class KVCache:
         elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
             import intel_extension_for_pytorch as ipex
 
+            kv_cache_dtype = "auto"
+            if key_cache.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            if key_cache.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
             ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
                 slots,
+                kv_cache_dtype=kv_cache_dtype,
                 k_scale=kv_scales.key_scale_cpu,
                 v_scale=kv_scales.value_scale_cpu,
             )
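The new branch only changes behavior when the cache tensors were allocated in an fp8 dtype; the quick check below sketches what the added condition sees. The shapes are made up and the real allocation happens elsewhere in the KVCache class; this is not code from the commit.

    import torch

    # Hypothetical cache allocation; only the dtype matters for the check above.
    key_cache = torch.empty(1024, 64, 8, 128, dtype=torch.float8_e5m2)
    print(key_cache.dtype == torch.float8_e5m2)   # True  -> kv_cache_dtype = "fp8_e5m2"

    key_cache = torch.empty(1024, 64, 8, 128, dtype=torch.bfloat16)
    print(key_cache.dtype == torch.float8_e5m2)   # False -> kv_cache_dtype stays "auto"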
@@ -279,8 +285,21 @@ def paged_reshape_and_cache(
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
 
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e5m2:
+            kv_cache_dtype = "fp8_e5m2"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            kv_cache_dtype = "fp8_e4m3"
+
         ipex.llm.modules.PagedAttention.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, k_scale=k_scale, v_scale=v_scale
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slots,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
     else:
         raise NotImplementedError(
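The two string tags correspond to the two fp8 dtypes PyTorch exposes; a quick comparison of their numeric ranges (not part of the commit, values are standard fp8 properties) shows why the checks distinguish them:

    import torch

    # e5m2: 5 exponent bits, 2 mantissa bits -> wider range, coarser precision.
    print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
    # e4m3fn: 4 exponent bits, 3 mantissa bits -> narrower range, finer precision.
    print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0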