diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index 9b37c1b5..817dfbd3 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -34,7 +34,7 @@ def attention(
     if ATTENTION == "flashdecoding":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
-            query,
+            query.contiguous() if query.device.type == "xpu" else query,
             kv_cache.key,
             kv_cache.value,
             seqlen.cu_seqlen_q,
@@ -87,7 +87,7 @@ def paged_attention(
     if ATTENTION == "flashdecoding":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
-            query,
+            query.contiguous() if query.device.type == "xpu" else query,
             kv_cache.key,
             kv_cache.value,
             seqlen.cu_seqlen_q,
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 191771ca..0ff6522e 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -66,7 +66,9 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size

-        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
+        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
+            SYSTEM == "ipex" and device == torch.device("cpu")
+        ):
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
@@ -174,7 +176,9 @@ class KVCache:
             scalar=True,
         )[0]

-        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
+        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
+            SYSTEM == "ipex" and key.device == torch.device("cpu")
+        ):
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
             if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 4ac6a6b4..3561f13f 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -4,6 +4,7 @@ from loguru import logger
 from typing import Dict, Optional

 from text_generation_server.utils.log import log_master
+from text_generation_server.utils.import_utils import SYSTEM

 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
@@ -27,9 +28,12 @@ TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1

+
 # This is overridden by the cli
 BLOCK_SIZE: int
-if ATTENTION == "flashdecoding":
+if SYSTEM == "ipex":
+    BLOCK_SIZE = 16
+elif ATTENTION == "flashdecoding":
     BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1