add flashdecoding-ipex

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2024-12-19 18:33:25 -08:00
parent ac67673788
commit d6ac8cdf81
6 changed files with 30 additions and 18 deletions

@@ -224,7 +224,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
-ENV ATTENTION=flashdecoding
+ENV ATTENTION=flashdecoding-ipex
 ENV PREFIX_CACHING=1
 ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
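
Note (not part of the diff): the new image default only changes the starting value of `ATTENTION`; the Python server still validates whatever ends up in the environment, so a run-time override keeps working. A minimal sketch of that consumption path, mirroring the `globals.py` change further down (the `setdefault` calls are illustrative, not TGI code):

```python
# Sketch only: mirrors how globals.py (changed below) consumes the image defaults.
# Overriding the container env (e.g. ATTENTION=paged) selects another backend;
# an unknown value fails the assert at import time.
import os

os.environ.setdefault("ATTENTION", "flashdecoding-ipex")  # image default set above
os.environ.setdefault("PREFIX_CACHING", "1")

_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
assert os.environ["ATTENTION"] in _expected, f"Attention is not valid, expected {_expected}"
```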

@@ -144,7 +144,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         }
     }

-    let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+    let fallback_attention = if compute_capability.is_none()
+        || matches!(compute_capability, Some((major, _)) if major < 8)
+    {
         "paged"
     } else {
         "flashdecoding"

@@ -31,7 +31,7 @@ def attention(
     out = torch.empty_like(query)

     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    if ATTENTION == "flashdecoding":
+    if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -84,7 +84,7 @@ def paged_attention(
     out = torch.empty_like(query)
-    if ATTENTION == "flashdecoding":
+    if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
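
Net effect of the two hunks above: with `ATTENTION=flashdecoding-ipex`, both the prefill path (`attention`) and the decode path (`paged_attention`) call the same `ipex.llm.modules.PagedAttention.flash_attn_varlen_func` paged kernel. A dispatch skeleton; the function name `dispatch` is illustrative and the argument lists are elided because they are unchanged context not shown in this diff:

```python
# Skeleton only; argument lists and the non-ipex branches are elided on purpose.
import torch

def dispatch(query: torch.Tensor, attention: str) -> torch.Tensor:
    out = torch.empty_like(query)
    if attention == "flashdecoding-ipex":
        # the diff keeps the .contiguous() guard for xpu tensors
        q = query.contiguous() if query.device.type == "xpu" else query
        ...  # ipex.llm.modules.PagedAttention.flash_attn_varlen_func(out, q, ...)
    else:
        ...  # pre-existing ipex kernels, unchanged by this commit
    return out
```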

@@ -66,8 +66,8 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size

-        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-            SYSTEM == "ipex" and device == torch.device("cpu")
+        if ATTENTION in {"flashdecoding", "flashinfer"} or (
+            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
         ):
             self.kv_cache = (
                 torch.empty(
@@ -82,6 +82,7 @@ class KVCache:
                 ),
             )
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
+            # ipex cpu flashdecoding kernel and paged attention kernel share same layout
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, num_heads, BLOCK_SIZE, head_size),
@@ -176,9 +177,7 @@ class KVCache:
                 scalar=True,
             )[0]

-        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-            SYSTEM == "ipex" and key.device == torch.device("cpu")
-        ):
+        if ATTENTION in {"flashdecoding", "flashinfer"}:
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
             if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
@@ -191,6 +190,12 @@ class KVCache:
             shape = key_cache.shape
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+                key, value, key_cache, value_cache, slots
+            )
         else:
             paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
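
Taken together, the hunks above mean `flashdecoding-ipex` uses the flash cache layout on XPU but keeps the shared ipex CPU layout on CPU. A sketch of the resulting shape selection; `kv_cache_shape`, `block_size` and `system` are stand-in names, and the flash layout tuple is an assumption taken from the unchanged allocation code that this diff does not show:

```python
# Sketch, not the KVCache class. block_size/system stand in for the BLOCK_SIZE/SYSTEM
# globals; the (num_blocks, block_size, num_heads, head_size) flash layout is assumed
# from the surrounding (unchanged) allocation code.
import torch

def kv_cache_shape(attention, system, device, num_blocks, num_heads, head_size, block_size):
    if attention in {"flashdecoding", "flashinfer"} or (
        attention == "flashdecoding-ipex" and device.type == "xpu"
    ):
        # flash layout, now also used by flashdecoding-ipex on XPU
        return (num_blocks, block_size, num_heads, head_size)
    if system == "ipex" and device == torch.device("cpu"):
        # ipex CPU: the flashdecoding and paged-attention kernels share this layout
        return (num_blocks, num_heads, block_size, head_size)
    raise NotImplementedError("other backends keep their existing paged layouts")

# flashdecoding-ipex on CPU keeps the shared ipex CPU layout:
assert kv_cache_shape(
    "flashdecoding-ipex", "ipex", torch.device("cpu"), 16, 8, 128, 64
) == (16, 8, 64, 128)
```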

@@ -4,7 +4,6 @@ from loguru import logger
 from typing import Dict, Optional

 from text_generation_server.utils.log import log_master
-from text_generation_server.utils.import_utils import SYSTEM

 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
@@ -15,13 +14,17 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -33,12 +36,11 @@ assert TGI_WIGGLE_ROOM < 1

 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
-    if SYSTEM == "ipex":
-        BLOCK_SIZE = 64
-    else:
-        BLOCK_SIZE = 256
+    BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16
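
After this hunk the block size is determined by `ATTENTION` alone (the `SYSTEM` import is gone). The dictionary below is only a restatement of the branches above, not code from the diff:

```python
# Resulting BLOCK_SIZE per backend ("paged" is the remaining value in _expected
# and hits the final else branch).
BLOCK_SIZES = {
    "flashdecoding": 256,
    "flashinfer": 1,
    "flashdecoding-ipex": 64,
    "paged": 16,
}
```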

@@ -79,10 +79,13 @@ class Model(ABC):
                 "Prefill chunking will be turned off",
             )
             support_chunking = False

-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False
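
A quick restatement of the resulting gate as a sketch (the helper name `chunking_enabled` is illustrative, not part of the Model class): prefill chunking now also stays enabled for `flashdecoding-ipex`.

```python
# Sketch of the gate above, not the Model class itself.
def chunking_enabled(attention: str, support_chunking: bool) -> bool:
    if attention not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"] and support_chunking:
        return False  # the model logs a warning and turns chunking off
    return support_chunking

assert chunking_enabled("flashdecoding-ipex", True) is True  # new in this commit
assert chunking_enabled("paged", True) is False
```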