Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 14:52:20 +00:00
Add flash decoding kernel and enable prefill chunking and prefix caching on Intel CPU/XPU (#2815)
* flash decoding
* enable xpu flashdecoding
* set flashdecoding blocksize as 64
* enable flashdecoding, prefill chunking and prefix caching
* add flashdecoding-ipex

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent: 82f6ea1b71
commit: 885144166f
@@ -224,9 +224,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
-ENV ATTENTION=paged
-ENV PREFIX_CACHING=0
-ENV PREFILL_CHUNKING=0
+ENV ATTENTION=flashdecoding-ipex
+ENV PREFIX_CACHING=1
+ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
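For orientation, the Intel image defaults above are plain environment variables consumed by the Python server at start-up. A minimal sketch of that consumption (not the TGI source; the parsing is assumed to mirror the globals module changed further down):

import os

# Hedged sketch: "1"/"true" enable the boolean features, ATTENTION picks the backend.
ATTENTION = os.getenv("ATTENTION", "flashdecoding-ipex")
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "1").lower() in {"1", "true"}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}

print(ATTENTION, PREFIX_CACHING, PREFILL_CHUNKING)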
@@ -143,7 +143,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         }
     }

-    let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+    let fallback_attention = if compute_capability.is_none()
+        || matches!(compute_capability, Some((major, _)) if major < 8)
+    {
         "paged"
     } else {
         "flashdecoding"
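The launcher change above makes the absence of a CUDA compute capability (the Intel CPU/XPU case) fall back to paged attention as well, instead of only pre-Ampere GPUs. A hedged Python restatement of that decision, using a (major, minor) tuple for the capability:

from typing import Optional, Tuple

def fallback_attention(compute_capability: Optional[Tuple[int, int]]) -> str:
    # No capability reported (e.g. non-CUDA device) or pre-Ampere GPU -> "paged",
    # otherwise "flashdecoding", mirroring the Rust condition above.
    if compute_capability is None or compute_capability[0] < 8:
        return "paged"
    return "flashdecoding"

assert fallback_attention(None) == "paged"
assert fallback_attention((7, 5)) == "paged"
assert fallback_attention((9, 0)) == "flashdecoding"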
@@ -1,9 +1,12 @@
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
-from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)

 SUPPORTS_WINDOWING = False

@@ -28,22 +31,38 @@ def attention(
     out = torch.empty_like(query)

     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    ipex.llm.functional.varlen_attention(
-        query.contiguous() if query.device.type == "xpu" else query,
-        key.contiguous() if key.device.type == "xpu" else key,
-        value.contiguous() if value.device.type == "xpu" else value,
-        out,
-        seqlen.cu_seqlen_q,
-        seqlen.cu_seqlen_q,
-        seqlen.max_q,
-        seqlen.max_q,
-        0.0,
-        softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            causal,
+            block_tables,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query.contiguous() if query.device.type == "xpu" else query,
+            key.contiguous() if key.device.type == "xpu" else key,
+            value.contiguous() if value.device.type == "xpu" else value,
+            out,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_q,
+            0.0,
+            softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )

     return out

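Both branches use the packed "varlen" convention: tokens from all sequences in a batch are concatenated into one tensor, and cumulative offsets mark the sequence boundaries. A small illustration with toy values (not tied to the TGI Seqlen class):

import torch

# Three sequences of length 5, 3 and 7 packed into one tensor of 15 tokens.
lengths = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlen = torch.nn.functional.pad(torch.cumsum(lengths, 0), (1, 0)).to(torch.int32)
max_len = int(lengths.max())

print(cu_seqlen.tolist())  # [0, 5, 8, 15] -> start offset of each sequence
print(max_len)             # 7 -> longest sequence, used as max_q / max_k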
@@ -64,20 +83,37 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")

     out = torch.empty_like(query)
-    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
-        out,
-        query,
-        kv_cache.key,
-        kv_cache.value,
-        kv_head_mapping,
-        softmax_scale,
-        block_tables,
-        input_lengths,
-        BLOCK_SIZE,
-        max_s,
-        None,
-    )
+
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            True,
+            block_tables,
+            None,
+        )
+    else:
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            BLOCK_SIZE,
+            max_s,
+            None,
+        )
     return out

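The flash-decoding path reads the KV cache through block_tables rather than through per-sequence length vectors. A toy sketch of the block-table indexing idea behind that (hypothetical values; the 64 comes from the globals change further down):

BLOCK_SIZE = 64  # flashdecoding-ipex block size (see globals.py change below)

# Hypothetical block table for one sequence: its KV cache lives in blocks 3 and 7.
block_table = [3, 7]
token_position = 70  # 0-based position within the sequence

# Flat slot in the cache: which block the token falls into, plus its offset inside it.
slot = block_table[token_position // BLOCK_SIZE] * BLOCK_SIZE + token_position % BLOCK_SIZE
print(slot)  # 7 * 64 + 6 = 454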
@@ -66,7 +66,9 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size

-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} or (
+            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
+        ):
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
@@ -80,6 +82,7 @@ class KVCache:
                 ),
             )
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
+            # ipex cpu flashdecoding kernel and paged attention kernel share same layout
+            self.kv_cache = (
-            self.kv_cache = (
                 torch.empty(
                     (num_blocks, num_heads, BLOCK_SIZE, head_size),
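The only difference between the two allocations above is the tensor layout: the flash-decoding caches keep the in-block position next to the block dimension, while the IPEX CPU layout keeps heads first (and is shared by its paged and flash-decoding kernels). A shape-only sketch with made-up sizes:

import torch

num_blocks, BLOCK_SIZE, num_heads, head_size = 32, 64, 8, 128

# flashdecoding / flashinfer / flashdecoding-ipex on XPU:
kv_xpu = torch.empty(num_blocks, BLOCK_SIZE, num_heads, head_size)
# ipex on CPU (same layout for its paged and flash-decoding kernels):
kv_cpu = torch.empty(num_blocks, num_heads, BLOCK_SIZE, head_size)

print(kv_xpu.shape, kv_cpu.shape)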
@@ -187,6 +190,12 @@ class KVCache:
             shape = key_cache.shape
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+                key, value, key_cache, value_cache, slots
+            )
         else:
             paged_reshape_and_cache(key, value, key_cache, value_cache, slots)

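The existing flash-decoding store path kept as context above is just a scatter into the flattened cache; the new XPU branch delegates the equivalent write to reshape_and_cache_flash. A toy version of that scatter, with made-up shapes:

import torch

num_blocks, BLOCK_SIZE, num_heads, head_size = 4, 64, 2, 16
key_cache = torch.zeros(num_blocks, BLOCK_SIZE, num_heads, head_size)

key = torch.randn(3, num_heads, head_size)   # three freshly computed keys
slots = torch.tensor([0, 1, 130])            # flat token slots in the cache

# Flatten (num_blocks, BLOCK_SIZE) into one token axis and scatter by slot.
key_cache.view(-1, num_heads, head_size)[slots] = key
assert torch.equal(key_cache.view(-1, num_heads, head_size)[slots], key)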
@@ -14,13 +14,17 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
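With the new backend in the allowed set, prefix caching can be enabled on IPEX as well. A hedged, standalone sketch of the resulting validation behaviour (it does not import TGI; the error message is kept verbatim from the code above):

_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}

def validate(attention: str, prefix_caching: bool) -> None:
    assert attention in _expected, f"Attention is not valid {attention}, expected {_expected}"
    if prefix_caching and attention not in {"flashinfer", "flashdecoding", "flashdecoding-ipex"}:
        raise RuntimeError("Prefix caching is only supported with flashinfer")

validate("flashdecoding-ipex", prefix_caching=True)   # now accepted
try:
    validate("paged", prefix_caching=True)            # still rejected
except RuntimeError as e:
    print(e)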
@@ -28,12 +32,15 @@ TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1


 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
     BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16

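As the commit message notes, the flash-decoding block size for IPEX is set to 64. An illustrative helper (not part of TGI) showing how the block size translates into the number of KV-cache blocks a context needs:

import math

BLOCK_SIZES = {"flashdecoding": 256, "flashinfer": 1, "flashdecoding-ipex": 64, "paged": 16}

def blocks_needed(context_len: int, attention: str) -> int:
    # Illustrative only: ceil(context_len / BLOCK_SIZE) blocks per sequence.
    return math.ceil(context_len / BLOCK_SIZES[attention])

print(blocks_needed(4096, "flashdecoding-ipex"))  # 64
print(blocks_needed(4096, "paged"))               # 256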
@@ -79,10 +79,13 @@ class Model(ABC):
                 "Prefill chunking will be turned off",
             )
             support_chunking = False
-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False

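Prefill chunking, now also allowed for flashdecoding-ipex, splits a long prompt prefill across several scheduler steps instead of one big pass, with the already-processed prefix staying in the KV cache. A toy illustration of the chunking idea (the chunk size here is hypothetical; TGI derives it from its own token budget):

def chunk_prompt(token_ids, chunk_size=256):
    # Yield consecutive slices of the prompt; each slice is prefilled in its own step.
    for start in range(0, len(token_ids), chunk_size):
        yield token_ids[start : start + chunk_size]

prompt = list(range(1000))
print([len(chunk) for chunk in chunk_prompt(prompt)])  # [256, 256, 256, 232]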