Add flash decoding kernel and enable prefill chunking and prefix caching on Intel CPU/XPU (#2815)

* flash decoding

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable xpu flashdecoding

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* set flashdecoding block size to 64

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable flashdecoding, prefill chunking and prefix caching

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add flashdecoding-ipex

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-01-17 19:04:57 +08:00 committed by GitHub
parent 82f6ea1b71
commit 885144166f
6 changed files with 97 additions and 40 deletions

Dockerfile_intel

@@ -224,9 +224,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
-ENV ATTENTION=paged
-ENV PREFIX_CACHING=0
-ENV PREFILL_CHUNKING=0
+ENV ATTENTION=flashdecoding-ipex
+ENV PREFIX_CACHING=1
+ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
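The new defaults make flashdecoding-ipex, prefix caching and prefill chunking the out-of-the-box behaviour of the Intel image, and they can still be overridden at run time because the server reads them back as plain environment variables. A minimal illustrative sketch of that parsing, mirroring the PREFIX_CACHING / PREFILL_CHUNKING expressions in globals.py further down (the setdefault calls and the ATTENTION lookup are stand-ins, not the exact library code):

import os

# Stand-ins for the image defaults set by the ENV lines above (illustrative only).
os.environ.setdefault("ATTENTION", "flashdecoding-ipex")
os.environ.setdefault("PREFIX_CACHING", "1")
os.environ.setdefault("PREFILL_CHUNKING", "1")

ATTENTION = os.environ["ATTENTION"]  # globals.py reads the same variable
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {"1", "true"}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
print(ATTENTION, PREFIX_CACHING, PREFILL_CHUNKING)  # flashdecoding-ipex True True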

launcher/src/main.rs

@@ -143,7 +143,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         }
     }

-    let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+    let fallback_attention = if compute_capability.is_none()
+        || matches!(compute_capability, Some((major, _)) if major < 8)
+    {
         "paged"
     } else {
         "flashdecoding"

server/text_generation_server/layers/attention/ipex.py

@@ -1,9 +1,12 @@
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
-from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)

 SUPPORTS_WINDOWING = False
@@ -28,22 +31,38 @@ def attention(
     out = torch.empty_like(query)

     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    ipex.llm.functional.varlen_attention(
-        query.contiguous() if query.device.type == "xpu" else query,
-        key.contiguous() if key.device.type == "xpu" else key,
-        value.contiguous() if value.device.type == "xpu" else value,
-        out,
-        seqlen.cu_seqlen_q,
-        seqlen.cu_seqlen_q,
-        seqlen.max_q,
-        seqlen.max_q,
-        0.0,
-        softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            causal,
+            block_tables,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query.contiguous() if query.device.type == "xpu" else query,
+            key.contiguous() if key.device.type == "xpu" else key,
+            value.contiguous() if value.device.type == "xpu" else value,
+            out,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_q,
+            0.0,
+            softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )

     return out
@@ -64,20 +83,37 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")

     out = torch.empty_like(query)
-    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
-        out,
-        query,
-        kv_cache.key,
-        kv_cache.value,
-        kv_head_mapping,
-        softmax_scale,
-        block_tables,
-        input_lengths,
-        BLOCK_SIZE,
-        max_s,
-        None,
-    )
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            True,
+            block_tables,
+            None,
+        )
+    else:
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            BLOCK_SIZE,
+            max_s,
+            None,
+        )

     return out
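In the flashdecoding-ipex path both attention and paged_attention hand the IPEX kernel cumulative sequence offsets plus a per-sequence block table, and read keys/values straight from the paged cache instead of contiguous tensors. A small illustrative sketch of the shapes involved during decoding follows; the lengths and block ids are made up, and BLOCK_SIZE = 64 matches the flashdecoding-ipex setting introduced below:

import torch
import torch.nn.functional as F

BLOCK_SIZE = 64                                            # flashdecoding-ipex block size
input_lengths = torch.tensor([1, 1], dtype=torch.int32)    # new tokens this step (decode)
cache_lengths = torch.tensor([69, 9], dtype=torch.int32)   # tokens already in the KV cache
total_lengths = input_lengths + cache_lengths              # keys each query attends over

# Cumulative offsets, as carried by Seqlen.cu_seqlen_q / Seqlen.cu_seqlen_k.
cu_seqlen_q = F.pad(torch.cumsum(input_lengths, 0), (1, 0)).to(torch.int32)  # [0, 1, 2]
cu_seqlen_k = F.pad(torch.cumsum(total_lengths, 0), (1, 0)).to(torch.int32)  # [0, 70, 80]

# One row of physical block ids per sequence, padded to the longest row.
# Sequence 0 needs ceil(70 / 64) = 2 blocks, sequence 1 needs 1.
block_tables = torch.tensor([[3, 7], [5, 0]], dtype=torch.int32)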

server/text_generation_server/layers/attention/kv_cache.py

@@ -66,7 +66,9 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size

-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} or (
+            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
+        ):
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
@@ -80,6 +82,7 @@ class KVCache:
                 ),
             )
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
+            # ipex cpu flashdecoding kernel and paged attention kernel share same layout
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, num_heads, BLOCK_SIZE, head_size),
@@ -187,6 +190,12 @@ class KVCache:
             shape = key_cache.shape
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+                key, value, key_cache, value_cache, slots
+            )
         else:
             paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
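On XPU, flashdecoding-ipex uses the flat (num_blocks, BLOCK_SIZE, num_heads, head_size) layout, so a token's slot index decomposes into a physical block and an offset inside it; the flashdecoding/flashinfer branch above writes through exactly that flat view, and reshape_and_cache_flash is expected to perform the equivalent scatter. A tiny illustrative sketch of the mapping, with a made-up slot value:

BLOCK_SIZE = 64

slot = 451                     # flat cache position assigned to one token (example value)
block_id = slot // BLOCK_SIZE  # 7: which physical block holds the token
offset = slot % BLOCK_SIZE     # 3: position of the token inside that block

# Writing via key_cache.view(-1, num_heads, head_size)[slot] = key, as the
# flashdecoding branch does, lands at the same place as key_cache[block_id, offset].
print(block_id, offset)  # 7 3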

server/text_generation_server/models/globals.py

@@ -14,13 +14,17 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -28,12 +32,15 @@ TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1

 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
     BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16
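The block size determines how many cache blocks a request occupies. A quick worked example with the values above; the 1024-token sequence length is an arbitrary illustration:

import math

seq_len = 1024  # arbitrary example sequence length
for attention, block_size in [
    ("paged", 16),
    ("flashdecoding", 256),
    ("flashdecoding-ipex", 64),
    ("flashinfer", 1),
]:
    blocks = math.ceil(seq_len / block_size)
    print(f"{attention}: {blocks} blocks of {block_size} tokens")
# paged: 64 blocks, flashdecoding: 4, flashdecoding-ipex: 16, flashinfer: 1024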

server/text_generation_server/models/model.py

@@ -79,10 +79,13 @@ class Model(ABC):
                 "Prefill chunking will be turned off",
             )
             support_chunking = False

-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False