Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)

Commit 3142c4ac6f: Merge ce8548f5c4 into 73e797528d

The diff below updates the Intel (XPU) Dockerfile (extra runtime package, IPEX built from source) and extends the IPEX attention backend and KV-cache module with FP8 KV-cache and sliding-window support.
@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
 
 RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
 
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -100,8 +100,6 @@ ENV HF_HOME=/data \
 WORKDIR /usr/src
 RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu
 
-RUN pip install triton-xpu==3.2.0b1 --no-cache-dir
-
 # Install server
 COPY proto proto
 COPY server server
@@ -119,7 +117,9 @@ ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
 
 RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout d5a7036316a01ea8220eb4da78a2207c423a1166
+RUN sed -i 's/VERSION_MINOR 7/VERSION_MINOR 6/' intel-extension-for-pytorch/version.txt
+RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
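The three new RUN lines above swap the prebuilt intel_extension_for_pytorch wheel for a source build pinned to commit d5a7036316a01ea8220eb4da78a2207c423a1166, with version.txt patched from minor version 7 to 6 so the build matches the torch==2.6.0 XPU wheel installed earlier. A quick way to confirm the resulting image has a coherent XPU stack is a sanity check along the following lines; this is a sketch, not part of the commit, and assumes the container's python can import both packages:

# Sanity-check sketch (not part of the commit): verify the torch XPU + IPEX
# stack inside the built image.
import torch
import intel_extension_for_pytorch as ipex  # registers the XPU backend

print("torch:", torch.__version__)          # expected: a 2.6.0+xpu build
print("ipex:", ipex.__version__)            # expected: a 2.6.x+xpu source build
print("xpu available:", torch.xpu.is_available())
if torch.xpu.is_available():
    x = torch.randn(2, 2, device="xpu")
    print((x @ x).cpu())                    # tiny matmul to exercise the device

The remaining hunks are server-side Python: the IPEX attention backend and the KV-cache module.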
@@ -8,7 +8,10 @@ from text_generation_server.models.globals import (
     BLOCK_SIZE,
 )
 
-SUPPORTS_WINDOWING = False
+if ATTENTION == "flashdecoding-ipex":
+    SUPPORTS_WINDOWING = True
+else:
+    SUPPORTS_WINDOWING = False
 
 
 def attention(
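Deriving SUPPORTS_WINDOWING from ATTENTION advertises sliding-window support only on the flashdecoding-ipex path. As an illustration of how such a capability flag is consumed at model-load time (the helper below is hypothetical, not code from the repository):

# Illustrative only: a stand-in for the backend flag defined above, and a
# hypothetical load-time check. Not code from text-generation-inference.
from typing import Optional

SUPPORTS_WINDOWING = True  # stand-in for the flag exported by the backend


def check_windowing(sliding_window: Optional[int]) -> None:
    # Reject sliding-window models on backends that cannot window.
    if sliding_window is not None and not SUPPORTS_WINDOWING:
        raise ValueError("Sliding window attention is not supported by this backend")


check_windowing(None)   # always fine
check_windowing(4096)   # only fine when the backend supports windowing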
@@ -25,13 +28,19 @@ def attention(
     causal: bool = True,
     softcap: Optional[float] = None,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
-
     out = torch.empty_like(query)
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
+        if softcap is None:
+            softcap = -1.0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
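The dtype-to-string mapping added here ("auto", "fp8_e5m2", "fp8_e4m3") is repeated inline in paged_attention and in both cache-write paths later in the diff. As a reading aid only (not part of the change), the convention amounts to:

# Reading aid, not part of the commit: the dtype-to-string convention the
# added lines repeat before each IPEX kernel call.
import torch


def ipex_kv_cache_dtype(dtype: torch.dtype) -> str:
    if dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    return "auto"  # fp16/bf16 caches: let the kernel infer the dtype


assert ipex_kv_cache_dtype(torch.float8_e5m2) == "fp8_e5m2"
assert ipex_kv_cache_dtype(torch.bfloat16) == "auto"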
@@ -45,8 +54,18 @@ def attention(
             causal,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.functional.varlen_attention(
             query.contiguous() if query.device.type == "xpu" else query,
             key.contiguous() if key.device.type == "xpu" else key,
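The unchanged context lines keep the guard query.contiguous() if query.device.type == "xpu" else query for every tensor handed to the IPEX kernels, which expect contiguous XPU tensors. Purely as a reading aid (not part of the change), the repeated guard is equivalent to:

# Reading aid, not part of the commit: the repeated XPU-contiguity guard.
import torch


def xpu_contiguous(t: torch.Tensor) -> torch.Tensor:
    # IPEX kernels want contiguous XPU tensors; other devices pass through.
    return t.contiguous() if t.device.type == "xpu" else t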
@@ -80,12 +99,16 @@ def paged_attention(
     softcap: Optional[float] = None,
     window_size_left: Optional[int] = -1,
 ):
-    if softcap is not None:
-        raise NotImplementedError("softcap is not available in IPEX")
-
     out = torch.empty_like(query)
-
+    kv_cache_dtype = "auto"
+    if kv_cache.key.dtype == torch.float8_e5m2:
+        kv_cache_dtype = "fp8_e5m2"
+    if kv_cache.key.dtype == torch.float8_e4m3fn:
+        kv_cache_dtype = "fp8_e4m3"
     if ATTENTION == "flashdecoding-ipex":
+        window_size_right = -1 if window_size_left == -1 else 0
+        if softcap is None:
+            softcap = -1.0
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
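Both attention() and paged_attention() now normalize the windowing and softcap arguments before calling flash_attn_varlen_func: -1 leaves a window side unbounded, a bounded left window implies a right window of 0, and softcap=None is mapped to the kernel's -1.0 "disabled" sentinel. A minimal restatement of that normalization, for reading purposes only:

# Reading aid, not part of the commit: the argument normalization performed at
# both flashdecoding-ipex call sites.
from typing import Optional, Tuple


def normalize_flashdecoding_args(
    window_size_left: int, softcap: Optional[float]
) -> Tuple[int, int, float]:
    window_size_right = -1 if window_size_left == -1 else 0
    if softcap is None:
        softcap = -1.0  # kernel sentinel for "no logit softcapping"
    return window_size_left, window_size_right, softcap


assert normalize_flashdecoding_args(-1, None) == (-1, -1, -1.0)
assert normalize_flashdecoding_args(4096, 30.0) == (4096, 0, 30.0)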
@@ -99,9 +122,19 @@ def paged_attention(
             True,
             block_tables,
             None,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
+            softcap=softcap,
         )
     else:
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        if softcap is not None:
+            raise NotImplementedError(
+                "softcap is not available in IPEX paged attention"
+            )
         ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
@@ -114,6 +147,8 @@ def paged_attention(
             BLOCK_SIZE,
             max_s,
             None,
+            k_scale=kv_scales.key_scale_cpu,
+            v_scale=kv_scales.value_scale_cpu,
         )
     return out
 
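The hunks below move to the KVCache class, where the FP8 cache dtypes are validated. The payoff of allowing float8_e5m2/float8_e4m3fn caches is visible in the element size measured just after the new checks: 1 byte per cache entry instead of 2 for fp16/bf16, roughly halving KV-cache memory. A tiny check (not part of the commit, assumes a torch build with the float8 dtypes):

# Reading aid, not part of the commit: fp8 cache entries are 1 byte vs 2 for
# fp16/bf16, which is what the element_size line in KVCache measures.
import torch

for dt in (torch.float16, torch.bfloat16, torch.float8_e5m2, torch.float8_e4m3fn):
    print(dt, torch.tensor([], dtype=dt).element_size())  # 2, 2, 1, 1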
@@ -68,15 +68,20 @@ class KVCache:
         if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
             if not (
                 (ATTENTION == "flashinfer" and SYSTEM == "cuda")
-                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm"))
+                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm", "ipex"))
+                or (ATTENTION == "flashdecoding-ipex")
             ):
                 raise ValueError(
-                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. "
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA, paged attention on CUDA, ROCm and Intel IPEX, and flashdecoding on Intel IPEX."
                 )
             if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                 raise ValueError(
                     "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )
+            if device.type == "cpu" and dtype == torch.float8_e4m3fn:
+                raise ValueError(
+                    "float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
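With this hunk the FP8 KV-cache support matrix becomes: flashinfer on CUDA, paged attention on CUDA/ROCm/IPEX, and flashdecoding-ipex, with float8_e5m2 still excluded on ROCm and float8_e4m3fn excluded on IPEX CPU. Restated as a standalone predicate, as a sketch only (attention/system/device mirror the values the class checks):

# Sketch only: restates the validation added to KVCache.__init__ as a predicate.
import torch


def fp8_kv_cache_supported(
    attention: str, system: str, device_type: str, dtype: torch.dtype
) -> bool:
    if dtype not in (torch.float8_e5m2, torch.float8_e4m3fn):
        return True  # fp16/bf16 caches are always accepted
    backend_ok = (
        (attention == "flashinfer" and system == "cuda")
        or (attention == "paged" and system in ("cuda", "rocm", "ipex"))
        or attention == "flashdecoding-ipex"
    )
    if not backend_ok:
        return False
    if system == "rocm" and dtype == torch.float8_e5m2:
        return False
    if device_type == "cpu" and dtype == torch.float8_e4m3fn:
        return False
    return True


assert fp8_kv_cache_supported("flashdecoding-ipex", "ipex", "xpu", torch.float8_e5m2)
assert not fp8_kv_cache_supported("paged", "ipex", "cpu", torch.float8_e4m3fn)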
@@ -133,7 +138,8 @@ class KVCache:
             return False
         elif self.dtype == torch.float8_e4m3fn and (
             (ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
-            or (ATTENTION == "paged" and SYSTEM == "rocm")
+            or (ATTENTION == "paged" and SYSTEM in ["rocm", "ipex"])
+            or (ATTENTION == "flashdecoding-ipex")
         ):
             log_once(logger.info, "Using FP8 KV cache scales")
             return True
@@ -141,7 +147,7 @@ class KVCache:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA, paged attention on ROCm/IPEX, and flashdecoding on IPEX",
             )
             return False
 
@@ -207,8 +213,20 @@ class KVCache:
         elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
             import intel_extension_for_pytorch as ipex
 
+            kv_cache_dtype = "auto"
+            if key_cache.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            if key_cache.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
             ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
-                key, value, key_cache, value_cache, slots
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                kv_cache_dtype=kv_cache_dtype,
+                k_scale=kv_scales.key_scale_cpu,
+                v_scale=kv_scales.value_scale_cpu,
             )
         else:
             paged_reshape_and_cache(
@@ -267,8 +285,21 @@ def paged_reshape_and_cache(
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
 
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e5m2:
+            kv_cache_dtype = "fp8_e5m2"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            kv_cache_dtype = "fp8_e4m3"
+
         ipex.llm.modules.PagedAttention.reshape_and_cache(
-            key, value, key_cache, value_cache, slots
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slots,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=k_scale,
+            v_scale=v_scale,
        )
    else:
        raise NotImplementedError(