From 065f87a3371077da1ed707b7ea717bdd946fe2dc Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sat, 29 Mar 2025 02:31:38 -0700 Subject: [PATCH 1/8] IPEX support FP8 kvcache Signed-off-by: Wang, Yi A --- Dockerfile_intel | 4 +++- .../layers/attention/ipex.py | 6 +++++ .../layers/attention/kv_cache.py | 24 ++++++++++++++----- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index 5bf7632ce..b015760cb 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -119,7 +119,9 @@ ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0 RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout d5a7036316a01ea8220eb4da78a2207c423a1166 +RUN sed -i 's/VERSION_MINOR 7/VERSION_MINOR 6/' intel-extension-for-pytorch/version.txt +RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 2b89060e9..0b44072cc 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -45,6 +45,8 @@ def attention( causal, block_tables, None, + k_scale=kv_scales.key_scale_cpu, + v_scale=kv_scales.value_scale_cpu, ) else: ipex.llm.functional.varlen_attention( @@ -99,6 +101,8 @@ def paged_attention( True, block_tables, None, + k_scale=kv_scales.key_scale_cpu, + v_scale=kv_scales.value_scale_cpu, ) else: input_lengths = seqlen.input_lengths + seqlen.cache_lengths @@ -114,6 +118,8 @@ def paged_attention( BLOCK_SIZE, max_s, None, + k_scale=kv_scales.key_scale_cpu, + v_scale=kv_scales.value_scale_cpu, ) return out diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index aaf4d2b22..7082d3ae1 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -68,15 +68,20 @@ class KVCache: if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}: if not ( (ATTENTION == "flashinfer" and SYSTEM == "cuda") - or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm")) + or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm", "ipex")) + or (ATTENTION == "flashdecoding-ipex") ): raise ValueError( - "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. 
" + "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA, ROCm and INTEL IPEX and flashdecoding in Intel IPEX " ) if SYSTEM == "rocm" and dtype == torch.float8_e5m2: raise ValueError( "float8_e5m2 FP8 KV cache is not supported on AMD ROCm" ) + if device.type == "cpu" and dtype == torch.float8_e4m3fn: + raise ValueError( + "float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU" + ) element_size = torch.tensor([], dtype=dtype).element_size() if SYSTEM == "ipex" and device.type == "xpu": @@ -133,7 +138,8 @@ class KVCache: return False elif self.dtype == torch.float8_e4m3fn and ( (ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda") - or (ATTENTION == "paged" and SYSTEM == "rocm") + or (ATTENTION == "paged" and SYSTEM in ["rocm", "ipex"]) + or (ATTENTION == "flashdecoding-ipex") ): log_once(logger.info, "Using FP8 KV cache scales") return True @@ -141,7 +147,7 @@ class KVCache: # We have scales, but not the correct FP8 cache type, so warn once. log_once( logger.info, - "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm", + "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm/IPEX and flashdecoding on IPEX", ) return False @@ -208,7 +214,13 @@ class KVCache: import intel_extension_for_pytorch as ipex ipex.llm.modules.PagedAttention.reshape_and_cache_flash( - key, value, key_cache, value_cache, slots + key, + value, + key_cache, + value_cache, + slots, + k_scale=kv_scales.key_scale_cpu, + v_scale=kv_scales.value_scale_cpu, ) else: paged_reshape_and_cache( @@ -268,7 +280,7 @@ def paged_reshape_and_cache( import intel_extension_for_pytorch as ipex ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slots + key, value, key_cache, value_cache, slots, k_scale=k_scale, v_scale=v_scale ) else: raise NotImplementedError( From 102e29902a5bf04b42a87cae5af450b9daa67e3b Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 2 Apr 2025 19:29:01 -0700 Subject: [PATCH 2/8] add kvcache dtype Signed-off-by: Wang, Yi A --- .../layers/attention/ipex.py | 13 ++++++++++++ .../layers/attention/kv_cache.py | 21 ++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 0b44072cc..ee608e71f 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -29,6 +29,11 @@ def attention( raise NotImplementedError("softcap is not available in IPEX") out = torch.empty_like(query) + kv_cache_dtype = "auto" + if kv_cache.key.dtype == torch.float8_e5m2: + kv_cache_dtype = "fp8_e5m2" + if kv_cache.key.dtype == torch.float8_e4m3fn: + kv_cache_dtype = "fp8_e4m3" # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. 
if ATTENTION == "flashdecoding-ipex": @@ -45,6 +50,7 @@ def attention( causal, block_tables, None, + kv_cache_dtype=kv_cache_dtype, k_scale=kv_scales.key_scale_cpu, v_scale=kv_scales.value_scale_cpu, ) @@ -86,6 +92,11 @@ def paged_attention( raise NotImplementedError("softcap is not available in IPEX") out = torch.empty_like(query) + kv_cache_dtype = "auto" + if kv_cache.key.dtype == torch.float8_e5m2: + kv_cache_dtype = "fp8_e5m2" + if kv_cache.key.dtype == torch.float8_e4m3fn: + kv_cache_dtype = "fp8_e4m3" if ATTENTION == "flashdecoding-ipex": ipex.llm.modules.PagedAttention.flash_attn_varlen_func( @@ -101,6 +112,7 @@ def paged_attention( True, block_tables, None, + kv_cache_dtype=kv_cache_dtype, k_scale=kv_scales.key_scale_cpu, v_scale=kv_scales.value_scale_cpu, ) @@ -118,6 +130,7 @@ def paged_attention( BLOCK_SIZE, max_s, None, + kv_cache_dtype=kv_cache_dtype, k_scale=kv_scales.key_scale_cpu, v_scale=kv_scales.value_scale_cpu, ) diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index 7082d3ae1..a37ecd4c3 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -213,12 +213,18 @@ class KVCache: elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu": import intel_extension_for_pytorch as ipex + kv_cache_dtype = "auto" + if key_cache.dtype == torch.float8_e5m2: + kv_cache_dtype = "fp8_e5m2" + if key_cache.dtype == torch.float8_e4m3fn: + kv_cache_dtype = "fp8_e4m3" ipex.llm.modules.PagedAttention.reshape_and_cache_flash( key, value, key_cache, value_cache, slots, + kv_cache_dtype=kv_cache_dtype, k_scale=kv_scales.key_scale_cpu, v_scale=kv_scales.value_scale_cpu, ) @@ -279,8 +285,21 @@ def paged_reshape_and_cache( elif SYSTEM == "ipex": import intel_extension_for_pytorch as ipex + kv_cache_dtype = "auto" + if key_cache.dtype == torch.float8_e5m2: + kv_cache_dtype = "fp8_e5m2" + if key_cache.dtype == torch.float8_e4m3fn: + kv_cache_dtype = "fp8_e4m3" + ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slots, k_scale=k_scale, v_scale=v_scale + key, + value, + key_cache, + value_cache, + slots, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, ) else: raise NotImplementedError( From d9e47b651c632e37fc997521f744839ec693fbbb Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 7 Apr 2025 22:42:19 -0700 Subject: [PATCH 3/8] add softcap and slidingwindow Signed-off-by: Wang, Yi A --- .../layers/attention/ipex.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index ee608e71f..479d65665 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -8,7 +8,10 @@ from text_generation_server.models.globals import ( BLOCK_SIZE, ) -SUPPORTS_WINDOWING = False +if ATTENTION == "flashdecoding-ipex": + SUPPORTS_WINDOWING = True +else: + SUPPORTS_WINDOWING = False def attention( @@ -25,8 +28,6 @@ def attention( causal: bool = True, softcap: Optional[float] = None, ): - if softcap is not None: - raise NotImplementedError("softcap is not available in IPEX") out = torch.empty_like(query) kv_cache_dtype = "auto" @@ -37,6 +38,7 @@ def attention( # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model 
load. if ATTENTION == "flashdecoding-ipex": + window_size_right = -1 if window_size_left == -1 else 0 ipex.llm.modules.PagedAttention.flash_attn_varlen_func( out, query.contiguous() if query.device.type == "xpu" else query, @@ -50,11 +52,18 @@ def attention( causal, block_tables, None, + window_size_left=window_size_left, + window_size_right=window_size_right, kv_cache_dtype=kv_cache_dtype, k_scale=kv_scales.key_scale_cpu, v_scale=kv_scales.value_scale_cpu, + softcap=softcap, ) else: + if softcap is not None: + raise NotImplementedError( + "softcap is not available in IPEX paged attention" + ) ipex.llm.functional.varlen_attention( query.contiguous() if query.device.type == "xpu" else query, key.contiguous() if key.device.type == "xpu" else key, @@ -88,17 +97,14 @@ def paged_attention( softcap: Optional[float] = None, window_size_left: Optional[int] = -1, ): - if softcap is not None: - raise NotImplementedError("softcap is not available in IPEX") - out = torch.empty_like(query) kv_cache_dtype = "auto" if kv_cache.key.dtype == torch.float8_e5m2: kv_cache_dtype = "fp8_e5m2" if kv_cache.key.dtype == torch.float8_e4m3fn: kv_cache_dtype = "fp8_e4m3" - if ATTENTION == "flashdecoding-ipex": + window_size_right = -1 if window_size_left == -1 else 0 ipex.llm.modules.PagedAttention.flash_attn_varlen_func( out, query.contiguous() if query.device.type == "xpu" else query, @@ -112,12 +118,19 @@ def paged_attention( True, block_tables, None, + window_size_left=window_size_left, + window_size_right=window_size_right, kv_cache_dtype=kv_cache_dtype, k_scale=kv_scales.key_scale_cpu, v_scale=kv_scales.value_scale_cpu, + softcap=softcap, ) else: input_lengths = seqlen.input_lengths + seqlen.cache_lengths + if softcap is not None: + raise NotImplementedError( + "softcap is not available in IPEX paged attention" + ) ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, @@ -130,9 +143,6 @@ def paged_attention( BLOCK_SIZE, max_s, None, - kv_cache_dtype=kv_cache_dtype, - k_scale=kv_scales.key_scale_cpu, - v_scale=kv_scales.value_scale_cpu, ) return out From ad15a9c0afbb3a7ead90425161836c7c2577f143 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 7 Apr 2025 22:47:11 -0700 Subject: [PATCH 4/8] kv scale in pageattn Signed-off-by: Wang, Yi A --- server/text_generation_server/layers/attention/ipex.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 479d65665..6ca02afe2 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -143,6 +143,8 @@ def paged_attention( BLOCK_SIZE, max_s, None, + k_scale=kv_scales.key_scale_cpu, + v_scale=kv_scales.value_scale_cpu, ) return out From 68ec7603ca92eec1694050af17e0aa4b6ac183a4 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 8 Apr 2025 00:26:53 -0700 Subject: [PATCH 5/8] remove triton installation, will be installed with torch Signed-off-by: Wang, Yi A --- Dockerfile_intel | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index b015760cb..e35c1e6df 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -100,8 +100,6 @@ ENV HF_HOME=/data \ WORKDIR /usr/src RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu -RUN pip install triton-xpu==3.2.0b1 --no-cache-dir - # Install server COPY proto proto COPY server server @@ -114,7 +112,7 @@ RUN cd server && \ ENV 
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib ENV CCL_ZE_IPC_EXCHANGE=sockets -ENV TORCH_LLM_ALLREDUCE=1 +#ENV TORCH_LLM_ALLREDUCE=1 ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0 From 8d36856d5789ab69c5aecd0855dbb41d4eb6f90d Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 8 Apr 2025 20:42:28 -0700 Subject: [PATCH 6/8] install xelink lib Signed-off-by: Wang, Yi A --- Dockerfile_intel | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index e35c1e6df..b0aaaeee4 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/ RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200 # Text Generation Inference base env ENV HF_HOME=/data \ @@ -112,7 +112,7 @@ RUN cd server && \ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib ENV CCL_ZE_IPC_EXCHANGE=sockets -#ENV TORCH_LLM_ALLREDUCE=1 +ENV TORCH_LLM_ALLREDUCE=1 ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0 From f8c8c3d3974077255e5f7baf1cc150d233f308d2 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 8 Apr 2025 22:42:03 -0700 Subject: [PATCH 7/8] softcap default -1.0 Signed-off-by: Wang, Yi A --- server/text_generation_server/layers/attention/ipex.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 6ca02afe2..31b745f0c 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -105,6 +105,8 @@ def paged_attention( kv_cache_dtype = "fp8_e4m3" if ATTENTION == "flashdecoding-ipex": window_size_right = -1 if window_size_left == -1 else 0 + if softcap is None: + softcap = -1.0 ipex.llm.modules.PagedAttention.flash_attn_varlen_func( out, query.contiguous() if query.device.type == "xpu" else query, From ce8548f5c4d49f020e102c9a7c2200244ce29a13 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 13 Apr 2025 20:02:05 -0700 Subject: [PATCH 8/8] softcap default -1.0 Signed-off-by: Wang, Yi A --- server/text_generation_server/layers/attention/ipex.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 31b745f0c..36ef2efca 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -39,6 +39,8 @@ def attention( # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. if ATTENTION == "flashdecoding-ipex": window_size_right = -1 if window_size_left == -1 else 0 + if softcap is None: + softcap = -1.0 ipex.llm.modules.PagedAttention.flash_attn_varlen_func( out, query.contiguous() if query.device.type == "xpu" else query,
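
Taken together, the support checks that patch 1 adds to kv_cache.py form a small decision table: FP8 KV cache is accepted for flashinfer on CUDA, paged attention on CUDA/ROCm/IPEX, and flashdecoding-ipex, while float8_e5m2 remains rejected on ROCm and float8_e4m3fn remains rejected on the IPEX CPU device. The sketch below restates that table as a standalone predicate for readability; the function name fp8_kv_cache_error and the return-a-message shape are illustrative only and do not appear in the patch.

from typing import Optional

import torch


def fp8_kv_cache_error(
    attention: str,     # "paged", "flashinfer", "flashdecoding-ipex", ...
    system: str,        # "cuda", "rocm", "ipex", ...
    device_type: str,   # "cuda", "xpu", "cpu", ...
    dtype: torch.dtype,
) -> Optional[str]:
    """Return a rejection reason for an FP8 KV-cache configuration, or None
    if the combination is allowed by the checks added in patch 1."""
    supported = (
        (attention == "flashinfer" and system == "cuda")
        or (attention == "paged" and system in ("cuda", "rocm", "ipex"))
        or attention == "flashdecoding-ipex"
    )
    if not supported:
        return "FP8 KV cache is not supported for this attention/system combination"
    if system == "rocm" and dtype == torch.float8_e5m2:
        return "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
    if device_type == "cpu" and dtype == torch.float8_e4m3fn:
        return "float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU"
    return None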
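
Patch 1 forwards the per-tensor scales as k_scale/v_scale at the IPEX call sites, and patch 2 adds the kv_cache_dtype string tag alongside them (torch.float8_e5m2 -> "fp8_e5m2", torch.float8_e4m3fn -> "fp8_e4m3", anything else -> "auto"). A minimal sketch of that mapping is shown below, factored into helpers for clarity; the names fp8_kv_cache_dtype and fp8_kv_kwargs are hypothetical, since the patches inline this logic in attention(), paged_attention(), reshape_and_cache_flash() and paged_reshape_and_cache() rather than sharing a helper.

import torch


def fp8_kv_cache_dtype(cache_dtype: torch.dtype) -> str:
    """Map a torch KV-cache dtype to the kv_cache_dtype string tag passed to
    the IPEX paged-attention kernels in patch 2."""
    if cache_dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if cache_dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    return "auto"


def fp8_kv_kwargs(key_cache: torch.Tensor, key_scale_cpu: float, value_scale_cpu: float) -> dict:
    """Collect the dtype tag plus the CPU-side key/value scales that the
    flashdecoding-ipex and cache-write call sites forward to the kernels."""
    return {
        "kv_cache_dtype": fp8_kv_cache_dtype(key_cache.dtype),
        "k_scale": key_scale_cpu,
        "v_scale": value_scale_cpu,
    }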
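
Patches 3, 7 and 8 normalize two optional arguments before the flashdecoding-ipex call to flash_attn_varlen_func: window_size_right is derived from window_size_left (-1 keeps both sides unbounded, otherwise the right side is set to 0, consistent with causal attention only looking left), and a softcap of None is replaced with -1.0. The condensed sketch below shows just that normalization, under the assumption that -1 and -1.0 act as "disabled" sentinels for the kernel; the helper name normalize_flashdecoding_args is illustrative.

from typing import Optional, Tuple


def normalize_flashdecoding_args(
    window_size_left: int = -1,
    softcap: Optional[float] = None,
) -> Tuple[int, int, float]:
    """Mirror the argument normalization added in patches 3, 7 and 8:
    an unbounded window stays (-1, -1), a finite left window clamps the
    right side to 0, and an unset softcap becomes the -1.0 sentinel."""
    window_size_right = -1 if window_size_left == -1 else 0
    if softcap is None:
        softcap = -1.0
    return window_size_left, window_size_right, softcap


# Example: a 4096-token sliding window with softcap left unset.
assert normalize_flashdecoding_args(4096, None) == (4096, 0, -1.0)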