diff --git a/Dockerfile_intel b/Dockerfile_intel index 2b41fd8b..d6556bea 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -224,9 +224,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher FROM ${PLATFORM} AS final -ENV ATTENTION=paged -ENV PREFIX_CACHING=0 -ENV PREFILL_CHUNKING=0 +ENV ATTENTION=flashdecoding-ipex +ENV PREFIX_CACHING=1 +ENV PREFILL_CHUNKING=1 ENV CUDA_GRAPHS=0 ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 18badeaf..394cc1e6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -143,7 +143,9 @@ fn resolve_attention(config: &Option, lora_adapters: &Option) -> } } - let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) { + let fallback_attention = if compute_capability.is_none() + || matches!(compute_capability, Some((major, _)) if major < 8) + { "paged" } else { "flashdecoding" diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 677f3f56..54422308 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -1,9 +1,12 @@ import intel_extension_for_pytorch as ipex import torch from text_generation_server.layers.attention.kv_cache import KVCache, KVScales -from text_generation_server.models.flash_causal_lm import BLOCK_SIZE from text_generation_server.layers.attention import Seqlen from typing import Optional +from text_generation_server.models.globals import ( + ATTENTION, + BLOCK_SIZE, +) SUPPORTS_WINDOWING = False @@ -28,22 +31,38 @@ def attention( out = torch.empty_like(query) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - ipex.llm.functional.varlen_attention( - query.contiguous() if query.device.type == "xpu" else query, - key.contiguous() if key.device.type == "xpu" else key, - value.contiguous() if value.device.type == "xpu" else value, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_q, - 0.0, - softmax_scale, - False, - causal, - False, - None, - ) + if ATTENTION == "flashdecoding-ipex": + ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + out, + query.contiguous() if query.device.type == "xpu" else query, + kv_cache.key, + kv_cache.value, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + seqlen.max_q, + seqlen.max_k, + softmax_scale, + causal, + block_tables, + None, + ) + else: + ipex.llm.functional.varlen_attention( + query.contiguous() if query.device.type == "xpu" else query, + key.contiguous() if key.device.type == "xpu" else key, + value.contiguous() if value.device.type == "xpu" else value, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_q, + 0.0, + softmax_scale, + False, + causal, + False, + None, + ) return out @@ -64,20 +83,37 @@ def paged_attention( raise NotImplementedError("softcap is not available in IPEX") out = torch.empty_like(query) - input_lengths = seqlen.input_lengths + seqlen.cache_lengths - ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( - out, - query, - kv_cache.key, - kv_cache.value, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - BLOCK_SIZE, - max_s, - None, - ) + + if ATTENTION == "flashdecoding-ipex": + ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + out, + query.contiguous() if query.device.type == "xpu" else query, + kv_cache.key, + kv_cache.value, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + seqlen.max_q, + seqlen.max_k, + softmax_scale, + True, + block_tables, + None, + ) + else: + input_lengths = seqlen.input_lengths + seqlen.cache_lengths + ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( + out, + query, + kv_cache.key, + kv_cache.value, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + BLOCK_SIZE, + max_s, + None, + ) return out diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index 93d74732..00308601 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -66,7 +66,9 @@ class KVCache: else: x = BLOCK_SIZE // element_size - if ATTENTION in {"flashdecoding", "flashinfer"}: + if ATTENTION in {"flashdecoding", "flashinfer"} or ( + ATTENTION == "flashdecoding-ipex" and device.type == "xpu" + ): self.kv_cache = ( torch.empty( (num_blocks, BLOCK_SIZE, num_heads, head_size), @@ -80,6 +82,7 @@ class KVCache: ), ) elif SYSTEM == "ipex" and device == torch.device("cpu"): + # ipex cpu flashdecoding kernel and paged attention kernel share same layout self.kv_cache = ( torch.empty( (num_blocks, num_heads, BLOCK_SIZE, head_size), @@ -187,6 +190,12 @@ class KVCache: shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value + elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu": + import intel_extension_for_pytorch as ipex + + ipex.llm.modules.PagedAttention.reshape_and_cache_flash( + key, value, key_cache, value_cache, slots + ) else: paged_reshape_and_cache(key, value, key_cache, value_cache, slots) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index ce879141..889de028 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -14,13 +14,17 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in { } PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -_expected = {"paged", "flashdecoding", "flashinfer"} +_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") -if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: +if PREFIX_CACHING and ATTENTION not in { + "flashinfer", + "flashdecoding", + "flashdecoding-ipex", +}: raise RuntimeError("Prefix caching is only supported with flashinfer") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None @@ -28,12 +32,15 @@ TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95")) assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM < 1 + # This is overridden by the cli BLOCK_SIZE: int if ATTENTION == "flashdecoding": BLOCK_SIZE = 256 elif ATTENTION == "flashinfer": BLOCK_SIZE = 1 +elif ATTENTION == "flashdecoding-ipex": + BLOCK_SIZE = 64 else: BLOCK_SIZE = 16 diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 805fd771..af4d1f08 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -79,10 +79,13 @@ class Model(ABC): "Prefill chunking will be turned off", ) support_chunking = False - if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking: + if ( + ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"] + and support_chunking + ): log_master( logger.warning, - "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.", + "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.", ) support_chunking = False