mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
Flash decoding kernel adding and prefill-chunking and prefix caching enabling in intel cpu/xpu (#2815)
* flash decoding Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable xpu flashdecoding Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * set flashdecoding blocksize as 64 Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable flashdecoding, prefill chunking and prefix caching Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add flashdecoding-ipex Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
82f6ea1b71
commit
885144166f
@ -224,9 +224,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
FROM ${PLATFORM} AS final
|
||||
ENV ATTENTION=paged
|
||||
ENV PREFIX_CACHING=0
|
||||
ENV PREFILL_CHUNKING=0
|
||||
ENV ATTENTION=flashdecoding-ipex
|
||||
ENV PREFIX_CACHING=1
|
||||
ENV PREFILL_CHUNKING=1
|
||||
ENV CUDA_GRAPHS=0
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
||||
|
@ -143,7 +143,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
}
|
||||
}
|
||||
|
||||
let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
|
||||
let fallback_attention = if compute_capability.is_none()
|
||||
|| matches!(compute_capability, Some((major, _)) if major < 8)
|
||||
{
|
||||
"paged"
|
||||
} else {
|
||||
"flashdecoding"
|
||||
|
@ -1,9 +1,12 @@
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import torch
|
||||
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
|
||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from typing import Optional
|
||||
from text_generation_server.models.globals import (
|
||||
ATTENTION,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
|
||||
@ -28,22 +31,38 @@ def attention(
|
||||
out = torch.empty_like(query)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
ipex.llm.functional.varlen_attention(
|
||||
query.contiguous() if query.device.type == "xpu" else query,
|
||||
key.contiguous() if key.device.type == "xpu" else key,
|
||||
value.contiguous() if value.device.type == "xpu" else value,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_q,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
if ATTENTION == "flashdecoding-ipex":
|
||||
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||
out,
|
||||
query.contiguous() if query.device.type == "xpu" else query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
softmax_scale,
|
||||
causal,
|
||||
block_tables,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
ipex.llm.functional.varlen_attention(
|
||||
query.contiguous() if query.device.type == "xpu" else query,
|
||||
key.contiguous() if key.device.type == "xpu" else key,
|
||||
value.contiguous() if value.device.type == "xpu" else value,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_q,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
@ -64,20 +83,37 @@ def paged_attention(
|
||||
raise NotImplementedError("softcap is not available in IPEX")
|
||||
|
||||
out = torch.empty_like(query)
|
||||
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
BLOCK_SIZE,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
|
||||
if ATTENTION == "flashdecoding-ipex":
|
||||
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||
out,
|
||||
query.contiguous() if query.device.type == "xpu" else query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
softmax_scale,
|
||||
True,
|
||||
block_tables,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
BLOCK_SIZE,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -66,7 +66,9 @@ class KVCache:
|
||||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"} or (
|
||||
ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
|
||||
):
|
||||
self.kv_cache = (
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
@ -80,6 +82,7 @@ class KVCache:
|
||||
),
|
||||
)
|
||||
elif SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
# ipex cpu flashdecoding kernel and paged attention kernel share same layout
|
||||
self.kv_cache = (
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||
@ -187,6 +190,12 @@ class KVCache:
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
|
||||
key, value, key_cache, value_cache, slots
|
||||
)
|
||||
else:
|
||||
paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||
|
||||
|
@ -14,13 +14,17 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
|
||||
}
|
||||
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
|
||||
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
||||
_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
|
||||
assert (
|
||||
ATTENTION in _expected
|
||||
), f"Attention is not valid {ATTENTION}, expected {_expected}"
|
||||
log_master(logger.info, f"Using Attention = {ATTENTION}")
|
||||
|
||||
if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
|
||||
if PREFIX_CACHING and ATTENTION not in {
|
||||
"flashinfer",
|
||||
"flashdecoding",
|
||||
"flashdecoding-ipex",
|
||||
}:
|
||||
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
@ -28,12 +32,15 @@ TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
|
||||
assert TGI_WIGGLE_ROOM > 0
|
||||
assert TGI_WIGGLE_ROOM < 1
|
||||
|
||||
|
||||
# This is overridden by the cli
|
||||
BLOCK_SIZE: int
|
||||
if ATTENTION == "flashdecoding":
|
||||
BLOCK_SIZE = 256
|
||||
elif ATTENTION == "flashinfer":
|
||||
BLOCK_SIZE = 1
|
||||
elif ATTENTION == "flashdecoding-ipex":
|
||||
BLOCK_SIZE = 64
|
||||
else:
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
@ -79,10 +79,13 @@ class Model(ABC):
|
||||
"Prefill chunking will be turned off",
|
||||
)
|
||||
support_chunking = False
|
||||
if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
|
||||
if (
|
||||
ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
|
||||
and support_chunking
|
||||
):
|
||||
log_master(
|
||||
logger.warning,
|
||||
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
|
||||
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
|
||||
)
|
||||
support_chunking = False
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user