diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index 677f3f56..9b37c1b5 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -1,9 +1,12 @@
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
-from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 
 SUPPORTS_WINDOWING = False
 
@@ -28,22 +31,38 @@ def attention(
     out = torch.empty_like(query)
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    ipex.llm.functional.varlen_attention(
-        query.contiguous() if query.device.type == "xpu" else query,
-        key.contiguous() if key.device.type == "xpu" else key,
-        value.contiguous() if value.device.type == "xpu" else value,
-        out,
-        seqlen.cu_seqlen_q,
-        seqlen.cu_seqlen_q,
-        seqlen.max_q,
-        seqlen.max_q,
-        0.0,
-        softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if ATTENTION == "flashdecoding":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            causal,
+            block_tables,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query.contiguous() if query.device.type == "xpu" else query,
+            key.contiguous() if key.device.type == "xpu" else key,
+            value.contiguous() if value.device.type == "xpu" else value,
+            out,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_q,
+            0.0,
+            softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 
     return out
 
@@ -64,20 +83,37 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
-    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
-        out,
-        query,
-        kv_cache.key,
-        kv_cache.value,
-        kv_head_mapping,
-        softmax_scale,
-        block_tables,
-        input_lengths,
-        BLOCK_SIZE,
-        max_s,
-        None,
-    )
+
+    if ATTENTION == "flashdecoding":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            True,
+            block_tables,
+            None,
+        )
+    else:
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            BLOCK_SIZE,
+            max_s,
+            None,
+        )
 
     return out
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index cad1d98a..191771ca 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -66,7 +66,7 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size
 
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
@@ -174,7 +174,7 @@ class KVCache:
                 scalar=True,
            )[0]
 
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
             if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: