Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 06:42:10 +00:00)
add softcap and slidingwindow
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 102e29902a
commit d9e47b651c

@@ -8,6 +8,9 @@ from text_generation_server.models.globals import (
     BLOCK_SIZE,
 )
 
-SUPPORTS_WINDOWING = False
+if ATTENTION == "flashdecoding-ipex":
+    SUPPORTS_WINDOWING = True
+else:
+    SUPPORTS_WINDOWING = False
 
 
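
With SUPPORTS_WINDOWING now reported per backend, callers can reject sliding-window models up front when the selected attention implementation cannot honour the window. A minimal sketch of such a load-time guard; check_sliding_window and its arguments are illustrative and not part of this commit:

from typing import Optional


def check_sliding_window(sliding_window: Optional[int], supports_windowing: bool) -> None:
    # Illustrative guard: a model config that declares a finite sliding window
    # can only be served by a backend that reports windowing support
    # (flashdecoding-ipex after this change).
    if sliding_window is not None and sliding_window > 0 and not supports_windowing:
        raise ValueError(
            "Sliding window attention is not supported by the selected attention backend"
        )
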
@@ -25,8 +28,6 @@ def attention(
     causal: bool = True,
     softcap: Optional[float] = None,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
     kv_cache_dtype = "auto"
@@ -37,6 +38,7 @@ def attention(
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
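
The added line window_size_right = -1 if window_size_left == -1 else 0 follows the usual flash-attention windowing convention: -1 leaves a side unbounded, and a causal sliding window only ever looks backwards, so the right window is 0 whenever a left window is set. A small self-contained sketch of the mask these two values describe; this only illustrates the semantics, not how the IPEX kernel applies the window internally:

import torch


def sliding_window_mask(seq_len, window_size_left, window_size_right):
    # Entry [i, j] is True when query position i may attend to key position j,
    # i.e. j lies in [i - window_size_left, i + window_size_right]; -1 on a
    # side leaves that side unbounded.
    q = torch.arange(seq_len).unsqueeze(1)  # query positions, shape [seq_len, 1]
    k = torch.arange(seq_len).unsqueeze(0)  # key positions, shape [1, seq_len]
    if window_size_left == -1:
        left_ok = torch.ones(seq_len, seq_len, dtype=torch.bool)
    else:
        left_ok = k >= q - window_size_left
    if window_size_right == -1:
        right_ok = torch.ones(seq_len, seq_len, dtype=torch.bool)
    else:
        right_ok = k <= q + window_size_right
    return left_ok & right_ok


# window_size_left=2, window_size_right=0: each token sees itself and the two
# previous tokens; window_size_left=-1, window_size_right=0 is plain causal.
print(sliding_window_mask(5, 2, 0).int())
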
@@ -50,11 +52,18 @@ def attention(
             causal,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
             kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.functional.varlen_attention(
             query.contiguous() if query.device.type == "xpu" else query,
             key.contiguous() if key.device.type == "xpu" else key,
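
Note that softcap is only forwarded on the flashdecoding-ipex path; the ipex.llm.functional.varlen_attention fallback still rejects it. For reference, logit soft-capping in the Gemma-2 style squashes the raw attention scores through tanh before the softmax. A sketch of that transform under the commonly used definition; this is not a claim about the IPEX kernel's internals:

import torch


def softcap_scores(scores, softcap):
    # Bound raw QK^T logits to (-softcap, softcap) smoothly instead of clamping.
    return softcap * torch.tanh(scores / softcap)


# Example: large logits stay within +/- 30 after capping.
logits = torch.randn(4, 4) * 100
assert softcap_scores(logits, softcap=30.0).abs().max() < 30.0
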
@@ -88,17 +97,14 @@ def paged_attention(
     softcap: Optional[float] = None,
     window_size_left: Optional[int] = -1,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
-
     out = torch.empty_like(query)
     kv_cache_dtype = "auto"
     if kv_cache.key.dtype == torch.float8_e5m2:
         kv_cache_dtype = "fp8_e5m2"
     if kv_cache.key.dtype == torch.float8_e4m3fn:
         kv_cache_dtype = "fp8_e4m3"
-
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -112,12 +118,19 @@ def paged_attention(
             True,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
             kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
@@ -130,9 +143,6 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
-            kv_cache_dtype=kv_cache_dtype,
-            k_scale=kv_scales.key_scale_cpu,
-            v_scale=kv_scales.value_scale_cpu,
         )
         return out
 
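
The fallback single_query_cached_kv_attention call above no longer receives the fp8 cache dtype or the kv scales. Conceptually, paged single-query (decode) attention gathers each sequence's K/V blocks through its block table and runs ordinary attention over the gathered tokens. A simplified, unoptimized PyTorch sketch of that idea; the cache layout and argument names here are assumptions for illustration, not the IPEX kernel's actual interface:

import torch


def paged_single_query_attention(query, key_cache, value_cache, block_tables, context_lens, scale):
    # query:        [num_seqs, num_heads, head_dim]        one new token per sequence
    # key_cache:    [num_blocks, block_size, num_heads, head_dim]  (assumed layout)
    # value_cache:  same layout as key_cache
    # block_tables: [num_seqs, max_blocks_per_seq] indices into the caches
    # context_lens: [num_seqs] number of cached tokens per sequence
    num_seqs, num_heads, head_dim = query.shape
    block_size = key_cache.shape[1]
    out = torch.empty_like(query)
    for i in range(num_seqs):
        ctx = int(context_lens[i])
        n_blocks = (ctx + block_size - 1) // block_size
        blocks = block_tables[i, :n_blocks]
        # Gather this sequence's keys/values from its (non-contiguous) blocks.
        k = key_cache[blocks].reshape(-1, num_heads, head_dim)[:ctx]
        v = value_cache[blocks].reshape(-1, num_heads, head_dim)[:ctx]
        scores = torch.einsum("hd,thd->ht", query[i], k) * scale
        probs = torch.softmax(scores, dim=-1)
        out[i] = torch.einsum("ht,thd->hd", probs, v)
    return out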