Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
add softcap and slidingwindow

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent 102e29902a
commit d9e47b651c
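For background: "softcap" here refers to attention logit softcapping, where the raw query-key scores are squashed through a scaled tanh before the softmax (as used by Gemma-2-style models), and "sliding window" restricts each token to attending over a bounded number of recent positions. The snippet below is only a plain-PyTorch illustration of that math, not the IPEX kernels this commit wires up; the function name is invented for the example.

import torch

def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Logit softcapping: bound raw attention scores to (-softcap, softcap)
    # with a scaled tanh before the softmax is applied.
    return softcap * torch.tanh(scores / softcap)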
@@ -8,6 +8,9 @@ from text_generation_server.models.globals import (
    BLOCK_SIZE,
)

if ATTENTION == "flashdecoding-ipex":
    SUPPORTS_WINDOWING = True
else:
    SUPPORTS_WINDOWING = False

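SUPPORTS_WINDOWING is the capability flag the rest of the server consults before loading a model that needs sliding-window attention; with this change it is only advertised when the flashdecoding-ipex path is active, since only that kernel accepts window arguments. A hypothetical load-time guard using such a flag could look like the sketch below; the function and parameter names are illustrative, not TGI's actual code.

def check_windowing(sliding_window, supports_windowing: bool) -> None:
    # Refuse to run a sliding-window model on a backend whose attention
    # kernels cannot apply the window themselves.
    if sliding_window is not None and sliding_window != -1 and not supports_windowing:
        raise NotImplementedError(
            "Sliding window attention is not supported by this attention backend"
        )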
@@ -25,8 +28,6 @@ def attention(
    causal: bool = True,
    softcap: Optional[float] = None,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
    kv_cache_dtype = "auto"

@@ -37,6 +38,7 @@ def attention(

    # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
    if ATTENTION == "flashdecoding-ipex":
        window_size_right = -1 if window_size_left == -1 else 0
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query.contiguous() if query.device.type == "xpu" else query,

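The window convention follows the flash-attention style: window_size_left == -1 means an unbounded left context, any other value keeps only the most recent window_size_left keys, and window_size_right is forced to 0 so causal decoding never looks ahead. As a dense-mask illustration of that convention (a sketch, not the kernel's implementation):

import torch

def sliding_window_causal_mask(seq_len: int, window_size_left: int) -> torch.Tensor:
    # True where query position i is allowed to attend to key position j.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    causal = j <= i                        # window_size_right == 0: no lookahead
    if window_size_left == -1:             # -1 means an unbounded left window
        return causal
    return causal & (i - j <= window_size_left)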
@@ -50,11 +52,18 @@ def attention(
            causal,
            block_tables,
            None,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
            softcap=softcap,
        )
    else:
        if softcap is not None:
            raise NotImplementedError(
                "softcap is not available in IPEX paged attention"
            )
        ipex.llm.functional.varlen_attention(
            query.contiguous() if query.device.type == "xpu" else query,
            key.contiguous() if key.device.type == "xpu" else key,

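Both branches above are variable-length ("varlen") kernels: the batch is flattened into a single token dimension and per-sequence boundaries are passed as cumulative sequence lengths instead of padding. A minimal sketch of how such cumulative offsets are typically built (the helper name is made up for illustration):

import torch

def cu_seqlens_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    # lengths: per-sequence token counts, e.g. tensor([3, 5, 2])
    # returns [0, 3, 8, 10], the offsets delimiting each sequence inside
    # the flattened (total_tokens, num_heads, head_dim) tensors.
    return torch.nn.functional.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))

Note that on this prefill path softcap is only forwarded to the flashdecoding-ipex kernel; the varlen_attention fallback still rejects it explicitly.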
@@ -88,17 +97,14 @@ def paged_attention(
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = -1,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
    kv_cache_dtype = "auto"
    if kv_cache.key.dtype == torch.float8_e5m2:
        kv_cache_dtype = "fp8_e5m2"
    if kv_cache.key.dtype == torch.float8_e4m3fn:
        kv_cache_dtype = "fp8_e4m3"

    if ATTENTION == "flashdecoding-ipex":
        window_size_right = -1 if window_size_left == -1 else 0
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query.contiguous() if query.device.type == "xpu" else query,

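The kv_cache_dtype strings advertise an FP8-quantized KV cache (float8_e5m2 or float8_e4m3fn storage), and k_scale / v_scale are the per-tensor scales the kernel uses to interpret it. Roughly, the round trip looks like the sketch below; the exact scaling scheme inside the IPEX kernel is an assumption here, not taken from this diff.

import torch

def quantize_kv_fp8(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Store KV entries in 8-bit float, scaled by a per-tensor factor.
    return (x / scale).to(torch.float8_e4m3fn)

def dequantize_kv_fp8(q: torch.Tensor, scale: float, dtype=torch.bfloat16) -> torch.Tensor:
    # The attention kernel conceptually recovers x ≈ q * scale before its matmuls.
    return q.to(dtype) * scale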
@@ -112,12 +118,19 @@ def paged_attention(
            True,
            block_tables,
            None,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
            softcap=softcap,
        )
    else:
        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        if softcap is not None:
            raise NotImplementedError(
                "softcap is not available in IPEX paged attention"
            )
        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out,
            query,

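On this fallback decode path each sequence attends to everything it has produced so far, so the effective length is the tokens already in the KV cache plus the ones appended this step, which is what seqlen.input_lengths + seqlen.cache_lengths computes before it reaches single_query_cached_kv_attention. A toy example of the arithmetic:

import torch

cache_lengths = torch.tensor([17, 4, 250])     # tokens already stored in the KV cache
input_lengths = torch.tensor([1, 1, 1])        # tokens appended in this decode step
total_lengths = input_lengths + cache_lengths  # tensor([ 18,   5, 251])

Since this path still rejects softcap and receives no window arguments, both features effectively require the flashdecoding-ipex backend.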
@@ -130,9 +143,6 @@ def paged_attention(
            BLOCK_SIZE,
            max_s,
            None,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
        )
    return out