diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index ee608e71..479d6566 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -8,7 +8,10 @@ from text_generation_server.models.globals import (
     BLOCK_SIZE,
 )
 
-SUPPORTS_WINDOWING = False
+if ATTENTION == "flashdecoding-ipex":
+    SUPPORTS_WINDOWING = True
+else:
+    SUPPORTS_WINDOWING = False
 
 
 def attention(
@@ -25,8 +28,6 @@ def attention(
     causal: bool = True,
     softcap: Optional[float] = None,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
     kv_cache_dtype = "auto"
@@ -37,6 +38,7 @@ def attention(
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -50,11 +52,18 @@ def attention(
             causal,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
             kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.functional.varlen_attention(
             query.contiguous() if query.device.type == "xpu" else query,
             key.contiguous() if key.device.type == "xpu" else key,
@@ -88,17 +97,14 @@ def paged_attention(
     softcap: Optional[float] = None,
     window_size_left: Optional[int] = -1,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
-
     out = torch.empty_like(query)
     kv_cache_dtype = "auto"
     if kv_cache.key.dtype == torch.float8_e5m2:
         kv_cache_dtype = "fp8_e5m2"
     if kv_cache.key.dtype == torch.float8_e4m3fn:
         kv_cache_dtype = "fp8_e4m3"
-
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -112,12 +118,19 @@ def paged_attention(
             True,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
             kv_cache_dtype=kv_cache_dtype,
             k_scale=kv_scales.key_scale_cpu,
             v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
@@ -130,9 +143,6 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
-            kv_cache_dtype=kv_cache_dtype,
-            k_scale=kv_scales.key_scale_cpu,
-            v_scale=kv_scales.value_scale_cpu,
         )
 
     return out
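
Note (not part of the patch): a minimal sketch of how the patch maps TGI's single window_size_left argument onto the (left, right) window pair passed to the flash-attention-style kernel, assuming the usual convention that -1 means "unlimited". The helper name window_pair is hypothetical and used only for illustration; the patch inlines the same expression directly before the flash_attn_varlen_func calls.

    # Illustrative sketch, not part of the diff above.
    def window_pair(window_size_left: int) -> tuple[int, int]:
        # No sliding window requested: leave both sides unlimited so the
        # kernel behaves as plain causal attention.
        if window_size_left == -1:
            return (-1, -1)
        # Sliding window requested: attend to at most window_size_left
        # tokens on the left and none on the right (causal decoding).
        return (window_size_left, 0)

    assert window_pair(-1) == (-1, -1)
    assert window_pair(4096) == (4096, 0)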