Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
add flashdecoding-ipex

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

commit d6ac8cdf81, parent ac67673788
@@ -224,7 +224,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 
 FROM ${PLATFORM} AS final
-ENV ATTENTION=flashdecoding
+ENV ATTENTION=flashdecoding-ipex
 ENV PREFIX_CACHING=1
 ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
@@ -144,7 +144,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         }
     }
 
-    let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+    let fallback_attention = if compute_capability.is_none()
+        || matches!(compute_capability, Some((major, _)) if major < 8)
+    {
         "paged"
     } else {
         "flashdecoding"
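
Reviewer note: the launcher change above means a host with no visible CUDA compute capability (e.g. an Intel-only machine) now falls back to "paged" attention; previously only a capability below 8 triggered the fallback, so a missing capability fell through to "flashdecoding". A tiny Python restatement of the new predicate, kept as a sketch because the real code is Rust inside resolve_attention (the (major, minor)-or-None argument mirrors the launcher's Option of a capability pair):

from typing import Optional, Tuple

def fallback_attention(compute_capability: Optional[Tuple[int, int]]) -> str:
    """Python sketch of the launcher's new fallback rule (not the Rust code itself)."""
    if compute_capability is None or compute_capability[0] < 8:
        return "paged"
    return "flashdecoding"

assert fallback_attention(None) == "paged"        # new: no CUDA capability detected -> paged
assert fallback_attention((7, 5)) == "paged"      # pre-Ampere -> paged (unchanged)
assert fallback_attention((9, 0)) == "flashdecoding"
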
@@ -31,7 +31,7 @@ def attention(
     out = torch.empty_like(query)
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    if ATTENTION == "flashdecoding":
+    if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -84,7 +84,7 @@ def paged_attention(
 
     out = torch.empty_like(query)
 
-    if ATTENTION == "flashdecoding":
+    if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -66,8 +66,8 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size
 
-        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-            SYSTEM == "ipex" and device == torch.device("cpu")
+        if ATTENTION in {"flashdecoding", "flashinfer"} or (
+            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
         ):
             self.kv_cache = (
                 torch.empty(
@@ -82,6 +82,7 @@ class KVCache:
                 ),
             )
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
+            # ipex cpu flashdecoding kernel and paged attention kernel share same layout
            self.kv_cache = (
                 torch.empty(
                     (num_blocks, num_heads, BLOCK_SIZE, head_size),
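
Reviewer note: taken together, the two KVCache allocation hunks above change which devices get the flash-style cache layout: flashdecoding-ipex on an XPU now allocates the same kind of cache as flashdecoding/flashinfer, while ipex on CPU keeps the shared (num_blocks, num_heads, BLOCK_SIZE, head_size) layout. A minimal sketch of just that selection, where `system` stands in for the SYSTEM global, `device_type` for torch.device(...).type, and the concrete torch.empty shapes of the first branch are not visible in this hunk:

def kv_layout(attention: str, system: str, device_type: str) -> str:
    """Sketch of the allocation branches above; returns a label instead of allocating tensors."""
    if attention in {"flashdecoding", "flashinfer"} or (
        attention == "flashdecoding-ipex" and device_type == "xpu"
    ):
        # Flash-style cache; the concrete shapes sit outside this hunk.
        return "flash layout"
    elif system == "ipex" and device_type == "cpu":
        # ipex cpu flashdecoding kernel and paged attention kernel share same layout
        return "(num_blocks, num_heads, BLOCK_SIZE, head_size)"
    else:
        # Everything else keeps the default paged layout (branch not shown in this diff).
        return "paged layout"

assert kv_layout("flashdecoding-ipex", "ipex", "xpu") == "flash layout"
assert kv_layout("flashdecoding-ipex", "ipex", "cpu") == "(num_blocks, num_heads, BLOCK_SIZE, head_size)"
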
@@ -176,9 +177,7 @@ class KVCache:
                 scalar=True,
             )[0]
 
-        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-            SYSTEM == "ipex" and key.device == torch.device("cpu")
-        ):
+        if ATTENTION in {"flashdecoding", "flashinfer"}:
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
             if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
@@ -191,6 +190,12 @@ class KVCache:
             shape = key_cache.shape
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+                key, value, key_cache, value_cache, slots
+            )
         else:
             paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
 
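
Reviewer note: pulled together, KVCache.store now has three paths. The sketch below condenses the retained flashdecoding/flashinfer branch, the new XPU branch, and the fallback into one function; the fp8 handling from the surrounding code is omitted, and paged_reshape_and_cache is only a stand-in stub for the existing TGI helper defined elsewhere:

def paged_reshape_and_cache(key, value, key_cache, value_cache, slots):
    """Stand-in stub for the existing TGI helper (defined elsewhere in the real module)."""
    raise NotImplementedError


def store_sketch(attention: str, key, value, key_cache, value_cache, slots):
    """Condensed view of the KVCache.store branches touched by this diff; not the module code."""
    if attention in {"flashdecoding", "flashinfer"}:
        # Flash-style caches: cast to the cache dtype and write by direct slot assignment.
        key = key.to(key_cache.dtype)
        value = value.to(value_cache.dtype)
        shape = key_cache.shape
        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
    elif attention == "flashdecoding-ipex" and key.device.type == "xpu":
        # New path: delegate to the IPEX kernel, imported lazily exactly as the diff does.
        import intel_extension_for_pytorch as ipex

        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key, value, key_cache, value_cache, slots
        )
    else:
        # CPU ipex and plain paged attention keep using the existing helper.
        paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
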
@@ -4,7 +4,6 @@ from loguru import logger
 from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
-from text_generation_server.utils.import_utils import SYSTEM
 
 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
@@ -15,13 +14,17 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")
 
-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
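
Reviewer note: flashdecoding-ipex is now an accepted ATTENTION value and is allowed together with prefix caching, which matches the Dockerfile defaults in the first hunk; the RuntimeError message itself still only names flashinfer even though flashdecoding and flashdecoding-ipex are also accepted. A small, self-contained sketch of the validation, using the same environment variables (the setdefault lines are only there to make the snippet runnable on its own):

import os

os.environ.setdefault("ATTENTION", "flashdecoding-ipex")
os.environ.setdefault("PREFIX_CACHING", "1")

ATTENTION = os.environ["ATTENTION"]
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {"1", "true"}

_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
assert ATTENTION in _expected, f"Attention is not valid {ATTENTION}, expected {_expected}"

if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding", "flashdecoding-ipex"}:
    raise RuntimeError("Prefix caching is only supported with flashinfer")

print(f"OK: ATTENTION={ATTENTION}, PREFIX_CACHING={PREFIX_CACHING}")
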
@@ -33,12 +36,11 @@ assert TGI_WIGGLE_ROOM < 1
 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
-    if SYSTEM == "ipex":
-        BLOCK_SIZE = 64
-    else:
-        BLOCK_SIZE = 256
+    BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16
 
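
Reviewer note: the net effect of this hunk is that the KV-cache block size now depends only on the attention implementation, not on SYSTEM. A compact restatement as a lookup (the dict form is mine, the values come straight from the hunk):

# Block size per attention implementation after this change (sketch, not the module code).
_BLOCK_SIZES = {
    "flashdecoding": 256,       # previously 64 on ipex, 256 elsewhere
    "flashinfer": 1,
    "flashdecoding-ipex": 64,   # new backend keeps the old ipex value
}

def block_size(attention: str) -> int:
    return _BLOCK_SIZES.get(attention, 16)  # "paged" and anything else fall back to 16

assert block_size("flashdecoding-ipex") == 64
assert block_size("paged") == 16
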
@@ -79,10 +79,13 @@ class Model(ABC):
                 "Prefill chunking will be turned off",
             )
             support_chunking = False
-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False
 
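
Reviewer note: with this last hunk, prefill chunking stays enabled for the new backend. A minimal sketch of the gate, with `support_chunking` standing in for the model-provided flag and log_master simplified to a direct loguru call:

from loguru import logger

def resolve_support_chunking(attention: str, support_chunking: bool) -> bool:
    """Sketch of the check in the Model hunk above; TGI's log_master wrapper is omitted."""
    if attention not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"] and support_chunking:
        logger.warning(
            "Prefill chunking is only supported with `flashinfer` or `flashdecoding` "
            "or `flashdecoding-ipex` attention types."
        )
        support_chunking = False
    return support_chunking

assert resolve_support_chunking("flashdecoding-ipex", True) is True
assert resolve_support_chunking("paged", True) is False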