add softcap and sliding window

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-04-07 22:42:19 -07:00
parent 102e29902a
commit d9e47b651c

@@ -8,7 +8,10 @@ from text_generation_server.models.globals import (
     BLOCK_SIZE,
 )
 
-SUPPORTS_WINDOWING = False
+if ATTENTION == "flashdecoding-ipex":
+    SUPPORTS_WINDOWING = True
+else:
+    SUPPORTS_WINDOWING = False
 
 
 def attention(
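
Note on the first hunk: windowing is now advertised only for the flashdecoding-ipex backend, since only that path passes the window arguments through to the kernel (see the later hunks). For reference, sliding-window attention restricts each query token to the most recent window_size_left keys. A minimal standalone sketch of that masking rule (illustrative only; the helper name is ours, not part of this patch):

import torch

def sliding_window_causal_mask(seq_len: int, window_size_left: int) -> torch.Tensor:
    # True where query i may attend to key j: causal (j <= i), and
    # j within the last `window_size_left` positions; -1 means unbounded.
    q = torch.arange(seq_len).unsqueeze(1)
    k = torch.arange(seq_len).unsqueeze(0)
    causal = k <= q
    if window_size_left == -1:
        return causal
    return causal & (q - k <= window_size_left)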
@@ -25,8 +28,6 @@ def attention(
     causal: bool = True,
     softcap: Optional[float] = None,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
     kv_cache_dtype = "auto"
@@ -37,6 +38,7 @@
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
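
The added window_size_right line encodes the flash-attention window convention: -1 disables a bound, and since this call is (by default) causal a query never attends to the right of its own position, so the right bound collapses to 0 whenever a left window is requested. As a standalone expression (the function name is ours, for illustration):

def right_window_for_causal(window_size_left: int) -> int:
    # No left window (-1) means no windowing at all; otherwise a causal
    # kernel may not look ahead, so the right bound is 0.
    return -1 if window_size_left == -1 else 0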
@@ -50,11 +52,18 @@ def attention(
             causal,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
             kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.functional.varlen_attention(
             query.contiguous() if query.device.type == "xpu" else query,
             key.contiguous() if key.device.type == "xpu" else key,
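
On softcap: rather than rejecting it up front, attention() now forwards it to flash_attn_varlen_func and only raises in the varlen_attention fallback. Soft-capping squashes the raw attention logits through a tanh before the softmax (the scheme popularized by Gemma 2), which smoothly bounds them to (-softcap, softcap). A pure-PyTorch sketch of the math the fused kernel is being asked to apply (ours, not the IPEX implementation):

import torch

def softcapped_attention_weights(q, k, softmax_scale, softcap=None):
    # Scaled dot-product logits of shape (num_q, num_k).
    scores = (q @ k.transpose(-2, -1)) * softmax_scale
    if softcap is not None:
        # tanh soft-capping: smoothly bounds logits to (-softcap, softcap).
        scores = softcap * torch.tanh(scores / softcap)
    return torch.softmax(scores, dim=-1)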
@@ -88,17 +97,14 @@ def paged_attention(
     softcap: Optional[float] = None,
     window_size_left: Optional[int] = -1,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
-
     out = torch.empty_like(query)
     kv_cache_dtype = "auto"
     if kv_cache.key.dtype == torch.float8_e5m2:
         kv_cache_dtype = "fp8_e5m2"
     if kv_cache.key.dtype == torch.float8_e4m3fn:
         kv_cache_dtype = "fp8_e4m3"
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
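
Unchanged context in this hunk, but worth noting for review: the kv_cache_dtype string handed to the IPEX kernel is derived from the torch dtype of the cached keys, which is what lets the fp8 scales below take effect. Standalone equivalent of that mapping (the helper name is ours):

import torch

def kv_cache_dtype_str(key_cache: torch.Tensor) -> str:
    # "auto" tells the kernel to match the compute dtype; the fp8 tags
    # select the two float8 kv-cache layouts.
    if key_cache.dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if key_cache.dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    return "auto"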
@@ -112,12 +118,19 @@ def paged_attention(
             True,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
             kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
@@ -130,9 +143,6 @@
             BLOCK_SIZE,
             max_s,
             None,
-            kv_cache_dtype=kv_cache_dtype,
-            k_scale=kv_scales.key_scale_cpu,
-            v_scale=kv_scales.value_scale_cpu,
         )
     return out
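
The last hunk drops the fp8 kv-cache arguments from single_query_cached_kv_attention, which suggests the float8 cache (like softcap and windowing) is only exercised via the flashdecoding-ipex path after this change. Putting the two new features together, a toy end-to-end weight computation built from the sketches above (illustrative only, not the kernel):

q = torch.randn(8, 64)
k = torch.randn(8, 64)
mask = sliding_window_causal_mask(seq_len=8, window_size_left=4)
scores = (q @ k.T) * 64 ** -0.5
scores = 30.0 * torch.tanh(scores / 30.0)            # softcap = 30.0
scores = scores.masked_fill(~mask, float("-inf"))    # sliding-window mask
weights = torch.softmax(scores, dim=-1)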