Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-26 12:32:10 +00:00)
IPEX support FP8 kvcache
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent: 0142550096, commit: 065f87a337
@@ -119,7 +119,9 @@ ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0

 RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout d5a7036316a01ea8220eb4da78a2207c423a1166
+RUN sed -i 's/VERSION_MINOR 7/VERSION_MINOR 6/' intel-extension-for-pytorch/version.txt
+RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark

 # Install router
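The Dockerfile change above swaps the prebuilt IPEX XPU wheel for a source build pinned to a specific commit, presumably so that the FP8-aware KV cache kernels used later in this diff are available. A quick sanity check one could run inside the resulting image; this is not part of the commit, and it assumes reshape_and_cache_flash is exposed as an inspectable Python wrapper:

# Illustrative check that the source-built IPEX exposes the k_scale/v_scale
# keyword arguments which the FP8 KV cache path below relies on.
import inspect

import intel_extension_for_pytorch as ipex

op = ipex.llm.modules.PagedAttention.reshape_and_cache_flash
try:
    params = inspect.signature(op).parameters
    has_scales = "k_scale" in params and "v_scale" in params
except (TypeError, ValueError):
    # Compiled ops may not expose a Python signature; fall back to existence.
    has_scales = hasattr(ipex.llm.modules.PagedAttention, "reshape_and_cache_flash")

print("IPEX version:", ipex.__version__)
print("reshape_and_cache_flash accepts k_scale/v_scale:", has_scales)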
@@ -45,6 +45,8 @@ def attention(
             causal,
             block_tables,
             None,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
         )
     else:
         ipex.llm.functional.varlen_attention(
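For readers outside the TGI codebase: kv_scales is TGI's small container of FP8 dequantization scales, and the IPEX kernels only consume its CPU float copies (key_scale_cpu / value_scale_cpu). A simplified stand-in, assuming that shape of the object; the real class lives in TGI's KV cache module and carries more than this sketch:

from dataclasses import dataclass

import torch


@dataclass
class KVScales:
    """Simplified stand-in for the object passed as `kv_scales`."""

    key_scale: torch.Tensor
    value_scale: torch.Tensor

    @property
    def key_scale_cpu(self) -> float:
        # Plain Python float consumed by the IPEX kernels above.
        return float(self.key_scale.item())

    @property
    def value_scale_cpu(self) -> float:
        return float(self.value_scale.item())


# Unquantized default: scales of 1.0 leave the cached values untouched.
default_scales = KVScales(torch.tensor(1.0), torch.tensor(1.0))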
@@ -99,6 +101,8 @@ def paged_attention(
             True,
             block_tables,
             None,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
         )
     else:
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
@@ -114,6 +118,8 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
         )
     return out

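Both paged_attention hunks simply thread the same two scales through to the IPEX kernels. The scales are per-tensor FP8 dequantization factors that TGI normally loads from checkpoint metadata; the sketch below only illustrates the convention assumed here (dequantized value is roughly the stored FP8 value times the scale) and is not how TGI actually computes them:

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def per_tensor_fp8_scale(t: torch.Tensor) -> float:
    # Map the tensor's max magnitude onto the representable e4m3 range;
    # a scale of 1.0 means the FP8 cache is effectively unscaled.
    amax = t.abs().max().clamp(min=1e-12)
    return float(amax / FP8_E4M3_MAX)


print(per_tensor_fp8_scale(torch.randn(4, 8)))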
@@ -68,15 +68,20 @@ class KVCache:
         if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
             if not (
                 (ATTENTION == "flashinfer" and SYSTEM == "cuda")
-                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm"))
+                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm", "ipex"))
+                or (ATTENTION == "flashdecoding-ipex")
             ):
                 raise ValueError(
-                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. "
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA, ROCm and INTEL IPEX and flashdecoding in Intel IPEX "
                 )
             if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                 raise ValueError(
                     "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )
+            if device.type == "cpu" and dtype == torch.float8_e4m3fn:
+                raise ValueError(
+                    "float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU"
+                )

         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
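Distilled from the updated check above, the accepted FP8 KV cache combinations can be restated as a standalone helper (illustrative only; ATTENTION and SYSTEM are TGI's global backend selectors, passed in here as plain arguments):

import torch


def fp8_kv_cache_supported(
    attention: str, system: str, device_type: str, dtype: torch.dtype
) -> bool:
    # Non-FP8 cache dtypes are unaffected by this check.
    if dtype not in (torch.float8_e5m2, torch.float8_e4m3fn):
        return True
    supported = (
        (attention == "flashinfer" and system == "cuda")
        or (attention == "paged" and system in ("cuda", "rocm", "ipex"))
        or attention == "flashdecoding-ipex"
    )
    if system == "rocm" and dtype == torch.float8_e5m2:
        supported = False
    if device_type == "cpu" and dtype == torch.float8_e4m3fn:
        supported = False
    return supported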
@@ -133,7 +138,8 @@ class KVCache:
             return False
         elif self.dtype == torch.float8_e4m3fn and (
             (ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
-            or (ATTENTION == "paged" and SYSTEM == "rocm")
+            or (ATTENTION == "paged" and SYSTEM in ["rocm", "ipex"])
+            or (ATTENTION == "flashdecoding-ipex")
         ):
             log_once(logger.info, "Using FP8 KV cache scales")
             return True
@@ -141,7 +147,7 @@ class KVCache:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm/IPEX and flashdecoding on IPEX",
             )
             return False
@@ -208,7 +214,13 @@ class KVCache:
             import intel_extension_for_pytorch as ipex

             ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
-                key, value, key_cache, value_cache, slots
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                k_scale=kv_scales.key_scale_cpu,
+                v_scale=kv_scales.value_scale_cpu,
             )
         else:
             paged_reshape_and_cache(
@@ -268,7 +280,7 @@ def paged_reshape_and_cache(
         import intel_extension_for_pytorch as ipex

         ipex.llm.modules.PagedAttention.reshape_and_cache(
-            key, value, key_cache, value_cache, slots
+            key, value, key_cache, value_cache, slots, k_scale=k_scale, v_scale=v_scale
         )
     else:
         raise NotImplementedError(
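Taken together, the last two hunks mean both IPEX cache-write paths now forward the scales: the flash-layout path (reshape_and_cache_flash, used with the flashdecoding-ipex backend) and the paged path (reshape_and_cache). A condensed, illustrative view of that dispatch; the branch condition and argument plumbing are simplified relative to the real KVCache.store / paged_reshape_and_cache:

import intel_extension_for_pytorch as ipex


def store_kv(key, value, key_cache, value_cache, slots, kv_scales, attention: str):
    # Sketch only: route to the flash-layout cache op for flashdecoding-ipex,
    # otherwise fall back to the paged cache op, forwarding the FP8 scales.
    if attention == "flashdecoding-ipex":
        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key, value, key_cache, value_cache, slots,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
        )
    else:
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slots,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
        )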