diff --git a/Dockerfile_intel b/Dockerfile_intel
index 5bf7632c..b015760c 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -119,7 +119,9 @@ ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
 
 RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout d5a7036316a01ea8220eb4da78a2207c423a1166
+RUN sed -i 's/VERSION_MINOR 7/VERSION_MINOR 6/' intel-extension-for-pytorch/version.txt
+RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index 2b89060e..0b44072c 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -45,6 +45,8 @@ def attention(
             causal,
             block_tables,
             None,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
         )
     else:
         ipex.llm.functional.varlen_attention(
@@ -99,6 +101,8 @@ def paged_attention(
             True,
             block_tables,
             None,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
         )
     else:
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
@@ -114,6 +118,8 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
         )
 
     return out
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index aaf4d2b2..7082d3ae 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -68,15 +68,20 @@ class KVCache:
         if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
             if not (
                 (ATTENTION == "flashinfer" and SYSTEM == "cuda")
-                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm"))
+                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm", "ipex"))
+                or (ATTENTION == "flashdecoding-ipex")
             ):
                 raise ValueError(
-                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. "
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA, ROCm and INTEL IPEX and flashdecoding in Intel IPEX "
                 )
             if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                 raise ValueError(
                     "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )
+            if device.type == "cpu" and dtype == torch.float8_e4m3fn:
+                raise ValueError(
+                    "float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
@@ -133,7 +138,8 @@ class KVCache:
             return False
         elif self.dtype == torch.float8_e4m3fn and (
             (ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
-            or (ATTENTION == "paged" and SYSTEM == "rocm")
+            or (ATTENTION == "paged" and SYSTEM in ["rocm", "ipex"])
+            or (ATTENTION == "flashdecoding-ipex")
         ):
             log_once(logger.info, "Using FP8 KV cache scales")
             return True
@@ -141,7 +147,7 @@ class KVCache:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm/IPEX and flashdecoding on IPEX",
             )
             return False
 
@@ -208,7 +214,13 @@ class KVCache:
             import intel_extension_for_pytorch as ipex
 
             ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
-                key, value, key_cache, value_cache, slots
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                k_scale=kv_scales.key_scale_cpu,
+                v_scale=kv_scales.value_scale_cpu,
             )
         else:
             paged_reshape_and_cache(
@@ -268,7 +280,7 @@ def paged_reshape_and_cache(
         import intel_extension_for_pytorch as ipex
 
         ipex.llm.modules.PagedAttention.reshape_and_cache(
-            key, value, key_cache, value_cache, slots
+            key, value, key_cache, value_cache, slots, k_scale=k_scale, v_scale=v_scale
         )
     else:
         raise NotImplementedError(